第七周.01.Message更新讲解+GCN实例
Posted oldmao_2000
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了第七周.01.Message更新讲解+GCN实例相关的知识,希望对你有一定的参考价值。
本文内容整理自深度之眼《GNN核心能力培养计划》
公式输入请参考: 在线Latex公式
update_all和send_and_recv是图网中非常重要两个函数,用人话来描述就是我们要如何汇聚邻居的信息:
1、汇聚什么?节点还是边?还是节点加边?还是节点减边?
2、如何汇聚?求和?最大?平均?
3、汇聚后更新节点表征需要什么操作?这个不是必须的,可以是做个特征变化啥的。
update_all
官网说明:https://docs.dgl.ai/generated/dgl.DGLGraph.update_all.html
DGLGraph.update_all(message_func, reduce_func, apply_node_func=None, etype=None)
来看下里面的几个参数。
message_func,消息函数(从源节点到目标节点进行操作),可以使用DGL自带的消息函数或者自定义消息函数
reduce_func,产生消息后,对消息进行汇聚aggregate操作,也是可以使用DGL自带的汇聚函数或者自定义汇聚函数
apply_node_func,节点更新函数,经过上面两步后如何更新节点embedding,这个函数只有用户自定义
以上三个函数是update_all中最核心的部分。
send_and_recv
官网说明:https://docs.dgl.ai/generated/dgl.DGLGraph.send_and_recv.html
DGLGraph.send_and_recv(edges, message_func, reduce_func, apply_node_func=None, etype=None, inplace=False)
这个函数和上面的update_all里面一样有三个核心消息函数,这里就不写了,不一样的是send_and_recv函数可以指定边(第一个参数)进行消息操作。
边这个参数可以有以下几种方式:
方式 | 含义 |
---|---|
整型int | 代表单个边的编号 |
整型 Tensor | Tensor 中的每个元素代表一个边的编号,tensor的device类型及数据类型要和Graph的ID类型要一致 |
可迭代的整型 | 每个元素代表一个边的编号 |
(Tensor ,Tensor ) | 用节点的方式来表示边,两个Tensor 分别表示起始和结束节点 |
(可迭代的整型,可迭代的整型) | 同上 |
Built-in Function
上面讲三个核心函数的时候有提到DGL有自带的消息处理函数,我们来看看:
官网地址:https://docs.dgl.ai/api/python/dgl.function.html#dgl-built-in-function
从表中可以看到有三大类:
第一大类是单对象操作,直接copy消息,下划线后面分别代表拷贝的对象:节点、边,后面两个和前面两个是功能一样的
第二大类是双对象操作,下划线前后分别代表两个对象,中间代表操作类型
第三大类是reduce函数,四个。
实例代码
https://docs.dgl.ai/tutorials/models/1_gnn/1_gcn.html
针对官网的GCN代码重新进行讲解,这次重点看上面提到的函数。
原文的模型描述公式没显示出来,这里重新贴下:
For each node
u
u
u:
- Aggregate neighbors’ representations h v h_{v} hv to produce an intermediate representation h ^ u \\hat{h}_u h^u.
- Transform the aggregated representation h ^ u \\hat{h}_{u} h^u with a linear projection followed by a non-linearity: h u = f ( W u h ^ u ) h_{u} = f(W_{u} \\hat{h}_u) hu=f(Wuh^u).
We will implement step 1 with DGL message passing, and step 2 by PyTorch nn.Module
.
GCN implementation with DGL
We first define the message and reduce function as usual. Since the
aggregation on a node
u
u
u only involves summing over the neighbors’
representations
h
v
h_v
hv, we can simply use builtin functions:
具体代码:
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
gcn_msg = fn.copy_u(u='h', out='m')#update_all的第一个参数,单对象操作,直接拷贝原节点信息作为消息输出
gcn_reduce = fn.sum(msg='m', out='h')#update_all的第二个参数,采用sum作为aggregate方式,吃的上面的输出
class GCNLayer(nn.Module):
def __init__(self, in_feats, out_feats):
super(GCNLayer, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)#out_feats是输出的分类数量
def forward(self, g, feature):
# Creating a local scope so that all the stored ndata and edata
# (such as the `'h'` ndata below) are automatically popped out
# when the scope exits.
with g.local_scope():
g.ndata['h'] = feature# 初始化的特征丢给节点
g.update_all(gcn_msg, gcn_reduce)# 更新节点表征,里面两个函数在上面,考虑一下博文中提出的三个问题
h = g.ndata['h']#将最后g.ndata读取出来作为结果
return self.linear(h)#update_all的第二个参数,采用sum作为aggregate方式,吃的上面的输出
下面定义一个GCN模型,在cora上进行一个分类,模型包含两层GCN layer,输入特征维度是1433,分类数是7
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.layer1 = GCNLayer(1433, 16)#输入1433,输出中间层为16
self.layer2 = GCNLayer(16, 7)#输入16,最后输出7分类
def forward(self, g, features):
x = F.relu(self.layer1(g, features))
x = self.layer2(g, x)#最后输出层不用做ReLU
return x
net = Net()
print(net)
导入cora数据,划分训练测试数据集
from dgl.data import CoraGraphDataset
def load_cora_data():
dataset = CoraGraphDataset()
g = dataset[0]
features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
test_mask = g.ndata['test_mask']
return g, features, labels, train_mask, test_mask
测试模型效果
def evaluate(model, g, features, labels, mask):
model.eval()
with th.no_grad():
logits = model(g, features)
logits = logits[mask]
labels = labels[mask]
_, indices = th.max(logits, dim=1)
correct = th.sum(indices == labels)
return correct.item() * 1.0 / len(labels)
训练模型
import time
import numpy as np
g, features, labels, train_mask, test_mask = load_cora_data()
# Add edges between each node and itself to preserve old node representations
g.add_edges(g.nodes(), g.nodes())#加selfloop:A'=A+I
optimizer = th.optim.Adam(net.parameters(), lr=1e-2)
dur = []
for epoch in range(50):
if epoch >=3:
t0 = time.time()
net.train()
logits = net(g, features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[train_mask], labels[train_mask])#两步计算交叉熵损失
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch >=3:
dur.append(time.time() - t0)
acc = evaluate(net, g, features, labels, test_mask)
print("Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
epoch, loss.item(), acc, np.mean(dur)))
结果:
以上是关于第七周.01.Message更新讲解+GCN实例的主要内容,如果未能解决你的问题,请参考以下文章