打印张量的所有内容
Posted
技术标签:
【中文标题】打印张量的所有内容【英文标题】:Printing all the contents of a tensor 【发布时间】:2019-03-11 10:23:39 【问题描述】:我遇到了this PyTorch 教程(在 neural_networks_tutorial.py 中),他们在其中构建了一个简单的神经网络并运行推理。我想打印整个输入张量的内容以进行调试。当我尝试打印张量时得到的结果是这样的,而不是整个张量:
我看到一个类似的link 用于 numpy,但不确定 PyTorch 的工作原理。我可以将它转换为 numpy 并且可以查看它,但希望避免额外的开销。有没有办法打印整个张量?
【问题讨论】:
【参考方案1】:我来这里实际上是在寻找如何在控制台的一行中打印整行张量的答案,所以我想我会添加这个。
tensor([[1.1573e+04, 6.0693e+02, 1.2436e+03, 2.7277e+04, 1.6673e+08, 2.0462e+00, 9.8891e-01],
[2.0237e+04, 5.9074e+02, 1.7208e+03, 2.7449e+04, 2.1301e+08, 2.0678e+00, 1.0011e+00],
[2.7456e+04, 6.1106e+02, 1.4897e+03, 2.7332e+04, 1.7310e+08, 2.0448e+00, 9.6041e-01],
[1.7732e+04, 6.0232e+02, 1.2608e+03, 2.7371e+04, 1.8106e+08, 1.9594e+00, 1.0040e+00],
...,
[1.1167e+04, 5.9867e+02, 1.3440e+03, 2.7263e+04, 2.3160e+08, 2.0190e+00, 1.0075e+00],
[1.6003e+04, 5.9590e+02, 1.2319e+03, 2.7368e+04, 1.7155e+08, 2.0171e+00, 1.0202e+00],
[1.5499e+04, 6.1471e+02, 9.4877e+02, 2.7395e+04, 1.8146e+08, 1.9016e+00, 9.5884e-01],
[3.3886e+04, 6.0689e+02, 1.0777e+03, 2.7259e+04, 2.1599e+08, 2.0179e+00, 1.0201e+00]], dtype=torch.float64)
我是这样做的
torch.set_printoptions(linewidth=200)
【讨论】:
【参考方案2】:为避免截断并控制打印张量数据的数量,请使用与 numpy 的 numpy.set_printoptions(threshold=10_000)
相同的 API。
例子:
x = torch.rand(1000, 2, 2)
print(x) # prints the truncated tensor
torch.set_printoptions(threshold=10_000)
print(x) # prints the whole tensor
如果您的张量非常大,请将threshold
值调整为更高的数字。
另一种选择是:
torch.set_printoptions(profile="full")
print(x) # prints the whole tensor
torch.set_printoptions(profile="default") # reset
print(x) # prints the truncated tensor
所有可用的set_printoptions
参数都记录在here。
【讨论】:
【参考方案3】:虽然我不建议这样做,但如果你愿意,那么
In [18]: torch.set_printoptions(edgeitems=1)
In [19]: a
Out[19]:
tensor([[-0.7698, ..., -0.1949],
...,
[-0.7321, ..., 0.8537]])
In [20]: torch.set_printoptions(edgeitems=3)
In [21]: a
Out[21]:
tensor([[-0.7698, 1.3383, 0.5649, ..., 1.3567, 0.6896, -0.1949],
[-0.5761, -0.9789, -0.2058, ..., -0.5843, 2.6311, -0.0008],
[ 1.3152, 1.8851, -0.9761, ..., 0.8639, -0.6237, 0.5646],
...,
[ 0.2851, 0.5504, -0.9471, ..., 0.0688, -0.7777, 0.1661],
[ 2.9616, -0.8685, -1.5467, ..., -1.4646, 1.1098, -1.0873],
[-0.7321, 0.7610, 0.3182, ..., 2.5859, -0.9709, 0.8537]])
【讨论】:
以上是关于打印张量的所有内容的主要内容,如果未能解决你的问题,请参考以下文章
为啥结果打印 b'hello,Python!' ,当我使用张量流? [复制]