PyTorch torch.max 在多个维度上
Posted
技术标签:
【中文标题】PyTorch torch.max 在多个维度上【英文标题】:PyTorch torch.max over multiple dimensions 【发布时间】:2020-07-24 04:22:47 【问题描述】:拥有像 :x.shape = [3, 2, 2]
这样的张量。
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
我需要将.max()
带到第二和第三维度。我希望像这样的[-0.2632, -0.1453, -0.0274]
作为输出。我尝试使用:x.max(dim=(1,2))
,但这会导致错误。
【问题讨论】:
我更新了我的答案,因为我提到的 PR 现在已合并,并且此功能在每晚版本中可用。请参阅下面的更新答案。 【参考方案1】:现在,您可以这样做了。 PR was merged(8 月 28 日)现在可以在夜间版本中使用。
只需使用torch.amax()
:
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
print(torch.amax(x, dim=(1, 2)))
# Output:
# >>> tensor([-0.2632, -0.1453, -0.0274])
原答案
截至今天(2020 年 4 月 11 日),在 PyTorch 中无法在多个维度上执行 .min()
或 .max()
。有一个关于它的open issue,你可以关注它,看看它是否得到实施。您的解决方法是:
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
print(x.view(x.size(0), -1).max(dim=-1))
# output:
# >>> values=tensor([-0.2632, -0.1453, -0.0274]),
# >>> indices=tensor([3, 3, 3]))
所以,如果您只需要这些值:x.view(x.size(0), -1).max(dim=-1).values
。
如果x
不是连续张量,则.view()
将失败。在这种情况下,您应该改用.reshape()
。
2020 年 8 月 26 日更新
此功能正在PR#43092 中实现,函数将被称为amin
和amax
。他们将只返回值。这可能很快就会被合并,所以当您阅读本文时,您可能能够在夜间构建中访问这些功能:) 玩得开心。
【讨论】:
谢谢。它有效,但在我的情况下需要使用 reshape insted 视图以避免错误 @iGero 好的,我会在答案中添加这个注释以防万一:) 很高兴它有帮助 我用 pytorch 1.5.0 和 1.6.0 版本试过这个,但是没有方法torch.amax
。你能验证吗?还是我做错了什么?
@zwep 正如我在回答中所说,此功能目前在 nightly release 中可用。因此,如果您想访问 amax,则必须升级到它,或者等到下一个稳定版本,即 1.7.0。
@Berriel 啊抱歉,我不知道哪个版本与夜间发布有关。虽然我不知道在这种情况下你是否可以谈论一个版本【参考方案2】:
虽然solution of Berriel 解决了这个特定问题,但我认为添加一些解释可能会帮助大家了解这里使用的技巧,以便它可以适应(m)任何其他维度。
让我们从检查输入张量x
的形状开始:
In [58]: x.shape
Out[58]: torch.Size([3, 2, 2])
所以,我们有一个形状为 (3, 2, 2)
的 3D 张量。现在,根据 OP 的问题,我们需要计算张量中沿 1st 和 2nd 维度的值的maximum
。在撰写本文时,torch.max()
的 dim
参数仅支持 int
。所以,我们不能使用元组。因此,我们将使用以下技巧,我将其称为,
Flatten & Max Trick:因为我们想要在 1st 和 2nd 维度上计算 max
,我们将进行展平这两个维度都归为一个维度,而第 0th 维度保持不变。这正是正在发生的事情:
In [61]: x.flatten().reshape(x.shape[0], -1).shape
Out[61]: torch.Size([3, 4]) # 2*2 = 4
所以,现在我们将 3D 张量缩小为 2D 张量(即矩阵)。
In [62]: x.flatten().reshape(x.shape[0], -1)
Out[62]:
tensor([[-0.3000, -0.2926, -0.2705, -0.2632],
[-0.1821, -0.1747, -0.1526, -0.1453],
[-0.0642, -0.0568, -0.0347, -0.0274]])
现在,我们可以简单地将max
应用于第一个st 维度(即在这种情况下,第一个维度也是最后一个维度),因为展平的维度位于该维度中。
In [65]: x.flatten().reshape(x.shape[0], -1).max(dim=1) # or: `dim = -1`
Out[65]:
torch.return_types.max(
values=tensor([-0.2632, -0.1453, -0.0274]),
indices=tensor([3, 3, 3]))
我们在结果张量中得到 3 个值,因为矩阵中有 3 行。
现在,另一方面,如果您想在 0th 和 1st 维度上计算 max
,您可以:
In [80]: x.flatten().reshape(-1, x.shape[-1]).shape
Out[80]: torch.Size([6, 2]) # 3*2 = 6
In [79]: x.flatten().reshape(-1, x.shape[-1])
Out[79]:
tensor([[-0.3000, -0.2926],
[-0.2705, -0.2632],
[-0.1821, -0.1747],
[-0.1526, -0.1453],
[-0.0642, -0.0568],
[-0.0347, -0.0274]])
现在,我们可以简单地将max
应用于第 0th 维度,因为这是我们展平的结果。 ((同样,从我们原来的 (3, 2, 2
) 形状来看,在前 2 个维度取 max 之后,我们应该得到两个值作为结果。)
In [82]: x.flatten().reshape(-1, x.shape[-1]).max(dim=0)
Out[82]:
torch.return_types.max(
values=tensor([-0.0347, -0.0274]),
indices=tensor([5, 5]))
类似地,您可以将此方法应用于多维和其他缩减函数,例如min
。
注意:我遵循基于 0 的维度 (0, 1, 2, 3, ...
) 的术语只是为了与 PyTorch 的使用和代码保持一致。
【讨论】:
哦,有点明白了。你能具体说明什么是“扁平化的结果”吗?我将不胜感激,谢谢! Flattening 总是返回一个 1D 大小的张量,该张量由原始形状中各个维度的乘积产生(即,此处为 3*2*2 与张量x
) 【参考方案3】:
如果您只想使用torch.max()
函数来获取二维张量中最大条目的索引,您可以这样做:
max_i_vals, max_i_indices = torch.max(x, 0)
print('max_i_vals, max_i_indices: ', max_i_vals, max_i_indices)
max_j_index = torch.max(max_i_vals, 0)[1]
print('max_j_index: ', max_j_index)
max_index = [max_i_indices[max_j_index], max_j_index]
print('max_index: ', max_index)
在测试中,上面为我打印出来:
max_i_vals: tensor([0.7930, 0.7144, 0.6985, 0.7349, 0.9162, 0.5584, 1.4777, 0.8047, 0.9008, 1.0169, 0.6705, 0.9034, 1.1159, 0.8852, 1.0353], grad_fn=\<MaxBackward0>)
max_i_indices: tensor([ 5, 8, 10, 6, 13, 14, 5, 6, 6, 6, 13, 4, 13, 13, 11])
max_j_index: tensor(6)
max_index: [tensor(5), tensor(6)]
这种方法可以扩展到 3 个维度。虽然不像这篇文章中的其他答案那样视觉上令人愉悦,但这个答案表明只能使用torch.max()
函数来解决问题(尽管我同意在多个维度上对torch.max()
的内置支持将是一个福音)。
跟进 我偶然发现了一个similar question in the PyTorch forums,并且发布者 ptrblck 提供了这行代码作为获取张量 x 中最大条目索引的解决方案:
x = (x==torch.max(x)).nonzero()
这种单线不仅可以处理 N 维张量而不需要对代码进行调整,而且它也比我上面写的方法(至少 2:1 比率)快得多,并且比公认的答案更快(大约 3:2 的比例)根据我的基准。
【讨论】:
以上是关于PyTorch torch.max 在多个维度上的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch中的torch.max()和torch.maximum()的用法详解
pytorch深度学习实践_p9_多分类问题_pytorch手写实现数字辨识
torch.max()函数predic = torch.max(outputs.data, 1)[1].cpu().numpy()