Training a Convolutional Neural Network (CNN) on the MNIST Dataset


In our previous article, we introduced the basics of Convolutional Neural Networks (CNNs) and their applications in image recognition. In this article, we’ll take a hands-on approach and train a CNN on the MNIST dataset using TensorFlow.


Table of Contents

  1. What is the MNIST Dataset?
  2. Preparing the Data
  3. Defining the CNN Model
  4. Compiling the Model
  5. Training the Model
  6. Evaluating the Model
  7. Complete Code Example

What is the MNIST Dataset?

The MNIST dataset is a collection of 70,000 images of handwritten digits (0-9) with a resolution of 28x28 pixels. The dataset is divided into 60,000 training images and 10,000 testing images. It is widely used as a benchmark for image classification tasks.


Preparing the Data

Before we can train our CNN, we need to prepare the data. We’ll use the TensorFlow library to load the MNIST dataset and preprocess the images.

import tensorflow as tf
from tensorflow.keras.datasets import mnist

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize the pixel values to be between 0 and 1
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Reshape the images to have a channel dimension
x_train = x_train.reshape((-1, 28, 28, 1))
x_test = x_test.reshape((-1, 28, 28, 1))

In this code:

  • We load the MNIST dataset, which consists of images and their corresponding labels.
  • We normalize the pixel values to be between 0 and 1, which helps the model converge faster.
  • The images are reshaped to include a channel dimension (1 channel for grayscale).
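
To make the two preprocessing steps concrete, here is a minimal pure-Python sketch on a made-up 2x2 "image". The pixel values and the helper names `normalize` and `add_channel_dim` are illustrative assumptions, not part of TensorFlow; they only mimic what dividing by 255 and reshaping to `(28, 28, 1)` do to a single image.

```python
# Toy 2x2 grayscale "image" with 8-bit pixel values (made-up data).
image = [
    [0, 51],
    [102, 255],
]

def normalize(img):
    """Scale 8-bit pixel values from [0, 255] to [0.0, 1.0]."""
    return [[pixel / 255 for pixel in row] for row in img]

def add_channel_dim(img):
    """Wrap each pixel in a list so the shape goes from (H, W) to (H, W, 1)."""
    return [[[pixel] for pixel in row] for row in img]

prepared = add_channel_dim(normalize(image))

# Every pixel is now a one-element list holding a float in [0, 1].
for row in prepared:
    print(row)
```

The real code performs the same operations on all 60,000 training images at once, which is why the reshape uses `-1` for the batch dimension.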

Defining the CNN Model

Next, we’ll define our CNN model using the TensorFlow Keras API. Our model will consist of two convolutional layers, each followed by a max-pooling layer, and then two fully connected layers.

# Define the CNN model
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
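
For intuition about what one convolution, ReLU, and pooling stage computes, the following is a rough pure-Python sketch on a toy 6x6 image. The image, the hand-written edge-detecting kernel, and the function names are assumptions made for illustration; in the real model the kernel values are learned during training and Keras handles many filters and channels at once.

```python
def conv2d_valid(image, kernel):
    """Slide a square kernel over a square image, stride 1, no padding."""
    k = len(kernel)
    out_size = len(image) - k + 1
    return [
        [
            sum(image[r + i][c + j] * kernel[i][j]
                for i in range(k) for j in range(k))
            for c in range(out_size)
        ]
        for r in range(out_size)
    ]

def relu(feature_map):
    """Replace negative values with zero."""
    return [[max(0, v) for v in row] for row in feature_map]

def max_pool_2x2(feature_map):
    """Keep the largest value in each non-overlapping 2x2 patch."""
    n = len(feature_map) // 2
    return [
        [
            max(feature_map[2 * r][2 * c], feature_map[2 * r][2 * c + 1],
                feature_map[2 * r + 1][2 * c], feature_map[2 * r + 1][2 * c + 1])
            for c in range(n)
        ]
        for r in range(n)
    ]

# Toy image: dark left half (0), bright right half (1).
image = [[0, 0, 0, 1, 1, 1] for _ in range(6)]

# Hand-picked kernel that responds to dark-to-bright vertical edges.
kernel = [
    [-1, 0, 1],
    [-1, 0, 1],
    [-1, 0, 1],
]

features = relu(conv2d_valid(image, kernel))  # 6x6 -> 4x4
pooled = max_pool_2x2(features)               # 4x4 -> 2x2
print(features)
print(pooled)
```

The feature map is largest where the kernel straddles the edge, and pooling keeps that strong response while halving the width and height.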

Explanation of the Layers:

  1. Conv2D(32, (3, 3)): This layer applies 32 filters of size 3x3 over the input image. The activation function is ReLU, which introduces non-linearity.
  2. MaxPooling2D((2, 2)): This pooling layer reduces the spatial dimensions by taking the maximum value in each 2x2 patch, which lowers the computational cost of the layers that follow.
  3. Conv2D(64, (3, 3)): The second convolutional layer applies 64 filters of size 3x3, again with ReLU activation.
  4. MaxPooling2D((2, 2)): A second pooling layer halves the spatial dimensions once more.
  5. Flatten(): This layer flattens the stack of 2D feature maps into a 1D vector for the fully connected layers.
  6. Dense(64): A fully connected layer with 64 neurons and ReLU activation.
  7. Dense(10): The output layer uses the softmax activation function to output probabilities for each of the 10 classes (digits 0-9).
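
The layer descriptions above can be checked with a little arithmetic. The sketch below traces the feature-map size through the network and counts the trainable parameters, assuming the Keras defaults used in the model (no padding and stride 1 for `Conv2D`, non-overlapping windows for `MaxPooling2D`). The helper names are made up for this example.

```python
def conv_out(size, kernel=3):
    """Output size of a stride-1 convolution without padding."""
    return size - kernel + 1

def pool_out(size, pool=2):
    """Output size of non-overlapping pooling (any remainder is dropped)."""
    return size // pool

def conv_params(kernel, in_channels, out_channels):
    """Weights plus one bias per filter."""
    return kernel * kernel * in_channels * out_channels + out_channels

def dense_params(n_in, n_out):
    """Weights plus one bias per neuron."""
    return n_in * n_out + n_out

size = 28
size = conv_out(size)    # Conv2D(32):  28 -> 26
size = pool_out(size)    # MaxPooling:  26 -> 13
size = conv_out(size)    # Conv2D(64):  13 -> 11
size = pool_out(size)    # MaxPooling:  11 -> 5
flat = size * size * 64  # Flatten:     5 * 5 * 64 values

params = {
    "conv1": conv_params(3, 1, 32),
    "conv2": conv_params(3, 32, 64),
    "dense1": dense_params(flat, 64),
    "dense2": dense_params(64, 10),
}
total = sum(params.values())

print(f"flattened vector length: {flat}")
for name, count in params.items():
    print(f"{name}: {count}")
print(f"total trainable parameters: {total}")
```

Calling `model.summary()` on the model reports the same shapes and counts, so it is a convenient way to confirm this kind of calculation.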

Compiling the Model

Before we can train the model, we need to compile it with a loss function, optimizer, and evaluation metric.

# Compile the model
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
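
In this code:

  • Loss Function: We use sparse_categorical_crossentropy because this is a multi-class classification problem (digits 0-9) and the labels are plain integers rather than one-hot vectors.
  • Optimizer: Adam is a widely used optimizer that adapts the step size of each parameter during training, which often gives good results without manual learning-rate tuning.
  • Metrics: We use accuracy, the fraction of correctly classified images, to monitor the model’s performance during training.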
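
As a rough illustration of what the loss measures, here is a pure-Python sketch of softmax followed by cross-entropy for a single example with an integer label. The scores are made up and the function names are hypothetical stand-ins; the Keras implementation works on whole batches and adds numerical safeguards that are omitted here.

```python
import math

def softmax(scores):
    """Turn raw scores into probabilities that sum to 1."""
    highest = max(scores)
    exps = [math.exp(s - highest) for s in scores]  # shift for stability
    total = sum(exps)
    return [e / total for e in exps]

def sparse_crossentropy(label, probs):
    """Negative log-probability assigned to the true class (integer label)."""
    return -math.log(probs[label])

# Made-up output scores for the 10 digit classes; class 7 scores highest.
scores = [0.0] * 10
scores[7] = 2.0
probs = softmax(scores)

loss_if_seven = sparse_crossentropy(7, probs)  # label matches the prediction
loss_if_three = sparse_crossentropy(3, probs)  # label disagrees

print(f"P(7) = {probs[7]:.3f}")
print(f"loss when true digit is 7: {loss_if_seven:.3f}")
print(f"loss when true digit is 3: {loss_if_three:.3f}")
```

The loss is small when the model puts high probability on the true digit and grows as that probability shrinks, which is the signal the optimizer uses to adjust the weights.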

Training the Model

Now we can train the model using the training data.

# Train the model
model.fit(x_train, y_train, epochs=10, batch_size=128, validation_data=(x_test, y_test))
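
In this code:

  • Epochs: The model is trained for 10 epochs, meaning it sees the entire training set 10 times.
  • Batch Size: The data is processed in batches of 128 samples, and the weights are updated once per batch, which is more efficient than updating after every single image.
  • Validation Data: We pass the test data to monitor validation accuracy after each epoch. Because the same images are used for the final evaluation, this is a convenience for the tutorial; a stricter setup would hold out a separate validation split from the training data.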
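
The epoch and batch-size settings translate directly into the number of weight updates. A small sketch of that arithmetic, using the 60,000 training images mentioned earlier (the variable names are just for illustration):

```python
import math

num_train = 60_000
batch_size = 128
epochs = 10

# The last batch of each epoch is smaller when the sizes do not divide evenly.
steps_per_epoch = math.ceil(num_train / batch_size)
last_batch_size = num_train - (steps_per_epoch - 1) * batch_size
total_updates = steps_per_epoch * epochs

print(f"steps per epoch: {steps_per_epoch}")
print(f"size of final batch: {last_batch_size}")
print(f"total weight updates: {total_updates}")
```

The steps-per-epoch value is the number shown in the progress bar that `model.fit` prints for each epoch.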

Evaluating the Model

After training the model, we can evaluate its performance on the testing data.

# Evaluate the model
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_acc:.2f}')

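After training, the model typically reaches a test accuracy of roughly 98 to 99% on MNIST (printed as a fraction, for example 0.99), which is quite good for such a simple architecture. Exact numbers vary slightly between runs because the weights are initialized randomly.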
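
The accuracy metric itself is simple: take the most probable class for each image and compare it with the true label. The sketch below shows the idea on made-up probabilities for a 3-class problem; the `argmax` and `accuracy` helpers are written here for illustration and are not the Keras functions.

```python
def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(prob_rows, labels):
    """Fraction of rows whose most probable class equals the true label."""
    correct = sum(argmax(row) == label for row, label in zip(prob_rows, labels))
    return correct / len(labels)

# Made-up softmax outputs for four samples and three classes.
predictions = [
    [0.8, 0.1, 0.1],  # predicts class 0
    [0.2, 0.7, 0.1],  # predicts class 1
    [0.3, 0.3, 0.4],  # predicts class 2
    [0.6, 0.3, 0.1],  # predicts class 0
]
true_labels = [0, 1, 2, 1]  # the last sample is misclassified

print(f"accuracy: {accuracy(predictions, true_labels):.2f}")
```

With the trained model, the same comparison runs over the class probabilities that `model.predict(x_test)` returns for all 10,000 test images.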


Complete Code Example

Here’s the complete code example for training a CNN on the MNIST dataset:

import tensorflow as tf
from tensorflow.keras.datasets import mnist

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize the pixel values to be between 0 and 1
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Reshape the images to have a channel dimension
x_train = x_train.reshape((-1, 28, 28, 1))
x_test = x_test.reshape((-1, 28, 28, 1))

# Define the CNN model
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Compile the model
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# Train the model
model.fit(x_train, y_train, epochs=10, batch_size=128, validation_data=(x_test, y_test))

# Evaluate the model
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_acc:.2f}')

This complete code can be executed in a Python environment with TensorFlow installed. The model is simple yet effective for the MNIST dataset, making it a great introduction to CNNs.

© 2024 Dominic Kneup