torch_geometric笔记:数据集 ENZYMES &Minibatches

Posted UQI-LIUWJ

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch_geometric笔记:数据集 ENZYMES &Minibatches相关的知识,希望对你有一定的参考价值。

         Pytorch Geometric中包含大量的常见基准数据集。在初始化数据集的时候,框架会自动下载数据集的原始文件,并将其处理为Data对象。例如要下载ENZYMES数据集(由600个graph划分为6个类别)

1 下载数据集

from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='', name='ENZYMES')

dataset
#ENZYMES(600)

type(dataset)
#torch_geometric.datasets.tu_dataset.TUDataset

len(dataset)
#600
#说明600张图

dataset.num_classes
#6
#图一共有6各不同的类

dataset.num_node_features
#3 每一个节点有三个特征

data = dataset[0]
data
#Data(edge_index=[2, 168], x=[37, 3], y=[1])
#第一张图有168条有向边,37个节点,每个节点3个特征,整张图有一个类别

data.is_undirected()
#True

2 Mini-batches

        神经网络通常以batch的方式进行训练,geometric在mini-batch实现了并行化,这种组合允许在一个batch中使用不同数量的边和节点。

        在torch_geometric.data.DataLoader中,已经包含了此过程。

        这种mini-batch的操作本质上来说是将一个batch的graph看成是一个大的graph,由此,无论batch size是多少,其将所有的操作都统一在一个大图上进行操作。

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    print(batch,batch.num_graphs)

'''
Batch(edge_index=[2, 3890], x=[1075, 21], y=[32], batch=[1075], ptr=[33]) 32
Batch(edge_index=[2, 4284], x=[1157, 21], y=[32], batch=[1157], ptr=[33]) 32
Batch(edge_index=[2, 4098], x=[1086, 21], y=[32], batch=[1086], ptr=[33]) 32
Batch(edge_index=[2, 3668], x=[916, 21], y=[32], batch=[916], ptr=[33]) 32
Batch(edge_index=[2, 4062], x=[1074, 21], y=[32], batch=[1074], ptr=[33]) 32
Batch(edge_index=[2, 4086], x=[1096, 21], y=[32], batch=[1096], ptr=[33]) 32
Batch(edge_index=[2, 3954], x=[1005, 21], y=[32], batch=[1005], ptr=[33]) 32
Batch(edge_index=[2, 4170], x=[1064, 21], y=[32], batch=[1064], ptr=[33]) 32
Batch(edge_index=[2, 4258], x=[1149, 21], y=[32], batch=[1149], ptr=[33]) 32
Batch(edge_index=[2, 3836], x=[997, 21], y=[32], batch=[997], ptr=[33]) 32
Batch(edge_index=[2, 3886], x=[1016, 21], y=[32], batch=[1016], ptr=[33]) 32
Batch(edge_index=[2, 4066], x=[1042, 21], y=[32], batch=[1042], ptr=[33]) 32
Batch(edge_index=[2, 3946], x=[1046, 21], y=[32], batch=[1046], ptr=[33]) 32
Batch(edge_index=[2, 3656], x=[927, 21], y=[32], batch=[927], ptr=[33]) 32
Batch(edge_index=[2, 4110], x=[1034, 21], y=[32], batch=[1034], ptr=[33]) 32
Batch(edge_index=[2, 3824], x=[1002, 21], y=[32], batch=[1002], ptr=[33]) 32
Batch(edge_index=[2, 4178], x=[1116, 21], y=[32], batch=[1116], ptr=[33]) 32
Batch(edge_index=[2, 3736], x=[974, 21], y=[32], batch=[974], ptr=[33]) 32
Batch(edge_index=[2, 2856], x=[804, 21], y=[24], batch=[804], ptr=[25]) 24
'''

以  Batch(edge_index=[2, 3890], x=[1075, 21], y=[32], batch=[1075], ptr=[33])  为例:

  • edge_index=[2, 3890]——这个batch一共3890条边
  • x=[1075, 21]——整个batch的节点特征矩阵,这个batch一共2075个点,至于这个21,我不太明白,是因为不同的图有不同的特征,所以拼起来一共21个不同的特征吗?欢迎大家在评论区指正!
  • y=[32]——32个图,32维特征
  • batch=[1075]——batch是一个列向量,它将每个节点映射到该batch中的对应的graph:

       

2.1 自己的图列表 &DataLoader

不难发现,这种下载的数据集,可以看成是图的集合

 

那么如果我门自己设计了一些图,集合成一个列表,我们可以直接用这个列表构造DataLoader(注:这里的DataLoader是torch_geometric.loader的DataLoader)

 

 

 

以上是关于torch_geometric笔记:数据集 ENZYMES &Minibatches的主要内容,如果未能解决你的问题,请参考以下文章

torch_geometric笔记:nn. graclus (图点分类)

torch_geometric 笔记:TORCH_GEOMETRIC.UTILS(更新中)

torch_geometric笔记:max_pool 与max_pool_x

pytorch 笔记:torch_geometric 创建一张图

torch_geometric 笔记:nn.ChebNet

torch_geometric 笔记:global_mean_pool