pytorch 常用函数参数详解

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch 常用函数参数详解相关的知识,希望对你有一定的参考价值。

参考技术A 1、torch.cat(inputs, dim=0) -> Tensor 

参考链接:

[Pytorch] 详解 torch.cat()

Pytorch学习笔记(一):torch.cat()模块的详解

函数作用:cat 是 concatnate 的意思:拼接,联系在一起。在给定维度上对输入的 Tensor 序列进行拼接操作。torch.cat 可以看作是 torch.split 和 torch.chunk 的反操作

参数:

inputs(sequence of Tensors):可以是任意相同类型的 Tensor 的 python 序列

dim(int, optional):defaults=0

dim=0: 按列进行拼接 

dim=1: 按行进行拼接

dim=-1: 如果行和列数都相同则按行进行拼接,否则按照行数或列数相等的维度进行拼接

假设 a 和 b 都是 Tensor,且 a 的维度为 [2, 3],b 的维度为 [2, 4],则

torch.cat((a, b), dim=1) 的维度为 [2, 7]

2、torch.nn.CrossEntropyLoss()

函数作用:CrossEntropy 是交叉熵的意思,故而 CrossEntropyLoss 的作用是计算交叉熵。CrossEntropyLoss 函数是将 torch.nn.Softmax 和 torch.nn.NLLLoss 两个函数组合在一起使用,故而传入的预测值不需要先进行 torch.nnSoftmax 操作。

参数:

input(N, C):N 是 batch_size,C 则是类别数,即在定义模型输出时,输出节点个数要定义为 [N, C]。其中特别注意的是 target 的数据类型需要是浮点数,即 float32

target(N):N 是 batch_size,故 target 需要是 1D 张量。其中特别注意的是 target 的数据类型需要是 long,即 int64

例子:

loss = nn.CrossEntropyLoss()

input = torch.randn(3, 5, requires_grad=True, dtype=torch.float32)

target = torch.empty(3, dtype=torch.long).random_(5)

output = loss(input, target)

output

输出为:

tensor(1.6916, grad_fn=<NllLossBackward>)

以上是关于pytorch 常用函数参数详解的主要内容,如果未能解决你的问题,请参考以下文章

pytorch常用损失函数

pytorch nn.Linear()详解

PyTorch常用函数:torch.ge; torch.gt; torch.le; equal; eq

Pytorch常用的交叉熵损失函数CrossEntropyLoss()详解

DB2常用函数详解:字符串函数

Pytorch冻结部分层的参数