Python源码示例:sklearn.metrics.pairwise_distances_argmin()
示例1
def test_birch_predict():
    """Test that ``Birch.predict`` agrees with the labels found by ``fit``.

    Two properties are asserted on a shuffled copy of generated clustered
    data:

    * ``predict`` on the training data reproduces ``labels_`` exactly;
    * assigning every sample to its nearest subcluster center
      (``pairwise_distances_argmin``) yields a labelling whose v-measure
      against ``labels_`` is 1.0.
    """
    rng = np.random.RandomState(0)
    X = generate_clustered_data(n_clusters=3, n_features=3,
                                n_samples_per_cluster=10)

    # Take the sample count from X (n_clusters * n_samples_per_cluster)
    # instead of hard-coding it, so the permutation always matches the data.
    shuffle_indices = np.arange(X.shape[0])
    rng.shuffle(shuffle_indices)
    X_shuffle = X[shuffle_indices, :]

    brc = Birch(n_clusters=4, threshold=1.)
    brc.fit(X_shuffle)
    centroids = brc.subcluster_centers_
    assert_array_equal(brc.labels_, brc.predict(X_shuffle))

    nearest_centroid = pairwise_distances_argmin(X_shuffle, centroids)
    assert_almost_equal(v_measure_score(nearest_centroid, brc.labels_), 1.0)
示例2
def test_birch_predict():
    """Test that ``Birch.predict`` agrees with the labels found by ``fit``.

    Two properties are asserted on a shuffled copy of generated clustered
    data:

    * ``predict`` on the training data reproduces ``labels_`` exactly;
    * assigning every sample to its nearest subcluster center
      (``pairwise_distances_argmin``) yields a labelling whose v-measure
      against ``labels_`` is 1.0.
    """
    rng = np.random.RandomState(0)
    X = generate_clustered_data(n_clusters=3, n_features=3,
                                n_samples_per_cluster=10)

    # Take the sample count from X (n_clusters * n_samples_per_cluster)
    # instead of hard-coding it, so the permutation always matches the data.
    shuffle_indices = np.arange(X.shape[0])
    rng.shuffle(shuffle_indices)
    X_shuffle = X[shuffle_indices, :]

    brc = Birch(n_clusters=4, threshold=1.)
    brc.fit(X_shuffle)
    centroids = brc.subcluster_centers_
    assert_array_equal(brc.labels_, brc.predict(X_shuffle))

    nearest_centroid = pairwise_distances_argmin(X_shuffle, centroids)
    assert_almost_equal(v_measure_score(nearest_centroid, brc.labels_), 1.0)
示例3
def find_clusters(x, n_clusters, current_split):
    """Run Lloyd-style k-means on one data split.

    The initial centroids are the first ``n_clusters`` rows of a shuffled
    copy of the split. The loop then alternates two steps until an
    iteration leaves every centroid exactly unchanged:

    * assign each row to its nearest centroid;
    * move each centroid to the mean of its rows.

    Parameters
    ----------
    x :
        Not used by the body. NOTE(review): the data is read from the
        module-level ``x_split`` container instead; confirm with callers
        whether ``x`` was meant to be that container.
    n_clusters : int
        Number of clusters to form.
    current_split :
        Key/index selecting the split inside ``x_split``. The split is
        presumably a 2-D array of shape (n_samples, n_features), as required
        by ``pairwise_distances_argmin``.

    Returns
    -------
    centroids : ndarray
        One row per cluster.
    labels : ndarray of int
        For each row of the split, the index of its nearest centroid.

    Raises
    ------
    ValueError
        If ``n_clusters`` is not between 1 and the number of rows in the split.
    """
    data = np.asarray(x_split[current_split])
    n_samples = len(data)
    if not 1 <= n_clusters <= n_samples:
        raise ValueError(
            "n_clusters must be between 1 and the number of samples "
            "({0}), got {1}".format(n_samples, n_clusters))

    # Initial centroids: the first n_clusters rows of a shuffled copy.
    candidates = list(data)
    shuffle(candidates)
    centroids = np.array(candidates[:n_clusters])

    while True:
        # Assign every row to its closest centroid.
        labels = pairwise_distances_argmin(data, centroids)

        # Move each centroid to the mean of its members. A cluster that lost
        # all members keeps its previous centroid: averaging an empty
        # selection would give NaN and make the next assignment step fail.
        new_centroids = []
        for i in range(n_clusters):
            members = data[labels == i]
            new_centroids.append(members.mean(0) if len(members) else centroids[i])
        new_centroids = np.array(new_centroids)

        # Converged once an update leaves every centroid exactly unchanged.
        if np.all(centroids == new_centroids):
            break
        centroids = new_centroids
    return centroids, labels
