如何在 PyTorch 中高效实现非全连接线性层?

Posted

技术标签:

【中文标题】如何在 PyTorch 中高效实现非全连接线性层?【英文标题】:How to efficiently implement a non-fully connected Linear Layer in PyTorch? 【发布时间】:2022-01-13 02:08:09 【问题描述】:

我制作了一个我正在尝试实现的缩小版本的示例图:

所以顶部的两个输入节点只与顶部的三个输出节点完全连接,同样的设计适用于底部的两个节点。到目前为止,我已经提出了两种在 PyTorch 中实现这一点的方法,但都不是最优的。

首先是创建一个包含许多较小线性层的 nn.ModuleList,并在前向传递期间通过它们迭代输入。对于图表的示例,它看起来像这样:

class Module(nn.Module):
  def __init__(self):
    self.layers = nn.Module([nn.Linear(2, 3) for i in range(2)])
  
  def forward(self, input):
    output = torch.zeros(2, 3)
    for i in range(2):
      output[i, :] = self.layers[i](input.view(2, 2)[i, :])
    return output.flatten()

所以这完成了图中的网络,主要问题是它非常慢。我认为这是因为 PyTorch 必须按顺序处理 for 循环,而不能并行处理输入张量。

要“矢量化”模块以便 PyTorch 可以更快地运行它,我有这个实现:

class Module(nn.Module):
  def __init__(self):
    self.layer = nn.Linear(4, 6)
    self.mask = # create mask of ones and zeros to "block" certain layer connections
  
  def forward(self, input):
    prune.custom_from_mask(self.layer, name='weight', mask=self.mask)
    return self.layer(input)

这也完成了图的网络,通过使用权重修剪来确保全连接层中的某些权重始终为零(例如,连接顶部输入节点和底部输出节点的权重将始终为零,因此它有效地“断开连接”)。这个模块比前一个模块快得多,因为没有 for 循环。现在的问题是这个模块占用了更多的内存。这可能是因为即使大多数层的权重为零,PyTorch 仍然将网络视为存在。这种实现基本上保留了比它需要的更多的权重。

以前有没有人遇到过这个问题并提出过有效的解决方案?

【问题讨论】:

【参考方案1】:

如果权重共享没问题,那么一维卷积应该可以解决问题:

class Module(nn.Module):
  def __init__(self):
    self.layers = nn.Conv1d(in_channels=2, out_channels=3, kernel_size=1)
    self._n_splits = 2

  
  def forward(self, input):
    
    B, C = input.shape
    output = self.layers(input.view(B, C//self._n_splits, -1))
    return output.view(B, C)

如果权重共享不行,那么您可以使用组卷积:self.layers = nn.Conv1d(in_channels=4, out_channels=4, kernel_size=1, stride=1, groups=2)。但是,我不确定这是否可以实现任意数量的通道拆分,您可以查看文档:https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html

一维卷积是输入所有通道上的全连接层。组卷积会将通道分成组并对它们执行单独的卷积操作(这是您想要的)。

实现将类似于:

class Module(nn.Module):
  def __init__(self):
    self.layers = nn.Conv1d(in_channels=2, out_channels=4, kernel_size=1, groups=2)

  
  def forward(self, input):
    
    B, C = input.shape
    output = self.layers(input.unsqueeze(-1))
    return output.squeeze()

编辑:

如果您需要奇数个输出通道,您可以组合两个组转换。

class Module(nn.Module):
  def __init__(self):
    self.layers = nn.Sequence(
         nn.Conv1d(in_channels=2, out_channels=4, kernel_size=1, groups=2),
         nn.Conv1d(in_channels=4, out_channels=3, kernel_size=1, groups=3))


  def forward(self, input):
    
    B, C = input.shape
    output = self.layers(input.unsqueeze(-1))
    return output.squeeze()

这将有效地定义图中所需的输入通道,并允许您使用任意数量的输出通道。请注意,如果第二个卷积具有groups=1,则您将允许混合通道并有效地使第一组卷积层无用。

从理论上讲,这两个卷积之间不需要激活函数。我们将它们组合成一个线性问题。但是,添加激活函数可能会提高性能。

【讨论】:

是的,不应该共享权重。我认为你的第二个解决方案接近我可以使用的东西,但它可能需要一些调整。例如,如果我尝试创建您的 Conv1d 层,则会收到“out_channels must be divisible by groups”错误。 这意味着如果您有 2 个组,则需要偶数个输出通道。请注意,如果您想要奇数个输出通道(无论出于何种原因),您可以组合两组卷积。 nn.Sequential(nn.Conv1d(in_channels=2, out_channels=4, kernel=1, groups=2), nn.Conv1d(in_channels=4, out_channels=3, kernel=1, groups=3)) 这将在一半的输入通道(和一半的输出通道)上进行一组卷积,并在通道方面进行另一组卷积,从 4 个输出通道转换为 3 个输出通道。

以上是关于如何在 PyTorch 中高效实现非全连接线性层?的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch全连接网络:激活函数对一维拟合问题的影响探讨;网络加深后带来的loss不降问题

如何使用 PyTorch 中的单个全连接层直接将输入连接到输出?

pytorch nn.Linear(对输入数据做线性变换:y=Ax+b)(全连接层?)

动手学习pytorch——多层感知机

Torch:为啥在相同数据大小下卷积层甚至比全连接线性层慢

神经网络之全连接层(线性层)