Multioutput Classification in scikit-learn: clean noise from pictures

In this notebook, we will explore how to create a model that removes noise from images. Image tasks are very commonly associated with deep learning; however, we will see how a relatively simple algorithm such as K-Nearest Neighbours can lead to very good results.

Photo by Bernard Hermant on Unsplash
## Outline

- Introduction
- Getting the data
- Data preprocessing
- Train test split
- Model training
- Model evaluation
- Conclusions

## Introduction

We will use the MNIST dataset, a very well-known collection of 70,000 labelled images of handwritten digits. "Dirty" pictures are generated by adding random values to each pixel of an image. The original, clean images are then used as the target, and the "dirty" images as the input.

For modelling, we will explore a very well-known algorithm: K-Nearest Neighbours, used as a classifier. This type of task can be referred to as multioutput-multiclass classification: the classifier outputs one label per pixel, and each label can take multiple values (pixel intensity goes from 0 to 255).

## Getting the data

As mentioned before, we use the MNIST dataset. Because it is so widely used, it can be downloaded and used with little to no preprocessing. Since version 0.20, scikit-learn provides an API to download datasets from OpenML. Let's see how to use it:
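```python
from sklearn.datasets import fetch_openml

# as_frame=False returns NumPy arrays; newer scikit-learn versions return a pandas DataFrame by default
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
mnist.keys()
```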
As mentioned before, we are only interested in the images, so we can ignore the labels (`target`) and keep only the `data`:
```python
images = mnist['data']
print("Images array has the following shape: {}".format(images.shape))
print("One image has the following shape: {}".format(images[0].shape))
```
We can write a small function to display some images and get a better idea of the data. There is one detail to handle first. Each image is stored as a vector of 784 elements, so to display it we need to restore its original layout by reshaping it into a 28 x 28 pixel square (the square root of 784 is 28). Let's write a function that displays 4 random pictures:
```python
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)

def plot_imgs(images, n=4):
    # pick n distinct random image indexes
    indexes = np.random.choice(len(images), n, replace=False)
    fig, axs = plt.subplots(1, n, sharey=True, figsize=(12, 4), dpi=200)
    for ax, idx in zip(axs, indexes):
        # turn the 784-element vector back into a 28 x 28 image
        ax.imshow(images[idx].reshape(28, 28), cmap='binary')
        ax.axis('off')

plot_imgs(images)
```
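Reshaping does not change any value; it only changes how NumPy indexes the values. The following check runs on a synthetic vector, not an MNIST image. With NumPy's default row-major order, pixel `(row, col)` of the 28 x 28 image is element `row * 28 + col` of the flat vector, and flattening again gives back the original vector.

```python
import numpy as np

# A synthetic stand-in for one flattened image: the values 0..783
flat = np.arange(784)

# Restore the 2D layout; NumPy uses row-major (C) order by default
square = flat.reshape(28, 28)

# Pixel (row, col) of the square image is element row * 28 + col of the vector
row, col = 5, 17
assert square[row, col] == flat[row * 28 + col]

# Flattening again gives back the original vector
assert np.array_equal(square.ravel(), flat)
print(square.shape)  # (28, 28)
```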
## Data preprocessing

In this part, we will first split our dataset into training and test sets, and then generate the input features for each set. We already have the clean images (y), so we will create the input features (X) by generating a random array of noise and adding it to the original images:
```python
y = images
y_train = y[:60000]  # the first 60,000 images are used for training
y_test = y[60000:]   # the remaining 10,000 images are used for testing

# integer noise in [0, 500) added to every pixel of every training image
noise = np.random.randint(0, 500, (len(y_train), 784))
X_train = y_train + noise

# use different (unseen) noise for testing
noise = np.random.randint(0, 500, (len(y_test), 784))
X_test = y_test + noise
```
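`np.random.randint(0, 500, ...)` draws integers from 0 to 499, because the upper bound is exclusive. Nothing clips the sum, so the noisy inputs span up to 255 + 499 = 754, while the clean targets stay within MNIST's 0 to 255 range. The following check runs on synthetic stand-in arrays (`y_demo` and `X_demo` are illustrative names, not the real dataset):

```python
import numpy as np

np.random.seed(0)

# Synthetic stand-ins for clean images: 1,000 "images" with pixel values 0..255
y_demo = np.random.randint(0, 256, (1000, 784))

# Same recipe as in the notebook: integer noise in [0, 500), added pixel by pixel
noise = np.random.randint(0, 500, y_demo.shape)
X_demo = y_demo + noise

print("clean range:", y_demo.min(), y_demo.max())  # within 0..255
print("noisy range:", X_demo.min(), X_demo.max())  # within 0..754
```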
Let's display some images from the training set side by side with their "labels" (the clean images). For this, we will use a slightly modified version of the function we wrote before, creating a grid of 2 rows and 4 columns:
```python
def plot_imgs_with_labels(images, labels):
    # pick 4 distinct random image indexes
    indexes = np.random.choice(len(images), 4, replace=False)
    fig, axs = plt.subplots(2, 4, figsize=(12, 8), sharey=True, dpi=200)
    # each image occupies two neighbouring cells: dirty input, then clean target
    for i, idx in enumerate(indexes):
        row, col = divmod(2 * i, 4)
        axs[row, col].imshow(images[idx].reshape(28, 28), cmap='binary')
        axs[row, col].set_title('Dirty image')
        axs[row, col + 1].imshow(labels[idx].reshape(28, 28), cmap='binary')
        axs[row, col + 1].set_title('Clean image')
    for ax in axs.flat:
        ax.axis('off')

plot_imgs_with_labels(X_train, y_train)
```
As we can see, the added noise does its job very well: it is hard even for a human to tell which digit is shown.

## Model training

As mentioned before, we will use the K-Nearest Neighbours classifier for this task. It is an instance-based type of learning. Instead of performing explicit generalization (as algorithms such as regression do), it compares new instances with the instances seen during the training phase, which are stored in the model's memory. The classification is computed in a very simple way: by a majority vote among the k nearest neighbours of each point. Nearest-neighbour methods are also commonly used in other tasks, such as clustering and segmentation. The scikit-learn API is very user friendly: it is enough to import `KNeighborsClassifier` and call its `fit` method:
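```python
from sklearn.neighbors import KNeighborsClassifier

# n_jobs=-1 uses all available CPU cores for the neighbour search
knn_clf = KNeighborsClassifier(n_jobs=-1)
knn_clf.fit(X_train, y_train)
```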
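The following from-scratch sketch makes the majority vote concrete for a multioutput target like ours. It finds the k training instances closest to a query, using Euclidean distance (what scikit-learn's default Minkowski metric with `p=2` computes). Then, for each output column independently, it keeps the most frequent label among those neighbours. The helper `knn_predict` and the toy data are made up for illustration. `KNeighborsClassifier` accepts such a 2D `y` directly. With `n_neighbors=3` (the default is 5) and uniform weights, it should return the same predictions on this tie-free example.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def knn_predict(X_train, y_train, x, k=5):
    """Predict one sample with a per-output majority vote of its k nearest neighbours."""
    # Euclidean distance from the query to every stored training instance
    dists = np.linalg.norm(X_train - x, axis=1)
    nearest = np.argsort(dists)[:k]
    # For each output column, keep the most frequent label among the neighbours
    prediction = []
    for column in y_train[nearest].T:
        values, counts = np.unique(column, return_counts=True)
        prediction.append(values[np.argmax(counts)])
    return np.array(prediction)

X_toy = np.array([[0, 0], [0, 1], [1, 0], [1, 1],
                  [10, 10], [10, 11], [11, 10], [11, 11]], dtype=float)
# Three output columns per sample, think of them as three "pixels"
y_toy = np.array([[0, 255, 0], [0, 255, 0], [0, 0, 0], [0, 255, 0],
                  [255, 0, 255], [255, 0, 255], [255, 255, 255], [255, 0, 0]])
queries = np.array([[0.2, 0.2], [10.8, 10.9]])

manual = np.array([knn_predict(X_toy, y_toy, q, k=3) for q in queries])
clf = KNeighborsClassifier(n_neighbors=3).fit(X_toy, y_toy)

print(manual.tolist())                               # [[0, 255, 0], [255, 0, 255]]
print(np.array_equal(manual, clf.predict(queries)))  # True
```

Each output column is voted on separately. This is why a reconstructed pixel can disagree with the rest of the digit it belongs to.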
We can now use our fitted model to predict some images from the test set. Let's adapt our function to display the predicted images along with the input and true images:
```python
def plot_imgs_with_predictions(images, labels, model):
    # pick 8 distinct random test images: two (input, predicted, true) triplets per row
    indexes = np.random.choice(len(images), 8, replace=False)
    # predict all selected images with a single call
    predictions = model.predict(images[indexes])
    fig, axs = plt.subplots(4, 6, figsize=(12, 8), sharey=True)
    for i, idx in enumerate(indexes):
        # images 0-3 fill columns 0-2, images 4-7 fill columns 3-5
        row, col = i % 4, 3 * (i // 4)
        panels = [images[idx], predictions[i], labels[idx]]
        for offset, panel in enumerate(panels):
            axs[row, col + offset].imshow(panel.reshape(28, 28), cmap='binary')
    titles = ['input image', 'predicted image', 'true image']
    for col in range(6):
        axs[0, col].set_title(titles[col % 3])
    for ax in axs.flat:
        ax.axis('off')

plot_imgs_with_predictions(X_test, y_test, knn_clf)
```
We can see that our classifier does a pretty decent job of reconstructing the damaged images. However, some images are messed up: some pixels are not reconstructed correctly, and some digits have been reconstructed as a different digit. We could try a different model, but not all scikit-learn models handle multioutput natively. For those that don't, we need to wrap the classifier in `MultiOutputClassifier`, a workaround that fits one classifier per target (here, one per pixel).

## Model evaluation

To evaluate the performance of this model, we need a metric that makes sense. Normally we could use a metric such as the mean absolute error (MAE), which returns an error in the same unit as the quantity we are trying to predict: if we are predicting house prices and the MAE is 6, we are on average off by $6. In this case, however, a custom metric (or a variation of an existing one) needs to be created.

A very simple approach is to subtract the predicted image from the true image. If the classifier worked perfectly, the difference for each pixel would be 0, and the result would be an array of 0s:
```python
perfect_prediction = y_test[0]
perfect_diff = y_test[0] - perfect_prediction  # every pixel is 0
plt.imshow(perfect_diff.reshape(28, 28), cmap='binary')
plt.title('Perfect image prediction')
```
Let's apply the same logic to an actual prediction:
```python
img = np.random.randint(0, len(y_test))  # pick a random image (it changes every time you run it)

fig, axs = plt.subplots(1, 4, figsize=(12, 3), sharey=True)
axs[0].imshow(y_test[img].reshape(28, 28), cmap='binary')
axs[0].set_title('Test image')
axs[1].imshow(X_test[img].reshape(28, 28), cmap='binary')
axs[1].set_title('model input')

predicted = knn_clf.predict([X_test[img]])
axs[2].imshow(predicted.reshape(28, 28), cmap='binary')
axs[2].set_title('model output')

# keep only positive differences: pixels darker in the true image than in the prediction
diff = y_test[img] - predicted
axs[3].imshow(np.maximum(diff, 0).reshape(28, 28), cmap='binary')
axs[3].set_title('Difference')
```
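The difference plot uses `np.maximum(diff, 0)`, so it only shows pixels where the true image is darker than the prediction (missed strokes). Pixels where the model added ink that is not there appear blank. The following sketch separates the two kinds of error. It uses a hypothetical helper, `split_errors`, on toy 4-pixel arrays. On real data you would pass `y_test[img]` and `predicted.flatten()`, then reshape each result to 28 x 28 for plotting.

```python
import numpy as np

def split_errors(true_img, pred_img):
    """Split the error into missed ink and extra ink (with cmap='binary', higher = darker)."""
    diff = np.asarray(true_img, dtype=float) - np.asarray(pred_img, dtype=float)
    # true pixel darker than predicted: part of a stroke was not reconstructed
    missed = np.where(diff > 0, diff, 0.0)
    # predicted pixel darker than true: the model drew ink that is not there
    extra = np.where(diff < 0, -diff, 0.0)
    return missed, extra

# Toy 4-pixel "images" standing in for y_test[img] and the model output
true_img = np.array([0, 255, 255, 0])
pred_img = np.array([0, 255, 0, 200])

missed, extra = split_errors(true_img, pred_img)
print(missed.tolist())  # [0.0, 0.0, 255.0, 0.0]
print(extra.tolist())   # [0.0, 0.0, 0.0, 200.0]
```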
You can run the notebook yourself to pick another image from the test set! This visualization shows where our model failed to reconstruct the image. We can also flatten the `predicted` array and apply a function to it to calculate a score. Such a performance metric lets us compare models: if we tried another model, we would know which one performs better. A good choice is the root mean squared error (RMSE). Because the errors are squared, large mistakes are penalized much more than small ones. When removing noise from our digits, being off by a small amount (a slightly darker or lighter region in the predicted image) is not a big deal.
```python
from sklearn.metrics import mean_squared_error

# compare the prediction with the true version of the same test image
mse = mean_squared_error(y_test[img], predicted.flatten())  # flatten the (1, 784) prediction to 1D
print("Root mean squared error is: {:.2f}".format(np.sqrt(mse)))
```
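The score above covers a single image. Comparing models works better with a summary over many test images. Here is a minimal sketch that computes the RMSE of each image and averages the results. It assumes the predictions are stacked one row per image, for example `knn_clf.predict(X_test[:n])` for some `n` of your choice (predicting the full test set can be slow). The helper `per_image_rmse` is hypothetical, and the toy arrays stand in for real images.

```python
import numpy as np
from sklearn.metrics import mean_squared_error

def per_image_rmse(y_true, y_pred):
    """Root mean squared error of each image (one row = one flattened image)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.sqrt(np.mean((y_true - y_pred) ** 2, axis=1))

# Toy stand-ins: 3 "images" of 4 pixels each
y_true = np.array([[0, 255, 255, 0],
                   [0, 0, 0, 0],
                   [255, 255, 255, 255]])
y_pred = np.array([[0, 255, 255, 0],      # perfect reconstruction
                   [0, 0, 0, 10],         # one pixel slightly off
                   [255, 0, 255, 255]])   # one pixel completely wrong

scores = per_image_rmse(y_true, y_pred)
print(scores.tolist())  # [0.0, 5.0, 127.5]
print(scores.mean())    # average RMSE over the toy "test set"

# Sanity check against scikit-learn for a single image
assert np.isclose(scores[1], np.sqrt(mean_squared_error(y_true[1], y_pred[1])))
```

The second and third toy images each have a single wrong pixel, yet their RMSE is 5 versus 127.5. Squaring makes one badly wrong pixel cost far more than one slightly wrong pixel.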
## Conclusions

A simple algorithm such as K-Nearest Neighbours can still lead to appreciable results. Before applying more advanced techniques, it is advisable to start with simple algorithms, which also helps to understand the problem better. This tutorial shows how the line between classification and regression is sometimes blurry. Predicting pixel intensity looks more like a regression task than a classification one. However, pixels take discrete values between 0 and 255, so treating each intensity as a class is also possible, which leaves room for discussion.
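To explore the regression reading of this task, one option is to swap `KNeighborsClassifier` for scikit-learn's `KNeighborsRegressor`. It predicts the mean of the neighbours' values instead of taking a majority vote, so it can output intensities that no single neighbour had. This alternative is not evaluated in this notebook. The toy sketch below only illustrates the difference; on MNIST, the regressor would be fitted on `X_train` and `y_train` in the same way.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor

# Toy data: 1 feature, 2 outputs ("pixels"); the 3 nearest neighbours of 0.05 are the first 3 rows
X_toy = np.array([[0.0], [0.1], [0.2], [5.0], [5.1]])
y_toy = np.array([[0, 255],
                  [0, 255],
                  [90, 0],
                  [255, 0],
                  [255, 0]])

query = np.array([[0.05]])
clf = KNeighborsClassifier(n_neighbors=3).fit(X_toy, y_toy)
reg = KNeighborsRegressor(n_neighbors=3).fit(X_toy, y_toy)

print(clf.predict(query).tolist())  # majority vote per output: [[0, 255]]
print(reg.predict(query).tolist())  # mean per output: [[30.0, 170.0]]
```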