示例4
def predict(self, X, set_outliers=True):
    """Label each sample of ``X`` with the index of its nearest cluster center.

    Each entry of ``self.cluster_centers_`` is given a trailing axis
    (``[:, None]``) so it is compared as a single-feature point.
    NOTE(review): this presumes 1-D ``cluster_centers_`` and single-column
    ``X``; confirm.

    Args:
        X: Samples to label, indexed as 2-D; column 0 is used for the
            outlier test.
        set_outliers: If true, samples whose column-0 value is above
            ``self.max`` or below ``self.min`` are labelled ``-1``.

    Returns:
        Array of nearest-center indices, with ``-1`` marking outliers when
        ``set_outliers`` is true.
    """
    from sklearn import metrics as skm

    centers_2d = self.cluster_centers_[:, None]
    labels = skm.pairwise_distances_argmin(X, centers_2d)
    if not set_outliers:
        return labels

    out_of_range = (X > self.max) | (X < self.min)
    labels[out_of_range[:, 0]] = -1
    return labels
示例5
def run_pruning_for_conv2d_layer(self, pruning_factor: float, layer: layers.Conv2D, layer_weight_mtx) -> List[int]:
    """Choose which output channels of a Conv2D kernel can be pruned.

    The output channels are clustered with k-means on their flattened
    weights. For each cluster center, the closest channel is kept and every
    other channel is reported as prunable. If that selection does not keep
    exactly the number of channels returned by
    ``_calculate_number_of_channels_to_keep``, channels are moved between
    the keep and prune lists to fix the count.

    Args:
        pruning_factor: Passed to ``_calculate_number_of_channels_to_keep``
            to determine how many channels (= clusters) to keep.
        layer: The Conv2D layer being pruned (not read by this method).
        layer_weight_mtx: Kernel weights laid out as
            ``(height, width, input_channels, output_channels)``.

    Returns:
        Indices of the output channels that can be pruned.

    Raises:
        ValueError: If, after adjustment, the number of kept channels still
            differs from the number of clusters.
    """
    _, _, _, nb_channels = layer_weight_mtx.shape

    # Initialize KMeans. ``init`` is passed by keyword because it is
    # keyword-only in current scikit-learn releases.
    nb_of_clusters, _ = self._calculate_number_of_channels_to_keep(pruning_factor, nb_channels)
    kmeans = cluster.KMeans(nb_of_clusters, init="k-means++")

    # Fit with the flattened weight matrix:
    # (height, width, input_channels, output_channels) -> (output_channels, flattened features)
    layer_weight_mtx_reshaped = layer_weight_mtx.transpose(3, 0, 1, 2).reshape(nb_channels, -1)
    # Apply some fuzz to the weights, to avoid duplicates
    self._apply_fuzz(layer_weight_mtx_reshaped)
    kmeans.fit(layer_weight_mtx_reshaped)

    # For every cluster center, the index of the nearest channel. A cluster
    # with a single member therefore always keeps that member.
    closest_point_to_cluster_center_indices = metrics.pairwise_distances_argmin(
        kmeans.cluster_centers_, layer_weight_mtx_reshaped)

    # Compute filter indices which can be pruned
    channel_indices = set(np.arange(len(layer_weight_mtx_reshaped)))
    channel_indices_to_keep = set(closest_point_to_cluster_center_indices)
    channel_indices_to_prune = list(channel_indices.difference(channel_indices_to_keep))
    channel_indices_to_keep = list(channel_indices_to_keep)

    # Move exactly ``diff`` leading elements between the lists. The previous
    # ``pop(i)`` inside ``for i in range(diff)`` advanced the index while the
    # list shrank, skipping elements and raising IndexError for larger diffs.
    if len(channel_indices_to_keep) > nb_of_clusters:
        print("Number of selected channels for pruning is less than expected")
        diff = len(channel_indices_to_keep) - nb_of_clusters
        print("Randomly adding {0} channels for pruning".format(diff))
        np.random.shuffle(channel_indices_to_keep)
        moved = channel_indices_to_keep[:diff]
        del channel_indices_to_keep[:diff]
        channel_indices_to_prune.extend(moved)
    elif len(channel_indices_to_keep) < nb_of_clusters:
        print("Number of selected channels for pruning is greater than expected. Leaving too few channels.")
        diff = nb_of_clusters - len(channel_indices_to_keep)
        print("Discarding {0} pruneable channels".format(diff))
        moved = channel_indices_to_prune[:diff]
        del channel_indices_to_prune[:diff]
        channel_indices_to_keep.extend(moved)

    if len(channel_indices_to_keep) != nb_of_clusters:
        # Report the kept-channel count, since that is what the check compares.
        raise ValueError(
            "Number of clusters {0} is not equal with the number of "
            "channels to keep {1}".format(nb_of_clusters, len(channel_indices_to_keep)))
    return channel_indices_to_prune