PyTorch 中的连接张量
Posted
技术标签:
【中文标题】PyTorch 中的连接张量【英文标题】:Concat tensors in PyTorch 【发布时间】:2019-07-10 16:20:06 【问题描述】:我有一个名为 data
的张量,形状为 [128, 4, 150, 150]
,其中 128 是批量大小,4 是通道数,最后两个维度是高度和宽度。我有另一个名为fake
的张量,形状为[128, 1, 150, 150]
。
我想从data
的第二维中删除最后一个list/array
;数据的形状现在是[128, 3, 150, 150]
;并将其与fake
连接起来,将连接的输出维度设为[128, 4, 150, 150]
。
换句话说,我想将data
的前三个维度与fake
连接起来,得到一个4 维张量。
我正在使用 PyTorch,遇到了 torch.cat()
和 torch.stack()
函数
这是我编写的示例代码:
fake_combined = []
for j in range(batch_size):
fake_combined.append(torch.stack((data[j][0].to(device), data[j][1].to(device), data[j][2].to(device), fake[j][0].to(device))))
fake_combined = torch.tensor(fake_combined, dtype=torch.float32)
fake_combined = fake_combined.to(device)
但我在行中遇到错误:
fake_combined = torch.tensor(fake_combined, dtype=torch.float32)
错误是:
ValueError: only one element tensors can be converted to Python scalars
另外,如果我打印fake_combined
的形状,我得到的输出是[128,]
而不是[128, 4, 150, 150]
当我打印fake_combined[0]
的形状时,我得到的输出为[4, 150, 150]
,这与预期的一样。
所以我的问题是,为什么我不能使用torch.tensor()
将列表转换为张量。我错过了什么吗?有没有更好的方法来做我打算做的事情?
任何帮助将不胜感激!谢谢!
【问题讨论】:
【参考方案1】:@rollthedice32 的回答非常好。出于教育目的,这里使用torch.cat
a = torch.rand(128, 4, 150, 150)
b = torch.rand(128, 1, 150, 150)
# Cut out last dimension
a = a[:, :3, :, :]
# Concatenate in 2nd dimension
result = torch.cat([a, b], dim=1)
print(result.shape)
# => torch.Size([128, 4, 150, 150])
【讨论】:
【参考方案2】:您也可以只分配给该特定维度。
orig = torch.randint(low=0, high=10, size=(2,3,2,2))
fake = torch.randint(low=111, high=119, size=(2,1,2,2))
orig[:,[2],:,:] = fake
原来的之前
tensor([[[[0, 1],
[8, 0]],
[[4, 9],
[6, 1]],
[[8, 2],
[7, 6]]],
[[[1, 1],
[8, 5]],
[[5, 0],
[8, 6]],
[[5, 5],
[2, 8]]]])
假的
tensor([[[[117, 115],
[114, 111]]],
[[[115, 115],
[118, 115]]]])
原版之后
tensor([[[[ 0, 1],
[ 8, 0]],
[[ 4, 9],
[ 6, 1]],
[[117, 115],
[114, 111]]],
[[[ 1, 1],
[ 8, 5]],
[[ 5, 0],
[ 8, 6]],
[[115, 115],
[118, 115]]]])
希望这会有所帮助! :)
【讨论】:
以上是关于PyTorch 中的连接张量的主要内容,如果未能解决你的问题,请参考以下文章