GraphSAGE_Code解析

Posted Dodo·D·Caster

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GraphSAGE_Code解析相关的知识,希望对你有一定的参考价值。

GraphSAGE_Code解析

https://blog.csdn.net/weixin_44027006/article/details/116888648

DataCenter类

该类用于加载数据,存储编号和ID的字典,label和数字的字典,并分割成训练集、测试集和验证集

UnsupervisedLoss类

计算损失函数,其中对于无监督的损失函数,其公式为:

J G ( z u ) = − l o g ( σ ( z u T z v ) ) − Q ⋅ E v n ∼ P n ( v ) l o g ( σ ( − z u T z v n ) ) J \\mathcalG (z_u) = - log(\\sigma (z_u^T z_v)) - Q \\cdot E_v_n \\sim P_n(v)log(\\sigma (-z_u^T z_v_n)) JG(zu)=log(σ(zuTzv))QEvnPn(v)log(σ(zuTzvn))

需要生成正样本和负样本

  • 正样本采用随机游走的方式生成
  • 负样本生成方式是用训练集中的节点减去n阶邻居后在剩余的邻居里面随机取样
# Q * Exception(negative score) 计算负例样本的Loss,即Loss函数的后一项
indexs = [list(x) for x in zip(*nps)]                                               # [[源节点,...,源节点],[采样得到的负节点1,...,采样得到的负节点n]]
node_indexs = [node2index[x] for x in indexs[0]]                                    # 获得源节点的编号
neighb_indexs = [node2index[x] for x in indexs[1]]                                  # 获得负样本节点的编号
neg_score = F.cosine_similarity(embeddings[node_indexs], embeddings[neighb_indexs]) # 计算余弦相似性
neg_score = self.Q * torch.mean(torch.log(torch.sigmoid(-neg_score)), 0)            # 计算损失的后一项

# multiple positive score 计算正例样本的Loss,即Loss函数的前一项
indexs = [list(x) for x in zip(*pps)]
node_indexs = [node2index[x] for x in indexs[0]]
neighb_indexs = [node2index[x] for x in indexs[1]]
pos_score = F.cosine_similarity(embeddings[node_indexs], embeddings[neighb_indexs])
pos_score = torch.log(torch.sigmoid(pos_score))                                     # 计算损失的前一项

nodes_score.append(torch.mean(- pos_score - neg_score).view(1, -1))                  # 把每个节点的损失加入到列表中 view(1, -1) 中-1表示自动判断维度

Classification类

初始化一个 input_size*类别数 的线性全连接层并运用log_softmax进行分类,返回所有类别(标签)的值, 最大的值即预测的类别

logists = torch.log_softmax(self.fc1(x), 1)

SageLayer类

即公式中的:

h v k ← σ ( W ⋅ C O N C A T ( h v k − 1 , h N ( v ) k ) h^k_v \\leftarrow \\sigma (W \\cdot CONCAT( h_v^k-1 , h^k_ N(v) ) hvkσ(WCONCAT(hvk1,hN(v)k)

如果聚合器不是gcn,则连接自身信息和聚合后的信息,并乘权重矩阵后传入RELU激活函数

如何是gcn,则直接乘权重矩阵后传入RELU激活函数

if not self.gcn:
		combined = torch.cat([self_feats, aggregate_feats], dim=1)
else:
		combined = aggregate_feats
combined = F.relu(self.weight.mm(combined.t())).t()

GraphSage类

即公式中的:

h N ( v ) k ← A G G R E G A T E k ( h u k − 1 , ∀ u ∈ N ( v ) ) h^k_ N(v) \\leftarrow AGGREGATE_k (h^k-1_u, \\forall u \\in N(v)) hN(v)kAGGREGATEk(huk1,uN(v))

采样:

  • 如果邻居个数大于num_sample,则随机采样num_sample个邻居
  • 如果邻居个数小于num_sample,则全部采样

聚合

  • 先建立mask矩阵(row表示源节点,col表示所有节点的邻居)方便矩阵运算
    • mask[i, j] = 0 表示节点 j 不是节点 i 的邻居
    • mask[i, j] = 1 表示节点 j 是节点 i 的邻居
mask = torch.zeros(len(samp_neighs), len(unique_nodes))
column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
mask[row_indices, column_indices] = 1
  • 如果是MEAN聚合方式,则聚合后的特征为该节点所有邻居特征的均值
    • 实现方式为先对mask矩阵中的值除以邻居个数,然后和嵌入矩阵做矩阵乘法
num_neigh = mask.sum(1, keepdim=True)                                       # 计算每个源节点有多少个邻居节点
mask = mask.div(num_neigh).to(embed_matrix.device)
aggregate_feats = mask.mm(embed_matrix)
  • 如果是MAX聚合方式,则聚合后的特征为该节点所有邻居特征的最大值
indexs = [x.nonzero() for x in mask == 1]
aggregate_feats = []
for feat in [embed_matrix[x.squeeze()] for x in indexs]:                    # np.squeeze()函数可以删除数组形状中的单维度条目,即把shape中为1的维度去掉,但是对非单维的维度不起作用
		if len(feat.size()) == 1:
				aggregate_feats.append(feat.view(1, -1))
		else:
				aggregate_feats.append(torch.max(feat, 0)[0].view(1, -1))
aggregate_feats = torch.cat(aggregate_feats, 0)

Evaluate函数

输入当前最大的 max_val_f1_score,计算最新的 val_f1_score,如果大于 max_f1_score 则计算test_f1_score

  • 分类器返回的结果中,所有标签中值最大的那个即预测的类别
  • 输入真实的标签集合和预测的标签集合可以计算 f1 score
vali_f1 = f1_score(labels_val, predicts.cpu().data, average="micro")

get_gnn_embeddings函数

从每一个batch中append训练好的嵌入,最终汇总成完整的嵌入

for index in range(batches):
		nodes_batch = nodes[index * batchSize : (index + 1) * batchSize]
    embs_batch = gnn_model(nodes_batch)
    assert len(embs_batch) == len(nodes_batch)
    embs.append(embs_batch)
 
assert len(embs) == batches
embs = torch.cat(embs, 0)

train_classification函数

分类器训练过程,包含了每一次迭代,每一个batch要做的事情

apply_model

节点嵌入训练过程,包含了不同的方法,每一个epoch要做的事情

main函数

设定参数,训练模型

以上是关于GraphSAGE_Code解析的主要内容,如果未能解决你的问题,请参考以下文章

损失函数Center Loss 代码解析

Unet项目解析: 模型编译-优化函数损失函数指标列表

PyTorch 中自定义后向函数的损失 - 简单 MSE 示例中的爆炸损失

收藏 | YOLOv4损失函数全面解析

收藏 | YOLOv4损失函数全面解析

YOLOv5全面解析教程③:更快更好的边界框回归损失