Real-World Implementation of Graph Neural Networks (GNNs)


In this article, we will implement a Graph Neural Network (GNN) model using PyTorch Geometric to solve a node classification task on a citation network dataset, such as Cora. The goal is to predict the class of a paper (node) based on its connections (edges) and content (features).

GNNs are ideal for this type of task because they excel at capturing both local and global graph structures, making them highly effective for problems like node classification, link prediction, and graph classification.


Table of Contents

  1. Dataset Overview
  2. Setting Up the Environment
  3. Building the GNN Model
  4. Training the GNN Model
  5. Evaluating the Model
  6. Deploying the GNN Model
  7. Conclusion

1. Dataset Overview

We’ll use the Cora dataset, a well-known citation network where nodes represent scientific papers and edges represent citation links between them. The task is to classify each paper into one of several predefined categories based on its content and citation patterns.

Graph Data: In a citation network, papers are represented as nodes, and the citations between papers are edges. Each node is associated with features (such as word frequencies in the paper abstract) and a label that indicates its category. In the context of GNNs, this data structure allows us to capture both the content of the paper (via node features) and its relationships with other papers (via graph structure).

Dataset Details:

  • Nodes: Papers
  • Edges: Citations between papers
  • Features: Word frequencies from paper content
  • Labels: Paper categories

The data can be represented as an adjacency matrix for graph structure, a node feature matrix for content, and labels for node classification. Luckily, PyTorch Geometric (PyG) provides utilities to load and process datasets like Cora, making preprocessing much easier.

from torch_geometric.datasets import Planetoid

# Load the Cora dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')

# Access graph data
data = dataset[0]
print(data)

The output will contain:

  • data.x: Node features matrix
  • data.edge_index: Graph connectivity in COO format (an edge index of shape [2, num_edges])
  • data.y: Node labels
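
As a quick sanity check, you can print the shapes of these tensors before moving on (a small sketch; the exact sizes depend on the Cora version PyG downloads):

# Inspect the loaded tensors
print(data.x.shape)           # node feature matrix: [num_nodes, num_features]
print(data.edge_index.shape)  # edge index in COO format: [2, num_edges]
print(data.y.shape)           # one class label per node: [num_nodes]
print(dataset.num_classes)    # number of target classes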

For custom datasets, you can manually create the edge index (graph connectivity), the node feature matrix, and the labels. Here’s how:

# If working with raw data
import torch
from torch_geometric.data import Data

# Create the edge index (COO format), node features, and labels manually
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
y = torch.tensor([0, 1, 0], dtype=torch.long)

# Create a data object (using a distinct name so it doesn't overwrite the Cora data object loaded earlier)
custom_data = Data(x=x, edge_index=edge_index, y=y)
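
You can sanity-check the resulting object with the standard attributes and helpers the Data class exposes:

# Basic checks on the custom graph
print(custom_data.num_nodes)        # 3
print(custom_data.num_edges)        # 4 (two undirected edges, each stored in both directions)
print(custom_data.is_undirected())  # True for this toy graph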

2. Setting Up the Environment

Before implementing the GNN, we need to install the necessary libraries and set up our development environment.

Required Libraries:

  • PyTorch: The core deep learning framework used for building neural networks.
  • PyTorch Geometric: A library that simplifies the implementation of GNNs by providing tools to process and manipulate graph data structures like adjacency matrices, edge lists, and node features.
  • NetworkX: A powerful library for the creation, manipulation, and study of complex networks (graphs). It can be used for graph visualization or additional graph algorithms.

Installation Commands

pip install torch
pip install torch-geometric
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv
pip install networkx

GPU Support: PyTorch Geometric can take advantage of CUDA-enabled GPUs to accelerate GNN training. To make sure your model leverages the GPU, check for CUDA availability and move both the model and the data to the GPU (the GCN class used below is defined in Section 3):

# Check if CUDA is available and move model to GPU if possible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = data.to(device)

With this in place, you can train the model on a GPU, which significantly speeds up training on larger graphs.


3. Building the GNN Model

Now, we’ll implement a Graph Convolutional Network (GCN) using PyTorch Geometric. The GCN model performs graph convolutions, which aggregate information from a node’s neighbors to compute updated node embeddings.

What is a Graph Convolution?

A graph convolution is an operation that updates each node’s feature vector by combining its own features with those of its neighboring nodes. This allows each node to learn a representation that depends not only on its own data but also on the graph structure around it.

Each layer of a GCN applies this convolution operation, progressively expanding the “receptive field” of each node. The node’s receptive field refers to the set of nodes whose features influence its final representation. With multiple layers, GCNs allow each node to aggregate information from more distant neighbors.
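
To make this concrete, here is a minimal dense-matrix sketch of the propagation rule popularized by Kipf and Welling, H' = D^{-1/2}(A + I)D^{-1/2} H W, which is essentially what GCNConv computes (GCNConv itself operates on the sparse edge_index rather than a dense adjacency matrix):

import torch

# Toy graph: 3 nodes on a path (0 - 1 - 2), written as a dense adjacency matrix
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
A_hat = A + torch.eye(3)                 # add self-loops so each node keeps its own features
deg = A_hat.sum(dim=1)                   # node degrees (including the self-loop)
D_inv_sqrt = torch.diag(deg.pow(-0.5))   # D^{-1/2}

H = torch.randn(3, 4)                    # node features: 3 nodes, 4 features each
W = torch.randn(4, 2)                    # learnable weight matrix: 4 -> 2

# One graph convolution: normalize, aggregate neighbor features, then transform
H_next = D_inv_sqrt @ A_hat @ D_inv_sqrt @ H @ W
print(H_next.shape)                      # torch.Size([3, 2])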

GCN Model Code

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# Define the GCN model
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        # Define two GCN layers
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        # First GCN layer + ReLU activation
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        
        # Dropout for regularization
        x = F.dropout(x, training=self.training)
        
        # Second GCN layer (output layer)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

# Instantiate the model, define loss function and optimizer
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

Key Components:

  • GCNConv: This is the core graph convolution operation. It aggregates features from neighboring nodes and updates the target node’s feature.
  • Forward Pass: In the forward method, node features are passed through two GCN layers. The first layer applies the graph convolution, followed by ReLU activation and dropout for regularization. The second layer outputs the logits for each node’s class.
  • Layer Dimensions: The first GCN layer reduces the input dimensionality (from the number of node features to 16 hidden units). The second GCN layer reduces the hidden representation to match the number of output classes.

Layer Explanation:

  • First GCN Layer: The input to this layer is the node feature matrix (x) and the edge list (edge_index). The output is a new feature matrix with 16 hidden units for each node.
  • Second GCN Layer: This layer reduces the 16-dimensional hidden feature vector for each node down to the number of classes, which allows for classification. The quick shape check below confirms these dimensions.
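
A single forward pass confirms the dimensions. This assumes the Cora data object loaded in Section 1 (not the small custom example) and the untrained model instantiated above, so the predictions themselves are not yet meaningful:

# Forward pass over the full graph
out = model(data)
print(out.shape)           # [num_nodes, num_classes]: one log-probability vector per node
print(out[0].exp().sum())  # rows of log_softmax exponentiate to probabilities summing to ~1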

4. Training the GNN Model

We’ll use the dataset’s predefined training, validation, and test splits and define a training loop that runs over several epochs while monitoring performance on the validation set. To reduce overfitting and obtain a well-generalized model, we will also use techniques such as early stopping and hyperparameter tuning.

Data Splitting

The Cora dataset already includes a predefined train/validation/test split:

data.train_mask  # Boolean mask for training nodes
data.val_mask    # Mask for validation nodes
data.test_mask   # Mask for test nodes
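
You can check how many nodes each split contains; the training set is deliberately small relative to the full graph, which is what makes this a semi-supervised node classification setting:

# Count the nodes in each predefined split
print(int(data.train_mask.sum()), int(data.val_mask.sum()), int(data.test_mask.sum()))
# The standard Planetoid split of Cora uses 140 training, 500 validation, and 1000 test nodes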

It’s important to monitor the model’s performance on the validation set during training to avoid overfitting.

Training Loop

We’ll define the training loop that includes both forward and backward passes to update the model’s parameters. The training loop also includes validation performance monitoring to fine-tune hyperparameters like learning rate and weight decay.

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test():
    model.eval()
    logits, accs = model(data), []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs

# Training the GNN model over multiple epochs
for epoch in range(200):
    loss = train()
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, '
          f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

Monitoring Overfitting: Early stopping based on the validation accuracy can prevent overfitting. If the validation accuracy stops improving over a set number of epochs, the training process should stop.

early_stopping_counter = 0
best_val_acc = 0.0

for epoch in range(200):
    loss = train()
    train_acc, val_acc, test_acc = test()

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
    
    if early_stopping_counter >= 10:  # Stop if no improvement for 10 consecutive epochs
        print(f'Early stopping at epoch {epoch}')
        break

This simple early stopping mechanism helps avoid unnecessary overfitting and reduces computation.
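
A common refinement (shown here as a sketch, not part of the loop above) is to also snapshot the best-performing weights and restore them once training stops, so the final model corresponds to the best validation epoch rather than the last one:

import copy

best_val_acc = 0.0
best_state = None
patience = 10
patience_counter = 0

for epoch in range(200):
    loss = train()
    train_acc, val_acc, test_acc = test()

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_state = copy.deepcopy(model.state_dict())  # keep a copy of the best weights
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f'Early stopping at epoch {epoch}')
        break

# Restore the weights from the best validation epoch
if best_state is not None:
    model.load_state_dict(best_state)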


5. Evaluating the Model

After training the model, we’ll evaluate it using the test set and compute performance metrics such as accuracy, precision, recall, and F1-score. Additionally, we will visualize the learned node embeddings to better understand the model’s performance.

Accuracy, Precision, Recall, and F1-Score

Accuracy, precision, recall, and F1-score are common metrics for classification tasks. We’ve already computed accuracy in the training loop. To go further, we can compute additional metrics using scikit-learn.

from sklearn.metrics import classification_report

# Get predictions on the test set
model.eval()
with torch.no_grad():
    out = model(data)
pred = out[data.test_mask].max(1)[1]
print(classification_report(data.y[data.test_mask].cpu(), pred.cpu()))

Visualizing Node Embeddings

Visualizing high-dimensional embeddings can help understand the model’s performance and the learned structure of the graph. We’ll use t-SNE to project the learned node embeddings into a 2D space and visualize the clustering of nodes.

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Extract embeddings from the hidden layer (output of the first GCN layer)
model.eval()
embeddings = model.conv1(data.x, data.edge_index).detach().cpu().numpy()

# Apply t-SNE
tsne = TSNE(n_components=2)
embeddings_2d = tsne.fit_transform(embeddings)

# Visualize node embeddings in 2D
plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=data.y.cpu(), cmap='jet')
plt.colorbar()
plt.show()

This visualization helps assess whether the model has learned meaningful representations by observing the clusters formed by different node classes.


6. Deploying the GNN Model

Once the model is trained and evaluated, we can save the model and use it to make predictions on new graph data in the future.

Saving the Model

torch.save(model.state_dict(), 'gcn_model.pth')

By saving the model’s state dictionary, we can later load it and use it for inference without needing to retrain the model.

Loading the Model and Inference

To make predictions on new data, we can load the saved model and perform inference on a new graph or updated dataset.

# Load the saved model
model = GCN()
model.load_state_dict(torch.load('gcn_model.pth'))
model.eval()

# Inference on new data
new_data = ...  # Load or preprocess new graph data
with torch.no_grad():
    output = model(new_data)
predictions = output.max(1)[1]

This snippet shows how to load a trained model and run inference on new graph data, which is useful for real-world applications such as predicting categories for newly published papers in a citation network or detecting fraud in a financial transaction network.
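
To make “new graph data” concrete, here is a hedged, hypothetical sketch: a tiny graph built with the same Data class from Section 1. The placeholder features are invented for illustration; in practice the feature vectors must use the same encoding and dimensionality the model was trained on.

import torch
from torch_geometric.data import Data

# Hypothetical new graph: 3 unseen papers with placeholder bag-of-words features
new_x = torch.zeros(3, dataset.num_node_features)          # same feature dimensionality as training
new_edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]], dtype=torch.long)
new_data = Data(x=new_x, edge_index=new_edge_index)

model.eval()
with torch.no_grad():
    log_probs = model(new_data)          # log-probabilities for each new node
predictions = log_probs.argmax(dim=1)    # predicted category per paper
print(predictions)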


7. Conclusion

In this article, we implemented a GNN model using PyTorch Geometric to perform node classification on the Cora citation network dataset. We covered every step, from loading the dataset, setting up the model, training it, and evaluating its performance, to deploying the model for future inference.

To further improve this implementation, you can experiment with more advanced GNN architectures such as Graph Attention Networks (GATs) or try applying the model to different types of graph data (e.g., social networks or molecular graphs). You could also explore hyperparameter optimization techniques and apply them to fine-tune the model’s performance.
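
As a starting point for such an experiment, here is a hedged sketch of a two-layer Graph Attention Network built with PyG’s GATConv; the head count, hidden size, and dropout rate follow commonly used values rather than tuned hyperparameters:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

# Two-layer GAT with the same input/output interface as the GCN above
class GAT(torch.nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        # 8 attention heads with 8 hidden units each; head outputs are concatenated (8 * 8 = 64)
        self.conv1 = GATConv(dataset.num_node_features, 8, heads=8, dropout=0.6)
        # Output layer: a single attention head mapping to the number of classes
        self.conv2 = GATConv(8 * 8, dataset.num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

Because it exposes the same forward signature as the GCN, it can be dropped into the training and evaluation loop from Section 4 without further changes.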

© 2024 Dominic Kneup