如何在 Batch、PyTorch 上填充零

Posted

技术标签:

【中文标题】如何在 Batch、PyTorch 上填充零【英文标题】:How to pad zeros on Batch, PyTorch 【发布时间】:2021-06-25 03:49:32 【问题描述】:

有没有更好的方法来做到这一点?如何在不创建新张量对象的情况下用零填充张量?我需要输入始终是相同的batchsize,所以我想用零填充小于batchsize 的输入。就像在 NLP 中当序列长度较短时填充零一样,但这是批处理的填充。

目前,我创建了一个新张量,但正因为如此,我的 GPU 将出现内存不足。我不想将批处理大小减少一半来处理这个操作。

import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self, batchsize=16):
        super().__init__()
        self.batchsize = batchsize
    
    def forward(self, x):
        b, d = x.shape
        
        print(x.shape) # torch.Size([7, 32])

        if b != self.batchsize: # 2. I need batches to be of size 16, if batch isn't 16, I want to pad the rest to zero
            new_x = torch.zeros(self.batchsize,d) # 3. so I create a new tensor, but this is bad as it increase the GPU memory required greatly
            new_x[0:b,:] = x
            x = new_x
            b = self.batchsize
        
        print(x.shape) # torch.Size([16, 32])

        return x

model = MyModel()
x = torch.randn((7, 32)) # 1. shape's batch is 7, because this is last batch, and I dont want to "drop_last"
y = model(x)
print(y.shape)

【问题讨论】:

【参考方案1】:

你可以像这样填充额外的元素:

import torch.nn.functional as F

n = self.batchsize - b

new_x = F.pad(x, (0,0,n,0)) # pad the start of 2d tensors
new_x = F.pad(x, (0,0,0,n)) # pad the end of 2d tensors
new_x = F.pad(x, (0,0,0,0,0,n)) # pad the end of 3d tensors

【讨论】:

以上是关于如何在 Batch、PyTorch 上填充零的主要内容,如果未能解决你的问题,请参考以下文章

批处理如何在 pytorch 的 seq2seq 模型中工作?

Pytorch 四种边界填充方式(Padding)

Pytorch 四种边界填充方式(Padding)

如何在忽略类中使用 pytorch 闪电精度?

Pytorch:与反向池化和复制填充类似的过程?

PyTorch学习系列——加载数据并生成batch数据