如何用其他 pytorch 函数替换 torch.norm?
Posted
技术标签:
【中文标题】如何用其他 pytorch 函数替换 torch.norm?【英文标题】:How to replace torch.norm with other pytorch function? 【发布时间】:2021-07-31 01:20:45 【问题描述】:我想用另一个 Pytorch 函数替换 torch.norm
函数。
我可以在x
不是矩阵的情况下替换torch.norm
,如下代码所示。
import torch
x = torch.randn(9)
out1 = torch.norm(x)
out2 = sum(abs(x)**2)**(1./2)
out1 == out2
>> tensor(True)
但是当 x 是矩阵时,我不知道如何替换它。
特别是,我想在 dim=1 and keepdim=True
的情况下替换它。
x = torch.randn([3, 136, 64, 64])
out1 = torch.norm(x, dim=1, keepdim=True)
out2 = ???
out1 == out2
背景:
我正在将 Pytorch 模型转换为 CoreML,但在 torch.norm
函数中定义的 _VF.frobenius_norm
运算符未使用 CoreMLTools 实现。
(torch.norm
里面的实现可以在here找到。)
有些人遇到了这个问题,但 CoreMLTools 仍然不受支持(您可以查看此issue)。
所以我想在没有torch.norm
中使用的运算符的情况下替换它。
我尝试过torch.linalg.norm()
和numpy.linalg.norm
,但它们不受支持。
我创建了一个简单的协作笔记本来重现这一点。 请使用以下 colab 对其进行测试。 https://colab.research.google.com/drive/11o6rTxHzEgZ_Rc7nFZHd3TvPugybB88h?usp=sharing
【问题讨论】:
【参考方案1】:您可以尝试以下方法:
import torch
x = torch.randn([3, 136, 64, 64])
out1 = torch.norm(x, dim=1, keepdim=True)
out2 = torch.square(x).sum(dim=1, keepdim=True).sqrt()
请注意,out1 == out2
不会完全给出所有 True
,因为精度有小误差。您可以检查错误在1e-7
的顺序为float32
。
在这里,范数是直接使用其数学定义计算的。您可以查看 Wolfram MathWorld 的this reference 了解更多详情。
【讨论】:
感谢您的回答!但它只适用于 pytorch。不幸的是,torch.diag
没有在 coremltools 中实现。我更新了 colab,并在 Answer from GoodDeeds
部分添加了您的代码。 colab.research.google.com/drive/… 没有torch.diag
是否可以实现这个?
@tommy19970714 不幸的是,我现在无法运行 colab,但我已经用一个更好的版本更新了我的答案,这样可以避免大量不必要的计算和使用 torch.diag
。请检查它是否有效。
@tommy19970714 再看一遍,我进一步简化了它。 torch.square(x)
有点新,如果不支持可以换成torch.mul(x, x)
。
我检查了你的最新代码。它工作得很好! e-06有错误,但没问题。非常感谢!
@tommy19970714 很高兴知道。几秒钟前我刚刚进行了最后一次编辑,以使代码更惯用。以上是关于如何用其他 pytorch 函数替换 torch.norm?的主要内容,如果未能解决你的问题,请参考以下文章