This post is still under construction; I am adding stuff as I find the time.
My foray into graph neural networks
What does graph data look like?
# %%capture
# import os
# import torch
# os.environ['TORCH'] = torch.__version__
# os.environ['PYTHONWARNINGS'] = "ignore"
# !pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install git+https://github.com/pyg-team/pytorch_geometric.git
# !pip install torch_geometric
# !pip install tabulate
# !pip install torchsummary
Collecting torchsummary
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
import torch
import torch_geometric as tg
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = tg.data.Data(x=x, edge_index=edge_index)
data
Data(x=[3, 1], edge_index=[2, 4])
This is a graph with 3 nodes, whose single node features are -1, 0 and 1.
The edge_index defines the connections: 0 -> 1, 1 -> 0, 1 -> 2, and 2 -> 1. Each undirected edge is stored as a pair of directed edges, so node 0 is connected to node 1, and node 1 to node 2.
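To see the structure, the Data object can be converted to a NetworkX graph and drawn (a small sketch, assuming networkx and matplotlib are available; this isn't part of the original workflow):

import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

g = to_networkx(data, to_undirected=True)  # collapse the paired directed edges
nx.draw(g, with_labels=True)
plt.show()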
A graph Data object can hold a few things:
- x: the node feature matrix, with shape [num_nodes, num_node_features]
- edge_index: graph connectivity in coordinate (COO) format, with shape [2, num_edges] and type torch.long
- edge_attr: edge feature matrix, with shape [num_edges, num_edge_features] (note that edges can have features too)
- y: target to train against (may have arbitrary shape), e.g., node-level targets of shape [num_nodes, *] or graph-level targets of shape [1, *]
- pos: node position matrix, with shape [num_nodes, num_dimensions]
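Putting these together (a minimal sketch with made-up values, just to show every attribute in one place):

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1.0], [0.0], [1.0]])       # one feature per node
edge_attr = torch.ones(edge_index.size(1), 2)  # two made-up features per edge
y = torch.tensor([0], dtype=torch.long)        # a graph-level label
pos = torch.rand(3, 2)                         # random 2-D node positions
full_data = tg.data.Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, pos=pos)
full_data
# -> Data(x=[3, 1], edge_index=[2, 4], edge_attr=[4, 2], y=[1], pos=[3, 2])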
Given a graph Data object in PyTorch Geometric, you can do a few things.
Validate the graph:
data.validate(raise_on_error=True)
Count how many nodes, edges, and features there are:
data.num_edge_features # no features for the edges
data.num_node_features # just one feature for each node
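The node and edge counts come straight off the Data object too (a quick sketch):

data.num_nodes, data.num_edges  # -> (3, 4): three nodes, four directed edges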
Is the graph directed?
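The Data object can answer that itself; since every edge above appears in both directions, this graph is undirected (a quick sketch):

data.is_directed(), data.is_undirected()  # -> (False, True)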
Let’s look at a real-life dataset
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
dataset
There are 600 graphs in this dataset; each graph is a protein structure. We can look at one of them…
a_protein_graph = dataset[1]
a_protein_graph
Data(edge_index=[2, 102], x=[23, 3], y=[1])
This protein graph has 23 nodes, each with 3 node-level features, and 102 edges. There are no edge attributes; y refers to the class of the protein.
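Those numbers can also be read straight off the object (a quick check):

a_protein_graph.num_nodes, a_protein_graph.num_edges  # -> (23, 102)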
b_protein_graph = dataset[0]
b_protein_graph
Data(edge_index=[2, 168], x=[37, 3], y=[1])
Another one.
We can split the 600 graphs in two, so that we train on some and test on the rest.
train_graphs = dataset[:500]
test_graphs = dataset[500:]
train_graphs, test_graphs
(ENZYMES(500), ENZYMES(100))
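For graph-level tasks like this one, the splits would usually be wrapped in a DataLoader that batches whole graphs together (a small sketch; the post doesn't actually train on ENZYMES, so batch_size=32 is just an illustrative choice):

from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)
for batch in train_loader:
    # the loader collates several graphs into one big, disconnected graph;
    # batch.batch maps every node back to the graph it came from
    print(batch.num_graphs, batch.x.shape, batch.batch.shape)
    break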
Training a graph neural net
Let's use another, less complex dataset, Cora:
from torch_geometric.datasets import Planetoid
import torchsummary
dataset = Planetoid(root='/tmp/Cora', name='Cora')
dataset
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
dataset.num_classes, dataset.num_node_features  # -> (7, 1433)
So, this dataset is a single graph with 2708 nodes, each with 1433 node-level features, and 10556 edges.
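The train_mask, val_mask and test_mask tensors in the Data object mark which of the 2708 nodes belong to each split. Their sizes can be checked directly (a quick sketch; for the standard Planetoid split this should come to 140 training, 500 validation and 1000 test nodes):

data = dataset[0]
int(data.train_mask.sum()), int(data.val_mask.sum()), int(data.test_mask.sum())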
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
class SimpleGNN(torch.nn.Module):
    def __init__(self, nnode_features, num_classes, num_layers=3):
        super().__init__()
        # two-layer GCN: node features -> 32 hidden channels -> class scores
        # (num_layers is accepted but not used yet)
        self.conv1 = tg.nn.GCNConv(in_channels=nnode_features, out_channels=32)
        self.conv2 = tg.nn.GCNConv(in_channels=32, out_channels=num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleGNN(nnode_features=dataset.num_node_features, num_classes=dataset.num_classes, num_layers=1).to(device)
print(tg.nn.summary(model, dataset[0].to(device)))
+------------------+--------------------------+----------------+----------+
| Layer | Input Shape | Output Shape | #Param |
|------------------+--------------------------+----------------+----------|
| SimpleGNN | [2708, 2708] | [2708, 7] | 46,119 |
| ├─(conv1)GCNConv | [2708, 1433], [2, 10556] | [2708, 32] | 45,888 |
| ├─(conv2)GCNConv | [2708, 32], [2, 10556] | [2708, 7] | 231 |
+------------------+--------------------------+----------------+----------+
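The parameter counts line up with the layer shapes: conv1 holds a 1433 x 32 weight matrix plus a 32-dimensional bias (45,888 parameters), and conv2 holds a 32 x 7 weight matrix plus a 7-dimensional bias (231 parameters).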
lls = ([], [], [], [])
train_loss_tally, train_acc_tally, test_loss_tally, test_acc_tally = lls
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-3)
for epoch in range(5000):
    # training step
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    train_loss_tally.append(loss.item())
    pred = out.argmax(dim=1)
    correct = (pred[data.train_mask] == data.y[data.train_mask]).sum()
    acc = int(correct) / int(data.train_mask.sum())
    train_acc_tally.append(acc)
    optimizer.step()
    # track test loss/accuracy each epoch, re-using the same forward pass
    model.eval()
    loss = F.nll_loss(out[data.test_mask], data.y[data.test_mask])
    test_loss_tally.append(loss.item())
    pred = out.argmax(dim=1)
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = int(correct) / int(data.test_mask.sum())
    test_acc_tally.append(acc)
import matplotlib.pyplot as plt
plt.plot(train_acc_tally, label='train')
plt.plot(test_acc_tally, label='test')
plt.legend(title="acc", loc=4, fontsize='small', fancybox=True)
plt.plot(train_loss_tally, label='train')
plt.plot(test_loss_tally, label='test')
plt.legend(title="loss", loc=4, fontsize='small', fancybox=True)
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
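The same masking pattern works for the validation nodes, if you want to tune hyper-parameters without touching the test split (a quick sketch, reusing pred from the cell above):

correct = (pred[data.val_mask] == data.y[data.val_mask]).sum()
val_acc = int(correct) / int(data.val_mask.sum())
print(f'Validation accuracy: {val_acc:.4f}')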
Graph attention network
Instead of the fixed, degree-normalised averaging that a GCN layer uses, a graph attention network (GAT) learns attention coefficients for each edge, so a node can weight its neighbours differently when aggregating their features. In PyTorch Geometric, GATConv slots into the same place as GCNConv.
class SimpleGAT(torch.nn.Module):
    def __init__(self, nnode_features, num_classes, num_layers=3):
        super().__init__()
        # two GAT layers with 12 attention heads each; concat=False averages
        # the heads, so each layer's output size stays at out_channels
        self.conv1 = tg.nn.conv.GATConv(in_channels=nnode_features, out_channels=32, heads=12, concat=False, dropout=0.2, add_self_loops=True)
        self.conv2 = tg.nn.conv.GATConv(in_channels=32, out_channels=num_classes, heads=12, concat=False, add_self_loops=True)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleGAT(nnode_features=dataset.num_node_features, num_classes=dataset.num_classes, num_layers=1).to(device)
print(tg.nn.summary(model, dataset[0].to(device)))
+------------------+--------------------------+----------------+----------+
| Layer | Input Shape | Output Shape | #Param |
|------------------+--------------------------+----------------+----------|
| SimpleGAT | [2708, 2708] | [2708, 7] | 553,935 |
| ├─(conv1)GATConv | [2708, 1433], [2, 10556] | [2708, 32] | 551,072 |
| ├─(conv2)GATConv | [2708, 32], [2, 10556] | [2708, 7] | 2,863 |
+------------------+--------------------------+----------------+----------+
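The parameter counts can again be reconciled with the layer settings: with 12 heads of 32 channels each, conv1's weights alone are 1433 x (12 x 32) = 550,272 parameters, and the per-head attention vectors plus a bias bring it to 551,072; conv2's 32 x (12 x 7) weights, attention vectors and bias give 2,863. Averaging the heads (concat=False) keeps the outputs at 32 and 7 channels, but the attention machinery still makes this model roughly twelve times larger than the GCN above.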
lls = ([], [], [], [])
train_loss_tally, train_acc_tally, test_loss_tally, test_acc_tally = lls
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001, weight_decay=5e-3)
for epoch in range(5000):
    # training step
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    train_loss_tally.append(loss.item())
    pred = out.argmax(dim=1)
    correct = (pred[data.train_mask] == data.y[data.train_mask]).sum()
    acc = int(correct) / int(data.train_mask.sum())
    train_acc_tally.append(acc)
    optimizer.step()
    # track test loss/accuracy each epoch, re-using the same forward pass
    model.eval()
    loss = F.nll_loss(out[data.test_mask], data.y[data.test_mask])
    test_loss_tally.append(loss.item())
    pred = out.argmax(dim=1)
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = int(correct) / int(data.test_mask.sum())
    test_acc_tally.append(acc)
import matplotlib.pyplot as plt
plt.plot(train_acc_tally, label='train')
plt.plot(test_acc_tally, label='test')
plt.legend(title="acc", loc=4, fontsize='small', fancybox=True)
plt.plot(train_loss_tally, label='train')
plt.plot(test_loss_tally, label='test')
plt.legend(title="loss", loc=4, fontsize='small', fancybox=True)
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')