K-Means Clustering and Related Algorithms

Ryan P. Adams
1 Clustering
For most clustering algorithms, the main thing that we need is a notion of distance between our
data. If our data live in some space X, then we need a function that takes two points in X, say x
and x′, and computes a distance between them. We'll write such a distance as ||x − x′||. If X is R^D,
then a natural choice would be the Euclidean (sometimes called ℓ2) distance

    ||x − x′||_2 = √( ∑_{d=1}^{D} (x_d − x′_d)² ).        (1)
Figure 1: Four iterations of K-Means applied to the lengths and widths of fruit (oranges and
lemons), as measured in centimeters. Panels (a)–(e) show the initialization followed by iterations
1–4. Here, K = 5 and the initial means were chosen randomly in the square shown. The colored
regions show the Voronoi partitions induced by each cluster center, and the data are colored
according to their association. The cluster centers are shown as ×.
There are many distance metrics that one might come up with, depending on what your data are
and what “similarity” means for the problem you want to solve. For strings or DNA sequences,
one might use edit distance.1 For bit vectors, it might be sensible to use Hamming distance.2 This
choice is important because it determines whether two objects ought to end up in the same
group or not.
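To make this concrete, here is a small numpy sketch (the helper names are my own, not part of
these notes) of two of the distances mentioned above:

import numpy as np

def euclidean(x, x_prime):
    """Euclidean (l2) distance between two real-valued vectors."""
    return np.sqrt(np.sum((x - x_prime) ** 2))

def hamming(b, b_prime):
    """Hamming distance: the number of positions at which two bit vectors differ."""
    return int(np.sum(b != b_prime))

print(euclidean(np.array([8.2, 6.1]), np.array([7.9, 6.4])))    # fruit-like measurements in cm
print(hamming(np.array([1, 0, 1, 1]), np.array([1, 1, 1, 0])))  # differs in two positions -> 2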
You also need to decide how many groups you want. This is the number K that gives the
K-means algorithm its name. Choosing K can be a bit of an art, because it depends on what
kind of structure you’re looking for. Sometimes you might know how many clusters there are in
advance. Other times, you might want to “oversegment” or “undersegment” the data, depending
on whether you’d like to end up with many clusters or just a few. In the compression view of
clustering, this boils down to asking whether you’d like a more compressed representation that
loses information (smaller K), or a less compressed representation that keeps more information
about your data (larger K). If you’re using K-Means for learning a feature representation, it’s
usually a good idea to use a larger K. If you want to interpret the groups, then perhaps you want
to go with a smaller K.
Our data are N points in X. Let's denote the nth of these as x_n, so we can write the data as the
set {x_n}_{n=1}^N. Clustering algorithms assign every one of these data to one of the K clusters. What
we’re doing is trying to find a good (ideally, the best) assignment of the data to the clusters. We
represent these assignments by giving every one of the N data a binary responsibility vector r_n.
This vector is all zeros except in one component, which corresponds to the cluster it is assigned
to. That is, if x_n is assigned to cluster k, then r_{nk} = 1 and all of the other entries in r_n are zero. This
is an example of one-hot coding, in which an integer between 1 and K is encoded as a length-K
binary vector that is zero everywhere except for one place.
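For instance, with K = 4, a datum assigned to cluster 3 has r_n = [0, 0, 1, 0]. A tiny numpy
illustration of this encoding (my own example, using np.eye as one convenient construction):

import numpy as np

K = 4
assignment = 2                        # zero-based index of "cluster 3"
r_n = np.eye(K, dtype=int)[assignment]
print(r_n)                            # [0 0 1 0]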
Algorithm 1 K-Means Clustering (Lloyd's Algorithm)   Note: written for clarity, not efficiency.
1: Input: Data vectors {x_n}_{n=1}^N, number of clusters K
2: for n ← 1 . . . N do                          ▷ Initialize all of the responsibilities.
3:   r_n ← [0, 0, · · · , 0]                     ▷ Zero out the responsibilities.
4:   k′ ← RandomInteger(1, K)                    ▷ Make one of them randomly one to initialize.
5:   r_{nk′} ← 1
6: end for
7: repeat
8:   for k ← 1 . . . K do                        ▷ Loop over the clusters.
9:     N_k ← ∑_{n=1}^{N} r_{nk}                  ▷ Compute the number assigned to cluster k.
10:    µ_k ← (1/N_k) ∑_{n=1}^{N} r_{nk} x_n      ▷ Compute the mean of the kth cluster.
11:  end for
12:  for n ← 1 . . . N do                        ▷ Loop over the data.
13:    r_n ← [0, 0, · · · , 0]                   ▷ Zero out the responsibilities.
14:    k′ ← argmin_k ||x_n − µ_k||_2             ▷ Find the closest mean.
15:    r_{nk′} ← 1
16:  end for
17: until none of the r_n change
18: Return assignments {r_n}_{n=1}^N for each datum, and cluster means {µ_k}_{k=1}^K.
When each group is represented by an average, or mean, this approach is called K-Means. The K-Means procedure is among the
most popular machine learning algorithms, due to its simplicity and interpretability. Pseudocode
for K-Means is shown in Algorithm 1. K-means is an iterative algorithm that loops until it con-
verges to a (locally optimal) solution. Within each loop, it makes two kinds of updates: it loops
over the responsibility vectors r n and changes them to point to the closest cluster, and it loops
over the mean vectors µ_k and changes each one to be the mean of the data that currently belong to that cluster.
There are K of these mean vectors (hence the name of the algorithm) and you can think of them as
“prototypes” that describe each of the clusters. The basic idea is to find a prototype that describes
a group in the data and to use the r n to assign the data to the best one. In the compression view
of K-Means, you can think of replacing your actual datum xn with its prototype and then trying
to find a situation in which that doesn’t seem so bad, i.e., that compression will not lose too much
information if the prototype accurately reflects the group. When updating the assignments, we
tend to use the squared distance, rather than the actual distance, as this doesn’t change the answer
and we avoid the square root in Eq. 1.
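To make the procedure concrete, here is a minimal numpy sketch of Algorithm 1 under the
Euclidean distance (the function name, the max_iters cap, and the empty-cluster handling are
my own additions, not part of the original notes):

import numpy as np

def kmeans(X, K, max_iters=100, seed=None):
    """Lloyd's algorithm on an N x D data matrix X. Returns (assignments, means)."""
    rng = np.random.default_rng(seed)
    N, D = X.shape
    z = rng.integers(0, K, size=N)                      # random initial assignments
    for _ in range(max_iters):
        # Mean update: average the data currently assigned to each cluster.
        mu = np.stack([X[z == k].mean(axis=0) if np.any(z == k) else X[rng.integers(N)]
                       for k in range(K)])
        # Assignment update: each datum moves to its closest mean (squared distance).
        d2 = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(axis=2)
        z_new = d2.argmin(axis=1)
        if np.all(z_new == z):                          # converged: no assignment changed
            break
        z = z_new
    return z, mu

With the fruit measurements stacked into an N × 2 array X, the example in Figure 1 corresponds
roughly to calling kmeans(X, 5).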
Figure 1 shows several iterations of the K-Means clustering algorithm applied to two-dimensional
data. These are the lengths and widths of fruit (oranges and lemons) purchased by Iain Murray.3
The measurements are in centimeters. I initialized the centers randomly within the square shown,
and used K = 5.
3 http://homepages.inf.ed.ac.uk/imurray2/teaching/oranges_and_lemons/
Figure 2: This is the result of K-Means clustering applied to the MNIST digits data. (a) The 16
cluster centers. (b–q) 25 data examples are shown for each of the 16 clusters. The clusters roughly
grab digits with similar stroke patterns.
Figure 3: This is the result of K-Means clustering applied to the CIFAR-100 image data. (a) The 16
cluster centers. (b–q) 25 data examples are shown for each of the 16 clusters. The clusters primarily
pick up on low-frequency color variations.
Figure 4: This is the result of K-Means clustering applied to a preprocessed variant of the Labeled
Faces in the Wild image data. (a) The 16 cluster centers. (b–q) 25 data examples are shown for each
of the 16 clusters. The clusters capture a combination of face shape, background, and illumination.
Table 1 shows the result of applying K-Means to a set of Grolier encyclopedia articles, with each
article represented by counts of the most common words, excluding stop words9 such as "the"
and "and". These counts were treated as the features directly, leading to a very simple notion of
distance. I used K = 12 and initialized with K-Means++ before applying Lloyd's algorithm.
3 Derivation
Where does this algorithm come from and what does it do? As we’ll see with many machine
learning algorithms, we begin by defining a loss function which specifies what solutions are good
and bad. This loss function takes as arguments the two sets of parameters that we introduced in
the previous sections, the responsibilities r n and the means µk . Given these parameters and the
data vectors xn , what does it mean to be in a good configuration versus a bad one? One intuition
is that good settings of r n and µk will be those in which as many of the data as possible can be
near their assigned µ_k. This fits well with the compression view of K-Means: if each of the x_n
were replaced by its assigned µ_k, then better solutions are those for which this replacement error
is small on average.
9 http://en.wikipedia.org/wiki/Stop_words
Cluster 1 Cluster 2 Cluster 3 Cluster 4 Cluster 5 Cluster 6
education south war war art light
united population german government century energy
american north british law architecture atoms
public major united political style theory
world west president power painting stars
social mi power united period chemical
government km government party sculpture elements
century sq army world form electrons
schools deg germany century artists hydrogen
countries river congress military forms carbon
Table 1: This is the result of a simple application of K-Means clustering to a set of Grolier encyclo-
pedia articles. Shown above are the words with the highest “mean counts” in each of the cluster
centers, with clusters as columns. Even with this very simple approach, K-Means identifies groups
of words that seem conceptually related.
Here we're continuing with the assumption that the distance is Euclidean, so the objective is

    J({r_n}_{n=1}^N, {µ_k}_{k=1}^K) = ∑_{n=1}^{N} ∑_{k=1}^{K} r_{nk} ||x_n − µ_k||²_2 .        (2)

This function sums up the squared distances between each example and the prototype it belongs
to. The K-Means algorithm minimizes this via coordinate descent, i.e., it alternates between 1)
minimizing with respect to each of the r_n, and 2) minimizing with respect to each of the µ_k.
Minimizing the r_n   If we look at the sum in Eq. 2, we see that the r_n only appears in one of the
outer sums, because it only affects one of the data examples. There are only K possible values
for r_n and so we can minimize it (holding everything else fixed) by choosing r_{nk} = 1 for the cluster
that has the smallest distance:

    r_{nk} = { 1   if k = argmin_{k′} ||x_n − µ_{k′}||²_2
             { 0   otherwise.                                   (3)
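In vectorized form this update is just a row-wise argmin over squared distances to the means; a
numpy sketch (the function name is my own):

import numpy as np

def update_responsibilities(X, mu):
    """Return an N x K one-hot responsibility matrix for data X (N x D) and means mu (K x D)."""
    d2 = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(axis=2)   # squared distances, shape (N, K)
    R = np.zeros_like(d2)
    R[np.arange(X.shape[0]), d2.argmin(axis=1)] = 1.0          # set a 1 at each row's closest mean
    return R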
Minimizing the µ_k   Having fixed everything else, we note that each µ_k only depends on one of
the parts of the inner sums. We can think about the objective function written in terms of only one
of these:

    J(µ_k) = ∑_{n=1}^{N} r_{nk} ||x_n − µ_k||²_2                          (4)

           = ∑_{n=1}^{N} r_{nk} (x_n − µ_k)^T (x_n − µ_k).                (5)
Here I’ve written out the squared Euclidean distance as a quadratic form, because it makes the
calculus a bit easier to see. To minimize, we differentiate this objective with respect to µk , set the
resulting gradient to zero, and then solve for µk .
    ∇_{µ_k} J(µ_k) = ∇_{µ_k} ∑_{n=1}^{N} r_{nk} (x_n − µ_k)^T (x_n − µ_k)        (6)

                   = ∑_{n=1}^{N} r_{nk} ∇_{µ_k} (x_n − µ_k)^T (x_n − µ_k).       (7)
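Since ∇_{µ_k}(x_n − µ_k)^T(x_n − µ_k) = −2(x_n − µ_k), setting the gradient to zero and solving for µ_k
recovers the mean update used in Algorithm 1; sketching the remaining algebra:

\[
\nabla_{\mu_k} J(\mu_k) = -2 \sum_{n=1}^{N} r_{nk}\,(x_n - \mu_k) = 0
\quad \Longrightarrow \quad
\mu_k = \frac{\sum_{n=1}^{N} r_{nk}\, x_n}{\sum_{n=1}^{N} r_{nk}}
      = \frac{1}{N_k} \sum_{n=1}^{N} r_{nk}\, x_n ,
\]

i.e., each µ_k is simply the average of the data currently assigned to cluster k.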
4 Practical Considerations
4.1 Hardness and Initialization
The objective function in Eq. 2 is highly non-convex, with many local minima. Some of them are
very easy to see. For example, you could clearly permute the indices of the clusters and wind up
in a “different” solution that was just as good. Coordinate descent, as described here, therefore
only finds a local minimum of the objective. Strictly speaking, “K-Means” is not an algorithm, but
a problem specified by finding a configuration of the r n that minimizes Eq. 2 – note that the µk
are completely determined by the r n for a given data set. The iterative algorithm described here
is often called Lloyd’s algorithm (Lloyd, 1982), but it represents just one way to optimize the K-
Means objective. It turns out that finding (one of) the globally optimal solutions to the K-Means
problem is NP-hard (Aloise et al., 2009), even if there are only two clusters.
10 http://en.wikipedia.org/wiki/Voronoi_diagram
Algorithm 2 K-Means++   Note: written for clarity, not efficiency.
1: Input: Data vectors {x_n}_{n=1}^N, number of clusters K
2: n ← RandomInteger(1, N)                       ▷ Choose a datum at random.
3: µ_1 ← x_n                                     ▷ Make this random datum the first cluster center.
4: for k ← 2 . . . K do                          ▷ Loop over the rest of the centers.
5:   for n ← 1 . . . N do                        ▷ Loop over the data.
6:     d_n ← min_{k′<k} ||x_n − µ_{k′}||_2       ▷ Compute the distance to the closest center.
7:   end for
8:   for n ← 1 . . . N do                        ▷ Loop over the data again.
9:     p_n ← d_n² / ∑_{n′} d_{n′}²               ▷ Compute a distribution proportional to d_n².
10:  end for
11:  n ← Discrete(p_1, p_2, . . . , p_N)         ▷ Draw a datum from this distribution.
12:  µ_k ← x_n                                   ▷ Make this datum the next center.
13: end for
14: Return cluster means {µ_k}_{k=1}^K.
When faced with highly non-convex optimization problems, a common strategy is to use ran-
dom restarts to try to find a good solution. That is, one runs Algorithm 1 several times (e.g., 10 or
20 times) with different random seeds, so that the initial r_n land in different places. Then one looks
at the final value of the objective in Eq. 2 to choose the best solution. Another practical strategy
for larger data sets is to do these restarts with a smaller subset of the data in order to find some
reasonable cluster centers before running the full iteration.
More recently, an algorithm has been proposed that is a bounded-error approximation to the
solution of K-Means (Arthur and Vassilvitskii, 2007). This algorithm, called K-Means++, is shown
in pseudocode in Algorithm 2, and can be an excellent alternative to the simple random initializa-
tion shown in Algorithm 1. In fact, Arthur and Vassilvitskii (2007) show that K-Means++ can do
well even without using Lloyd’s algorithm at all.
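A numpy sketch of Algorithm 2 (the function name and interface are my own conventions):

import numpy as np

def kmeans_pp_init(X, K, seed=None):
    """K-Means++ seeding: return a K x D array of initial cluster centers."""
    rng = np.random.default_rng(seed)
    N = X.shape[0]
    centers = [X[rng.integers(N)]]                          # first center: a uniformly random datum
    for _ in range(1, K):
        diffs = X[:, None, :] - np.stack(centers)[None, :, :]
        d = np.sqrt((diffs ** 2).sum(axis=2)).min(axis=1)   # distance to the closest existing center
        p = d ** 2 / (d ** 2).sum()                         # sample proportionally to squared distance
        centers.append(X[rng.choice(N, p=p)])
    return np.stack(centers)

The returned centers can then be handed to Lloyd's algorithm in place of the random initialization
in Algorithm 1.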
Algorithm 3 Data Standardization   Note: written for clarity, not efficiency.
1: Input: Data vectors {x_n}_{n=1}^N
2: for d ← 1 . . . D do                          ▷ Loop over dimensions.
3:   m_1 ← 0                                     ▷ For storing the total of the values.
4:   m_2 ← 0                                     ▷ For storing the total of the squared values.
5:   for n ← 1 . . . N do                        ▷ Loop over the data.
6:     m_1 ← m_1 + x_{n,d}
7:     m_2 ← m_2 + x_{n,d}²
8:   end for
9:   µ ← m_1/N                                   ▷ Compute sample mean.
10:  σ² ← (m_2/N) − µ²                           ▷ Compute sample variance.
11:  for n ← 1 . . . N do                        ▷ Loop over the data again to modify.
12:    x′_{n,d} ← (x_{n,d} − µ)/σ                ▷ Shift by mean, scale by standard deviation.
13:  end for
14: end for
15: Return transformed data {x′_n}_{n=1}^N
    E_{N(x_d | 0,1)} [ ||x − µ||²_2 ] = ∫_{−∞}^{∞} N(x_d | 0, 1) ∑_{d′=1}^{D} (x_{d′} − µ_{d′})² dx_d          (11)

                                      = [ ∑_{d′≠d} (x_{d′} − µ_{d′})² ] + ∫_{−∞}^{∞} N(x_d | 0, 1) (x_d − µ_d)² dx_d   (12)

                                      = [ ∑_{d′≠d} (x_{d′} − µ_{d′})² ] + 1 + µ_d² .                           (13)
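As a quick numerical check of the integral in the last step (this snippet and its variable names are
my own), a Monte Carlo estimate of ∫ N(x_d | 0, 1)(x_d − µ_d)² dx_d should come out close to 1 + µ_d²:

import numpy as np

rng = np.random.default_rng(0)
mu_d = 0.7
x_d = rng.standard_normal(1_000_000)       # draws from N(0, 1)
print(np.mean((x_d - mu_d) ** 2))          # approximately 1 + mu_d**2 = 1.49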
In numpy, Algorithm 3 boils down to a few lines:

import numpy as np

def standardize(data):
    '''Take an N x D numpy array as input and return a standardized version of it.'''
    mean = np.mean(data, axis=0)    # per-dimension sample means
    std = np.std(data, axis=0)      # per-dimension sample standard deviations
    return (data - mean) / std      # shift by the mean and scale by the standard deviation
    D_{n,m} = (x_n − y_m)(x_n − y_m)^T                          (14)
            = x_n x_n^T − 2 x_n y_m^T + y_m y_m^T ,             (15)
where x_n ∈ R^D is the nth row of X and y_m ∈ R^D is the mth row of Y. These are row vectors, so
this is a sum of inner products. Using broadcasting, which numpy does naturally and can be done
in Matlab using bsxfun, this gives us a rapid way to offload distance calculations to BLAS (or
equivalent) without looping in our high-level code.
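A numpy sketch of this trick (the function name is my own): compute the three terms of Eq. 15
with one matrix product and two row/column sums, and let broadcasting assemble the result:

import numpy as np

def pairwise_sq_dists(X, Y):
    """Squared Euclidean distances between all rows of X (N x D) and all rows of Y (M x D)."""
    XX = (X ** 2).sum(axis=1)[:, None]   # x_n x_n^T terms, shape (N, 1)
    YY = (Y ** 2).sum(axis=1)[None, :]   # y_m y_m^T terms, shape (1, M)
    XY = X @ Y.T                         # x_n y_m^T terms for all pairs, handled by BLAS
    return XX - 2.0 * XY + YY            # broadcasting fills out the (N, M) matrix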
Figure 5: The gap statistic computed for the fruit data, plotted against the number of clusters
(K = 1, . . . , 10). Here, it seems reasonable to choose K = 5. This used 100 reference data sets,
each from a Gaussian MLE fit for the null model.
The gap statistic (Tibshirani et al., 2001) compares the dispersion of our clustering fit to the same
statistic fit to synthetic data that are known not to have clusters. When we have the right number
of clusters, we would expect to have less dispersion than random. Let's imagine
that we have run K-Means and we now have a set of responsibilities {r_n}_{n=1}^N and cluster
centers {µ_k}_{k=1}^K. We define the within-cluster dispersion to be the sum of squared distances between
all pairs in a given cluster:
    D_k = ∑_{n=1}^{N} ∑_{n′=1}^{N} r_{n,k} r_{n′,k} ||x_n − x_{n′}||² .        (16)
The dispersion for a size-K clustering is the normalized sum of the within-cluster dispersion over
all of the clusters:
    W_K = ∑_{k=1}^{K} D_k/(2 N_k) = ∑_{k=1}^{K} (1/(2 N_k)) ∑_{n=1}^{N} ∑_{n′=1}^{N} r_{n,k} r_{n′,k} ||x_n − x_{n′}||²        (17)

    N_k = ∑_{n=1}^{N} r_{n,k} .        (18)
This is basically a measure of “tightness” of the fit normalized by the size of each cluster. It will be
smaller when the clustered points group together closely. The gap statistic uses a null distribution
for the data we have clustered, from which we generate reference data. You can think of this as a
distribution on the same space, with similar coarse statistics, but in which there aren’t clusters.
For example, if our original data were in the unit hypercube [0, 1]^D, we might generate reference
data uniformly from the cube. If we standardized data on R^D so that each feature has zero sample
mean and unit sample variance, then it would be natural to make the null distribution N(0, I_D).
To compute the gap statistic, we now generate several sets of reference data, cluster each of them,
and compute their dispersions. This gives us an idea (with error bars) as to what the expected
cow         | humpback whale | german shepherd | mole
walks       | hairless       | furry           | furry
quadrapedal | toughskin      | meatteeth       | small
vegetation  | big            | walks           | fast
ground      | swims          | fast            | active
big         | strong         | quadrapedal     | newworld
ox          | blue whale     | siamese cat     | hamster
pig         | seal           | wolf            | rat
sheep       | walrus         | chihuahua       | squirrel
buffalo     | dolphin        | dalmatian       | mouse
horse       | killer whale   | weasel          | skunk
Table 2: This table shows the result of applying K-Medoids to binary features associated with 50
animals, using Hamming distance. Here K = 4. The bold animals along the top are the medoids
for each cluster. The top five most common features are shown next, followed by five other non-
medoid animals from the same cluster.
dispersion would be based on the coarse properties of the data space, for a given K. We can then
compute the gap statistic as

    Gap_N(K) = E*_N[log W_K] − log W_K ,

where E*_N denotes the expectation under the reference (null) distribution. Here the first term is
the expected log of the dispersion under the null distribution, something we can compute by
averaging the log dispersions of our reference data. We subtract from this the log dispersion of
our actual data. We can then look for the K which maximizes Gap_N(K). We
choose the smallest one that appears to be statistically significant. Figure 5 shows a boxplot of the
gap statistic for the fruit data. Here the null distribution was an MLE Gaussian fit to the data. I
generated 100 sets of reference data.
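A numpy sketch of this procedure, using a Gaussian MLE null as in Figure 5 (the helper names
and the cluster_fn argument, which should return integer cluster assignments, are my own
conventions, not from these notes):

import numpy as np

def within_dispersion(X, z, K):
    """W_K from Eq. 17: pairwise squared distances within each cluster, scaled by 1/(2 N_k)."""
    W = 0.0
    for k in range(K):
        Xk = X[z == k]
        if len(Xk) == 0:
            continue
        d2 = ((Xk[:, None, :] - Xk[None, :, :]) ** 2).sum(axis=2)
        W += d2.sum() / (2.0 * len(Xk))
    return W

def gap_statistic(X, K, cluster_fn, n_refs=100, seed=None):
    """Gap_N(K) with a Gaussian MLE null; cluster_fn(X, K) must return integer assignments."""
    rng = np.random.default_rng(seed)
    log_W = np.log(within_dispersion(X, cluster_fn(X, K), K))
    mean = X.mean(axis=0)
    cov = np.cov(X, rowvar=False)
    ref_log_W = []
    for _ in range(n_refs):
        X_ref = rng.multivariate_normal(mean, cov, size=X.shape[0])
        ref_log_W.append(np.log(within_dispersion(X_ref, cluster_fn(X_ref, K), K)))
    return np.mean(ref_log_W) - log_W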
Algorithm 4 K-Medoids   Note: written for clarity, not efficiency.
1: Input: Data vectors {x_n}_{n=1}^N, number of clusters K
2: for n ← 1 . . . N do                          ▷ Initialize all of the responsibilities.
3:   r_n ← [0, 0, · · · , 0]                     ▷ Zero out the responsibilities.
4:   k′ ← RandomInteger(1, K)                    ▷ Make one of them randomly one to initialize.
5:   r_{nk′} ← 1
6: end for
7: repeat
8:   for k ← 1 . . . K do                        ▷ Loop over the clusters.
9:     for n ← 1 . . . N do                      ▷ Loop over the data.
10:      if r_{n,k} = 1 then
11:        J_n ← ∑_{n′=1}^{N} r_{n′,k} ||x_n − x_{n′}||   ▷ Sum distances to this datum.
12:      else
13:        J_n ← ∞                               ▷ Infinite cost for data not in this cluster.
14:      end if
15:    end for
16:    n⋆ ← argmin_n J_n                         ▷ Pick the one that minimizes the sum of distances.
17:    µ_k ← x_{n⋆}                              ▷ Make the minimizing one the cluster center.
18:  end for
19:  for n ← 1 . . . N do                        ▷ Loop over the data.
20:    r_n ← [0, 0, · · · , 0]                   ▷ Zero out the responsibilities.
21:    k′ ← argmin_k ||x_n − µ_k||               ▷ Find the closest medoid.
22:    r_{nk′} ← 1
23:  end for
24: until none of the r_n change
25: Return assignments {r_n}_{n=1}^N for each datum, and cluster medoids {µ_k}_{k=1}^K.
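A numpy sketch of the medoid update at the heart of Algorithm 4 (the function name and the
dist argument are my own); for the animals example below, dist would be the Hamming distance:

import numpy as np

def update_medoids(X, z, K, dist):
    """For each cluster, choose the member that minimizes the summed distance to the other members."""
    medoids = np.zeros((K, X.shape[1]))
    for k in range(K):
        members = np.where(z == k)[0]
        if len(members) == 0:
            continue                                   # leave an empty cluster's medoid at zero
        costs = [sum(dist(X[n], X[m]) for m in members) for n in members]
        medoids[k] = X[members[int(np.argmin(costs))]]
    return medoids

hamming = lambda a, b: np.sum(a != b)                  # e.g., for binary animal features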
Table 2 shows the result of applying K-Medoids to a data set of binary features associated with
animals; the features are properties such as "swims" or "furry". There are 50 animals and 85
features. I used K = 4 and Hamming distance. I initialized with K-Means++. The table has four
columns with the medoid at the top, the most common five features within the group, and five
other animals in the cluster, excluding the medoid.
7 Advanced Topics
These topics are outside the scope of this class, but are good things to look into next if you find
clustering to be an interesting topic.
Figure 6: This is the result of K-Medoids clustering applied to the CIFAR-100 image data. (a) The
16 cluster medoids. (b–q) 25 data examples are shown for each of the 16 clusters.
Spectral Clustering Clusters in data are often more complicated than simple isotropic blobs.
Our human visual system likes to group things in more complicated ways. Spectral clustering
(see, e.g., Ng et al. (2002); Von Luxburg (2007)) constructs a graph over the data first and then
performs operations on that graph using the Laplacian. The idea is that data which are close
together tend to be in the same group, even if they are not close to some single prototype.
Affinity Propagation One powerful way to think about clustering, which we’ll see later in the
semester, is to frame the groups in terms of latent variables in a probabilistic model. Inference
in many probabilistic models can be performed efficiently using “message passing” algorithms
which take advantage of an underlying graph structure. The algorithm of affinity propagation (Frey
and Dueck, 2007) is a nice way to perform K-Medoids clustering efficiently with such a message
passing procedure.
Biclustering The clustering algorithms we’ve looked at here have operated on the data instances.
That is, we’ve thought about finding partitions of the rows of an N × D matrix. Many data are not
well represented by “instances” and “features”, but are interactions between items. For example,
we could imagine our data to be a matrix of outcomes between sports teams, or protein-protein
interaction data. Biclustering (Hartigan, 1972) is an algorithm for simultaneously grouping both
rows and columns of a matrix, in effect discovering blocks. Variants of this technique have become
immensely important to biological data analysis.
References

Adam Coates, Andrew Y. Ng, and Honglak Lee. An analysis of single-layer networks in unsupervised feature learning. In International Conference on Artificial Intelligence and Statistics, pages 215–223, 2011. URL http://www.stanford.edu/~acoates/papers/coatesleeng_aistats_2011.pdf.

Stuart Lloyd. Least squares quantization in PCM. IEEE Transactions on Information Theory, 28(2):129–137, 1982. URL http://www.nt.tuwien.ac.at/fileadmin/courses/389075/Least_Squares_Quantization_in_PCM.pdf.

Daniel Aloise, Amit Deshpande, Pierre Hansen, and Preyas Popat. NP-hardness of Euclidean sum-of-squares clustering. Machine Learning, 75(2):245–248, 2009. URL http://link.springer.com/article/10.1007%2Fs10994-009-5103-0.

David Arthur and Sergei Vassilvitskii. k-means++: The advantages of careful seeding. In Proceedings of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms, pages 1027–1035. Society for Industrial and Applied Mathematics, 2007. URL http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf.

Allan D. Gordon. Null models in cluster validation. In From Data to Knowledge, pages 32–44. Springer, 1996.

Greg Hamerly and Charles Elkan. Learning the k in k-means. Advances in Neural Information Processing Systems, 16:281, 2004.

Robert Tibshirani, Guenther Walther, and Trevor Hastie. Estimating the number of clusters in a data set via the gap statistic. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 63(2):411–423, 2001. URL http://www.stanford.edu/~hastie/Papers/gap.pdf.

Andrew Y. Ng, Michael I. Jordan, and Yair Weiss. On spectral clustering: Analysis and an algorithm. Advances in Neural Information Processing Systems, 2:849–856, 2002. URL http://machinelearning.wustl.edu/mlpapers/paper_files/nips02-AA35.pdf.

Ulrike von Luxburg. A tutorial on spectral clustering. Statistics and Computing, 17(4):395–416, 2007. URL http://web.mit.edu/~wingated/www/introductions/tutorial_on_spectral_clustering.pdf.

Brendan J. Frey and Delbert Dueck. Clustering by passing messages between data points. Science, 315(5814):972–976, 2007.

John A. Hartigan. Direct clustering of a data matrix. Journal of the American Statistical Association, 67(337):123–129, 1972.