如何创建图神经网络数据集? (pytorch 几何)
Posted
技术标签:
【中文标题】如何创建图神经网络数据集? (pytorch 几何)【英文标题】:How to create a graph neural network dataset? (pytorch geometric) 【发布时间】:2021-06-21 14:33:58 【问题描述】:如何将我自己的数据集转换为可供 pytorch 几何图形神经网络使用?
所有教程都使用已转换为可供 pytorch 使用的现有数据集。例如,如果我有自己的点云数据集,我如何使用它来训练图神经网络的分类?我自己的分类图像数据集呢?
【问题讨论】:
【参考方案1】:您需要如何转换数据取决于您的模型所期望的格式。
图神经网络通常期望(的一个子集):
节点特征 边缘 边缘属性 节点目标取决于问题。您可以在PyTorch Geometric 中使用这些值的张量创建一个对象(并根据需要扩展属性),并使用Data
对象,如下所示:
data = Data(x=x, edge_index=edge_index, y=y)
data.train_idx = torch.tensor([...], dtype=torch.long)
data.test_mask = torch.tensor([...], dtype=torch.bool)
【讨论】:
【参考方案2】:就像文档中提到的那样。 pytorch-geometric
我真的需要使用这些数据集接口吗? 不!就像在常规 PyTorch 中一样,您不必使用数据集,例如,当您想要动态创建合成数据而不将它们显式保存到磁盘时。在这种情况下,只需传递一个包含 torch_geometric.data.Data 对象的常规 python 列表并将它们传递给 torch_geometric.loader.DataLoader
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)
【讨论】:
【参考方案3】:from torch_geometric.data import Dataset, Data
class MyCustomDataset(Dataset):
def __init__():
self.filename = .. # List of raw files, in your case point cloud
super(MyCustomDataset, self).__init()
@property
def raw_file_names(self):
return self.filename
@property
def processed_file_names(self):
""" return list of files should be in processed dir, if found - skip processing."""
processed_filename = []
return processed_filename
def download(self):
pass
def process(self):
for file in self.raw_paths:
self._process_one_step(file)
def _process_one_step(self, path):
out_path = (self.processed_dir, "some_unique_filename.pt")
# read your point cloud here,
# convert point cloud to Data object
data = Data(x=node_features,
edge_index=edge_index,
edge_attr=edge_attr,
y=label #you can add more arguments as you like
)
torch.save(data, out_path)
return
def __len__(self):
return len(self.processed_file_names)
def __getitem__(self, idx):
data = torch.load(os.path.join(self.processed_dir, self.processed_file_names[idx]))
return data
这将以正确的格式创建数据。然后您可以使用torch_geometric.data.Dataloader
创建一个数据加载器,然后训练您的网络。
【讨论】:
以上是关于如何创建图神经网络数据集? (pytorch 几何)的主要内容,如果未能解决你的问题,请参考以下文章
MNIST数据集上卷积神经网络的简单实现(使用PyTorch)
我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!