Fine-Tuning CNNs on Custom Datasets Using Transfer Learning
Convolutional Neural Networks (CNNs) have proven to be highly effective for image-related tasks such as classification, object detection, and segmentation. However, training these models from scratch often requires vast amounts of labeled data and computational resources. Transfer learning allows us to leverage pre-trained CNN models (e.g., ResNet, VGG) and fine-tune them on custom datasets, drastically reducing training time and improving performance, especially when data is limited.
In this guide, we will walk through the process of fine-tuning pre-trained CNNs using TensorFlow, explaining how to adapt models like ResNet and VGG to fit specific custom datasets.
Table of Contents
- What is Transfer Learning?
- Pre-trained Models for Fine-Tuning
- Steps to Fine-Tune CNNs on Custom Datasets
- Fine-Tuning Pre-Trained Layers
- Best Practices for Fine-Tuning CNNs
- Real-World Example: Fine-Tuning ResNet50 on a Custom Dataset
- Evaluating the Fine-Tuned Model
- Common Issues and Solutions
- Conclusion
1. What is Transfer Learning?
Transfer learning is a machine learning technique where a model trained on a large dataset (e.g., ImageNet) is repurposed for a different, smaller dataset. Instead of starting from scratch, transfer learning leverages learned features from the source dataset and adapts them to the target dataset by fine-tuning specific layers.
Why Use Transfer Learning?
- Saves Time: Training large CNNs from scratch can take days or weeks, while fine-tuning can be done in hours or minutes.
- Improves Accuracy: Pre-trained models have learned general features like edges, textures, and shapes that can be useful across different image classification tasks.
- Works with Limited Data: Transfer learning enables good performance even when you have only a small custom dataset by transferring knowledge from larger, pre-trained models.
2. Pre-trained Models for Fine-Tuning
Some commonly used pre-trained CNN architectures for transfer learning are:
2.1 ResNet
ResNet (Residual Network) is known for its deep architecture and skip connections that help mitigate the vanishing gradient problem in deep networks. Pre-trained ResNet models are ideal for fine-tuning on custom image datasets due to their powerful feature extraction.
2.2 VGG
VGG is a simpler architecture built from uniform stacks of 3x3 convolutions and max-pooling layers. While computationally heavier than ResNet, VGG models are also widely used for fine-tuning due to their straightforward architecture and high performance.
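Both architectures can be loaded through the same tf.keras.applications interface. The sketch below is one possible way to switch between them; the BACKBONES table and the load_backbone helper are hypothetical names introduced here for illustration, not part of TensorFlow. Each architecture ships with its own preprocess_input function, so the helper returns it alongside the model.

```python
import tensorflow as tf

# Hypothetical lookup table (not a TensorFlow API): maps a short name to the
# Keras constructor and the preprocessing function that matches its weights.
BACKBONES = {
    "resnet50": (tf.keras.applications.ResNet50,
                 tf.keras.applications.resnet50.preprocess_input),
    "vgg16": (tf.keras.applications.VGG16,
              tf.keras.applications.vgg16.preprocess_input),
}

def load_backbone(name, weights="imagenet", input_shape=(224, 224, 3)):
    """Return a frozen convolutional base and its preprocessing function."""
    constructor, preprocess = BACKBONES[name]
    base = constructor(weights=weights, include_top=False, input_shape=input_shape)
    base.trainable = False  # freeze all pre-trained layers
    return base, preprocess
```

Calling load_backbone("vgg16") or load_backbone("resnet50") downloads the ImageNet weights on first use; passing weights=None builds the same architecture with random weights, which is handy for quick shape checks.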
3. Steps to Fine-Tune CNNs on Custom Datasets
3.1 Step 1: Load a Pre-Trained Model
The first step in fine-tuning is loading the pre-trained model (e.g., ResNet or VGG) with weights pre-trained on a large dataset, such as ImageNet. This can be done easily using TensorFlow’s tf.keras.applications module.
Example (ResNet50):
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
# Load ResNet50 model without the top (classification) layer
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# Freeze the base model
base_model.trainable = False
Explanation:
- include_top=False: This ensures that the model’s top layer (used for the original task, like ImageNet classification) is removed, which allows you to customize it for your own dataset.
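- input_shape=(224, 224, 3): ResNet50’s default input shape is 224x224 pixels with 3 color channels (RGB). If your custom dataset has different image sizes, resize your images to this shape during preprocessing. The ImageNet weights also expect inputs normalized with tf.keras.applications.resnet50.preprocess_input, so apply it to your images before they reach the model.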
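The resizing step can be sketched as a small mapping function for a tf.data pipeline. This is a minimal illustration, assuming batches of RGB images with pixel values in the 0-255 range; the name prepare_for_resnet is made up for this example. It also applies ResNet50's own preprocess_input, which the ImageNet weights were trained with.

```python
import tensorflow as tf

IMG_SIZE = (224, 224)  # matches the input_shape passed to ResNet50

def prepare_for_resnet(images, labels):
    """Resize a batch of RGB images and apply ResNet50's ImageNet preprocessing."""
    images = tf.image.resize(images, IMG_SIZE)  # bilinear resize, returns float32
    images = tf.keras.applications.resnet50.preprocess_input(images)
    return images, labels

# Assumed usage: train_dataset yields (images, labels) batches.
# train_dataset = train_dataset.map(prepare_for_resnet)
```

The same function should be mapped over the validation and test datasets so that all splits see identically preprocessed inputs.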
3.2 Step 2: Add Custom Layers
Once the pre-trained model is loaded, you will need to add custom classification layers that are specific to your task. For example, if you are performing binary classification, you might add a dense layer with a sigmoid activation function.
# Add custom layers on top of the pre-trained model
model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1, activation='sigmoid')  # For binary classification
])
Explanation:
- GlobalAveragePooling2D(): Reduces the spatial dimensions of the output from the pre-trained model.
- Dense layers: Custom fully connected layers added to the model for classification.
- Dropout: Regularization technique to prevent overfitting.
- Dense(1, activation='sigmoid'): The output layer has 1 unit because this is a binary classification problem. For multi-class classification, you would use a softmax activation with the number of units equal to the number of classes in the problem.
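To make the binary versus multi-class distinction concrete, here is a sketch of a helper that picks the output layer from the number of classes. The function name build_classifier is an assumption made for this example; the layer stack itself mirrors the one shown above.

```python
import tensorflow as tf

def build_classifier(base_model, num_classes):
    """Attach a classification head sized for binary or multi-class output."""
    if num_classes == 2:
        # Binary: a single sigmoid unit giving the probability of the positive class.
        output_layer = tf.keras.layers.Dense(1, activation="sigmoid")
    else:
        # Multi-class: one softmax unit per class.
        output_layer = tf.keras.layers.Dense(num_classes, activation="softmax")
    return tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.5),
        output_layer,
    ])
```

For example, build_classifier(base_model, 5) would produce a model whose predictions are 5 probabilities summing to 1 for each image.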
3.3 Step 3: Compile the Model
Before training the model, compile it by specifying the loss function, optimizer, and metrics to be tracked. Since you are fine-tuning, you can start with a small learning rate.
# Compile the model with a low learning rate for fine-tuning
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss='binary_crossentropy',
              metrics=['accuracy'])
Explanation:
- learning_rate=1e-4: Fine-tuning typically requires a smaller learning rate to avoid overwriting the learned weights too quickly.
- binary_crossentropy: Used for binary classification. For multi-class classification, use categorical_crossentropy or sparse_categorical_crossentropy depending on the format of your labels.
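The difference between the two multi-class losses is only the label format. The small check below, using made-up probabilities for 2 samples and 3 classes, illustrates that integer labels with sparse_categorical_crossentropy and one-hot labels with categorical_crossentropy yield the same loss value.

```python
import tensorflow as tf

# Softmax outputs for 2 samples over 3 classes (made-up numbers).
probs = tf.constant([[0.7, 0.2, 0.1],
                     [0.1, 0.3, 0.6]])

int_labels = tf.constant([0, 2])                  # integer class ids
onehot_labels = tf.one_hot(int_labels, depth=3)   # the same labels, one-hot encoded

sparse_loss = tf.keras.losses.SparseCategoricalCrossentropy()(int_labels, probs)
onehot_loss = tf.keras.losses.CategoricalCrossentropy()(onehot_labels, probs)

# Both are the mean of -log(probability assigned to the true class).
print(float(sparse_loss), float(onehot_loss))
```

In practice, pick whichever loss matches how your dataset already stores its labels, rather than converting labels to fit the loss.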
3.4 Step 4: Train the Custom Layers
At this stage, only the newly added custom layers will be trained, while the pre-trained layers are frozen. This allows the model to learn task-specific features without disturbing the pre-trained weights.
# Train the model on the custom dataset
model.fit(train_dataset, epochs=10, validation_data=val_dataset)
Explanation:
- train_dataset: The dataset containing your custom images and labels.
- val_dataset: The validation dataset to monitor the model’s performance.
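The guide does not show how train_dataset and val_dataset are created. One common option, sketched here under the assumption that your images are stored in one sub-folder per class, is tf.keras.utils.image_dataset_from_directory. The helper name load_splits, the 80/20 split, and the seed value are illustrative choices, not requirements.

```python
import tensorflow as tf

def load_splits(data_dir, image_size=(224, 224), batch_size=32,
                val_fraction=0.2, seed=42):
    """Build training and validation datasets from a class-per-folder directory."""
    common = dict(
        labels="inferred",        # class labels come from the sub-folder names
        label_mode="binary",      # float labels of shape (batch, 1) for a sigmoid output
        image_size=image_size,    # images are resized while loading
        batch_size=batch_size,
        validation_split=val_fraction,
        seed=seed,                # same seed in both calls so the subsets do not overlap
    )
    train_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="training", **common)
    val_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="validation", **common)
    return train_ds.prefetch(tf.data.AUTOTUNE), val_ds.prefetch(tf.data.AUTOTUNE)
```

For a multi-class problem, label_mode="int" or label_mode="categorical" would replace "binary", paired with the matching loss function.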
4. Fine-Tuning Pre-Trained Layers
4.1 Step 5: Unfreeze Some Layers for Fine-Tuning
After training the new layers, you can unfreeze certain layers of the pre-trained model to fine-tune them for the custom dataset. It is common to unfreeze only the last few layers while keeping the early layers frozen since early layers typically learn general features like edges, textures, and shapes. The number of layers to unfreeze depends on the complexity of your task and the size of your custom dataset. For smaller datasets, unfreezing too many layers may lead to overfitting, while larger datasets allow for more layers to be fine-tuned.
# Unfreeze the last few layers of the pre-trained model
base_model.trainable = True
for layer in base_model.layers[:-10]:
    layer.trainable = False  # Freeze all layers except the last 10
Explanation:
- layer.trainable = False: Freezes all but the last 10 layers of the model. Fine-tuning only the top layers helps retain general features learned from the original dataset while adapting specific features to the custom dataset. The choice to unfreeze 10 layers is arbitrary and should be tuned based on the complexity of the problem.
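A variation on the unfreezing step is sketched below. It wraps the same slicing logic in a helper (unfreeze_top_layers is a hypothetical name) and additionally keeps BatchNormalization layers frozen. Leaving those layers frozen is a common recommendation for fine-tuning, since updating their statistics on a small dataset can disrupt the pre-trained features. Treat it as an optional refinement rather than a required step.

```python
import tensorflow as tf

def unfreeze_top_layers(base_model, num_layers):
    """Make the last `num_layers` layers trainable, except BatchNormalization."""
    if num_layers < 1:
        raise ValueError("num_layers must be at least 1")
    base_model.trainable = True
    for layer in base_model.layers[:-num_layers]:
        layer.trainable = False  # keep earlier layers frozen
    for layer in base_model.layers[-num_layers:]:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = False  # keep the pre-trained normalization statistics
    # Return the names of the layers that will actually be updated.
    return [layer.name for layer in base_model.layers if layer.trainable]
```

Printing the returned list is a quick way to confirm which layers will be updated before you recompile and continue training.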
4.2 Step 6: Recompile the Model
Changes to a layer’s trainable attribute only take effect once the model is compiled again, so after unfreezing layers you must recompile the model before continuing training.
# Recompile the model with a lower learning rate
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss='binary_crossentropy',
              metrics=['accuracy'])
Explanation:
- learning_rate=1e-5: Even smaller learning rate for fine-tuning to ensure the pre-trained weights are adjusted carefully.
4.3 Step 7: Fine-Tune the Model
Now, fine-tune the model by training it with both the pre-trained layers and the new custom layers.
# Fine-tune the model with unfrozen layers
model.fit(train_dataset, epochs=5, validation_data=val_dataset)
5. Best Practices for Fine-Tuning CNNs
- Use Smaller Learning Rates: Fine-tuning requires a smaller learning rate to carefully adjust pre-trained weights without overwriting them.
- Freeze Early Layers: Early layers capture general features (e.g., edges, textures, shapes) that are useful across tasks, so freezing them helps retain this useful information.
- Monitor Validation Loss: Watch for overfitting when fine-tuning, especially if the custom dataset is small.
- Data Augmentation: When working with limited data, augmenting the dataset (e.g., flipping, rotating, scaling) can improve generalization.
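The augmentation practice can be sketched with Keras preprocessing layers. The specific layers and factors below (horizontal flip, rotation, zoom) are illustrative assumptions; suitable values depend on your images.

```python
import tensorflow as tf

augment = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),  # rotate by up to +/- 10% of a full turn
    tf.keras.layers.RandomZoom(0.2),      # zoom in or out by up to 20%
], name="augmentation")

def augment_batch(images, labels):
    """Apply random augmentation; training=True enables the random transforms."""
    return augment(images, training=True), labels

# Assumed usage: apply to the training split only, never to validation data.
# train_dataset = train_dataset.map(augment_batch)
```

Because the transforms are random, each epoch sees slightly different versions of the same images, which is what helps generalization on small datasets.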
6. Real-World Example: Fine-Tuning ResNet50 on a Custom Dataset
Let’s say you have a custom dataset of dog and cat images, and you want to fine-tune a pre-trained ResNet50 model for binary classification. Here’s how you would proceed:
Custom Dataset Details:
- Dataset size: 5,000 images (2,500 dogs, 2,500 cats).
- Class balance: Balanced dataset with equal samples from each class.
- Data Augmentation: Random flipping, rotation, and brightness adjustments to increase dataset variety.
- Load ResNet50 pre-trained on ImageNet:
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False  # Freeze the model
- Add custom classification layers:
model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
- Compile and train the model:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.fit(train_dataset, epochs=10, validation_data=val_dataset)
- Unfreeze last layers for fine-tuning:
base_model.trainable = True
for layer in base_model.layers[:-10]:
    layer.trainable = False
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.fit(train_dataset, epochs=5, validation_data=val_dataset)
7. Evaluating the Fine-Tuned Model
Once fine-tuning is complete, it is important to evaluate the model’s performance on the test set. Common evaluation metrics include:
- Accuracy: Measures the proportion of correct predictions.
- Precision: Measures how many of the predicted positives were actually positive.
- Recall: Measures how many of the actual positives were correctly identified.
- F1 Score: The harmonic mean of precision and recall, useful for imbalanced datasets.
- ROC-AUC: Measures the ability of the model to distinguish between classes at different thresholds.
# Evaluate the model on the test dataset
test_loss, test_acc = model.evaluate(test_dataset)
print(f'Test Accuracy: {test_acc:.2f}')
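The evaluate call above reports only loss and accuracy. To show how precision, recall, and F1 relate to the predictions, here is a small plain-Python sketch that computes them from true labels and predicted probabilities. The binary_metrics function, the 0.5 threshold, and the toy numbers are assumptions for illustration.

```python
def binary_metrics(y_true, y_prob, threshold=0.5):
    """Compute accuracy, precision, recall and F1 from labels and probabilities."""
    y_pred = [1 if p >= threshold else 0 for p in y_prob]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # true positives
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # false positives
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # false negatives
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)  # true negatives
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    accuracy = (tp + tn) / len(y_true)
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1}

# Toy example with made-up labels and sigmoid outputs.
y_true = [1, 0, 1, 1, 0, 0]
y_prob = [0.9, 0.4, 0.3, 0.8, 0.6, 0.1]
print(binary_metrics(y_true, y_prob))
```

With a trained model, y_prob would typically come from model.predict(test_dataset). Keras also provides tf.keras.metrics.Precision, tf.keras.metrics.Recall, and tf.keras.metrics.AUC, which can be passed in the metrics list of model.compile so that evaluate reports them directly.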
8. Common Issues and Solutions
Overfitting:
Fine-tuning can lead to overfitting, especially if the custom dataset is small. Solutions include:
- Using dropout to reduce over-reliance on specific neurons.
- Data augmentation to create more diverse training data.
- Early stopping to halt training when validation performance starts to degrade.
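Early stopping is available as a built-in Keras callback. The sketch below shows one possible configuration; the patience of 3 epochs and the epoch count in the usage comment are arbitrary example values.

```python
import tensorflow as tf

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",         # watch the validation loss
    patience=3,                 # tolerate 3 epochs without improvement
    restore_best_weights=True,  # roll back to the best epoch when stopping
)

# Assumed usage with the model and datasets from the earlier steps:
# model.fit(train_dataset, epochs=30, validation_data=val_dataset,
#           callbacks=[early_stop])
```

With restore_best_weights=True, the model ends training with the weights from the epoch that had the lowest validation loss rather than the last epoch.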
Underfitting:
If the model is underfitting, you may need to:
- Unfreeze more layers for fine-tuning.
- Train for more epochs to allow the model to learn more from the dataset.
Vanishing Gradients:
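In deep CNNs, gradients can shrink toward zero as they propagate back through many layers. Architectures with skip connections (e.g., ResNet) and batch normalization mitigate this. A smaller learning rate does not fix vanishing gradients by itself, but it does keep fine-tuning updates stable.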
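Why skip connections help can be illustrated with a deliberately simplified scalar model. It assumes every layer has the same small local derivative, which real networks do not, so treat it as intuition only. A plain layer y = f(x) contributes a factor f'(x) to the gradient, while a residual layer y = x + f(x) contributes 1 + f'(x).

```python
def gradient_through_stack(local_grad, depth, residual):
    """Gradient of output w.r.t. input for `depth` stacked scalar layers.

    Plain layer:    y = f(x)      ->  dy/dx = f'(x)
    Residual layer: y = x + f(x)  ->  dy/dx = 1 + f'(x)
    """
    per_layer = (1.0 + local_grad) if residual else local_grad
    return per_layer ** depth  # chain rule: multiply the per-layer factors

# Toy numbers: 20 layers, each with a local derivative of 0.1.
plain = gradient_through_stack(0.1, depth=20, residual=False)
skip = gradient_through_stack(0.1, depth=20, residual=True)
print(f"plain: {plain:.1e}, with skip connections: {skip:.2f}")
```

In the plain stack the gradient collapses to a vanishingly small number, while the identity path in the residual stack keeps each factor at or above 1 in this toy setting.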
9. Conclusion
Fine-tuning CNNs on custom datasets using pre-trained models like ResNet and VGG is an efficient way to achieve high performance on image-related tasks without needing vast amounts of data. By leveraging the powerful feature extraction capabilities of these models and customizing them for specific tasks, you can drastically reduce training time while still achieving great results.
Understanding the balance between freezing and unfreezing layers, using appropriate learning rates, and applying techniques like data augmentation is key to successful fine-tuning in transfer learning.