如何从 libtorch 输出中删除乘数并显示最终结果?
Posted
技术标签:
【中文标题】如何从 libtorch 输出中删除乘数并显示最终结果?【英文标题】:How to remove the multiplier from the libtorch output and display the final result? 【发布时间】:2020-12-13 00:08:53 【问题描述】:当我尝试在屏幕上显示/打印一些张量时,我遇到了类似下面的情况,而不是得到最终结果,似乎 libtorch 显示带有乘数的张量(即0.01*
和类似的东西)见下文):
offsets.shape: [1, 4, 46, 85]
probs.shape: [46, 85]
offsets: (1,1,.,.) =
0.01 *
0.1006 1.2322
-2.9587 -2.2280
(1,2,.,.) =
0.01 *
1.3772 1.3971
-1.2813 -0.8563
(1,3,.,.) =
0.01 *
6.2367 9.2561
3.5719 5.4744
(1,4,.,.) =
0.2901 0.2963
0.2618 0.2771
[ CPUFloatType1,4,2,2 ]
probs: 0.0001 *
1.4593 1.0351
6.6782 4.9104
[ CPUFloatType2,2 ]
如何禁用此行为并获得最终输出?我试图将其显式转换为浮点数,希望这会导致最终输出被存储/显示,但这也不起作用。
【问题讨论】:
【参考方案1】:根据libtorch输出张量的源码,在repository中搜索“*”字符串后,发现这个“pretty-print”是在aten/src/ATen/core/Formatting.cpp翻译单元中完成的.此处添加了比例和星号:
static void printScale(std::ostream & stream, double scale)
FormatGuard guard(stream);
stream << defaultfloat << scale << " *" << std::endl;
之后,张量的所有坐标都除以scale
:
if(scale != 1)
printScale(stream, scale);
double* tensor_p = tensor.data_ptr<double>();
for(int64_t i = 0; i < tensor.size(0); i++)
stream << std::setw(sz) << tensor_p[i]/scale << std::endl;
基于这个翻译单元,这是完全不可配置的。
我猜你有两个选择:
-
调整函数并以最低限度编辑现有函数以满足您的要求。
删除(或添加
#ifdef
)Formatting.cpp 中张量的 <<
运算符重载并提供您自己的实现。但是,在构建 libtorch 时,您必须将其链接到包含该方法实现的目标。
但是,这两个选项都需要您更改第 3 方代码,我认为这很糟糕。
【讨论】:
非常感谢。这确实很糟糕! :( 如果只是在您的用户代码中打印张量,您可以随时提供自己的实用程序来打印它...如果您希望 libtorch 以这种形式打印它们,例如当记录……嗯,这可能是不可能的。我只能建议在例如编写一些后处理器。 python 可以根据需要格式化日志。以上是关于如何从 libtorch 输出中删除乘数并显示最终结果?的主要内容,如果未能解决你的问题,请参考以下文章
Build Libtorch from Source Code for x86
Build Libtorch from Source Code for x86
Build Libtorch from Source Code for x86