[Daily] Building GNNs with the dgl library for node classification and edge classification: examples
Posted 囚生CY
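Preface
My earlier note, [Study Notes] Getting-Started Tutorial for the Graph Neural Network Library DGL (backend pytorch), was fairly detailed, but the code in that tutorial was scattered across many snippets. Here I have consolidated the code for the two most common tasks, node classification and edge classification, and added comments to make it easier to follow, for future reference.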
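1 Node classification code example
Node classification uses dgl's built-in dataset CiteseerGraphDataset, which downloads quickly. By default it is saved under the ~/.dgl directory (on Windows, C:\Users\<username>\.dgl). CiteseerGraphDataset is a node classification dataset that contains a single citation graph, retrieved with dataset[0]. Its nodes fall into 6 classes, and the task is to predict the class of each node: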
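Note that CiteseerGraphDataset gives every node three extra fields: train_mask, val_mask and test_mask. These are boolean masks that split the nodes into a training set, a validation set and a test set. As the training code shows, the loss is computed only on nodes where train_mask is True, the validation accuracy during training uses only nodes where val_mask (stored in the variable valid_mask) is True, and the final evaluation on the test set uses test_mask. This technique is very practical because a graph cannot easily be cut into separate training and test subgraphs: the whole graph is still used for message passing, and the masks only decide which nodes contribute to the loss or the metric.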
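To make the mask mechanism concrete, here is a minimal pure-Python sketch (no DGL or PyTorch, so it runs anywhere). The toy labels, logits and masks are invented for illustration, and the helper names select, argmax and masked_accuracy are hypothetical; they only mirror what logits[mask] indexing and the evaluate function in the node classification code do.

```python
# Sketch of boolean-mask splitting and masked accuracy (toy data, hypothetical helpers).

def select(values, mask):
    """Keep the entries whose mask value is True, like tensor[mask]."""
    return [v for v, keep in zip(values, mask) if keep]

def argmax(row):
    """Index of the largest score, like torch.max(logits, dim=1) for one row."""
    return max(range(len(row)), key=lambda i: row[i])

def masked_accuracy(logits, labels, mask):
    """Accuracy computed over only the nodes selected by mask."""
    picked_logits = select(logits, mask)
    picked_labels = select(labels, mask)
    correct = sum(argmax(row) == y for row, y in zip(picked_logits, picked_labels))
    return correct / len(picked_labels)

# Six toy nodes, three classes.
labels = [0, 1, 2, 1, 0, 2]
logits = [
    [2.0, 0.1, 0.3],  # predicts 0 (correct)
    [0.2, 1.5, 0.1],  # predicts 1 (correct)
    [0.9, 0.2, 0.1],  # predicts 0 (wrong, label is 2)
    [0.1, 0.3, 2.2],  # predicts 2 (wrong, label is 1)
    [1.1, 0.4, 0.2],  # predicts 0 (correct)
    [0.0, 0.1, 0.8],  # predicts 2 (correct)
]
# Disjoint toy masks: here every node belongs to exactly one split.
train_mask = [True, True, True, False, False, False]
val_mask = [False, False, False, True, True, False]
test_mask = [False, False, False, False, False, True]

print(masked_accuracy(logits, labels, train_mask))  # 2 of 3 correct
print(masked_accuracy(logits, labels, val_mask))    # 1 of 2 correct
print(masked_accuracy(logits, labels, test_mask))   # 1 of 1 correct
```

All six nodes still exist in every computation; the masks only choose which rows are scored, which is why the same graph can serve training, validation and testing at once.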
# -*- coding: UTF-8 -*-
import dgl
import torch
import numpy as np
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F

# Load data.
dataset = dgl.data.CiteseerGraphDataset()
graph = dataset[0]  # The dataset holds a single graph. num_nodes: 3327 | num_edges: 9228

# Construct a two-layer GNN model.
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')

    def forward(self, graph, inputs):  # inputs are the node features
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h

node_features = graph.ndata['feat']            # Node features: shape (3327, 3703)
node_labels = graph.ndata['label']             # Node labels: shape (3327,)
train_mask = graph.ndata['train_mask']         # Boolean mask, shape (3327,): True for training nodes
valid_mask = graph.ndata['val_mask']           # Boolean mask, shape (3327,): True for validation nodes
test_mask = graph.ndata['test_mask']           # Boolean mask, shape (3327,): True for test nodes
n_features = node_features.shape[1]            # Number of features: 3703
n_labels = int(node_labels.max().item() + 1)   # Number of classes: 6

# Define the evaluation metric (accuracy on the masked nodes).
def evaluate(model, graph, features, labels, mask):
    model.eval()  # Switch to evaluation mode.
    with torch.no_grad():  # No gradients are needed during evaluation.
        logits = model(graph, features)  # Raw (unnormalized) class scores for every node.
        logits = logits[mask]  # Keep only the masked nodes.
        labels = labels[mask]  # True labels of those nodes.
        _, indices = torch.max(logits, dim=1)  # Predicted class = index of the largest score.
        correct = torch.sum(indices == labels)  # Number of correct predictions.
        return correct.item() * 1.0 / len(labels)  # Accuracy.

# Train model.
model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())
for epoch in range(100):
    model.train()
    logits = model(graph, node_features)
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])  # Loss on training nodes only.
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)  # Validation accuracy before this step's update.
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item(), acc)

print('Accuracy on test: {}'.format(evaluate(model, graph, node_features, node_labels, test_mask)))

# Save model (the whole module is pickled).
torch.save(model, 'node_sage.m')
Example run output: each printed line shows the loss value on the left and the validation accuracy on the right.
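The SAGE model above stacks two SAGEConv layers with aggregator_type='mean'. As I understand DGL's SAGEConv documentation, for each node this layer averages the features of its in-neighbours (messages flow from source to destination), applies one linear map to the node's own feature and another to that neighbour mean, and adds the two. The sketch below reproduces only this core computation in pure Python on scalar features; w_self and w_neigh are scalar stand-ins for the two linear layers, the bias, normalisation and optional activation of the real layer are omitted, and the function name sage_mean_layer is hypothetical.

```python
# Sketch of SAGE-style mean aggregation on scalar node features (hypothetical helper).

def sage_mean_layer(num_nodes, src, dst, h, w_self, w_neigh):
    """One mean-aggregation step over directed edges src[i] -> dst[i]."""
    total = [0.0] * num_nodes
    count = [0] * num_nodes
    for u, v in zip(src, dst):
        total[v] += h[u]  # message from source u arrives at destination v
        count[v] += 1
    # Nodes with no incoming edges get a neighbour mean of 0.
    neigh_mean = [t / c if c else 0.0 for t, c in zip(total, count)]
    return [w_self * h[i] + w_neigh * neigh_mean[i] for i in range(num_nodes)]

# Toy graph: 0 -> 2, 1 -> 2, 2 -> 0 (node 1 has no incoming edges).
src = [0, 1, 2]
dst = [2, 2, 0]
h = [1.0, 3.0, 5.0]
out = sage_mean_layer(3, src, dst, h, w_self=1.0, w_neigh=0.5)
print(out)  # [3.5, 3.0, 6.0]
```

Stacking two such layers with a ReLU in between, as the SAGE class does, lets each node's output depend on nodes up to two hops away.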
2 Edge classification code example
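Strictly speaking, this is edge regression: the dataset is a randomly generated graph whose edge labels are random floating-point numbers, so the model being trained is a regression model. The data are much smaller than CiteseerGraphDataset, so training runs very quickly. The output is the loss value of each epoch (the mean squared error, as the code shows). Because the labels are pure noise, the example only demonstrates the training pipeline; the loss says nothing about real predictive ability.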
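The data generation and the DotProductPredictor in the edge regression code can be mirrored in pure Python to show what they compute. Concatenating src/dst with their reverse puts every random pair into the graph in both directions (so 500 pairs become 1000 edges), fn.u_dot_v('h', 'h', 'score') gives each edge the dot product of its source and destination representations, and the loss is the mean squared error over the edges selected by train_mask. The node embeddings, labels and mask below are toy values, and the helper names are hypothetical.

```python
# Pure-Python mirror of the edge-regression pipeline (toy data, hypothetical helpers).

def make_bidirectional(src, dst):
    """Return (src + dst, dst + src): every pair appears in both directions."""
    return src + dst, dst + src

def dot_product_scores(src, dst, h):
    """One score per edge: dot(h[src[e]], h[dst[e]]), as fn.u_dot_v computes."""
    return [sum(a * b for a, b in zip(h[u], h[v])) for u, v in zip(src, dst)]

def masked_mse(pred, label, mask):
    """Mean squared error over the edges whose mask entry is True."""
    sq_errors = [(p - y) ** 2 for p, y, m in zip(pred, label, mask) if m]
    return sum(sq_errors) / len(sq_errors)

pair_src = [0, 1, 2]
pair_dst = [1, 3, 3]
src, dst = make_bidirectional(pair_src, pair_dst)
print(src, dst)  # [0, 1, 2, 1, 3, 3] [1, 3, 3, 0, 1, 2]

h = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0], [3.0, -1.0]]  # toy 2-d node embeddings
scores = dot_product_scores(src, dst, h)
print(scores)  # [0.0, -2.0, 2.0, 0.0, -2.0, 2.0]

labels = [0.5] * len(scores)  # toy regression targets
mask = [True, False, True, False, True, False]
print(masked_mse(scores, labels, mask))  # equals (0.25 + 2.25 + 6.25) / 3
```

Because the dot product is symmetric, an edge and its reverse always get the same prediction. In the original script the 1000 labels are drawn independently, so the two directions of a pair usually have different targets that DotProductPredictor cannot both fit; the order-sensitive MLPPredictor does not have this restriction.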
# -*- coding: UTF-8 -*-
import dgl
import torch
import numpy as np
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

# 1 Construct a two-layer GNN model.
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')

    def forward(self, graph, inputs):
        # inputs are the node features
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h

# 2 Generate data randomly: 500 random node pairs among 100 nodes, added in both directions (1000 edges).
src = np.random.randint(0, 100, 500)
dst = np.random.randint(0, 100, 500)
# num_nodes=100 guarantees 100 nodes even if the largest ID is never drawn.
edge_pred_graph = dgl.graph((np.concatenate([src, dst]), np.concatenate([dst, src])), num_nodes=100)
edge_pred_graph.ndata['feature'] = torch.randn(100, 10)
edge_pred_graph.edata['feature'] = torch.randn(1000, 10)
edge_pred_graph.edata['label'] = torch.randn(1000)
edge_pred_graph.edata['train_mask'] = torch.zeros(1000, dtype=torch.bool).bernoulli(0.6)

# 3 Define predictors that compute a score for each edge.
# Two predictors are given here, `DotProductPredictor` and `MLPPredictor`, but only the former is used.
class DotProductPredictor(nn.Module):
    # Score each edge by the dot product of its source and destination node representations.
    def forward(self, graph, h):
        # h contains the node representations computed by the SAGE model above.
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            return graph.edata['score']  # shape (num_edges, 1)

class MLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.Linear(in_features * 2, out_classes)

    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(torch.cat([h_u, h_v], 1))
        return {'score': score}

    def forward(self, graph, h):
        # h contains the node representations computed by the SAGE model above.
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']

# 4 Define model.
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.sage = SAGE(in_features, hidden_features, out_features)
        self.pred = DotProductPredictor()

    def forward(self, g, x):
        h = self.sage(g, x)
        return self.pred(g, h)

node_features = edge_pred_graph.ndata['feature']
edge_label = edge_pred_graph.edata['label']  # Not a class label but a real-valued target, so this is regression.
train_mask = edge_pred_graph.edata['train_mask']

# Train model.
model = Model(10, 20, 5)
opt = torch.optim.Adam(model.parameters())
for epoch in range(1000):
    # Flatten (num_edges, 1) to (num_edges,) so it matches edge_label;
    # otherwise the subtraction below broadcasts to an (n, n) matrix.
    pred = model(edge_pred_graph, node_features).squeeze(-1)
    loss = ((pred[train_mask] - edge_label[train_mask]) ** 2).mean()  # Mean squared error.
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

# Save model.
torch.save(model, 'edge_sage.m')
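The unused MLPPredictor scores an edge by applying one linear layer to the concatenation of the source and destination representations, W [h_u; h_v] + b. Unlike the dot product, this is order-sensitive, so an edge and its reverse can receive different scores. The pure-Python sketch below shows this for a single output (out_classes=1); the weights, bias and embeddings are toy values and the function name mlp_edge_score is hypothetical. Using the MLP in Model would mean something like MLPPredictor(out_features, 1) in place of DotProductPredictor() (an untested suggestion); its output also has shape (num_edges, 1), so it has to be flattened to (num_edges,) before being compared with edge_label.

```python
# Sketch of concatenation-based edge scoring (hypothetical helper, toy weights).

def mlp_edge_score(h_u, h_v, weight, bias):
    """Linear layer on [h_u; h_v], like nn.Linear(in_features * 2, 1) for one edge."""
    x = h_u + h_v  # list concatenation plays the role of torch.cat([h_u, h_v], 1)
    return sum(w * xi for w, xi in zip(weight, x)) + bias

h = [[1.0, 0.0], [0.0, 2.0]]     # toy embeddings for nodes 0 and 1
weight = [1.0, 0.5, -1.0, 0.25]  # toy weights for the 2 * 2 = 4-dim input
bias = 0.1

forward = mlp_edge_score(h[0], h[1], weight, bias)   # edge 0 -> 1
backward = mlp_edge_score(h[1], h[0], weight, bias)  # edge 1 -> 0
print(forward, backward)  # different scores for the two directions
```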