torch_geometric笔记:max_pool 与max_pool_x
Posted UQI-LIUWJ
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch_geometric笔记:max_pool 与max_pool_x相关的知识,希望对你有一定的参考价值。
1 max_pool
1.1 函数介绍
torch_geometric.nn.max_pool(
cluster,
data,
transform=None)
对由torch_geometricy .data给出的图形进行池化和粗化。
数据对象根据集群cluster中定义的集群。同一集群中的所有节点将表示为一个节点。最终节点特征由同一簇内所有节点的特征最大值定义,节点位置平均,边的index定义为同一簇内所有节点的边index的并集。
1.2 参数说明
cluster (LongTensor) | 簇向量,每一个维度表示了一个点属于哪个簇 |
data (Data) | torch_geometric的data 对象 |
transform (callable, optional) | 一个函数/转换,接受粗化和池化的torch_geometry .data。数据对象,并返回转换后的版本。 |
返回torch_geometric的data 对象
1.3 举例说明
假如我们一开始的data 为:
Batch(x=[9893, 1], edge_index=[2, 34637], y=[9893, 1], batch=[9893], ptr=[2])
from torch_geometric.nn import max_pool
cluster = graclus(data.edge_index, num_nodes=x.shape[0])
cluster
#tensor([ 0, 1, 1, ..., 9890, 9891, 9892])
#第i维表示第i个点在以第几个点为中心点的簇中
data_c = max_pool(
cluster,
data)
data_c
#Batch(x=[5863, 1], edge_index=[2, 21983], batch=[5863])
#分成了5863 个cluster
1.3.1 mini-batch 的max_pool
在mini_batch的话,需要这样写:
data_c = max_pool(
cluster,
Data(
x=data.x,
batch=data.batch,
edge_index=data.edge_index))
data
2 max_pool_x
对一个cluster中中的x的特征进行最大池化操作
max_pool_x(
cluster,
x,
batch,
size: Optional[int] = None)
注意和max_pool的区别
max_pool 返回的是data,max_pool_x返回的是Tensor
max_pool 相当于max_pool_x的基础上,再对图的边进行了修改合并操作
以上是关于torch_geometric笔记:max_pool 与max_pool_x的主要内容,如果未能解决你的问题,请参考以下文章
pytorch 笔记:torch_geometric 创建一张图
torch_geometric 笔记: 数据集Cora &简易 GNN
torch_geometric笔记:max_pool 与max_pool_x