PyTorch torch.max 在多个维度上

Posted

技术标签:

【中文标题】PyTorch torch.max 在多个维度上【英文标题】:PyTorch torch.max over multiple dimensions 【发布时间】:2020-07-24 04:22:47 【问题描述】:

拥有像 :x.shape = [3, 2, 2] 这样的张量。

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

我需要将.max() 带到第二和第三维度。我希望像这样的[-0.2632, -0.1453, -0.0274] 作为输出。我尝试使用:x.max(dim=(1,2)),但这会导致错误。

【问题讨论】:

我更新了我的答案,因为我提到的 PR 现在已合并,并且此功能在每晚版本中可用。请参阅下面的更新答案。 【参考方案1】:

现在,您可以这样做了。 PR was merged(8 月 28 日)现在可以在夜间版本中使用。

只需使用torch.amax():

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

print(torch.amax(x, dim=(1, 2)))

# Output:
# >>> tensor([-0.2632, -0.1453, -0.0274])

原答案

截至今天(2020 年 4 月 11 日),在 PyTorch 中无法在多个维度上执行 .min().max()。有一个关于它的open issue,你可以关注它,看看它是否得到实施。您的解决方法是:

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

print(x.view(x.size(0), -1).max(dim=-1))

# output:
# >>> values=tensor([-0.2632, -0.1453, -0.0274]),
# >>> indices=tensor([3, 3, 3]))

所以,如果您只需要这些值:x.view(x.size(0), -1).max(dim=-1).values

如果x 不是连续张量,则.view() 将失败。在这种情况下,您应该改用.reshape()


2020 年 8 月 26 日更新

此功能正在PR#43092 中实现,函数将被称为aminamax。他们将只返回值。这可能很快就会被合并,所以当您阅读本文时,您可能能够在夜间构建中访问这些功能:) 玩得开心。

【讨论】:

谢谢。它有效,但在我的情况下需要使用 reshape insted 视图以避免错误 @iGero 好的,我会在答案中添加这个注释以防万一:) 很高兴它有帮助 我用 pytorch 1.5.0 和 1.6.0 版本试过这个,但是没有方法torch.amax。你能验证吗?还是我做错了什么? @zwep 正如我在回答中所说,此功能目前在 nightly release 中可用。因此,如果您想访问 amax,则必须升级到它,或者等到下一个稳定版本,即 1.7.0。 @Berriel 啊抱歉,我不知道哪个版本与夜间发布有关。虽然我不知道在这种情况下你是否可以谈论一个版本【参考方案2】:

虽然solution of Berriel 解决了这个特定问题,但我认为添加一些解释可能会帮助大家了解这里使用的技巧,以便它可以适应(m)任何其他维度。

让我们从检查输入张量x的形状开始:

In [58]: x.shape   
Out[58]: torch.Size([3, 2, 2])

所以,我们有一个形状为 (3, 2, 2) 的 3D 张量。现在,根据 OP 的问题,我们需要计算张量中沿 1st 和 2nd 维度的值的maximum。在撰写本文时,torch.max()dim 参数仅支持 int。所以,我们不能使用元组。因此,我们将使用以下技巧,我将其称为,

Flatten & Max Trick:因为我们想要在 1st 和 2nd 维度上计算 max,我们将进行展平这两个维度都归为一个维度,而第 0th 维度保持不变。这正是正在发生的事情:

In [61]: x.flatten().reshape(x.shape[0], -1).shape   
Out[61]: torch.Size([3, 4])   # 2*2 = 4

所以,现在我们将 3D 张量缩小为 2D 张量(即矩阵)。

In [62]: x.flatten().reshape(x.shape[0], -1) 
Out[62]:
tensor([[-0.3000, -0.2926, -0.2705, -0.2632],
        [-0.1821, -0.1747, -0.1526, -0.1453],
        [-0.0642, -0.0568, -0.0347, -0.0274]])

现在,我们可以简单地将max 应用于第一个st 维度(即在这种情况下,第一个维度也是最后一个维度),因为展平的维度位于该维度中。

