如何从 libtorch 输出中删除乘数并显示最终结果?

Posted

技术标签:

【中文标题】如何从 libtorch 输出中删除乘数并显示最终结果?【英文标题】:How to remove the multiplier from the libtorch output and display the final result? 【发布时间】:2020-12-13 00:08:53 【问题描述】:

当我尝试在屏幕上显示/打印一些张量时,我遇到了类似下面的情况,而不是得到最终结果,似乎 libtorch 显示带有乘数的张量(即0.01* 和类似的东西)见下文):

offsets.shape: [1, 4, 46, 85]
probs.shape: [46, 85]
offsets: (1,1,.,.) =
 0.01 *
  0.1006  1.2322
  -2.9587 -2.2280

(1,2,.,.) =
 0.01 *
  1.3772  1.3971
  -1.2813 -0.8563

(1,3,.,.) =
 0.01 *
  6.2367  9.2561
   3.5719  5.4744

(1,4,.,.) =
  0.2901  0.2963
  0.2618  0.2771
[ CPUFloatType1,4,2,2 ]
probs: 0.0001 *
 1.4593  1.0351
  6.6782  4.9104
[ CPUFloatType2,2 ]

如何禁用此行为并获得最终输出?我试图将其显式转换为浮点数,希望这会导致最终输出被存储/显示,但这也不起作用。

【问题讨论】:

【参考方案1】:

根据libtorch输出张量的源码,在repository中搜索“*”字符串后,发现这个“pretty-print”是在aten/src/ATen/core/Formatting.cpp翻译单元中完成的.此处添加了比例和星号:

static void printScale(std::ostream & stream, double scale) 
  FormatGuard guard(stream);
  stream << defaultfloat << scale << " *" << std::endl;

之后,张量的所有坐标都除以scale

if(scale != 1) 
  printScale(stream, scale);

double* tensor_p = tensor.data_ptr<double>();
for(int64_t i = 0; i < tensor.size(0); i++) 
  stream << std::setw(sz) << tensor_p[i]/scale << std::endl;

基于这个翻译单元,这是完全不可配置的。

我猜你有两个选择:

    调整函数并以最低限度编辑现有函数以满足您的要求。 删除(或添加 #ifdef)Formatting.cpp 中张量的 &lt;&lt; 运算符重载并提供您自己的实现。但是,在构建 libtorch 时,您必须将其链接到包含该方法实现的目标。

但是,这两个选项都需要您更改第 3 方代码,我认为这很糟糕。

【讨论】:

非常感谢。这确实很糟糕! :( 如果只是在您的用户代码中打印张量,您可以随时提供自己的实用程序来打印它...如果您希望 libtorch 以这种形式打印它们,例如当记录……嗯,这可能是不可能的。我只能建议在例如编写一些后处理器。 python 可以根据需要格式化日志。

以上是关于如何从 libtorch 输出中删除乘数并显示最终结果?的主要内容,如果未能解决你的问题,请参考以下文章

Build Libtorch from Source Code for x86

Build Libtorch from Source Code for x86

Build Libtorch from Source Code for x86

如何从PyQt5中显示的excel文件中删除数据并刷新它

如何在 Maskcrnn libtorch 中获取元组对象返回的值

如何从 Maxima tex 的输出中删除 TeX 显示的方程标记?