
Attacking a simple Image Classifier from scratch

Assemble and evaluate a simple machine learning model


I hold a PhD in Computer Science and have been published in a variety of international peer-reviewed journals.

AI is going to be a problem. I don't know what will cause the first "big issue": it might be a courtroom where a defendant is sent to jail based on erroneous AI-generated data, or it could be a death in a medical setting. But something is going to happen.

Let's take the existing adversarial AI research (there's been plenty) and make it useful.

I'm here to bring you up to speed.

MNIST dataset

The Modified National Institute of Standards and Technology dataset (or just 'MNIST') is one of the most popular beginner datasets in ML research. It's simply a collection of 70,000 images of handwritten digits: 60,000 for training and 10,000 for testing.

Each digit is saved as a 28x28 pixel greyscale image, like below:

Source: MNIST

This dataset is perfect for starting out. It's both open-source and small. Its size makes it easy to train on our own - no GPUs or cloud rentals are required.

We'll start by training a hand-crafted model that recognizes handwritten digits. By the way, if it's your first foray into training models, don't despair - it's going to be super simple.

I'll also provide the model weights below. This will allow those in a hurry to bypass the model training - but if it's your first time, give it a shot.

Build an MNIST classifier

Don't forget to install the dependencies, including tensorflow and tensorflow_datasets, using pip (`pip install tensorflow tensorflow-datasets`).

Downloading MNIST

Let's start by downloading MNIST.

import tensorflow as tf
import tensorflow_datasets as tfds

# MNIST download using TFDS; split into training data and test data
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

This small block downloads the MNIST dataset and splits it into our training data and our test data; setting as_supervised=True returns each example as an (image, label) pair. As discussed in our introductory article, training data is used to build the model, whereas test data is used to measure the model's accuracy on images it hasn't seen during training.

Preprocessing MNIST images

Before we can use the data, we need to preprocess it. This converts the raw images from the MNIST dataset (integer pixel values from 0 to 255) into normalized floating-point arrays the model can handle, then shuffles and batches them.

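Don't overlook this step, and pay attention to the shape. TFDS already delivers each MNIST image as a 28x28x1 array (height, width, and a single greyscale channel), so there's no need to add a channel dimension ourselves; calling tf.expand_dims(images, -1) here would produce an unwanted 28x28x1x1 shape.

def preprocess(images, labels):
    # Convert the images to float32
    images = tf.cast(images, tf.float32)
    # Normalize the pixel values from [0, 255] to [0, 1]
    images = images / 255.0
    # TFDS MNIST images already have shape (28, 28, 1); no extra channel dimension needed
    return images, labels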

# Apply the preprocess function to our training and testing data
ds_test = ds_test.map(preprocess)
ds_train = ds_train.map(preprocess)

ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_test = ds_test.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
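If you want to see what preprocess and batch(128) do to the data without downloading anything, here's a NumPy-only sketch on a synthetic batch. The random pixels stand in for real digits, and preprocess_np is a hypothetical helper that mirrors the TensorFlow function above; it isn't part of any library.

```python
import numpy as np

# Synthetic stand-in for one batch of raw MNIST images as TFDS yields them:
# uint8 pixels in [0, 255], shape (batch, 28, 28, 1). Random values, not real digits.
rng = np.random.default_rng(0)
raw_batch = rng.integers(0, 256, size=(128, 28, 28, 1), dtype=np.uint8)

def preprocess_np(images):
    """NumPy mirror of the preprocess step: cast to float32 and scale to [0, 1]."""
    return images.astype(np.float32) / 255.0

batch = preprocess_np(raw_batch)
# Shape is unchanged; only the dtype and value range differ
print(batch.shape, batch.dtype, batch.min(), batch.max())
```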

Building the model!

Okay, we have the data and have prepared our datasets - but we don't have a model yet. Let's build one using Keras (which is just a wrapper around TensorFlow).

## create the model
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

Here, we define a Neural Network (NN) with three layers. The first, the input layer, expects a shape of (28, 28, 1), which matches our preprocessed 28x28 greyscale images; Flatten unrolls each image into a 784-element vector (28 x 28 = 784) that the next layer can consume.

The second layer is a 'hidden layer'. We've defined 128 nodes whose activation function is a Rectified Linear Unit (ReLU), which outputs its input if it's positive and 0 otherwise. It's one of the most popular activation functions because of its simplicity and its effectiveness for deep-learning tasks. A simple way to think about it is that we've defined a wide net of filters (128 to be exact). Each filter computes a weighted sum of the 784 input pixels, and ReLU then either passes that signal along to the next layer or blocks it. During training, these weights are adjusted using gradients computed by backpropagation, which is the heart of ML training. A complete course is outside the scope of what we'll do here, but there are several excellent free resources. Specifically for ReLU, you can't go wrong with this 2-minute overview: Relu Activation Function.
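To make the 'filter' picture concrete, here's a NumPy sketch of what one Dense(128, activation='relu') layer computes for a single flattened image: a weighted sum plus a bias, followed by ReLU. The weights are random stand-ins (a trained model learns them), and relu is our own helper, not a Keras function.

```python
import numpy as np

def relu(x):
    # Positive values pass through unchanged; negative values become 0
    return np.maximum(0.0, x)

rng = np.random.default_rng(0)
x = rng.random(784)                        # one flattened 28x28 image, pixels in [0, 1)
W = rng.normal(0, 0.05, size=(784, 128))   # random stand-in weights (training would learn these)
b = np.zeros(128)                          # bias per hidden node

hidden = relu(x @ W + b)                   # what Dense(128, activation='relu') computes
print(hidden.shape, int((hidden == 0).sum()), "of 128 nodes blocked (output 0)")
```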

Finally, the output layer is defined as 10 nodes with a softmax activation function. If you think about what we're doing with this model, we're trying to determine whether a given image is a 0, 1, 2, 3, 4, 5, 6, 7, 8, or 9 (a total of 10 digits). Each output node corresponds to one of these choices, and softmax turns the 10 outputs into probabilities that sum to 1. The 'most activated' output node (the one with the highest probability) will be our answer. Note that we never write code that says which node means which digit; rather, because each training label is an integer from 0 to 9 and the loss function compares node k against label k, training teaches node 0 to respond to images labeled 0, node 1 to images labeled 1, and so on.
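Here's a small NumPy sketch of how softmax and argmax turn the 10 output scores into an answer. The scores are made up for illustration, and softmax here is a hand-written helper rather than the Keras activation.

```python
import numpy as np

def softmax(z):
    # Subtracting the max avoids overflow in exp and doesn't change the result
    e = np.exp(z - np.max(z))
    return e / e.sum()

# Hypothetical raw scores from the 10 output nodes (index k = digit k)
scores = np.array([0.1, 0.3, 2.5, 0.2, -1.0, 0.0, 0.4, 1.1, 0.0, -0.5])
probs = softmax(scores)

print(probs.round(3))
print("predicted digit:", int(np.argmax(probs)), "confidence:", f"{probs.max():.1%}")
```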

That's a lot of text on NN models - but that's 99% of what we need to discuss for our purposes.

Train the model!!

Finally, we can compile and train the model!

#compile the model 
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
# train the model using our 'training' dataset and validating it with our 'testing' dataset
model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
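The compile step pairs our softmax output with SparseCategoricalCrossentropy(from_logits=False): for each image, the loss looks up the probability the model gave the correct digit and penalizes it as -log(p), so confident correct answers cost little and confident wrong ones cost a lot. This NumPy sketch shows the math only; Keras's own implementation also guards against log(0), so don't treat it as a drop-in replacement.

```python
import numpy as np

def sparse_categorical_crossentropy(probs, labels):
    # probs: (batch, 10) softmax outputs; labels: (batch,) integer digits 0-9
    # Pick the probability assigned to each example's correct class
    picked = probs[np.arange(len(labels)), labels]
    # Average of -log(p_correct) across the batch
    return float(np.mean(-np.log(picked)))

probs = np.array([[0.05, 0.90, 0.05] + [0.0] * 7,   # confident and correct (label 1)
                  [0.60, 0.30, 0.10] + [0.0] * 7])  # leans toward the wrong digit (label 1)
labels = np.array([1, 1])

print(sparse_categorical_crossentropy(probs, labels))
```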

That's it! We now have a model that's completely trained. Let's test it out!

Testing our model

Install Matplotlib (which provides pyplot) using `pip install matplotlib`.

import matplotlib.pyplot as plt
import numpy as np

# Take one batch (128 images) from the test set
for images, labels in ds_test.take(1):
    # Keep the first 10 images and labels from that batch
    test_images = images[:10]
    test_labels = labels[:10]
    predictions = model.predict(test_images)

# Display the images and the model's predictions in a 2x5 grid
plt.figure(figsize=(10, 5))
for i in range(10):
    plt.subplot(2, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(test_images[i].numpy().squeeze(), cmap=plt.cm.binary)
    plt.xlabel(f"Actual: {test_labels[i].numpy()}")
    # argmax picks the output node (digit) with the highest probability
    plt.title(f"Predicted: {np.argmax(predictions[i])}")
plt.tight_layout()
plt.show()

Voila! The Predicted value is the output from our model; the Actual value is the label from our dataset (MNIST).

Okay - so we've built an image recognition model using Keras and a common dataset. Super easy using modern frameworks like TensorFlow and Keras.

Housekeeping

Before we move on to attacks, let's add a little housekeeping code: save the model so we don't have to retrain every time we run our code.

First, move all of our current code into a new function, train_model(model_path), and add a line that saves the model once it's trained.

It will look something like this:

import tensorflow as tf
import tensorflow_datasets as tfds
import os
import matplotlib.pyplot as plt
import numpy as np

def train_model(model_path):
    # all the code we've written so far; moved into this function
    (ds_train, ds_test), ds_info = tfds.load(
        'mnist',
        split=['train', 'test'],
        shuffle_files=True,
        as_supervised=True,
        with_info=True,
    )

    def preprocess(images, labels):
        # Convert the images to float32
        images = tf.cast(images, tf.float32)
        # Normalize the images to [0, 1]
        images = images / 255.0
        # TFDS MNIST images already have shape (28, 28, 1); no extra channel dimension needed
        return images, labels

    # Apply the preprocess function to our training and testing data
    ds_test = ds_test.map(preprocess)
    ds_train = ds_train.map(preprocess)

    ds_train = ds_train.cache()
    ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
    ds_train = ds_train.batch(128)
    ds_test = ds_test.batch(128)
    ds_train = ds_train.prefetch(tf.data.AUTOTUNE)


    ## create the model
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.001),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

    model.fit(
        ds_train,
        epochs=6,
        validation_data=ds_test,
    )

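    # Save the model. On TF 2.x with Keras 2, a path without an extension writes a
    # SavedModel directory; Keras 3 (TF 2.16+) instead requires a path ending in '.keras'
    tf.keras.models.save_model(model, model_path)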

Next, let's add the code to load a model if it exists.

def load_model(model_path):
    model = tf.keras.models.load_model(model_path)
    return model

Finally, check if it exists and train a new model if it does not:

model_path = 'mnist-saved-model'
# Check if the model file exists
if not os.path.exists(model_path):
    print(f"The model file {model_path} does not exist. Training now. ")
    # train the model if it doesn't exist yet 
    train_model(model_path)
model = load_model(model_path)

Now our model will be trained and saved to a folder containing a handful of files. I've shared mine below; simply unzip the folder and point your code to the directory (default mnist-saved-model).


Attacking our MNIST classifier model

Instead of thinking about this in terms of attacking some esoteric black-box AI model, I've found the best analogy is that we're attacking a specific database. Each database will be drastically different (for example, GPT-3.5 vs GPT-4), so the fun part of this work comes from evaluating each database (aka 'model' or 'algorithm').

💡
Think of it this way: we're attacking a specific database

We're not executing a SQL injection through a WAF. We've already got access to the raw database. So the next question is, how do we execute attacks if we're already at the end goal?

This is where traditional cyber engineers get confused. Our red team objectives are different here. Instead of saying, "Crack a password from this hash", we're saying "Trick the algorithm by using malicious input".

So let's trick the MNIST algorithm we just built.

First, we'll build a wrapper around our MNIST model that accepts requests over an HTTP API, so we can build a command-line attack tool. We'll feed it images, and it will respond with its predicted digit (0-9).

Second, we'll build a script that talks with the API.

Third, we'll send known good images and test the API and our model.

Finally, we'll build an attack script that will change our input images and look for errors in the output.

(1) Build API wrapper for our model

Building an API to access our model might sound difficult, but it will only take a few lines of Python.

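## add the following imports (np and load_model come from the code above)
from http.server import BaseHTTPRequestHandler, HTTPServer
import json
import io
from PIL import Image  # pip install pillow; decodes the uploaded image bytes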


class RequestHandler(BaseHTTPRequestHandler):
    model = load_model('mnist-saved-model')

    def do_POST(self):
        if self.path == '/predict':
            content_length = int(self.headers['Content-Length'])
            post_data = self.rfile.read(content_length)
            print("[-] Received request.. ")

            try:
                # Use PIL to open the image and convert it to the expected format
                image = Image.open(io.BytesIO(post_data)).convert('L')
                image = image.resize((28, 28))
                image = np.array(image) / 255.0
                image = image.reshape(1, 28, 28, 1)
                print("[-] Making prediction from submitted image.. ")
                # Make prediction
                prediction = self.model.predict(image)
                predicted_class = np.argmax(prediction, axis=1)
                print(f'This image most likely is a {predicted_class[0]} with a probability of {np.max(prediction)}.')

                # Send response
                self.send_response(200)
                self.send_header('Content-type', 'application/json')
                self.end_headers()
                resp = f'This image most likely is a {predicted_class[0]} with a probability of {np.max(prediction):.3%}'
                self.wfile.write(json.dumps(resp).encode())
            except Exception as e:
                self.send_response(500)
                self.end_headers()
                response = {'error': str(e)}
                self.wfile.write(json.dumps(response).encode())
        else:
            self.send_response(404)
            self.end_headers()

def runServer(server_class=HTTPServer, handler_class=RequestHandler, port=42000):
    server_address = ('', port)
    httpd = server_class(server_address, handler_class)
    print(f'Serving HTTP on port {port}...')
    httpd.serve_forever()

runServer()

Now we can submit files using standard HTTP tools, such as curl!

curl -X POST --data-binary @test.png http://localhost:42000/predict

"This image most likely is a 5 with a probability of 17.230%"
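If you'd prefer a client without third-party dependencies, the same request curl makes can be built with Python's standard library. This is a sketch: build_predict_request and predict are hypothetical helper names, and the Content-Type header is our own addition (the server above ignores it and simply reads the raw body).

```python
import json
import urllib.request

def build_predict_request(image_bytes, url="http://localhost:42000/predict"):
    # Same request curl sends: a POST whose body is the raw image bytes
    return urllib.request.Request(
        url,
        data=image_bytes,
        headers={"Content-Type": "image/png"},  # assumption; not checked by our server
        method="POST",
    )

def predict(image_path, url="http://localhost:42000/predict"):
    with open(image_path, "rb") as f:
        req = build_predict_request(f.read(), url)
    with urllib.request.urlopen(req, timeout=30) as resp:
        # The server replies with a JSON-encoded string
        return json.loads(resp.read().decode())

# Usage (requires the server from step 1 to be running):
# print(predict("test.png"))
```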

(2) Build attack script skeleton

Create a new Python file, client.py, which we'll use to modify our images to trick the classifier.

import numpy as np
import matplotlib.pyplot as plt
import requests
from keras.datasets import mnist
from PIL import Image
import io

# The path to the image you want to send (any greyscale PNG of a digit)
image_path = 'test.png'
server_url = 'http://localhost:42000/predict'

# Open the image in binary mode
with open(image_path, 'rb') as image_file:
    # The POST request with the binary data of the image
    image_binary = image_file.read()

#send the OG image
response = requests.post(server_url, data=image_binary)
print(response.text)
$ python client.py

"This image most likely is a 2 with a probability of 99.897%"

(3) Test known good examples

Let's extract a test image from MNIST and send it through the API to our model. Note that this code replaces our last code block.

import numpy as np
import matplotlib.pyplot as plt
import requests
from keras.datasets import mnist
from PIL import Image
import io

# Load the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Combine the train and test sets if you want to select from the entire dataset
all_images = np.concatenate((train_images, test_images), axis=0)
# Generate a random index
random_index = np.random.choice(all_images.shape[0])
# Select the image
random_image = all_images[random_index]
# Display the image
plt.imshow(random_image, cmap='gray')
plt.title(f"Random MNIST digit: {random_index}")
plt.axis('off')  # Hide the axis to focus on the image
plt.show()

# Save the image to the filesystem as a regular PNG (PIL infers greyscale 'L' mode from the 2D uint8 array)
filename = f"mnist_digit_{random_index}.png"
Image.fromarray(random_image).save(filename)
print(f"Image saved as {filename}")

# Open the image in binary mode
with open(filename, 'rb') as image_file:
    # The POST request with the binary data of the image
    image_binary = image_file.read()

# send the original image
server_url = 'http://localhost:42000/predict'
response = requests.post(server_url, data=image_binary)
print(response.text)

We use pyplot to display the image, then save it to disk as a regular PNG before sending it to the API.

(4) Implement the attack script

If you've made it this far, you may have noticed that we haven't done anything adversarial yet. We've built a simple ML model using an introductory dataset and wrapped it in a little HTTP API.

But finally... we've made it to the fun stuff!

In our introductory article, we discussed random noise. Let's implement a routine that takes an MNIST image, adds noise, and feeds it to the model over our API.

def add_random_noise(imageIn, noise_level=0.1):
    # Assuming imageIn is a numpy array of pixel values scaled to [0, 1] (any shape works)
    # Add random noise to the image
    perturbation = noise_level * np.random.randn(*imageIn.shape)
    perturbed_image = imageIn + perturbation
    # Clip the image pixel values to be between 0 and 1
    perturbed_image = np.clip(perturbed_image, 0.0, 1.0)
    return perturbed_image

Ok - let's break this down.

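The first thing to wrap your head around is that an image is represented as an array of pixel values. We could add a single random number to the array (NumPy would broadcast it to every pixel), but that would only shift the whole image's brightness by the same amount. To perturb each pixel independently, we need a noise array with the same shape as the image (e.g., both 28x28 arrays).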
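A quick NumPy sketch of the difference: adding a scalar shifts every pixel equally, while adding a same-shape noise array changes each pixel independently. The 3x3 array is a tiny stand-in for a 28x28 image, and this sketch uses NumPy's Generator API (default_rng) for reproducibility rather than randn.

```python
import numpy as np

image = np.full((3, 3), 0.5)   # tiny stand-in for a 28x28 greyscale image

# Adding one scalar broadcasts it: every pixel shifts by the same amount
shifted = image + 0.1

# Adding a same-shape noise array perturbs each pixel independently
rng = np.random.default_rng(42)
noisy = image + 0.1 * rng.standard_normal(image.shape)

print(np.unique(shifted).size, "distinct value(s) after a scalar shift")
print(np.unique(noisy).size, "distinct values after per-pixel noise")
```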

We generate the random noise array (called perturbation) using NumPy's randn, which draws values from a standard normal distribution. We create it with the same shape as the image passed into our function and scale it by noise_level (typically a small value between 0 and 1), which sets the noise's standard deviation. Matching the shapes ensures the two arrays line up element by element for our next step: adding the noise.

The last step clips the values so that every pixel stays within the valid range of our grayscale image, between 0 and 1.
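To sanity-check add_random_noise on its own, here's a small sketch that runs it on a synthetic image (a crude vertical stroke, not a real MNIST digit) with a fixed seed and confirms that the clipping keeps every pixel in [0, 1].

```python
import numpy as np

def add_random_noise(imageIn, noise_level=0.1):
    # Same logic as above: Gaussian noise with standard deviation noise_level, then clip
    perturbation = noise_level * np.random.randn(*imageIn.shape)
    return np.clip(imageIn + perturbation, 0.0, 1.0)

np.random.seed(0)                      # make the run repeatable
image = np.zeros((28, 28))
image[4:24, 12:16] = 1.0               # crude synthetic "1" stroke
noisy = add_random_noise(image, 0.05)

print(noisy.min() >= 0.0, noisy.max() <= 1.0)   # clipping keeps pixels in range
print(f"mean |change| per pixel: {np.abs(noisy - image).mean():.4f}")
```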

That's it!

Let's call our function.

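# Scale the original MNIST image from step 3 (uint8, 0-255) to [0, 1]
image_array = random_image / 255.0
# Apply the noise function - play with the noise_level, which we can pass in here
perturbed_image_array = add_random_noise(image_array, .05)
# Scale back to 0-255 before casting to uint8; casting [0, 1] floats directly truncates almost every pixel to 0
perturbed_image = Image.fromarray(np.round(perturbed_image_array * 255).astype('uint8'))
perturbed_image_path = 'perturbed_image.png'
perturbed_image.save(perturbed_image_path)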
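Before saving, the [0, 1] floats need to go back to the 0-255 range that an 8-bit greyscale PNG uses. This NumPy sketch shows why: casting normalized values straight to uint8 drops the fractional part, which turns nearly every pixel into 0.

```python
import numpy as np

pixels = np.array([0.0, 0.25, 0.5, 0.99, 1.0])   # normalized greyscale values in [0, 1]

truncated = pixels.astype(np.uint8)                  # casting drops the fraction
rescaled = np.round(pixels * 255).astype(np.uint8)   # map back to 0-255 first

print(truncated)  # → [0 0 0 0 1]
print(rescaled)   # → [  0  64 128 252 255]
```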

Finally, let's display the image to the user and send it over to the API!

plt.subplot(1, 2, 1)
plt.axis('off')
plt.title("Original")
# vmin/vmax pin the 0-255 range so pyplot doesn't auto-rescale the display
plt.imshow(random_image, cmap='gray', vmin=0, vmax=255)
plt.subplot(1, 2, 2)
plt.title("Modified")
plt.imshow(perturbed_image, cmap='gray', vmin=0, vmax=255)
plt.axis('off')  # Turn off axis numbers and ticks
plt.show()

with open(perturbed_image_path, 'rb') as image_file:
    perturbed_image_binary = image_file.read()

#send the perturbed image
response = requests.post(server_url, data=perturbed_image_binary)
print(response.text)
$ python client.py

"This image most likely is a 8 with a probability of 99.180%"

"This image most likely is a 5 with a probability of 17.175%"

The first thing we'll notice is the amount of change we've made. Given our super-simple dataset of 28x28 images, it's going to be painfully obvious that we've created relatively drastic changes: even though it still looks like an 8, we can tell it's been modified. When we move on to more complex examples, this same effect will be subtle enough to escape notice.
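To put a number on how much we've changed the image, a few simple measures help: how many pixels changed, the largest single-pixel change (the L-infinity norm), and the overall L2 norm. perturbation_stats is a hypothetical helper; in client.py you could call it as perturbation_stats(random_image, np.array(perturbed_image)). The toy arrays below are synthetic, not real MNIST data.

```python
import numpy as np

def perturbation_stats(original, perturbed):
    # Both inputs: uint8 greyscale images (0-255) with the same shape.
    # Cast to a signed type first so the subtraction can't wrap around.
    diff = perturbed.astype(np.int16) - original.astype(np.int16)
    return {
        "pixels_changed": int(np.count_nonzero(diff)),
        "max_abs_change": int(np.abs(diff).max()),        # L-infinity norm, in 0-255 units
        "l2_norm": float(np.linalg.norm(diff / 255.0)),   # L2 norm on the [0, 1] scale
    }

# Toy example with synthetic images
original = np.zeros((28, 28), dtype=np.uint8)
perturbed = original.copy()
perturbed[0, 0] = 255
perturbed[1, 1] = 10
print(perturbation_stats(original, perturbed))
```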

The important concept is that we've tricked the Neural Network into classifying what is obviously an 8 to a human observer as a 5.
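Our experiment used a single noise level. In line with the plan of changing input images and looking for errors in the output, the search can be automated. This is a hedged sketch of a simple random search: classify is a hypothetical callable you'd write yourself (for example, one that rescales the image to 0-255, saves it as a PNG, POSTs it to /predict, and parses the digit from the response), and toy_classify is a stand-in so the code runs without the server.

```python
import numpy as np

def find_fooling_noise(image, true_label, classify,
                       levels=(0.05, 0.1, 0.2, 0.3, 0.5), trials=10, seed=0):
    """Try increasing noise levels until the prediction changes.

    image: float array scaled to [0, 1]
    classify: hypothetical callable mapping a [0, 1] image to a digit 0-9
    Returns (noise_level, noisy_image, predicted_digit) for the first miss, or None.
    """
    rng = np.random.default_rng(seed)
    for level in levels:
        for _ in range(trials):
            noisy = np.clip(image + level * rng.standard_normal(image.shape), 0.0, 1.0)
            predicted = classify(noisy)
            if predicted != true_label:
                return level, noisy, predicted
    return None

# Toy stand-in classifier (NOT our model): calls an image "0" unless it's too bright
def toy_classify(img):
    return 0 if img.mean() < 0.06 else 7

result = find_fooling_noise(np.zeros((28, 28)), true_label=0, classify=toy_classify)
if result is not None:
    level, noisy, predicted = result
    print(f"fooled at noise_level={level}: predicted {predicted}")
```

Starting with small noise levels means the first misclassification found is also (roughly) the least visible one the search tried.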

Downloads

Client: https://github.com/cyberaiguy/attacking-mnist/blob/main/client.py

Server: https://github.com/cyberaiguy/attacking-mnist/blob/main/server.py

Model weights: mailto cyberaiguy at cyberaiguy.com