
K-means++ Algorithm - ML

Last Updated : 25 May, 2025

In clustering, we group similar data points together. K-Means Clustering is one of the simplest and most popular clustering algorithms, but it has one major drawback: the random initialization of cluster centers often leads to poor clustering results. Some clusters may end up with no points, or multiple centroids may land inside the same cluster. For example, consider the images shown below, produced using standard K-Means clustering.

Here we can see that the clusters are not forming properly. To solve this problem, KMeans++ was introduced. It improves the way the initial cluster centers are selected, making the clustering results more accurate and convergence faster. This is how the clustering should have looked:

KMeans++ is an improved version of the KMeans algorithm that chooses better starting points instead of selecting them purely at random. The key idea is to pick the initial cluster centers so that they are spread out across the data, which helps the algorithm converge faster and gives better clustering results.

How the K-Means++ Algorithm Works

The KMeans++ algorithm works in two steps:

1. Initialization Step:

  • Choose the first cluster center randomly from the data points.
  • For each remaining cluster center, select the next center from the data points with probability proportional to the squared distance between each point and its closest already-selected center (a minimal sketch of this weighted selection is shown after the note below).

2. Clustering Step:

  • After selecting the initial centers, KMeans++ performs clustering the same way as standard KMeans (a runnable sketch of these iterations is included at the end of the implementation below):
  • Assign each data point to the nearest cluster center.
  • Recalculate each cluster center as the mean of all points assigned to that cluster.
  • Repeat these steps until the cluster centers stop changing or a fixed number of iterations is reached.

Note: Although the initialization in K-means++ is computationally more expensive than that of standard K-means, the run-time to convergence is typically much shorter, because the initially chosen centroids are likely to already lie in different clusters.
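The weighted selection in the initialization step can be sketched in a few lines of NumPy. This is only a minimal illustration of the idea, not the implementation used later in this article; the function name kmeans_pp_init and the fixed random seed are assumptions made for the example:

Python
import numpy as np

def kmeans_pp_init(data, k, seed=0):
    rng = np.random.default_rng(seed)
    # First center: a uniformly random data point.
    centers = [data[rng.integers(data.shape[0])]]
    for _ in range(k - 1):
        c = np.array(centers)  # shape (num_selected, num_features)
        # Squared distance from each point to its closest selected center.
        d2 = ((data[:, None, :] - c[None, :, :]) ** 2).sum(axis=2).min(axis=1)
        # Sample the next center with probability proportional to D^2.
        probs = d2 / d2.sum()
        centers.append(data[rng.choice(data.shape[0], p=probs)])
    return np.array(centers)

This D²-weighted sampling is the core of the method; in practice, libraries such as scikit-learn additionally try several candidate centers at each step and keep the best one.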

Implementation in Python

Let's understand how KMeans++ initializes centroids step by step using the following implementation:

1. Dataset Creation

Four separate Gaussian clusters are generated with different means and covariances to simulate different groupings in the data.

Python
import numpy as np
import matplotlib.pyplot as plt

# Cluster 1: centered at the origin.
mean_01 = np.array([0.0, 0.0])
cov_01 = np.array([[1, 0.3], [0.3, 1]])
dist_01 = np.random.multivariate_normal(mean_01, cov_01, 100)

# Cluster 2: centered at (6, 7) with a wider spread along x.
mean_02 = np.array([6.0, 7.0])
cov_02 = np.array([[1.5, 0.3], [0.3, 1]])
dist_02 = np.random.multivariate_normal(mean_02, cov_02, 100)

# Cluster 3: centered at (7, -5), reusing cov_01.
mean_03 = np.array([7.0, -5.0])
dist_03 = np.random.multivariate_normal(mean_03, cov_01, 100)

# Cluster 4: centered at (2, -7) with its own covariance.
mean_04 = np.array([2.0, -7.0])
cov_04 = np.array([[1.2, 0.5], [0.5, 1.3]])
dist_04 = np.random.multivariate_normal(mean_04, cov_04, 100)

# Stack the four clusters into one dataset and shuffle the rows.
data = np.vstack((dist_01, dist_02, dist_03, dist_04))
np.random.shuffle(data)

2. Plotting Helper Function

This function is used to visualize the data points and the selected centroids at each step. All data points are shown in gray.

  • Previously selected centroids are marked in black.
  • The current centroid being added is marked in red.
  • This helps visualize the centroid initialization process step by step.
Python
def plot(data, centroids):
    # All data points in gray.
    plt.scatter(data[:, 0], data[:, 1], marker='.', color='gray', label='Data Points')
    # Previously selected centroids in black.
    if centroids.shape[0] > 1:
        plt.scatter(centroids[:-1, 0], centroids[:-1, 1], color='black', label='Selected Centroids')
    # The centroid just added in red.
    plt.scatter(centroids[-1, 0], centroids[-1, 1], color='red', label='Next Centroid')
    plt.title(f'Selecting Centroid {centroids.shape[0]}')
    plt.legend()
    plt.xlim(-5, 12)
    plt.ylim(-10, 15)
    plt.show()

3. Euclidean Distance Function

A standard helper that computes the Euclidean distance between two points p1 and p2 (here in 2D, though the same formula works in any dimension).

Python
def distance(p1, p2):
    return np.sqrt(np.sum((p1 - p2)**2))

4. K-Means++ Initialization

This function selects initial centroids using a K-Means++-style strategy. The first centroid is chosen randomly from the dataset. For each subsequent centroid:

  • It calculates the distance from every point to its nearest already-selected centroid.
  • It chooses the point farthest from its nearest centroid as the next centroid, which keeps the initial centroids spaced far apart and gives better cluster separation.

Note that this is a deterministic simplification: standard K-Means++ samples the next center with probability proportional to the squared distance (as sketched earlier), whereas this implementation always takes the farthest point.


Python
def initialize(data, k):
    centroids = []
    # First centroid: a random data point.
    centroids.append(data[np.random.randint(data.shape[0])])
    plot(data, np.array(centroids))

    for _ in range(k - 1):
        # Distance from every point to its nearest already-selected centroid.
        distances = []
        for point in data:
            min_dist = min([distance(point, c) for c in centroids])
            distances.append(min_dist)

        # Pick the point farthest from the existing centroids as the next centroid.
        next_centroid = data[np.argmax(distances)]
        centroids.append(next_centroid)
        plot(data, np.array(centroids))

    return np.array(centroids)

# Run initialization
centroids = initialize(data, k=4)

Output: 

The first plot shows the dataset with the first randomly selected centroid (in red). No black points are visible since only one centroid has been selected.

The second centroid is selected as the point farthest from the first centroid. The first centroid is now shown in black and the new centroid is marked in red.

The third centroid is selected. The two previously selected centroids are shown in black while the newly selected centroid is in red.

The final centroid is selected completing the initialization. Three previously selected centroids are in black and the last selected centroid is in red.
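With the initial centroids selected, the clustering step described earlier proceeds exactly as in standard K-Means: assign points to their nearest centroid, recompute the centroids, and repeat. Below is a minimal sketch of these iterations reusing the data and centroids variables from above; the helper name run_kmeans, the iteration cap, and the tolerance are illustrative choices, not part of the original code:

Python
def run_kmeans(data, centroids, max_iters=100, tol=1e-6):
    for _ in range(max_iters):
        # Assign each point to its nearest centroid.
        dists = ((data[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = np.argmin(dists, axis=1)
        # Recompute each centroid as the mean of its assigned points.
        new_centroids = np.array([
            data[labels == i].mean(axis=0) if np.any(labels == i) else centroids[i]
            for i in range(centroids.shape[0])
        ])
        # Stop when the centroids no longer move.
        if np.allclose(new_centroids, centroids, atol=tol):
            break
        centroids = new_centroids
    return centroids, labels

final_centroids, labels = run_kmeans(data, centroids)

Alternatively, scikit-learn's KMeans accepts a precomputed centroid array via its init parameter (with n_init=1), and in fact uses init='k-means++' by default.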

Applications of the K-Means++ Algorithm

  • Image segmentation: K-Means++ can be used to segment images into regions based on their color or texture features. This is useful in computer vision applications such as object recognition and tracking.
  • Customer segmentation: It can group customers into segments based on purchasing habits, demographic data, or other characteristics, helping businesses in marketing and advertising target their efforts more effectively.
  • Recommender systems: K-Means++ can be used to cluster users or items by past purchases or preferences in order to recommend products or services. This is useful in e-commerce and online advertising applications.
