Around IT in 256 seconds

#93: K-means clustering: machine learning algorithm to easily split observations into multiple buckets

January 11, 2023 | 3 Minute Read

K-means clustering is an algorithm for partitioning data into multiple, non-overlapping buckets. For example, if you have a bunch of points in two-dimensional space, this algorithm can easily find concentrated clusters of points. To be honest, that’s quite a simple task for humans. Just plot all the points on a piece of paper and find areas with higher density. For example, most of the points are located on the top-left of the plane, some at the bottom and a few at the centre-right. However, this is not that straightforward once you can no longer rely on graphical representation, for instance when your data points live in 3-, 4- or 100-dimensional space. Turns out, this is not that uncommon. Let me clarify.

The simple, flat-space example is actually pretty common and useful. Let’s say you’re a city mayor and can afford to build 3 hospitals. The K-means algorithm examines where every single citizen lives. It then splits all citizens into groups, most likely, but not necessarily, by city districts. These groups are called clusters. Moreover, the algorithm even tells you where the middle point of each group is. That point is called the centroid. It is the average position of everyone within that group, which makes it the point with the smallest total squared distance to all of them.
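To make the centroid a bit more concrete, here is a minimal Python sketch. The `centroid` helper and the `homes` coordinates are made up for illustration; the only assumption is that a centroid is the plain coordinate-wise average of the points in a group.

```python
def centroid(points):
    """Return the coordinate-wise mean of a non-empty list of points."""
    dims = len(points[0])
    return tuple(sum(p[d] for p in points) / len(points) for d in range(dims))


# Hypothetical home locations (x, y) of the citizens in one group.
homes = [(1.0, 2.0), (3.0, 4.0), (5.0, 0.0)]

# A candidate hospital site for this group: the average of all locations.
hospital_site = centroid(homes)
print(hospital_site)
```

The same helper works unchanged for points with any number of dimensions, as long as every point has the same length.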

OK, but what about the promised hundred dimensions? Let’s think about a different example. A bank has a bunch of customers and a very detailed profile of each one. It includes age, gender, marital status, income, credit score, place of living… you name it. It’s the 21st century, so assume every big company knows more about you than you do yourself. Anyway, each piece of information can be treated as a dimension. On the income dimension, some customers are on the left, and some are on the right. The same goes for age, credit score, etc. Sometimes you must get creative, for example, -1 for men and +1 for women on the gender dimension.

At this point (no pun intended), every customer is a data point in a multi-dimensional space. Something that’s impossible to draw and visualize. But you can still calculate the distance between two points using an abstract formula. Without diving into math, it’s enough to say that the more similar two customers are on each dimension, the closer their points are. If you now feed these data points into the K-means algorithm, it will split customers into distinct clusters. You must decide on the number of clusters in advance. But you don’t know what kind of clusters you’ll get. For this domain, we can assume it’ll easily distinguish, for example, between high-income households living in the suburbs, students renting in cheaper locations and blue-collar workers in their 40s.
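For the curious, here is a small Python sketch of what such a distance could look like. It assumes the ordinary Euclidean distance, which is the usual choice for K-means but not something spelled out above. The three customers and their numbers are entirely hypothetical.

```python
import math

# Hypothetical customers encoded as (age, income in thousands, gender as -1/+1).
alice = (34.0, 72.0, 1.0)
bob = (36.0, 70.0, -1.0)
carol = (61.0, 25.0, 1.0)


def distance(a, b):
    """Euclidean distance: square root of the summed squared differences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


# Alice and Bob have similar age and income, so they end up close together.
print(distance(alice, bob))
# Carol differs a lot on age and income, so she is much further from Alice.
print(distance(alice, carol))
```

Nothing in the function depends on there being exactly three dimensions, so it works the same way for a hundred.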

Of course, this is a gross oversimplification. In real life, we can expect tens of clusters with more subtle characteristics. Such insight into data will help every organization target specific groups with better advertising and products. Of course.

OK, but how does the K-means algorithm work? It’s actually pretty simple to implement:

  1. First, randomly assign a cluster number to each data point (from 1 to K)
  2. Then, find the centroid of each cluster. The centroid is the average of all observations within that cluster
  3. Finally, reassign points to their closest centroid.

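Now, repeat the process of finding the centroid and reassigning until the assignments stop changing. The algorithm should converge rather quickly and stabilize. However, it’s not guaranteed to produce an optimal solution. The result depends on the random starting assignment, so a different start may give different clusters.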
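The steps above can be sketched in a few lines of plain Python. Treat this as an illustration rather than a reference implementation: the names `k_means` and `mean`, the `max_iterations` and `seed` parameters, and the way an empty cluster is handled (re-seeding it with a random data point) are all assumptions made for this sketch. Clusters are numbered from 0 to K-1 instead of 1 to K.

```python
import math
import random


def mean(points):
    """Coordinate-wise average of a non-empty list of points."""
    dims = len(points[0])
    return tuple(sum(p[d] for p in points) / len(points) for d in range(dims))


def k_means(points, k, max_iterations=100, seed=0):
    """Return (labels, centroids); max_iterations must be at least 1."""
    rng = random.Random(seed)
    # Step 1: randomly assign a cluster number (0..k-1) to each data point.
    labels = [rng.randrange(k) for _ in points]
    for _ in range(max_iterations):
        # Step 2: find the centroid of each cluster.
        centroids = []
        for c in range(k):
            members = [p for p, label in zip(points, labels) if label == c]
            # Assumption: an empty cluster is re-seeded with a random point.
            centroids.append(mean(members) if members else rng.choice(points))
        # Step 3: reassign every point to its closest centroid.
        new_labels = [
            min(range(k), key=lambda c: math.dist(p, centroids[c]))
            for p in points
        ]
        if new_labels == labels:
            break  # nothing moved, so the clusters have stabilized
        labels = new_labels
    return labels, centroids


# Two hypothetical groups of points, far away from each other.
points = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0),
          (10.0, 10.0), (10.0, 11.0), (11.0, 10.0)]
labels, centroids = k_means(points, k=2)
print(labels)
print(centroids)
```

With data this clearly separated, the first three points should share one label and the last three the other. On messier data, running the function with a few different `seed` values and comparing the results is a cheap way to see how much the random start matters.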

That’s it, thanks for listening, bye!

Tags: data-science, k-means, machine-learning
