Transfer Learning - How to Use Pre-trained Models


Transfer learning is a powerful technique in machine learning where a model pre-trained on a large dataset is reused and adapted for a different but related task. Instead of training a model from scratch, transfer learning allows you to leverage the learned knowledge of pre-trained models such as ResNet for image classification or BERT for natural language processing (NLP). This not only reduces training time but often results in higher performance, especially when data is limited.

In this article, we’ll explore how to use transfer learning, fine-tune pre-trained models, and apply them to new tasks with practical examples in TensorFlow using ResNet for image classification and BERT for text classification.


Table of Contents

  1. What is Transfer Learning?
  2. Why Use Transfer Learning?
  3. Transfer Learning with ResNet for Image Classification
  4. Fine-Tuning a Pre-trained Model
  5. Transfer Learning with BERT for Text Classification
  6. Best Practices for Transfer Learning
  7. Conclusion

1. What is Transfer Learning?

Transfer learning involves using a model that has already been trained on one task (such as image classification on ImageNet or text classification on large corpora) and fine-tuning it for a new, related task. This process can be broken down into two main approaches:

  1. Feature Extraction: Use the pre-trained model as a feature extractor and only train the final classifier layer on your new dataset.
  2. Fine-tuning: Start with the pre-trained weights and fine-tune the entire model (or part of the model) on your dataset.

2. Why Use Transfer Learning?

Transfer learning is particularly useful when:

  • Limited Data: Your dataset is too small to train a deep neural network from scratch.
  • Faster Training: Since much of the model is already trained, the fine-tuning process is faster than training from scratch.
  • High-Performance Models: Pre-trained models are often trained on large, high-quality datasets like ImageNet or Wikipedia text, which improves their performance on related tasks.

3. Transfer Learning with ResNet for Image Classification

3.1 Loading a Pre-trained Model (ResNet)

ResNet is a popular convolutional neural network (CNN) architecture that won the 2015 ImageNet competition. In TensorFlow, we can load a pre-trained ResNet model that has been trained on the ImageNet dataset. The top (classifier) layer is excluded because it’s specific to the ImageNet task, which has 1,000 classes, and may not be suitable for your new task.

import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model

# Load pre-trained ResNet50 model without the top (classifier) layer
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze the base model so it acts as a fixed feature extractor
base_model.trainable = False
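
Before adding a classifier head, it can help to check what the frozen backbone actually outputs; this also explains why a Flatten or pooling layer is needed in the next step:

# With 224x224 inputs and include_top=False, ResNet50 ends in 7x7x2048 feature maps
print(base_model.output_shape)  # (None, 7, 7, 2048)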

3.2 Adding Custom Classifier Layers

Since we are using ResNet as a feature extractor, we’ll add custom layers on top for classification.

# Add custom layers
x = Flatten()(base_model.output)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax')(x)  # Assume 10 classes in the new task

# Define the model
model = Model(inputs=base_model.input, outputs=output)
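
Flatten works, but it turns the 7x7x2048 feature maps into a roughly 100,000-dimensional vector, which makes the first Dense layer quite large. A common alternative (an optional variant, not a required change) is GlobalAveragePooling2D, which averages each feature map down to a single value and keeps the head much smaller; if you prefer it, the head would look like this:

from tensorflow.keras.layers import GlobalAveragePooling2D

# Alternative head: average each 7x7 feature map down to one value (2048-dim vector)
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=output)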

3.3 Compiling and Training the Model

Now that we’ve added custom layers, we can compile and train the model. Because the base model is frozen, only the weights of the newly added layers are updated during training.

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Assume x_train, y_train and x_val, y_val are your prepared data
# (labels one-hot encoded, since the loss is categorical_crossentropy)
model.fit(x_train, y_train, epochs=10, batch_size=32, validation_data=(x_val, y_val))
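
For completeness, here is a minimal sketch of how x_train and y_train might be prepared. The raw_images and raw_labels variables are hypothetical placeholders for your own data; the important steps are applying ResNet50’s preprocess_input and one-hot encoding the labels to match the categorical_crossentropy loss:

import numpy as np
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.utils import to_categorical

# raw_images: hypothetical array of shape (num_samples, 224, 224, 3); raw_labels: integer class IDs
x_train = preprocess_input(raw_images.astype('float32'))  # ImageNet-style normalization
y_train = to_categorical(raw_labels, num_classes=10)       # one-hot labels for categorical_crossentropy
# (x_val, y_val would be prepared the same way)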

Real-World Example:

Consider a scenario where you’re developing a model to classify medical images (e.g., detecting diseases in X-rays). Instead of training a CNN from scratch, you could fine-tune a ResNet model pre-trained on ImageNet, saving both time and computational resources.


4. Fine-Tuning a Pre-trained Model

In some cases, you might want to fine-tune the pre-trained model on your dataset. This involves unfreezing some layers of the pre-trained model and allowing them to be updated during training. Fine-tuning only the last few layers is a common approach because the early layers of a CNN tend to learn more general features (e.g., edges, textures) that are applicable across tasks, while later layers capture more task-specific information.

4.1 Fine-Tuning Selected Layers

# Unfreeze the last 10 layers of the base model for fine-tuning (earlier layers stay frozen)
for layer in base_model.layers[-10:]:
    layer.trainable = True

# Re-compile and continue training
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),  # Use a lower learning rate
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_val, y_val))
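
One caveat worth knowing: ResNet50 contains BatchNormalization layers, and letting their statistics update on a small dataset can hurt accuracy. A common precaution (sketched here as one option; apply it before the re-compile step above) is to keep those layers frozen while fine-tuning the rest:

# Keep BatchNormalization layers frozen so their moving statistics are not disturbed
for layer in base_model.layers[-10:]:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False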

Real-World Example:

Fine-tuning is especially useful in tasks like object detection or specialized image recognition, where subtle differences in images require that the model adjust its pre-trained weights.


5. Transfer Learning with BERT for Text Classification

In natural language processing (NLP), BERT (Bidirectional Encoder Representations from Transformers) is a pre-trained transformer model that excels in tasks such as text classification, question answering, and named entity recognition.

5.1 Loading Pre-trained BERT

We can load a pre-trained BERT model using the Hugging Face transformers library, which integrates seamlessly with TensorFlow and provides a wide range of pre-trained models for various NLP tasks.

# Install the Hugging Face library first: pip install transformers
import tensorflow as tf
from transformers import TFBertForSequenceClassification, BertTokenizer
from tensorflow.keras.optimizers import Adam

# Load pre-trained BERT model for sequence classification
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize input text
inputs = tokenizer("This is an example sentence", return_tensors="tf", padding=True, truncation=True, max_length=128)
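
The tokenized inputs can be passed straight to the model to obtain raw classification logits. The classification head is still untrained at this point, so the probabilities are not meaningful yet, but the snippet shows how the pieces fit together:

# Forward pass: the model returns logits, one score per label
outputs = model(inputs)
probs = tf.nn.softmax(outputs.logits, axis=-1)
print(probs.numpy())  # shape (1, 2): one probability per class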

5.2 Training the BERT Model

Now, we can train the BERT model on a text classification task (e.g., sentiment analysis).
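
The training step below expects batched tf.data.Dataset objects. As a minimal sketch, assuming train_texts, train_labels, val_texts, and val_labels are hypothetical Python lists of strings and integer labels, the datasets could be built like this:

def make_dataset(texts, labels, batch_size=32, shuffle=False):
    # Tokenize all texts at once and pair them with their labels
    encodings = tokenizer(list(texts), padding=True, truncation=True, max_length=128, return_tensors='tf')
    ds = tf.data.Dataset.from_tensor_slices((dict(encodings), labels))
    if shuffle:
        ds = ds.shuffle(len(labels))
    return ds.batch(batch_size)

train_dataset = make_dataset(train_texts, train_labels, shuffle=True)
val_dataset = make_dataset(val_texts, val_labels)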

# Compile the model; BERT outputs raw logits, so the loss needs from_logits=True
optimizer = Adam(learning_rate=2e-5)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

# Train the model (train_dataset and val_dataset are already batched, so no batch_size here)
model.fit(train_dataset, epochs=3, validation_data=val_dataset)

Real-World Example:

You can use BERT for tasks like sentiment analysis (classifying text as positive or negative) or spam detection in emails. Fine-tuning BERT on your specific dataset allows you to leverage the model’s understanding of language while tailoring it to your particular task.


6. Best Practices for Transfer Learning

  • Freeze or Fine-Tune: Decide whether to freeze the pre-trained layers or fine-tune them based on your dataset’s size and similarity to the original dataset.
  • Use Pre-trained Models for Feature Extraction: If you’re working with a small dataset, use pre-trained models as feature extractors without retraining them.
  • Adjust Learning Rates: Use a smaller learning rate when fine-tuning pre-trained models to avoid large updates that could disrupt pre-trained weights.
  • Monitor Model Performance: During fine-tuning, regularly monitor the model’s performance on a validation set to avoid overfitting, especially when fine-tuning for too many epochs.
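
As one way to put the last two points into practice, here is a small sketch (reusing the image model and data from Section 3) that adds an early-stopping callback so fine-tuning halts once the validation loss stops improving:

from tensorflow.keras.callbacks import EarlyStopping

# Stop when validation loss stops improving and roll back to the best weights seen
early_stop = EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True)

model.fit(x_train, y_train, epochs=20, batch_size=32,
          validation_data=(x_val, y_val), callbacks=[early_stop])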

7. Conclusion

Transfer learning offers a powerful way to leverage pre-trained models like ResNet for image classification and BERT for NLP tasks, reducing training time and improving model performance. Whether you’re using a model as a feature extractor or fine-tuning it for your specific use case, transfer learning can help you achieve better results, especially when working with limited data. It remains useful with larger datasets as well, since pre-trained weights provide a strong starting point and typically speed up convergence.

By understanding the flexibility and power of transfer learning, you can apply this technique to a variety of tasks, from image recognition to text classification, and significantly boost the performance of your models.

© 2024 Dominic Kneup