Building a nearest neighbour classifier in Python

The nearest neighbour classifier is a very simple algorithm for image classification. While it is not used much in practice, it is easy to implement and it helps build a deeper understanding of the problems in image classification.

Given an image, the nearest neighbour classifier tries to find the closest matching image among the ones it already knows, and predicts that image’s label. We “train” the classifier by giving it a large collection of labelled images that it is allowed to search through to find matches. The more images we give it, the closer the match the classifier is likely to find, but also the longer it will take to find that match.

How do you compare images? An image is basically a large grid of pixels. Just as colours can be made by mixing the three primary colours of light, each pixel is made up of a red, a green and a blue component. A 32x32 colour image is therefore represented as three 32x32 matrices, one each for red, green and blue.
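In NumPy terms, a sketch of this (using a randomly generated stand-in rather than a real photo) is a 32x32x3 array of 8-bit integers, where slicing the last axis picks out one colour matrix:

```python
import numpy as np

# A stand-in "image": random pixel values in 0..255, shape (height, width, channel)
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)

# Each colour is its own 32x32 matrix
red, green, blue = image[:, :, 0], image[:, :, 1], image[:, :, 2]
print(image.shape, red.shape)  # (32, 32, 3) (32, 32)

# A single pixel is a (red, green, blue) triple
print(image[0, 0])
```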

Once this is understood, it is easy to find ways to compare images. One way is to take the differences between the two images in each of the red, green and blue layers, pixel by pixel, and then sum their absolute values. This is called the $L^1$ distance and is probably the simplest way to compare two images.
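Written out, for two images $I_1$ and $I_2$ the $L^1$ distance sums over every pixel value $p$ (in every colour channel): $d_1(I_1, I_2) = \sum_p |I_1^p - I_2^p|$. A minimal NumPy sketch follows; the function name `l1_distance` is our own, and it casts to a signed type first so the subtraction can go negative:

```python
import numpy as np

def l1_distance(img1, img2):
    """Sum of absolute pixel differences over all channels (L1 distance)."""
    # Cast to a signed type: uint8 subtraction would wrap around instead of going negative
    diff = img1.astype(np.int32) - img2.astype(np.int32)
    return int(np.abs(diff).sum())

a = np.array([[[10, 20, 30]]], dtype=np.uint8)  # a 1x1 "image" with one RGB pixel
b = np.array([[[12, 15, 30]]], dtype=np.uint8)
print(l1_distance(a, b))  # |10-12| + |20-15| + |30-30| = 7
```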

The Nearest Neighbour classifier is not a particularly efficient or useful classifier. It will do better than chance, but it will achieve nowhere near the performance of more sophisticated classifiers.

Setup

We are following along with the Stanford CS231n tutorial. We will use a dataset called CIFAR-10, which comprises 60,000 labelled tiny images (32x32 pixels) across ten classes, split into 50,000 training images and 10,000 test images. Each class has 5,000 training images.

The first thing to do is download the dataset and extract it to a folder. CIFAR-10 comes in ‘pickled’ form, meaning that each batch is a Python object serialised to a byte stream. This is not a convenient format for us, so we have to convert the byte stream back into Python objects (unpickling) before we can use it.
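To see what unpickling does, here is a small round trip with Python’s standard `pickle` module on a made-up dictionary shaped loosely like a CIFAR-10 batch. The CIFAR-10 Python files were written under Python 2, which is why the `unpickle` helper passes `encoding='latin1'` when reading them under Python 3.

```python
import pickle
import numpy as np

# A made-up dictionary in the same spirit as a CIFAR-10 batch
batch = {'labels': [3, 8], 'data': np.zeros((2, 3072), dtype=np.uint8)}

stream = pickle.dumps(batch)       # pickling: Python object -> bytes
print(type(stream))                # <class 'bytes'>

restored = pickle.loads(stream)    # unpickling: bytes -> Python object
print(restored['labels'], restored['data'].shape)  # [3, 8] (2, 3072)
```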

First, let’s import the libraries we’ll need:

import pickle
import os
import sys
import glob
import numpy as np

import matplotlib.pyplot as plt
# Jupyter magic: show plots inside the notebook (remove this line in a plain .py script)
%matplotlib inline

The data comes in six batches: five for the training set and one for the test set. Each batch file contains a dictionary with the following keys:

  • data: a 10000x3072 numpy array of uint8s, where each row of the array stores a 32x32 colour image. The first 1024 entries contain the red channel values, the next 1024 green, and the final 1024 blue.
  • labels: a list of 10000 numbers in the range 0-9. The number at index i indicates the label of the ith image in the array data.
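The labels are plain integers. CIFAR-10 also ships a `batches.meta` file whose `label_names` entry gives the class name for each index; as a hard-coded sketch, assuming the standard class order, the mapping looks like this:

```python
# Class names in label order (0-9), as listed in CIFAR-10's batches.meta file
LABEL_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

labels = [6, 9, 9, 4]  # example label values
print([LABEL_NAMES[y] for y in labels])  # ['frog', 'truck', 'truck', 'deer']
```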

We will create a helper function for unpickling our batches, which will return a dictionary for each batch. Dictionaries aren’t that convenient to work with, so we’ll pull the data and labels out into numpy arrays instead.

We’ll create three functions to unpack the data:

  • unpickle: a function for unpickling each batch
  • load_CIFAR_batch: a function for extracting the data and labels from each batch
  • load_CIFAR10: a function for loading the data and labels of the entire dataset (separating train and test batches)

Each row of the returned data contains one image. This format is hard to visualise, so we’ll reshape each row into a 32x32x3 block (height, width, colour channel) so that we can view the images.
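Because each row stores the 1024 red values, then green, then blue, reshaping a row to (3, 32, 32) gives one 32x32 plane per channel, and transposing moves the channel axis last so plotting libraries see a (32, 32, 3) image. For a whole batch the same idea becomes `transpose(0, 2, 3, 1)`. A quick check on a synthetic row (not real CIFAR data):

```python
import numpy as np

# Synthetic 3072-value row laid out like CIFAR-10: 1024 red, 1024 green, 1024 blue
row = np.arange(3072) % 256

planes = row.reshape(3, 32, 32)    # (channel, height, width)
image = planes.transpose(1, 2, 0)  # (height, width, channel)

# The same result built by hand from the three channel slices
red = row[0:1024].reshape(32, 32)
green = row[1024:2048].reshape(32, 32)
blue = row[2048:3072].reshape(32, 32)
print(image.shape)                                                   # (32, 32, 3)
print(np.array_equal(image, np.stack([red, green, blue], axis=-1)))  # True
```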

def unpickle(file):
    """
    Unpickles a file stream
    :param file: file path
    :return: dict with keys {batch_label, labels, data, filenames}
    """
    with open(file, 'rb') as fo:
        # The CIFAR-10 files were pickled under Python 2, hence encoding='latin1'
        batch_dict = pickle.load(fo, encoding='latin1')
    return batch_dict

def load_CIFAR_batch(file):
    """
    Load data and labels from a batch file
    :param file: file path
    :return:
        X: ndarray of shape (number of images in batch, 3x32x32)
        Y: ndarray of shape (number of images in batch,)
    """
    file_dict = unpickle(file)
    X = file_dict['data']              # uint8 array, one image per row
    Y = np.array(file_dict['labels'])  # labels come as a list; convert to an array
    return X, Y

def load_CIFAR10(root):
    """
    Load the CIFAR10 dataset.
    Currently handles multiple train batches, but only one test batch.
    :param root: path to the root folder containing the CIFAR10 dataset
    :return:
      x_train: ndarray of shape (number of train images, 32, 32, 3) holding RGB values of images
      y_train: ndarray of shape (number of train images, ) holding labels for the training set
      x_test:  ndarray of shape (number of test images, 32, 32, 3) holding RGB values of images
      y_test:  ndarray of shape (number of test images, ) holding labels for the testing set
    """
    # Look for the batch files inside root (sorted so the batches load in order)
    train_batch_list = sorted(glob.glob(os.path.join(root, 'data_batch*')))
    test_batch = glob.glob(os.path.join(root, 'test_batch*'))
    # Training set: load every batch, then stack them into one array
    x_batches, y_batches = [], []
    for file in train_batch_list:
        x_batch, y_batch = load_CIFAR_batch(file)
        x_batches.append(x_batch)
        y_batches.append(y_batch)
    x_train = np.concatenate(x_batches)
    y_train = np.concatenate(y_batches)
    # Change x_train from (n_images, 3072) to (n_images, 32, 32, 3): reshape each row
    # to channel-first (3, 32, 32), then move the channel axis last
    x_train = x_train.reshape((x_train.shape[0], 3, 32, 32)).transpose(0, 2, 3, 1)
    # Test set
    x_test, y_test = load_CIFAR_batch(test_batch[0])
    # Same reshape for x_test
    x_test = x_test.reshape((x_test.shape[0], 3, 32, 32)).transpose(0, 2, 3, 1)
    return x_train, y_train, x_test, y_test

We can go ahead and load up our data now.

# Path to the extracted cifar-10-batches-py folder (change this to match your machine)
path_data = '/Users/tomroth/Documents/deeplearning_courses/cs231n_exercises/cifar-10-batches-py/'
x_train, y_train, x_test, y_test = load_CIFAR10(path_data)

Let’s take a look at an image to see if it loaded correctly.
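One way to do this is with matplotlib’s `imshow`. A minimal sketch follows; the helper name `show_image` is hypothetical, and it is demonstrated on a random stand-in image so it runs without the dataset. With the real data you would call something like `show_image(x_train[1], y_train[1])`.

```python
import numpy as np
import matplotlib.pyplot as plt

def show_image(image, label=None):
    """Display one (32, 32, 3) uint8 image, optionally titled with its label."""
    fig, ax = plt.subplots(figsize=(2, 2))
    ax.imshow(image)  # uint8 RGB values in 0..255 are displayed as-is
    ax.axis('off')    # hide the pixel-index axes
    if label is not None:
        ax.set_title(str(label))
    return fig

# Random stand-in image
rng = np.random.default_rng(0)
fig = show_image(rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8), label=6)
plt.show()
```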


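It’s worth noting that if we change x_train to a float type, the image will no longer display properly (for me it turned out negative), which can really mess with you if you’re not expecting it. This happens because `imshow` reads integer RGB values on a 0-255 scale but float RGB values on a 0-1 scale, so float pixels in the 0-255 range are out of range.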

x_train_float = x_train.astype(float)
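So before plotting a float copy such as `x_train_float`, it needs rescaling to 0-1 (or casting back with `.astype(np.uint8)`). A small sketch; the helper name `to_display_range` is our own:

```python
import numpy as np

def to_display_range(image_float):
    """Rescale float pixels from 0..255 to the 0..1 range imshow expects for floats."""
    return np.clip(image_float / 255.0, 0.0, 1.0)

pixels = np.array([[[0.0, 127.5, 255.0]]])  # one float pixel with 0..255 values
print(to_display_range(pixels))             # values become 0.0, 0.5 and 1.0
```

With the real data this would be `plt.imshow(to_display_range(x_train_float[1]))`.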

Building the nearest neighbour classifier

We are now ready to build our classifier. We do this by creating a class that contains:

  • a constructor (the __init__ method)
  • a train method
  • a predict method

Once the class is defined, using it is easy: we create a NearestNeighbour object, then call its train and predict methods.

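class NearestNeighbour(object):
    def __init__(self):
        pass  # nothing to set up until train() is called

    def train(self, x_train, y_train):
        """ x_train is shape NxD, where N is the number of images and D=3x32x32.
            Assumes x_train is flattened out - a 2D array with one row per image.
            This function remembers the training data, and that is all. """
        # Store pixels as int16: uint8 subtraction would wrap around (3 - 5 -> 254),
        # while int16 holds every difference between -255 and 255
        self.x_train = x_train.astype(np.int16)
        self.y_train = np.asarray(y_train)

    def predict(self, x_test):
        """ Compare each image in the test set with every image in the training set
            and predict the label of the closest one (smallest L1 distance).
            Assumes x_test is flattened out - a 2D array with one row per image. """
        y_predicted_classes = []
        for i in x_test.astype(np.int16):
            # L1 distance from this test image to every training image
            differences = np.abs(self.x_train - i)
            # 1D array of distances; np.sum widens int16 to the platform integer, so no overflow
            differences_rowsums = np.sum(differences, axis=1)
            closest_image_index = differences_rowsums.argmin()
            y_predicted_classes.append(self.y_train[closest_image_index])
            # track progress
            progress = len(y_predicted_classes)
            if progress % 100 == 0:
                print('Progress: %f' % (progress / x_test.shape[0]))
        return y_predicted_classes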
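One detail worth checking in any L1 implementation on this data: CIFAR-10 pixels are `uint8`, and NumPy arithmetic on unsigned 8-bit integers wraps around modulo 256 rather than going negative, so `abs` never sees a negative number. A small demonstration:

```python
import numpy as np

a = np.array([3], dtype=np.uint8)
b = np.array([5], dtype=np.uint8)

print(np.abs(a - b))                                    # [254]  wrapped around, not 2
print(np.abs(a.astype(np.int16) - b.astype(np.int16)))  # [2]    correct absolute difference
```

Casting to a wider signed type before subtracting avoids this; int16 is enough, since pixel differences lie between -255 and 255.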

The Nearest Neighbour classifier requires the data to be flattened out: one row per image in a 2D array.

x_train_rows = x_train.reshape((x_train.shape[0], 3*32*32))
x_test_rows = x_test.reshape((x_test.shape[0], 3*32*32))
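Since `x_train` was transposed to (N, 32, 32, 3) when loading, these rows interleave the colours pixel by pixel rather than following the file’s red-then-green-then-blue layout. For the L1 distance this does not matter as long as the training and test sets are flattened the same way, because the distance sums over every pixel value regardless of order. A quick check on random stand-in images:

```python
import numpy as np

rng = np.random.default_rng(0)
# Two random stand-in images in (height, width, channel) layout
img1 = rng.integers(0, 256, size=(32, 32, 3)).astype(np.int32)
img2 = rng.integers(0, 256, size=(32, 32, 3)).astype(np.int32)

# L1 distance with the interleaved flattening, and with a channel-first (planar) flattening
interleaved = np.abs(img1.reshape(-1) - img2.reshape(-1)).sum()
planar = np.abs(img1.transpose(2, 0, 1).reshape(-1) - img2.transpose(2, 0, 1).reshape(-1)).sum()
print(interleaved == planar)  # True: same pixel values, just in a different order
```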

Now we can train our model and make some predictions for our test set. It takes a while to compute predictions on the test set, so we’ll only test half the observations in the test set to get a general idea of the classification accuracy.

nn = NearestNeighbour()
nn.train(x_train_rows, y_train)
predictions = nn.predict(x_test_rows[0:5000])

We can evaluate our accuracy by comparing our predictions against the labelled test set.

pred_array = np.array(predictions)
np.mean(pred_array == y_test[0:5000])

I scored 25% accuracy using this classifier, which is better than guessing randomly (10%), but not by much! We could improve things a little by generalising to a k-nearest neighbours classifier, which, instead of finding the single closest training image to each test image, finds the k closest training images and takes a majority vote over their labels. That would likely improve our accuracy, but in reality there are classifiers that are much more powerful than this one and we really should just use them.
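A k-nearest neighbours prediction for a single test image can be sketched as follows. This is our own minimal version, not code from the tutorial: it assumes flattened rows as above, uses the L1 distance, and breaks ties in the vote by picking the smallest label (which is what `np.bincount(...).argmax()` does). The function name `knn_predict_one` is hypothetical.

```python
import numpy as np

def knn_predict_one(x_train, y_train, x, k=5):
    """Predict the label of one flattened image x by majority vote of its k nearest
    training rows under L1 distance. Ties go to the smallest label."""
    # Signed arithmetic so that uint8 pixel differences don't wrap around
    distances = np.abs(x_train.astype(np.int32) - x.astype(np.int32)).sum(axis=1)
    nearest = np.argsort(distances)[:k]                # indices of the k closest rows
    votes = np.bincount(np.asarray(y_train)[nearest])  # count each label among them
    return int(votes.argmax())

# Toy data: four 2-pixel "images" with labels 0 and 1
x_train = np.array([[0, 0], [1, 1], [10, 10], [11, 11]], dtype=np.uint8)
y_train = np.array([0, 0, 1, 1])
print(knn_predict_one(x_train, y_train, np.array([2, 2], dtype=np.uint8), k=3))  # 0
```

Setting `k=1` recovers the plain nearest neighbour classifier; in practice k would be chosen on a validation split.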