October 23, 2012 17:23 / 17 comments / algorithms python

I'm working on a little photography website for my Dad and thought it would be neat to extract color information from photographs. I tried a couple of different approaches before finding one that works pretty well. This approach uses k-means clustering to cluster the pixels into groups based on their color. The centers of the resulting clusters are then the "dominant" colors. k-means is a great fit for this problem because it is (usually) fast. It has the caveat of requiring you to specify up-front how many clusters you want -- I found it works well when I specified around 3.

A warning

I'm no expert on data-mining -- almost all my experience comes from reading Toby Segaran's excellent book Programming Collective Intelligence. In one of the first chapters Toby covers clustering algorithms, including a nice treatment of k-means, so if you want to really learn from an expert I'd suggest picking up a copy. You won't be disappointed.

How it works

The way I understand it, you start with a bunch of data points. For simplicity let's say they're numbers on a number-line. You want to group the numbers into "k" clusters, so pick "k" points randomly from the data to use as your initial cluster centers.

Now loop over every point in the data and calculate its distance to each of the "k" cluster centers. Find the nearest center and assign the point to that cluster. When you've looped over all the points, each one should be assigned to one of the "k" clusters. Now, for each cluster, recalculate its center by averaging the values of all its associated points, and start over.

When the centers stop moving very much you can stop looping. You will end up with something like this -- the points are colored based on what "cluster" they are in and the dark-black circles indicate the centers of each cluster.

[Figure: k-means clustering -- points colored by cluster, with dark circles marking each cluster's center]
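
To make that concrete, here is a minimal one-dimensional sketch of the loop described above. The names (kmeans_1d, min_diff) are my own for illustration, and it glosses over details like what to do when a cluster ends up empty:

import random

def kmeans_1d(numbers, k, min_diff=0.01):
    # Pick k of the numbers at random to serve as the starting centers.
    centers = random.sample(numbers, k)
    while True:
        groups = [[] for _ in range(k)]
        for x in numbers:
            # Assign each number to its nearest center.
            idx = min(range(k), key=lambda i: abs(x - centers[i]))
            groups[idx].append(x)
        # Recompute each center as the mean of its assigned numbers
        # (keeping the old center if a group came up empty).
        new_centers = [sum(g) / float(len(g)) if g else centers[i]
                       for i, g in enumerate(groups)]
        moved = max(abs(a - b) for a, b in zip(centers, new_centers))
        centers = new_centers
        if moved < min_diff:
            return centers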

Applying it to photographs

The neat thing about this algorithm is that, since it relies only on a simple distance calculation, you can extend it to multi-dimensional data. Color is often represented using 3 channels: Red, Green, and Blue. So what I did was treat all the pixels in the image like points in a 3-dimensional space. That's all there was to it!
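
For example, the distance between two colors is the same Euclidean calculation as on the number line, just in three dimensions (the pixel values below are made up for illustration):

from math import sqrt

# Two pixels as points in RGB space.
p1 = (255, 40, 0)    # a bright red-orange
p2 = (200, 60, 30)   # a darker, browner red

distance = sqrt(sum((a - b) ** 2 for a, b in zip(p1, p2)))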

I made a few optimizations along the way (a quick sketch of both follows the list):

  1. resize the image down to 200x200 or so using PIL
  2. instead of storing "duplicate" points, store a count with each -- saves on calculations
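
Both of those boil down to a couple of lines with PIL -- the filename below is only a placeholder:

from PIL import Image

img = Image.open('photo.jpg')        # placeholder filename
img.thumbnail((200, 200))            # resize in place, keeping the aspect ratio
w, h = img.size
color_counts = img.getcolors(w * h)  # (count, color) pairs -- "duplicate" pixels collapsed into counts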

Looking at some results

Akira motorcycles

The results: [extracted color swatches]

Akira motorcycles 2

The results: [extracted color swatches]

Akira 3

The results: [extracted color swatches]

Akira 4

The results: [extracted color swatches]

The source code

Below is the source code. It requires PIL to resize the image down to 200x200 and to extract the colors/counts. The "colorz" function is the one that returns the actual color codes for a filename.

from collections import namedtuple
from math import sqrt
import random
try:
    import Image
except ImportError:
    from PIL import Image

# A Point holds `coords` (e.g. an (R, G, B) tuple), `n` (the number of
# dimensions), and `ct` (how many pixels share this exact color).
Point = namedtuple('Point', ('coords', 'n', 'ct'))
# A Cluster holds its member points, its center Point, and its dimensionality.
Cluster = namedtuple('Cluster', ('points', 'center', 'n'))

def get_points(img):
    points = []
    w, h = img.size
    # getcolors() returns (count, color) pairs; passing w * h as the limit
    # guarantees every distinct color in the image is included.
    for count, color in img.getcolors(w * h):
        points.append(Point(color, 3, count))
    return points

rtoh = lambda rgb: '#%s' % ''.join(('%02x' % p for p in rgb))

def colorz(filename, n=3):
    img = Image.open(filename)
    # Shrink the image in place; thumbnail() preserves the aspect ratio.
    img.thumbnail((200, 200))
    w, h = img.size

    points = get_points(img)
    clusters = kmeans(points, n, 1)
    # Convert each cluster center back to integer RGB values, then to hex strings.
    rgbs = [map(int, c.center.coords) for c in clusters]
    return map(rtoh, rgbs)

def euclidean(p1, p2):
    return sqrt(sum([
        (p1.coords[i] - p2.coords[i]) ** 2 for i in range(p1.n)
    ]))

def calculate_center(points, n):
    # The new center is the count-weighted average of the points' coordinates.
    vals = [0.0 for i in range(n)]
    plen = 0
    for p in points:
        plen += p.ct
        for i in range(n):
            vals[i] += (p.coords[i] * p.ct)
    return Point([(v / plen) for v in vals], n, 1)

def kmeans(points, k, min_diff):
    # Seed the clusters with k points chosen at random from the data.
    clusters = [Cluster([p], p, p.n) for p in random.sample(points, k)]

    while 1:
        plists = [[] for i in range(k)]

        # Assign every point to the cluster whose center is nearest.
        for p in points:
            smallest_distance = float('Inf')
            for i in range(k):
                distance = euclidean(p, clusters[i].center)
                if distance < smallest_distance:
                    smallest_distance = distance
                    idx = i
            plists[idx].append(p)

        # Recompute each center and track how far the centers moved.
        diff = 0
        for i in range(k):
            old = clusters[i]
            center = calculate_center(plists[i], old.n)
            new = Cluster(plists[i], center, old.n)
            clusters[i] = new
            diff = max(diff, euclidean(old.center, new.center))

        # Stop once no center has moved more than min_diff.
        if diff < min_diff:
            break

    return clusters
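
As a quick usage sketch (the filename is just a placeholder; any image PIL can open will do):

if __name__ == '__main__':
    print(list(colorz('photo.jpg', n=3)))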

Playing with it in the browser

I ported the code over to JavaScript -- let me tell you, it's pretty rough, but it works and is fast. If you'd like to take a look at a live example, check out:

http://charlesleifer.com/static/colors/ -- you can view the source to see the js version, but basically it is just using the HTML5 canvas and its getImageData method.

Thanks for reading

Thanks for reading, I hope you found this post interesting. I am sure this is not the only approach so if you have other ideas please feel free to leave a comment or contact me directly.


Comments (17)

Stephen | oct 2012, at 12:21am

Neat. I can see using something like this to power those ambient color lights behind some TVs.

Jim Bumgardner | oct 2012, at 03:24pm

I've been using this method for a few years. A neat extension to it enables you to determine a "foreground" and "background" color. The cluster with the most pixels near the center of the image is "foreground". The cluster with the most pixels near the edges is "background".

Ryan O'Hern | oct 2012, at 03:54pm

An interesting extension to this would be to weight more highly data points for colors that are "eye grabbing". For example, in the motorbike picture, you have captured the most common three shades in the image, but the blue lights dominate the viewer's perception (at least for me). If I were choosing dominant colors to do a styling template for something like...Warhammer pieces (this is the only thing that comes to mind, strangely) not choosing that bright blue color would be strange, nay? At least as an accent. Perhaps you could capture a fourth data point exclusive to this feature to give exactly that: accent color.

You may know that k-means can mess up and give non-optimal clustering due to local optimums. You might want to add an additional step to run k-means N times and then choose the trial with the min cost value averaged over all data points (where cost is distance to the matched centroid) as your result.

Cool post and simple implementation. I like.

David | oct 2012, at 03:54pm

Would be really nice to see in a gallery. With each image the page adapts the colors around it =)

Adam | oct 2012, at 04:44pm

Consider while True: instead of while 1:. More readable, and doesn't incur a type conversion every iteration.

Mathew Sanders | oct 2012, at 04:47pm

@Ryan I wonder if using HSL instead of RGB as the colour model and somehow shifting the center of the cluster to have more weight on the saturation axis would help (we tend to perceive saturated colours to 'pop' out more).

@Charles this is a great project! I've often asked developers if we could do something like this to automatically tag images but I've always been told that it's too expensive :( super impressed with your javascript implementation!

Samantha K | oct 2012, at 07:52pm

Can this algorithm be sped up? It runs for over 10 seconds on my computer for each image. I have several hundred thousand images to process. I could wait 1 second per image but 10 is a bit too much.

Phillip Cloud | oct 2012, at 08:13pm

You should try scikits-learn. They have a nicely encapsulated kmeans algorithm. It's very fast.

Marjan | oct 2012, at 08:20pm

This was the coolest thing I have ever seen making real use of clustering algorithms. Well done

Hugh Brown | oct 2012, at 08:57pm

def points(img):
    w, h = img.size
    return [Point(color, 3, count) for count, color in img.getcolors(w * h)]

ai_ja_nai | oct 2012, at 03:36am

Nice application. But, due to the resize to 200x200 to reduce the search space, don't you end up smoothing the dataset too much, resulting in a lot of approximation in the pixel counts?

Matt W | oct 2012, at 04:19am

We were using similar techniques in the 1990s on analytical microscope data to try to unravel the information held in multivariate signals.

Cushen | oct 2012, at 06:04am

This is a similar application and might be of interest:

http://mkweb.bcgsc.ca/color_summarizer/

Mark Ransom | oct 2012, at 10:27pm

@Samantha, one easy way to speed it up would be to remove the sqrt() function from the euclidean distance calculation. When comparing distances it doesn't matter if you compare the actual distance or the square of the distance.

pit | oct 2012, at 04:19am

Is there an online version of this or a similar script?

Tim Wilder | oct 2012, at 02:05pm

Great entry. It was succinct, informative, and pretty interesting. Thank you for posting!

Shashank | oct 2012, at 04:33am

Nice post. I guess this is the same method used in Windows 7 to shade the taskbar icons on mouse hover.


Commenting has been closed, but please feel free to contact me