The K-Means Clustering Algorithm in Java
1. Overview
Clustering is an umbrella term for a class of unsupervised algorithms to discover groups of things,
people, or ideas that are closely related to each other.
In this apparently simple one-liner definition, we saw a few buzzwords. What exactly is clustering?
What is an unsupervised algorithm?
In this tutorial, we're going to, first, shed some light on these concepts. Then, we'll see how they can manifest themselves in Java.
2. Unsupervised Algorithms
Before we use most learning algorithms, we should somehow feed some sample data to them and allow the algorithm to learn from those data. In Machine Learning terminology, we call that sample dataset training data. Also, the whole process is known as the training process.
Anyway, we can classify learning algorithms based on the amount of supervision they need during
the training process. The two main types of learning algorithms in this category are:
Supervised Learning: In supervised algorithms, the training data should include the actual
solution for each point. For example, if we’re about to train our spam filtering algorithm, we feed
both the sample emails and their label, i.e. spam or not-spam, to the algorithm. Mathematically
speaking, we’re going to infer the f(x) from a training set including both xs and ys.
Unsupervised Learning: When there are no labels in training data, then the algorithm is an
unsupervised one. For example, we have plenty of data about musicians, and we're going to discover groups of similar musicians in the data.
3. Clustering
Unlike supervised algorithms, we're not training clustering algorithms with examples of known labels. Instead, clustering tries to find structures within a training set where no data point is the label.
(/wp-content/uploads/2019/08/Date-6.png)
K-Means is one of the most popular clustering algorithms. It begins with k randomly placed centroids. Centroids, as their name suggests, are the center points of the clusters. For example, here we're adding four random centroids:
(/wp-content/uploads/2019/08/Date-7.png)
Then we assign each existing data point to its nearest centroid:
(/wp-content/uploads/2019/08/Date-8.png)
After the assignment, we move each centroid to the average location of the points assigned to it. Remember, centroids are supposed to be the center points of clusters:
(/wp-content/uploads/2019/08/Date-10.png)
Each iteration concludes once we're done relocating the centroids. We repeat these iterations until the assignments stop changing between consecutive iterations:
(/wp-content/uploads/2019/08/Date-copy.png)
When the algorithm terminates, those four clusters are found as expected. Now that we know how
K-Means works, let’s implement it in Java.
When modeling different training datasets, we need a data structure to represent model attributes and their corresponding values. For example, a musician can have a genre attribute with a value like Rock. We usually use the term feature to refer to the combination of an attribute and its value.
To prepare a dataset for a particular learning algorithm, we usually use a common set of numerical
attributes that can be used to compare different items. For example, if we let our users tag each artist
with a genre, then at the end of the day, we can count how many times each artist is tagged with a
specific genre:
(/wp-content/uploads/2019/08/Screen-Shot-1398-04-29-at-22.30.58.png)
The feature vector for an artist like Linkin Park is [rock -> 7890, nu-metal -> 700, alternative -> 520,
pop -> 3]. So if we could find a way to represent attributes as numerical values, then we can simply
compare two different items, e.g. artists, by comparing their corresponding vector entries.
Since numeric vectors are such versatile data structures, we're going to represent features using them. Here's how we implement feature vectors in Java:
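The class itself did not survive extraction; a minimal sketch of such a feature-vector wrapper, consistent with the Record, getFeatures(), and getDescription() usages later in the article (imports from java.util omitted, as in the article's other snippets), could look like this:

public class Record {

    // a human-friendly name for the item, e.g. the artist name
    private final String description;

    // the feature vector: attribute name -> numerical value
    private final Map<String, Double> features;

    public Record(String description, Map<String, Double> features) {
        this.description = description;
        this.features = features;
    }

    public String getDescription() {
        return description;
    }

    public Map<String, Double> getFeatures() {
        return features;
    }

    // equals, hashCode and toString omitted for brevity
}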
To cluster feature vectors, we also need a way to measure how far apart two of them are. One common choice is the Euclidean distance between their corresponding entries:

d(f1, f2) = sqrt( Σ (f1[i] - f2[i])² )
Let’s implement this function in Java. First, the abstraction:
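The abstraction itself is missing from this extraction; based on the calculate(Map, Map) method implemented below, it is presumably a single-method interface along these lines:

public interface Distance {

    double calculate(Map<String, Double> f1, Map<String, Double> f2);
}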
In addition to Euclidean distance, there are other approaches to compute the distance or similarity between different items, like the Pearson Correlation Coefficient (/cs/correlation-coefficient). This abstraction makes it easy to switch between different distance metrics.

Let's see the implementation for Euclidean distance:
public class EuclideanDistance implements Distance {

    @Override
    public double calculate(Map<String, Double> f1, Map<String, Double> f2) {
        double sum = 0;
        for (String key : f1.keySet()) {
            Double v1 = f1.get(key);
            Double v2 = f2.get(key);

            // accumulate the squared difference whenever both vectors contain the key
            if (v1 != null && v2 != null) {
                sum += Math.pow(v1 - v2, 2);
            }
        }

        return Math.sqrt(sum);
    }
}
First, we calculate the sum of squared differences between corresponding entries. Then, by applying
the sqrt function, we compute the actual Euclidean distance.
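As a quick sanity check (not part of the original article), two small feature maps can be compared like this; Map.of assumes Java 9+:

Distance distance = new EuclideanDistance();
Map<String, Double> a = Map.of("rock", 3.0, "pop", 1.0);
Map<String, Double> b = Map.of("rock", 0.0, "pop", 1.0);
double d = distance.calculate(a, b); // sqrt((3 - 0)^2 + (1 - 1)^2) = 3.0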
Now that we have a few necessary abstractions in place, it’s time to write our K-Means
implementation. Here’s a quick look at our method signature:
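The signature itself is cut off in this extraction. Based on how the pieces are assembled later, a plausible sketch follows; the method name fit and the maxIterations parameter are assumptions, and a minimal Centroid wrapper is included because the signature refers to it (its value-based equals/hashCode matters for the convergence check further down):

public class Centroid {

    private final Map<String, Double> coordinates;

    public Centroid(Map<String, Double> coordinates) {
        this.coordinates = coordinates;
    }

    public Map<String, Double> getCoordinates() {
        return coordinates;
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof Centroid && coordinates.equals(((Centroid) o).coordinates);
    }

    @Override
    public int hashCode() {
        return coordinates.hashCode();
    }
}

public static Map<Centroid, List<Record>> fit(
  List<Record> records, int k, Distance distance, int maxIterations) {
    // ...
}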
Although each centroid can contain totally random coordinates, it's a good practice to generate random coordinates between the minimum and maximum possible values for each attribute. Generating random centroids without considering the range of possible values would cause the algorithm to converge more slowly.

First, we should compute the minimum and maximum value for each attribute, and then generate random values between each pair of them:
private static List<Centroid> randomCentroids(List<Record> records, int k) {
    List<Centroid> centroids = new ArrayList<>();
    Map<String, Double> maxs = new HashMap<>();
    Map<String, Double> mins = new HashMap<>();
    for (Record record : records) {
        record.getFeatures().forEach((key, value) -> {
            // compare the value with the current max and keep the bigger one
            maxs.compute(key, (k1, max) -> max == null || value > max ? value : max);
            // compare the value with the current min and keep the smaller one
            mins.compute(key, (k1, min) -> min == null || value < min ? value : min);
        });
    }
    // the generation loop below is reconstructed from the description above:
    // for each attribute, pick a random coordinate between its observed min and max
    for (int i = 0; i < k; i++) {
        Map<String, Double> coordinates = new HashMap<>();
        for (String attribute : maxs.keySet()) {
            double max = maxs.get(attribute);
            double min = mins.get(attribute);
            coordinates.put(attribute, min + Math.random() * (max - min));
        }
        centroids.add(new Centroid(coordinates));
    }
    return centroids;
}
3.7. Assignment
First off, given a Record, we should find the centroid nearest to it by iterating over all centroids and keeping track of the minimum distance seen so far.
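Only the closing lines of the original helper survived extraction; a minimal sketch consistent with the nearestCentroid(record, centroids, distance) call used later might be:

private static Centroid nearestCentroid(Record record, List<Centroid> centroids, Distance distance) {
    double minimumDistance = Double.MAX_VALUE;
    Centroid nearest = null;

    for (Centroid centroid : centroids) {
        double currentDistance = distance.calculate(record.getFeatures(), centroid.getCoordinates());

        // keep the centroid with the smallest distance seen so far
        if (currentDistance < minimumDistance) {
            minimumDistance = currentDistance;
            nearest = centroid;
        }
    }

    return nearest;
}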
Then, we should add each record to the cluster of its nearest centroid, keeping track of the assigned records.
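Again, only the tail of this helper is present in the extraction; a sketch consistent with the assignToCluster(clusters, record, centroid) call used later could be:

private static void assignToCluster(Map<Centroid, List<Record>> clusters, Record record, Centroid centroid) {
    clusters.compute(centroid, (key, list) -> {
        // lazily create the member list for this centroid, then add the record to it
        if (list == null) {
            list = new ArrayList<>();
        }

        list.add(record);
        return list;
    });
}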
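After the assignment step, each centroid should move to the average position of its assigned records. The relocation code is missing from this extraction; a sketch, assuming a hypothetical average(centroid, records) helper that computes the per-attribute mean of the assigned records, could be:

private static List<Centroid> relocateCentroids(Map<Centroid, List<Record>> clusters) {
    // map each centroid to the average position of its assigned records
    return clusters.entrySet().stream()
      .map(e -> average(e.getKey(), e.getValue()))
      .collect(Collectors.toList());
}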
This simple one-liner iterates through all centroids, relocates them, and returns the new centroids.
Finally, we can put all the pieces together. In each iteration, after assigning every record to its nearest centroid, we compare the current assignments with the previous ones; if nothing has changed, the algorithm has converged, otherwise we relocate the centroids and start a new round. The iteration and convergence logic below is reconstructed from the description earlier in the article:

public static Map<Centroid, List<Record>> fit(List<Record> records, int k, Distance distance, int maxIterations) {
    List<Centroid> centroids = randomCentroids(records, k);
    Map<Centroid, List<Record>> clusters = new HashMap<>();
    Map<Centroid, List<Record>> lastState = new HashMap<>();
    for (int i = 0; i < maxIterations; i++) {
        // in each iteration we should find the nearest centroid for each record
        for (Record record : records) {
            Centroid centroid = nearestCentroid(record, centroids, distance);
            assignToCluster(clusters, record, centroid);
        }
        boolean shouldTerminate = i == maxIterations - 1 || clusters.equals(lastState);
        lastState = clusters;
        if (shouldTerminate) {
            break;
        }
        centroids = relocateCentroids(clusters);
        clusters = new HashMap<>();
    }
    return lastState;
}
4. Example: Discovering Similar Artists on Last.fm
Last.fm builds a detailed profile of each user's musical taste by recording details of what the user listens to. In this section, we're going to find clusters of similar artists. To build a dataset appropriate for this task, we'll use three APIs from Last.fm:
1. API to get a collection of top artists (https://www.last.fm/api/show/chart.getTopArtists) on
Last.fm.
2. Another API to find popular tags (https://www.last.fm/api/show/chart.getTopTags). Each user
can tag an artist with something, e.g. rock. So, Last.fm maintains a database of those tags and
their frequencies.
3. And an API to get the top tags for an artist (https://www.last.fm/api/show/artist.getTopTags),
ordered by popularity. Since there are many such tags, we’ll only keep those tags that are among
the top global tags.
We can call these endpoints through a Retrofit-style service interface (the interface name here is illustrative; Artists, Tags, and TopTags are the response models):

public interface LastFmService {

    @GET("/2.0/?method=chart.gettopartists&format=json&limit=50")
    Call<Artists> topArtists(@Query("page") int page);

    @GET("/2.0/?method=artist.gettoptags&format=json&limit=20&autocorrect=1")
    Call<Tags> topTagsFor(@Query("artist") String artist);

    @GET("/2.0/?method=chart.gettoptags&format=json&limit=100")
    Call<TopTags> topTags();
}
Using this service, we can page through the chart endpoint and collect the top artists.
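Only the tail of the paging helper (return artists) survived extraction; a minimal sketch, using Retrofit's synchronous execute() and a hypothetical getNames() accessor on the Artists response model, might be:

private static List<String> getTopArtists(LastFmService service, int pages) throws IOException {
    List<String> artists = new ArrayList<>();
    // page through the chart endpoint and accumulate the artist names
    for (int page = 1; page <= pages; page++) {
        artists.addAll(service.topArtists(page).execute().body().getNames());
    }
    return artists;
}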
Finally, we can build a dataset of artists along with their tag frequencies.
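Only the tail of this builder (return records) is left in the extraction; a sketch of the idea, with a hypothetical Tag model exposing getName() and getCount(), keeping only the tags that appear among the top global tags:

private static List<Record> datasetWithTaggedArtists(Map<String, List<Tag>> tagsPerArtist, Set<String> topTags) {
    List<Record> records = new ArrayList<>();
    tagsPerArtist.forEach((artist, tags) -> {
        Map<String, Double> features = new HashMap<>();
        // keep only the tags that are among the top global tags
        tags.stream()
          .filter(tag -> topTags.contains(tag.getName()))
          .forEach(tag -> features.put(tag.getName(), (double) tag.getCount()));
        records.add(new Record(artist, features));
    });
    return records;
}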
Now we can fit the model and print each cluster's centroid followed by its member artists, separating the clusters with blank lines.
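A sketch of that driver code, assuming records is the dataset built above and the fit method sketched earlier lives in a KMeans class; k = 7 and maxIterations = 1000 are illustrative values:

Map<Centroid, List<Record>> clusters = KMeans.fit(records, 7, new EuclideanDistance(), 1000);
clusters.forEach((centroid, members) -> {
    System.out.println("------------------------------ CLUSTER -----------------------------------");
    // print the centroid coordinates, then the names of the assigned artists
    System.out.println(centroid.getCoordinates());
    System.out.println(members.stream().map(Record::getDescription).collect(Collectors.joining(", ")));
    System.out.println();
    System.out.println();
});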
If we run this code, then it would visualize the clusters as text output:
------------------------------ CLUSTER -----------------------------------
Centroid {classic rock=65.58333333333333, rock=64.41666666666667, british=20.333333333333332, ... }
David Bowie, Led Zeppelin, Pink Floyd, System of a Down, Queen, blink-182, The Rolling Stones, Metallica,
Fleetwood Mac, The Beatles, Elton John, The Clash

------------------------------ CLUSTER -----------------------------------
Centroid {Hip-Hop=97.21428571428571, rap=64.85714285714286, hip hop=29.285714285714285, ... }
Kanye West, Post Malone, Childish Gambino, Lil Nas X, A$AP Rocky, Lizzo, xxxtentacion,
Travi$ Scott, Tyler, the Creator, Eminem, Frank Ocean, Kendrick Lamar, Nicki Minaj, Drake

------------------------------ CLUSTER -----------------------------------
Centroid {rock=87.38888888888889, alternative=72.11111111111111, alternative rock=49.16666666, ... }
Weezer, The White Stripes, Nirvana, Foo Fighters, Maroon 5, Oasis, Panic! at the Disco, Gorillaz,
Green Day, The Cure, Fall Out Boy, OneRepublic, Paramore, Coldplay, Radiohead, Linkin Park,
Red Hot Chili Peppers, Muse
Since centroid coordinates are sorted by the average tag frequency, we can easily spot the dominant genre in each cluster. For example, the last cluster is a cluster of good old rock bands, and the second one is filled with rap stars.

Although this clustering makes sense for the most part, it's not perfect, since the data is merely collected from user behavior.
5. Visualization
A few moments ago, our algorithm visualized the clusters of artists in a terminal-friendly way. If we
convert our cluster configuration to JSON and feed it to D3.js, then with a few lines of JavaScript, we’ll
have a nice human-friendly Radial Tidy-Tree (https://observablehq.com/@d3/radial-tidy-tree?
collection=@d3/d3-hierarchy):
(/wp-content/uploads/2019/08/Screen-Shot-1398-05-04-at-12.09.40.png)
We have to convert our Map<Centroid, List<Record>> to a JSON with a schema similar to this d3.js example (https://raw.githubusercontent.com/d3/d3-hierarchy/v1.1.8/test/data/flare.json).
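As a sketch of that conversion (not shown in the extraction), assuming Jackson for JSON serialization and labeling each cluster node with its centroid coordinates:

private static String toD3Json(Map<Centroid, List<Record>> clusters) throws JsonProcessingException {
    List<Map<String, Object>> children = new ArrayList<>();
    clusters.forEach((centroid, records) -> {
        // each cluster becomes a node whose children are its member artists
        List<Map<String, Object>> leaves = records.stream()
          .map(r -> Map.<String, Object>of("name", r.getDescription()))
          .collect(Collectors.toList());
        children.add(Map.<String, Object>of("name", centroid.getCoordinates().toString(), "children", leaves));
    });
    return new ObjectMapper().writeValueAsString(Map.of("name", "Artists", "children", children));
}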
6. Number of Clusters
One of the fundamental properties of K-Means is the fact that we should define the number of
clusters in advance. So far, we used a static value for k, but determining this value can be a
challenging problem. There are two common ways to calculate the number of clusters:
1. Domain Knowledge
2. Mathematical Heuristics
If we're lucky enough to know a lot about the domain, we might be able to simply guess the right number. Otherwise, we can apply a few heuristics like the Elbow Method or the Silhouette Method to get a sense of the number of clusters.
Before going any further, we should know that these heuristics, although useful, are just heuristics
and may not provide clear-cut answers.
To apply the elbow method, we need to measure, for each candidate k, how far the cluster members are from their centroids. One way to perform this distance calculation is to use the Sum of Squared Errors. The sum of squared errors, or SSE, is equal to the sum of squared differences between a centroid and all its members:
public static double sse(Map<Centroid, List<Record>> clustered, Distance distance) {
    double sum = 0;
    for (Map.Entry<Centroid, List<Record>> entry : clustered.entrySet()) {
        Centroid centroid = entry.getKey();
        for (Record record : entry.getValue()) {
            double d = distance.calculate(centroid.getCoordinates(), record.getFeatures());
            sum += Math.pow(d, 2);
        }
    }

    return sum;
}
Then, we can run the K-Means algorithm for different values of k and calculate the SSE for each of
them:
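A sketch of that loop; the range of k values (2 to 16) and the surrounding method are illustrative, and fit and sse are the methods sketched earlier (fit assumed to live in a KMeans class):

static List<Double> sseForEachK(List<Record> records, Distance distance) {
    List<Double> errors = new ArrayList<>();
    // run K-Means for each candidate k and record the resulting SSE
    for (int k = 2; k <= 16; k++) {
        Map<Centroid, List<Record>> clusters = KMeans.fit(records, k, distance, 1000);
        errors.add(sse(clusters, distance));
    }
    return errors;
}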
At the end of the day, it’s possible to find an appropriate k by plotting the number of clusters against
the SSE:
(/wp-content/uploads/2019/08/Screen-Shot-1398-05-04-at-17.01.36.png)
Usually, as the number of clusters increases, the distance between cluster members decreases.
However, we can't choose arbitrarily large values for k, since having multiple clusters with just one
member defeats the whole purpose of clustering.
The idea behind the elbow method is to find an appropriate value for k in a way that the SSE
decreases dramatically around that value. For example, k=9 can be a good candidate here.
7. Conclusion
In this tutorial, first, we covered a few important concepts in Machine Learning. Then we got acquainted with the mechanics of the K-Means clustering algorithm. Finally, we wrote a simple implementation of K-Means, tested our algorithm with a real-world dataset from Last.fm, and visualized the clustering result in a nice graphical way.
As usual, the sample code is available on our GitHub
(https://github.com/eugenp/tutorials/tree/master/algorithms-modules/algorithms-miscellaneous-
3) project, so make sure to check it out!