Paper Reading and Analysis: "DeepGCNs: Can GCNs Go as Deep as CNNs?"

Posted by KPer_Yang


torch_geometric.nn — pytorch_geometric documentation (pytorch-geometric.readthedocs.io)

The implemented skip connections include the pre-activation residual connection ("res+"), the residual connection ("res"), the dense connection ("dense"), and no connection ("plain").

  • Res+ ("res+"):

$$\text{Normalization} \to \text{Activation} \to \text{Dropout} \to \text{GraphConv} \to \text{Res}$$

  • Res ("res") / Dense ("dense") / Plain ("plain"):

$$\text{GraphConv} \to \text{Normalization} \to \text{Activation} \to \text{Res/Dense/Plain} \to \text{Dropout}$$
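The two orderings above can be checked with a small framework-free sketch. It is only an illustration: `traced` and `run_block` are hypothetical helper names, and every operation is an identity stand-in that records its name, not a real normalization, activation, or graph convolution.

```python
from typing import Callable, List


def traced(name: str, trace: List[str]) -> Callable[[float], float]:
    """Return an identity op that records its name each time it is called."""
    def op(h: float) -> float:
        trace.append(name)
        return h
    return op


def run_block(block: str) -> List[str]:
    """Run one block on a dummy input and return the order of operations."""
    trace: List[str] = []
    norm = traced("Normalization", trace)
    act = traced("Activation", trace)
    dropout = traced("Dropout", trace)
    conv = traced("GraphConv", trace)

    x = 1.0
    if block == "res+":
        # Pre-activation: everything happens before the convolution.
        h = conv(dropout(act(norm(x))))
        trace.append("Res")
        h = x + h
    else:
        # Post-activation: convolution first, skip connection, then dropout.
        h = act(norm(conv(x)))
        if block == "res":
            h = x + h
        # "dense" would concatenate x and h; "plain" leaves h unchanged.
        trace.append(block.capitalize())
        h = dropout(h)
    return trace


for name in ("res+", "res", "dense", "plain"):
    print(name, "->", " -> ".join(run_block(name)))
```

The practical difference is where the skip connection sits: in "res+" the input is added back after the convolution with nothing applied afterwards, while in the other three modes dropout is applied to the combined result.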


from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from torch.utils.checkpoint import checkpoint

class DeepGCNLayer(torch.nn.Module):
    r"""The skip connection operations for deep GCNs.

    Args:
        conv (torch.nn.Module, optional): the GCN operator.
            (default: :obj:`None`)
        norm (torch.nn.Module): the normalization layer. (default: :obj:`None`)
        act (torch.nn.Module): the activation layer. (default: :obj:`None`)
        block (string, optional): The skip connection operation to use
            (:obj:`"res+"`, :obj:`"res"`, :obj:`"dense"` or :obj:`"plain"`).
            (default: :obj:`"res+"`)
        dropout (float, optional): The dropout probability.
            (default: :obj:`0.`)
        ckpt_grad (bool, optional): If set to :obj:`True`, will checkpoint this
            part of the model. Checkpointing works by trading compute for
            memory, since intermediate activations do not need to be kept in
            memory. Set this to :obj:`True` in case you encounter out-of-memory
            errors while going deep. (default: :obj:`False`)
    """
    def __init__(
        self,
        conv: Optional[Module] = None,
        norm: Optional[Module] = None,
        act: Optional[Module] = None,
        block: str = 'res+',
        dropout: float = 0.,
        ckpt_grad: bool = False,
    ):
        super().__init__()

        self.conv = conv
        self.norm = norm
        self.act = act
        self.block = block.lower()
        assert self.block in ['res+', 'res', 'dense', 'plain']
        self.dropout = dropout
        self.ckpt_grad = ckpt_grad

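    def reset_parameters(self):
        # Re-initialize the learnable parameters of the wrapped modules.
        if self.conv is not None:
            self.conv.reset_parameters()
        if self.norm is not None:
            self.norm.reset_parameters()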

    def forward(self, *args, **kwargs) -> Tensor:
        args = list(args)
        x = args.pop(0)

        if self.block == 'res+':
            h = x
            if self.norm is not None:
                h = self.norm(h)
            if self.act is not None:
                h = self.act(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
            if self.conv is not None and self.ckpt_grad and h.requires_grad:
                # checkpoint does not store intermediate activations; they
                # are recomputed during the backward pass.
                h = checkpoint(self.conv, h, *args, **kwargs)
            else:
                h = self.conv(h, *args, **kwargs)

            return x + h

        else:
            if self.conv is not None and self.ckpt_grad and x.requires_grad:
                h = checkpoint(self.conv, x, *args, **kwargs)
            else:
                h = self.conv(x, *args, **kwargs)
            if self.norm is not None:
                h = self.norm(h)
            if self.act is not None:
                h = self.act(h)

            if self.block == 'res':
                h = x + h
            elif self.block == 'dense':
                h = torch.cat([x, h], dim=-1)
            elif self.block == 'plain':
                pass  # no skip connection

            return F.dropout(h, p=self.dropout, training=self.training)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(block={self.block})'
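One consequence of the "dense" branch is worth spelling out: because it concatenates `x` and `h` along the last dimension, the feature width grows with every block, so each subsequent convolution has to accept a wider input. The sketch below is plain arithmetic under the assumption that every convolution emits the same number of channels; `dense_input_widths`, `in_channels`, and `growth` are hypothetical names, not part of the library.

```python
from typing import List


def dense_input_widths(in_channels: int, growth: int, num_layers: int) -> List[int]:
    """Input width seen by each layer when every block outputs cat([x, h])."""
    widths = []
    width = in_channels
    for _ in range(num_layers):
        widths.append(width)
        # Concatenation keeps the block input and appends `growth` new channels.
        width += growth
    return widths


print(dense_input_widths(64, 32, 4))  # [64, 96, 128, 160]
```

With "res+" and "res", by contrast, the addition `x + h` requires the convolution output to have the same width as its input.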

Note: PyTorch checkpoint

torch.utils.checkpoint — PyTorch 1.13 documentation

Introduction to torch.utils.checkpoint and basic usage (ONE_SIX_MIX's blog, CSDN)

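PyTorch's checkpoint trades compute time for GPU memory. In ordinary training, PyTorch keeps intermediate values after each operation so that gradients can be computed later. A function wrapped in checkpoint does not keep those intermediate values; they are recomputed when gradients are needed, which reduces GPU memory usage. Used well, checkpointing can reportedly allow a batch size about 30% larger than training the same model without it.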
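The recompute-instead-of-store idea can be illustrated without any framework. The sketch below is not how PyTorch implements checkpointing; it is a minimal scalar example under stated assumptions: each "layer" is a function paired with its derivative, and `grad_store_all`, `grad_checkpointed`, and `segment` are hypothetical names chosen for illustration.

```python
import math
from typing import Callable, List, Tuple

# A "layer" is a scalar function paired with its derivative (illustrative only).
Layer = Tuple[Callable[[float], float], Callable[[float], float]]


def grad_store_all(layers: List[Layer], x: float) -> Tuple[float, int]:
    """Backpropagate while keeping every layer input.

    Returns d(output)/dx and the number of stored values.
    """
    inputs = []
    h = x
    for f, _ in layers:
        inputs.append(h)  # kept until the backward pass
        h = f(h)
    grad = 1.0
    for (_, df), a in zip(reversed(layers), reversed(inputs)):
        grad *= df(a)  # chain rule, last layer first
    return grad, len(inputs)


def grad_checkpointed(layers: List[Layer], x: float, segment: int) -> Tuple[float, int]:
    """Keep only each segment's input and recompute the rest when needed.

    Returns d(output)/dx and the peak number of values held at once.
    """
    checkpoints = []
    h = x
    for i, (f, _) in enumerate(layers):
        if i % segment == 0:
            checkpoints.append(h)  # only segment inputs survive the forward pass
        h = f(h)

    grad = 1.0
    peak = len(checkpoints)
    for s in reversed(range(len(checkpoints))):
        seg = layers[s * segment:(s + 1) * segment]
        # Recompute this segment's layer inputs from its checkpoint.
        inputs = []
        h = checkpoints[s]
        for f, _ in seg:
            inputs.append(h)
            h = f(h)
        peak = max(peak, len(checkpoints) + len(inputs))
        for (_, df), a in zip(reversed(seg), reversed(inputs)):
            grad *= df(a)
    return grad, peak


layers = [(math.tanh, lambda a: 1.0 - math.tanh(a) ** 2)] * 12
g_full, kept_full = grad_store_all(layers, 0.5)
g_ckpt, kept_ckpt = grad_checkpointed(layers, 0.5, segment=4)
print(g_full, kept_full)
print(g_ckpt, kept_ckpt)
```

With 12 layers and segments of 4, the first version holds 12 values, while the second peaks at 7 (3 checkpoints plus 4 recomputed inputs), at the cost of running each layer's forward computation a second time. Both return the same gradient.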


[2006.07739] DeeperGCN: All You Need to Train Deeper GCNs (arxiv.org)

[1904.03751] DeepGCNs: Can GCNs Go as Deep as CNNs? (arxiv.org)
