pytorch-geometric 从入门到不放弃 day3
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch-geometric 从入门到不放弃 day3相关的知识,希望对你有一定的参考价值。
参考技术A 已经学习了data,dataset和dataloader,不如就先实战根据自己的数据集,写好自定义的dataset吧。1、首先将每个图数据预处理成Data需要的形式:
x是所有节点的特征,【num_nodes, embed_dim】,要注意这里所有的节点特征维度需要一致;
edge_index是邻接表,有向图:【【0,1】,【1,2】】;无向图:【【0,1,1,2】,【1,2,0,1】】;
y类别标签;
其他自定义的数据,需要是int或者float类型。
最后分别转换成numpy.array类型,使用numpy.savez()保存成npz文件,分别存放在train/eval/test路径下的graph文件夹里,后面要用。
np.savez(os.path.join(path, data_name, 'graph', file_id+'.npz'), x=x, edge_index=edge_idx, y=y, dtype=object)
2、自定义dataset,主要是__getitem__函数,逻辑是传入上面处理好的文件list,然后getitem函数按照列表下标读取,返回Data类型就好。
class GraphDataset(Dataset):
def __init__(self, root, file_list, treeLenDic, lower = 2, upper = 100000):
super(GraphDataset, self).__init__()
self.root = root
self.file_list = list(filter(lambda id: id.split('.')[0] in treeLenDic.keys() and treeLenDic[id.split('.')[0]] >= lower and treeLenDic[id.split('.')[0]] <= upper, file_list))
def __len__(self):
return len(self.file_list)
def __getitem__(self, idx):
id = self.file_list[idx]
data = np.load(os.path.join(self.root, id), allow_pickle=True)
return Data(x=torch.tensor(data['x'], dtype=torch.float32),
edge_index=torch.LongTensor(data['edge_index']),
y=torch.LongTensor([int(data['y'])]))
这里对每个图文件的长度做了筛选,要至少有两个节点,那种只有一个点的就不考虑了,TreeLenDic是个字典,graph_id: len.
3. 将Dataset实例化的对象传入DataLoader就可以批量读取数据了
好啦,到这里我数据预处理以及自定义Dataset就搞定了,可以开始学习torch.geometric.nn里面的网络模型啦~
《Java从入门到放弃》文章目录
转眼半个月过去了,不知不觉也写了10篇博客,突然发现所有的目录都没有纯列表的展示,所以特意写一个目录篇,来记录该系列下所有的文章。
当然,因为现在还没有写完,所以先按时间顺序排列,等相关内容都写完后,再按学习顺序来整理。
好了,先整理到这儿,如果大家有什么感兴趣的入门级的内容,可以在评论回复,博主可以优先编写相关内容。
本文出自 “软件思维” 博客,请务必保留此出处http://softi.blog.51cto.com/13093971/1953609
以上是关于pytorch-geometric 从入门到不放弃 day3的主要内容,如果未能解决你的问题,请参考以下文章
10年Web前端工程师自白:Web前端开发如何从入门到不放弃
10年web前端工程师自白:web前端开发如何从入门到不放弃