如何用其他 pytorch 函数替换 torch.norm?

Posted

技术标签:

【中文标题】如何用其他 pytorch 函数替换 torch.norm?【英文标题】:How to replace torch.norm with other pytorch function? 【发布时间】:2021-07-31 01:20:45 【问题描述】:

我想用另一个 Pytorch 函数替换 torch.norm 函数。

我可以在x不是矩阵的情况下替换torch.norm,如下代码所示。

import torch
x = torch.randn(9)
out1 = torch.norm(x)
out2 = sum(abs(x)**2)**(1./2)
out1 == out2
>> tensor(True)

但是当 x 是矩阵时,我不知道如何替换它。 特别是,我想在 dim=1 and keepdim=True 的情况下替换它。

x = torch.randn([3, 136, 64, 64])
out1 = torch.norm(x, dim=1, keepdim=True)
out2 = ???
out1 == out2

背景:

我正在将 Pytorch 模型转换为 CoreML,但在 torch.norm 函数中定义的 _VF.frobenius_norm 运算符未使用 CoreMLTools 实现。 (torch.norm里面的实现可以在here找到。)

有些人遇到了这个问题,但 CoreMLTools 仍然不受支持(您可以查看此issue)。 所以我想在没有torch.norm中使用的运算符的情况下替换它。

我尝试过torch.linalg.norm()numpy.linalg.norm,但它们不受支持。

我创建了一个简单的协作笔记本来重现这一点。 请使用以下 colab 对其进行测试。 https://colab.research.google.com/drive/11o6rTxHzEgZ_Rc7nFZHd3TvPugybB88h?usp=sharing

【问题讨论】:

【参考方案1】:

您可以尝试以下方法:

import torch
x = torch.randn([3, 136, 64, 64])
out1 = torch.norm(x, dim=1, keepdim=True)
out2 = torch.square(x).sum(dim=1, keepdim=True).sqrt()

请注意,out1 == out2 不会完全给出所有 True,因为精度有小误差。您可以检查错误在1e-7 的顺序为float32

在这里,范数是直接使用其数学定义计算的。您可以查看 Wolfram MathWorld 的this reference 了解更多详情。

【讨论】:

感谢您的回答!但它只适用于 pytorch。不幸的是,torch.diag 没有在 coremltools 中实现。我更新了 colab,并在 Answer from GoodDeeds 部分添加了您的代码。 colab.research.google.com/drive/… 没有torch.diag 是否可以实现这个? @tommy19970714 不幸的是,我现在无法运行 colab,但我已经用一个更好的版本更新了我的答案,这样可以避免大量不必要的计算和使用 torch.diag。请检查它是否有效。 @tommy19970714 再看一遍,我进一步简化了它。 torch.square(x) 有点新,如果不支持可以换成torch.mul(x, x) 我检查了你的最新代码。它工作得很好! e-06有错误,但没问题。非常感谢! @tommy19970714 很高兴知道。几秒钟前我刚刚进行了最后一次编辑,以使代码更惯用。

以上是关于如何用其他 pytorch 函数替换 torch.norm?的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch中的数学函数

pytorch如何导入ctc库

pytorch----nn.Modulenn.functionalnn.Sequentialnn.optim

使用 PyTorch 张量将对角线屏蔽为特定值

pytorch 常用函数参数详解

Pytorch 中 torch.cat() 函数解析