GraphSAGE_Code解析
Posted Dodo·D·Caster
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GraphSAGE_Code解析相关的知识,希望对你有一定的参考价值。
GraphSAGE_Code解析
https://blog.csdn.net/weixin_44027006/article/details/116888648
DataCenter类
该类用于加载数据,存储编号和ID的字典,label和数字的字典,并分割成训练集、测试集和验证集
UnsupervisedLoss类
计算损失函数,其中对于无监督的损失函数,其公式为:
J G ( z u ) = − l o g ( σ ( z u T z v ) ) − Q ⋅ E v n ∼ P n ( v ) l o g ( σ ( − z u T z v n ) ) J \\mathcalG (z_u) = - log(\\sigma (z_u^T z_v)) - Q \\cdot E_v_n \\sim P_n(v)log(\\sigma (-z_u^T z_v_n)) JG(zu)=−log(σ(zuTzv))−Q⋅Evn∼Pn(v)log(σ(−zuTzvn))
需要生成正样本和负样本
- 正样本采用随机游走的方式生成
- 负样本生成方式是用训练集中的节点减去n阶邻居后在剩余的邻居里面随机取样
# Q * Exception(negative score) 计算负例样本的Loss,即Loss函数的后一项
indexs = [list(x) for x in zip(*nps)] # [[源节点,...,源节点],[采样得到的负节点1,...,采样得到的负节点n]]
node_indexs = [node2index[x] for x in indexs[0]] # 获得源节点的编号
neighb_indexs = [node2index[x] for x in indexs[1]] # 获得负样本节点的编号
neg_score = F.cosine_similarity(embeddings[node_indexs], embeddings[neighb_indexs]) # 计算余弦相似性
neg_score = self.Q * torch.mean(torch.log(torch.sigmoid(-neg_score)), 0) # 计算损失的后一项
# multiple positive score 计算正例样本的Loss,即Loss函数的前一项
indexs = [list(x) for x in zip(*pps)]
node_indexs = [node2index[x] for x in indexs[0]]
neighb_indexs = [node2index[x] for x in indexs[1]]
pos_score = F.cosine_similarity(embeddings[node_indexs], embeddings[neighb_indexs])
pos_score = torch.log(torch.sigmoid(pos_score)) # 计算损失的前一项
nodes_score.append(torch.mean(- pos_score - neg_score).view(1, -1)) # 把每个节点的损失加入到列表中 view(1, -1) 中-1表示自动判断维度
Classification类
初始化一个 input_size*类别数 的线性全连接层并运用log_softmax进行分类,返回所有类别(标签)的值, 最大的值即预测的类别
logists = torch.log_softmax(self.fc1(x), 1)
SageLayer类
即公式中的:
h v k ← σ ( W ⋅ C O N C A T ( h v k − 1 , h N ( v ) k ) h^k_v \\leftarrow \\sigma (W \\cdot CONCAT( h_v^k-1 , h^k_ N(v) ) hvk←σ(W⋅CONCAT(hvk−1,hN(v)k)
如果聚合器不是gcn,则连接自身信息和聚合后的信息,并乘权重矩阵后传入RELU激活函数
如何是gcn,则直接乘权重矩阵后传入RELU激活函数
if not self.gcn:
combined = torch.cat([self_feats, aggregate_feats], dim=1)
else:
combined = aggregate_feats
combined = F.relu(self.weight.mm(combined.t())).t()
GraphSage类
即公式中的:
h N ( v ) k ← A G G R E G A T E k ( h u k − 1 , ∀ u ∈ N ( v ) ) h^k_ N(v) \\leftarrow AGGREGATE_k (h^k-1_u, \\forall u \\in N(v)) hN(v)k←AGGREGATEk(huk−1,∀u∈N(v))
采样:
- 如果邻居个数大于num_sample,则随机采样num_sample个邻居
- 如果邻居个数小于num_sample,则全部采样
聚合
- 先建立mask矩阵(row表示源节点,col表示所有节点的邻居)方便矩阵运算
- mask[i, j] = 0 表示节点 j 不是节点 i 的邻居
- mask[i, j] = 1 表示节点 j 是节点 i 的邻居
mask = torch.zeros(len(samp_neighs), len(unique_nodes))
column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
mask[row_indices, column_indices] = 1
- 如果是MEAN聚合方式,则聚合后的特征为该节点所有邻居特征的均值
- 实现方式为先对mask矩阵中的值除以邻居个数,然后和嵌入矩阵做矩阵乘法
num_neigh = mask.sum(1, keepdim=True) # 计算每个源节点有多少个邻居节点
mask = mask.div(num_neigh).to(embed_matrix.device)
aggregate_feats = mask.mm(embed_matrix)
- 如果是MAX聚合方式,则聚合后的特征为该节点所有邻居特征的最大值
indexs = [x.nonzero() for x in mask == 1]
aggregate_feats = []
for feat in [embed_matrix[x.squeeze()] for x in indexs]: # np.squeeze()函数可以删除数组形状中的单维度条目,即把shape中为1的维度去掉,但是对非单维的维度不起作用
if len(feat.size()) == 1:
aggregate_feats.append(feat.view(1, -1))
else:
aggregate_feats.append(torch.max(feat, 0)[0].view(1, -1))
aggregate_feats = torch.cat(aggregate_feats, 0)
Evaluate函数
输入当前最大的 max_val_f1_score,计算最新的 val_f1_score,如果大于 max_f1_score 则计算test_f1_score
- 分类器返回的结果中,所有标签中值最大的那个即预测的类别
- 输入真实的标签集合和预测的标签集合可以计算 f1 score
vali_f1 = f1_score(labels_val, predicts.cpu().data, average="micro")
get_gnn_embeddings函数
从每一个batch中append训练好的嵌入,最终汇总成完整的嵌入
for index in range(batches):
nodes_batch = nodes[index * batchSize : (index + 1) * batchSize]
embs_batch = gnn_model(nodes_batch)
assert len(embs_batch) == len(nodes_batch)
embs.append(embs_batch)
assert len(embs) == batches
embs = torch.cat(embs, 0)
train_classification函数
分类器训练过程,包含了每一次迭代,每一个batch要做的事情
apply_model
节点嵌入训练过程,包含了不同的方法,每一个epoch要做的事情
main函数
设定参数,训练模型
以上是关于GraphSAGE_Code解析的主要内容,如果未能解决你的问题,请参考以下文章