Graph Neural Networks Part 3: How GraphSAGE Handles Changing Graph Structure

In the previous parts of this series, we looked at Graph Convolutional Networks (GCNs) and Graph Attention Networks (GATs). Both architectures work well, but they also have some limitations. A big one is that for large graphs, calculating the node representations with GCNs and GATs becomes very slow. Another limitation is that GCNs and GATs do not generalize when the graph structure changes: if nodes are added to the graph, a GCN or GAT cannot make predictions for them. Luckily, these issues can be solved!

In this post, I will explain GraphSAGE and how it solves these common problems of GCNs and GATs. We will train GraphSAGE and use it for predictions on a graph, to compare its performance with GCNs and GATs.

New to GNNs? You can start with post 1 about GCNs (which also contains the initial setup for running the code samples) and post 2 about GATs.


Two Key Problems with GCNs and GATs

I touched on this briefly in the introduction, but let’s dive a bit deeper. What are the problems with the previous GNN models?

Problem 1. They don’t generalize

GCNs and GATs struggle to generalize to unseen graphs: the graph structure at prediction time needs to be the same as in the training data. This is known as transductive learning, where the model trains and makes predictions on the same fixed graph. In effect, the model overfits to one specific graph topology. In reality, graphs change: nodes and edges are added or removed, and this happens often in real-world scenarios. We want our GNNs to learn patterns that generalize to unseen nodes, or even to entirely new graphs (this is called inductive learning).

Problem 2. They have scalability issues

Training GCNs and GATs on large-scale graphs is computationally expensive. GCNs require repeated neighbor aggregation, and the number of nodes involved in computing a single representation grows exponentially with the number of layers, while GATs add (multi-head) attention computations that scale poorly as the number of nodes and edges increases.
In big production recommendation systems, with graphs of millions of users and products, full-graph GCNs and GATs are impractical and slow.
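
To get a feel for how quickly neighbor aggregation blows up, here is a back-of-the-envelope sketch. It assumes every node has the same number of neighbors (a simplification; the degree of 50 and the fan-outs of 10 are made-up values, and `receptive_field_size` is a helper written for this illustration). It counts how many nodes can end up in the computation tree of one target node, with and without capping the neighbors per hop:

```python
def receptive_field_size(fanouts):
    """Worst-case number of nodes in the computation tree of one target node.

    fanouts[k] is the number of neighbors expanded per node at hop k + 1.
    Nodes reached more than once are counted each time, so this is an upper bound.
    """
    total = 1      # the target node itself
    frontier = 1   # number of nodes at the current hop
    for fanout in fanouts:
        frontier *= fanout
        total += frontier
    return total


avg_degree = 50  # hypothetical average degree

full_two_layers = receptive_field_size([avg_degree, avg_degree])
full_three_layers = receptive_field_size([avg_degree] * 3)
sampled_two_layers = receptive_field_size([10, 10])

print(f"2 layers, all neighbors:     {full_two_layers}")
print(f"3 layers, all neighbors:     {full_three_layers}")
print(f"2 layers, 10 neighbors each: {sampled_two_layers}")
```

Under these assumptions, adding one layer multiplies the work per node by the degree, while a fixed fan-out keeps the count bounded no matter how well connected the graph is.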

Let’s take a look at GraphSAGE to fix these issues.

GraphSAGE (SAmple and aggreGatE)

GraphSAGE makes training much faster and more scalable. It does this by sampling only a subset of neighbors. For very large graphs it’s computationally infeasible to process all neighbors of a node, as traditional GCNs do (unless you have limitless time, which none of us has…). Another important step of GraphSAGE is combining the features of the sampled neighbors with an aggregation function.
We will walk through all the steps of GraphSAGE below.

1. Sampling Neighbors

With tabular data, sampling is easy. It’s something you do in every common machine learning project when creating train, validation, and test sets. With graphs, you cannot simply select random nodes, because this can result in disconnected graphs, nodes without neighbors, and so on:

Randomly selecting nodes, but some are disconnected. Image by author.

What you can do with graphs is select a random, fixed-size subset of neighbors. For example, in a social network you can sample 3 friends for each user (instead of all friends):

Randomly selecting three rows in the table, all neighbors selected in the GCN, three neighbors selected in GraphSAGE. Image by author.
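
As a minimal sketch of this idea in plain Python (not the PyG implementation; the `friends` graph and the `sample_neighbors` helper are made up for illustration), sampling a fixed number of friends per user could look like this:

```python
import random


def sample_neighbors(adjacency, node, num_samples, rng):
    """Return a random subset of at most `num_samples` neighbors of `node`.

    Assumption of this sketch: a node with fewer neighbors than
    `num_samples` simply keeps all of them.
    """
    neighbors = adjacency[node]
    if len(neighbors) <= num_samples:
        return list(neighbors)
    return rng.sample(neighbors, num_samples)


# Hypothetical social network: user -> list of friends
friends = {
    "ann": ["bob", "cat", "dan", "eve", "fay"],
    "bob": ["ann", "cat"],
    "cat": ["ann", "bob", "dan"],
    "dan": ["ann", "cat"],
    "eve": ["ann"],
    "fay": ["ann"],
}

rng = random.Random(0)  # fixed seed so the example is repeatable
sampled = sample_neighbors(friends, "ann", 3, rng)
print(sampled)                                   # three of ann's five friends
print(sample_neighbors(friends, "bob", 3, rng))  # bob has only two friends
```

Every sampled node is still a real neighbor, so the sample never produces disconnected nodes. Keeping all neighbors when there are too few is one design choice; sampling with replacement to always reach the fixed size is another option.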

2. Aggregate Information

After the neighbor selection from the previous step, GraphSAGE combines the neighbors’ features into one single representation. There are multiple ways to do this (multiple aggregation functions). The most common types, and the ones explained in the paper, are mean aggregation, LSTM aggregation, and pooling.

With mean aggregation, the average is computed over all sampled neighbors’ features (very simple and often effective). In a formula:
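
Written out in notation along the lines of the GraphSAGE paper (this is my reconstruction: $\mathcal{N}(v)$ denotes the sampled neighbors of node $v$, and $h_u^{k-1}$ is the representation of neighbor $u$ after layer $k-1$), mean aggregation looks like this:

```latex
h_{\mathcal{N}(v)}^{k} = \frac{1}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} h_u^{k-1}
```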

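LSTM aggregation uses an LSTM (a type of recurrent neural network) to process the neighbor features sequentially. Because neighbors have no natural order, they are fed to the LSTM in a random order. It can capture more complex relationships and is more expressive than mean aggregation, at a higher computational cost.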

The third type, pooling aggregation, passes each neighbor’s features through a small neural network layer with a non-linearity and then takes the element-wise maximum to extract key features (think of max-pooling in a neural network, where you also keep the maximum of a set of values).
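
To make the aggregation step concrete, here is a small plain-Python sketch of the mean and max variants. The feature values are made up, and the pooling version leaves out the learned layer that is normally applied to each neighbor first, so it only shows the element-wise maximum:

```python
def mean_aggregate(neighbor_features):
    """Element-wise average over the sampled neighbors' feature vectors."""
    count = len(neighbor_features)
    return [sum(values) / count for values in zip(*neighbor_features)]


def max_pool_aggregate(neighbor_features):
    """Element-wise maximum over the sampled neighbors' feature vectors.

    Simplification: no learned transformation is applied before the max.
    """
    return [max(values) for values in zip(*neighbor_features)]


# Three sampled neighbors, each with a 3-dimensional feature vector (made-up values)
sampled_neighbor_features = [
    [1.0, 0.0, 2.0],
    [3.0, 4.0, 0.0],
    [2.0, 2.0, 1.0],
]

mean_vector = mean_aggregate(sampled_neighbor_features)
max_vector = max_pool_aggregate(sampled_neighbor_features)
print(mean_vector)  # [2.0, 2.0, 1.0]
print(max_vector)   # [3.0, 4.0, 2.0]
```

Both functions return the same result if the neighbors are shuffled, which is a useful property because neighbors in a graph have no natural order.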

3. Update Node Representation

After sampling and aggregation, the node combines its previous features with the aggregated neighbor features. Nodes will learn from their neighbors but also keep their own identity, just like we saw before with GCNs and GATs. Information can flow across the graph effectively. 

This is the formula for this step:
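
In notation along the lines of the GraphSAGE paper (my reconstruction: $h_v^{k-1}$ is the node’s own representation from the previous layer, $h_{\mathcal{N}(v)}^{k}$ is the aggregated neighbor vector from step 2, $W^{k}$ is the weight matrix of layer $k$, and $\sigma$ is the non-linearity), the update and the optional normalization can be written as:

```latex
h_v^{k} = \sigma\left( W^{k} \cdot \mathrm{CONCAT}\left( h_v^{k-1},\; h_{\mathcal{N}(v)}^{k} \right) \right),
\qquad
h_v^{k} \leftarrow \frac{h_v^{k}}{\lVert h_v^{k} \rVert_2}
```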

The aggregation of step 2 is done over all sampled neighbors, and the result is concatenated with the node’s own feature representation. This vector is multiplied by the weight matrix and passed through a non-linearity (for example, ReLU). As a final step, normalization can be applied.
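
The same update step as a plain-Python sketch. The `update_node` helper, the feature values, and the weight matrix are all made up for illustration; in a real model the weights are learned during training:

```python
import math


def update_node(self_features, aggregated_neighbors, weight_matrix):
    """One GraphSAGE-style update: concatenate, linear map, ReLU, L2-normalize."""
    # 1. Concatenate the node's own features with the aggregated neighbor features
    combined = self_features + aggregated_neighbors
    # 2. Multiply by the weight matrix (one output value per row)
    transformed = [
        sum(weight * value for weight, value in zip(row, combined))
        for row in weight_matrix
    ]
    # 3. Non-linearity (ReLU)
    activated = [max(0.0, value) for value in transformed]
    # 4. Optional L2 normalization (skipped for an all-zero vector)
    norm = math.sqrt(sum(value * value for value in activated))
    return [value / norm for value in activated] if norm > 0 else activated


self_features = [1.0, 0.0]         # the node's previous representation
aggregated_neighbors = [2.0, 2.0]  # output of the aggregation step
weight_matrix = [                  # hypothetical 3 x 4 weights
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0, 2.0],
    [-1.0, 0.0, 0.0, 0.0],
]

new_representation = update_node(self_features, aggregated_neighbors, weight_matrix)
print(new_representation)
```

Because of the concatenation, the weight matrix has separate columns for the node’s own features and for the neighbor features, which is how the node keeps its own identity while still learning from its neighbors.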

4. Repeat for Multiple Layers

The first three steps can be repeated multiple times. When this happens, information can flow in from more distant neighbors. In the image below you see a node with three neighbors selected in the first layer (direct neighbors), and two neighbors selected in the second layer (neighbors of neighbors).

Selected node with selected neighbors, three in the first layer, two in the second layer. Interesting to note is that one of the neighbors of the nodes in the first step is the selected node, so that one can also be selected when two neighbors are selected in the second step (just a bit harder to visualize). Image by author.
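
Here is a small plain-Python sketch of this layered sampling with fan-outs of three and two. The graph and the `sample_computation_graph` helper are hypothetical and only meant to illustrate the idea; this is not how PyG implements it internally:

```python
import random


def sample_computation_graph(adjacency, target, fanouts, rng):
    """Sample a multi-hop neighborhood around `target`.

    Returns one dict per hop, mapping each expanded node to its sampled neighbors.
    """
    hops = []
    frontier = [target]
    for fanout in fanouts:
        sampled = {}
        next_frontier = []
        for node in frontier:
            neighbors = adjacency[node]
            if len(neighbors) <= fanout:
                chosen = list(neighbors)
            else:
                chosen = rng.sample(neighbors, fanout)
            sampled[node] = chosen
            next_frontier.extend(chosen)
        hops.append(sampled)
        # Expand every node only once in the next hop, keeping the order
        frontier = list(dict.fromkeys(next_frontier))
    return hops


# Hypothetical undirected graph as adjacency lists
graph = {
    "a": ["b", "c", "d", "e"],
    "b": ["a", "c", "f"],
    "c": ["a", "b", "g"],
    "d": ["a", "h"],
    "e": ["a"],
    "f": ["b"],
    "g": ["c"],
    "h": ["d"],
}

rng = random.Random(42)
hops = sample_computation_graph(graph, "a", fanouts=[3, 2], rng=rng)
for depth, sampled in enumerate(hops, start=1):
    print(f"hop {depth}: {sampled}")
```

Note that the target node "a" can show up again in the second hop, because it is a neighbor of its own neighbors.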

To summarize, the key strengths of GraphSAGE are: scalability (sampling makes it efficient for massive graphs); flexibility (it supports inductive learning, so it can predict on unseen nodes and graphs); generalization (aggregation smooths out noisy features); and multiple layers, which allow the model to learn from far-away nodes.

Cool! And best of all, GraphSAGE is implemented in PyG (PyTorch Geometric), so we can easily use it with PyTorch.

Predicting with GraphSAGE

In the previous posts, we implemented an MLP, GCN, and GAT on the Cora dataset (CC BY-SA). To refresh your memory, Cora is a dataset of scientific publications where you have to predict the subject of each paper, with seven classes in total. This dataset is relatively small, so it might not be the best set for testing GraphSAGE. We will use it anyway, just to be able to compare. Let’s see how well GraphSAGE performs.

Some interesting GraphSAGE-related parts of the code that I’d like to highlight:

  • The NeighborLoader, which selects the neighbors for each layer:
from torch_geometric.loader import NeighborLoader

# 10 neighbors sampled in the first layer, 10 in the second layer
num_neighbors = [10, 10]

# sample data from the train set
train_loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    input_nodes=data.train_mask,
)
  • The aggregation type is set in the SAGEConv layer. The default is mean; you can change this to max or lstm:
from torch_geometric.nn import SAGEConv

SAGEConv(in_c, out_c, aggr='mean')
  • Another important difference is that GraphSAGE is trained in mini-batches, while GCN and GAT are trained on the full dataset. This touches the essence of GraphSAGE: because of neighbor sampling, each mini-batch only needs a sampled subgraph, so we don’t need the full graph anymore. GCNs and GATs, as set up in this series, use the complete graph for feature propagation and for calculating attention scores, so that’s why we train them on the full graph.
  • The rest of the code is similar to before, except that we have one class where all the different models are instantiated based on the model_type (GCN, GAT, or SAGE). This makes it easy to compare them or make small changes.

This is the complete script. We train for 100 epochs and repeat the experiment 10 times to calculate the average accuracy and standard deviation for each model:

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GCNConv, GATConv
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader

# dataset_name can be 'Cora', 'CiteSeer', 'PubMed'
dataset_name = 'Cora'
hidden_dim = 64
num_layers = 2
num_neighbors = [10, 10]
batch_size = 128
num_epochs = 100
model_types = ['GCN', 'GAT', 'SAGE']

dataset = Planetoid(root='data', name=dataset_name)
data = dataset[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, model_type='SAGE', gat_heads=8):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.model_type = model_type
        self.gat_heads = gat_heads

        def get_conv(in_c, out_c, is_final=False):
            if model_type == 'GCN':
                return GCNConv(in_c, out_c)
            elif model_type == 'GAT':
                heads = 1 if is_final else gat_heads
                concat = False if is_final else True
                return GATConv(in_c, out_c, heads=heads, concat=concat)
            else:
                return SAGEConv(in_c, out_c, aggr='mean')

        if model_type == 'GAT':
            self.convs.append(get_conv(in_channels, hidden_channels))
            in_dim = hidden_channels * gat_heads
            for _ in range(num_layers - 2):
                self.convs.append(get_conv(in_dim, hidden_channels))
                in_dim = hidden_channels * gat_heads
            self.convs.append(get_conv(in_dim, out_channels, is_final=True))
        else:
            self.convs.append(get_conv(in_channels, hidden_channels))
            for _ in range(num_layers - 2):
                self.convs.append(get_conv(hidden_channels, hidden_channels))
            self.convs.append(get_conv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
        x = self.convs[-1](x, edge_index)
        return x

@torch.no_grad()
def test(model):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs

results = {}

for model_type in model_types:
    print(f'Training {model_type}')
    results[model_type] = []

    for i in range(10):
        model = GNN(dataset.num_features, hidden_dim, dataset.num_classes, num_layers, model_type, gat_heads=8).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

        if model_type == 'SAGE':
            train_loader = NeighborLoader(
                data,
                num_neighbors=num_neighbors,
                batch_size=batch_size,
                input_nodes=data.train_mask,
            )

            def train():
                model.train()
                total_loss = 0
                for batch in train_loader:
                    batch = batch.to(device)
                    optimizer.zero_grad()
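                    out = model(batch.x, batch.edge_index)
                    # Only the first batch.batch_size nodes are the seed (training) nodes;
                    # the remaining nodes are sampled neighbors and may belong to val/test,
                    # so they must not contribute to the loss.
                    loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])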
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                return total_loss / len(train_loader)

        else:
            def train():
                model.train()
                optimizer.zero_grad()
                out = model(data.x, data.edge_index)
                loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
                loss.backward()
                optimizer.step()
                return loss.item()

        best_val_acc = 0
        best_test_acc = 0
        for epoch in range(1, num_epochs + 1):
            loss = train()
            train_acc, val_acc, test_acc = test(model)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc
            if epoch % 10 == 0:
                print(f'Epoch {epoch:02d} | Loss: {loss:.4f} | Train: {train_acc:.4f} | Val: {val_acc:.4f} | Test: {test_acc:.4f}')

        results[model_type].append([best_val_acc, best_test_acc])

for model_name, model_results in results.items():
    model_results = torch.tensor(model_results)
    print(f'{model_name} Val Accuracy: {model_results[:, 0].mean():.3f} ± {model_results[:, 0].std():.3f}')
    print(f'{model_name} Test Accuracy: {model_results[:, 1].mean():.3f} ± {model_results[:, 1].std():.3f}')

And here are the results:

GCN Val Accuracy: 0.791 ± 0.007
GCN Test Accuracy: 0.806 ± 0.006
GAT Val Accuracy: 0.790 ± 0.007
GAT Test Accuracy: 0.800 ± 0.004
SAGE Val Accuracy: 0.899 ± 0.005
SAGE Test Accuracy: 0.907 ± 0.004

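Impressive improvement! Even on this small dataset, GraphSAGE outperforms GAT and GCN easily. I repeated this test for the CiteSeer and PubMed datasets, and GraphSAGE always came out best. One caveat when reading the SAGE numbers: each NeighborLoader batch contains the seed (training) nodes first, followed by their sampled neighbors, and those neighbors can be validation or test nodes. The loss should only be computed on the first batch.batch_size nodes; if it covers every node in the batch, labels from outside the training set leak into training and inflate the accuracy. A gap this large over GCN and GAT on Cora is therefore best treated with caution.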

What I’d like to note here is that GCN is still very useful; it’s one of the most effective baselines (if the graph structure allows it). Also, I didn’t do much hyperparameter tuning, but just went with some standard values (like 8 heads for the GAT multi-head attention). In larger, more complex, and noisier graphs, the advantages of GraphSAGE become clearer than in this example. We didn’t do any runtime benchmarking, because for these small graphs GraphSAGE isn’t faster than GCN.


Conclusion

GraphSAGE brings some very nice improvements and benefits compared to GATs and GCNs. Inductive learning is possible, so GraphSAGE can handle changing graph structures quite well. And although we didn’t test it in this post, neighbor sampling makes it possible to create feature representations for larger graphs with good performance.

The post Graph Neural Networks Part 3: How GraphSAGE Handles Changing Graph Structure appeared first on Towards Data Science.