In [65]: x.flatten().reshape(x.shape[0], -1).max(dim=1)    # or: `dim = -1`
Out[65]: 
torch.return_types.max(
values=tensor([-0.2632, -0.1453, -0.0274]),
indices=tensor([3, 3, 3]))

我们在结果张量中得到 3 个值,因为矩阵中有 3 行。


现在,另一方面,如果您想在 0th 和 1st 维度上计算 max,您可以:

In [80]: x.flatten().reshape(-1, x.shape[-1]).shape 
Out[80]: torch.Size([6, 2])    # 3*2 = 6

In [79]: x.flatten().reshape(-1, x.shape[-1]) 
Out[79]: 
tensor([[-0.3000, -0.2926],
        [-0.2705, -0.2632],
        [-0.1821, -0.1747],
        [-0.1526, -0.1453],
        [-0.0642, -0.0568],
        [-0.0347, -0.0274]])

现在,我们可以简单地将max 应用于第 0th 维度,因为这是我们展平的结果。 ((同样,从我们原来的 (3, 2, 2) 形状来看,在前 2 个维度取 max 之后,我们应该得到两个值作为结果。)

In [82]: x.flatten().reshape(-1, x.shape[-1]).max(dim=0) 
Out[82]: 
torch.return_types.max(
values=tensor([-0.0347, -0.0274]),
indices=tensor([5, 5]))

类似地,您可以将此方法应用于多维和其他缩减函数,例如min


注意:我遵循基于 0 的维度 (0, 1, 2, 3, ...) 的术语只是为了与 PyTorch 的使用和代码保持一致。

【讨论】:

哦,有点明白了。你能具体说明什么是“扁平化的结果”吗?我将不胜感激,谢谢! Flattening 总是返回一个 1D 大小的张量,该张量由原始形状中各个维度的乘积产生(即,此处为 3*2*2 与张量 x【参考方案3】:

如果您只想使用torch.max() 函数来获取二维张量中最大条目的索引,您可以这样做:

max_i_vals, max_i_indices = torch.max(x, 0)
print('max_i_vals, max_i_indices: ', max_i_vals, max_i_indices)
max_j_index = torch.max(max_i_vals, 0)[1]
print('max_j_index: ', max_j_index)
max_index = [max_i_indices[max_j_index], max_j_index]
print('max_index: ', max_index)

在测试中,上面为我打印出来:

max_i_vals: tensor([0.7930, 0.7144, 0.6985, 0.7349, 0.9162, 0.5584, 1.4777, 0.8047, 0.9008, 1.0169, 0.6705, 0.9034, 1.1159, 0.8852, 1.0353], grad_fn=\<MaxBackward0>)   
max_i_indices: tensor([ 5,  8, 10,  6, 13, 14,  5,  6,  6,  6, 13,  4, 13, 13, 11])  
max_j_index:  tensor(6)  
max_index:  [tensor(5), tensor(6)]

这种方法可以扩展到 3 个维度。虽然不像这篇文章中的其他答案那样视觉上令人愉悦,但这个答案表明只能使用torch.max() 函数来解决问题(尽管我同意在多个维度上对torch.max() 的内置支持将是一个福音)。

跟进 我偶然发现了一个similar question in the PyTorch forums,并且发布者 ptrblck 提供了这行代码作为获取张量 x 中最大条目索引的解决方案:

x = (x==torch.max(x)).nonzero()

这种单线不仅可以处理 N 维张量而不需要对代码进行调整,而且它也比我上面写的方法(至少 2:1 比率)快得多,并且比公认的答案更快(大约 3:2 的比例)根据我的基准。

【讨论】:

以上是关于PyTorch torch.max 在多个维度上的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch中的torch.max()和torch.maximum()的用法详解

pytorch深度学习实践_p9_多分类问题_pytorch手写实现数字辨识

[基于Pytorch的MNIST识别04]模型调试

torch.argmax与torch.max详解

torch.max()函数predic = torch.max(outputs.data, 1)[1].cpu().numpy()

Pytorch实现TripletAttention