pytorch 常用函数参数详解
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch 常用函数参数详解相关的知识,希望对你有一定的参考价值。
参考技术A 1、torch.cat(inputs, dim=0) -> Tensor参考链接:
[Pytorch] 详解 torch.cat()
Pytorch学习笔记(一):torch.cat()模块的详解
函数作用:cat 是 concatnate 的意思:拼接,联系在一起。在给定维度上对输入的 Tensor 序列进行拼接操作。torch.cat 可以看作是 torch.split 和 torch.chunk 的反操作
参数:
inputs(sequence of Tensors):可以是任意相同类型的 Tensor 的 python 序列
dim(int, optional):defaults=0
dim=0: 按列进行拼接
dim=1: 按行进行拼接
dim=-1: 如果行和列数都相同则按行进行拼接,否则按照行数或列数相等的维度进行拼接
假设 a 和 b 都是 Tensor,且 a 的维度为 [2, 3],b 的维度为 [2, 4],则
torch.cat((a, b), dim=1) 的维度为 [2, 7]
2、torch.nn.CrossEntropyLoss()
函数作用:CrossEntropy 是交叉熵的意思,故而 CrossEntropyLoss 的作用是计算交叉熵。CrossEntropyLoss 函数是将 torch.nn.Softmax 和 torch.nn.NLLLoss 两个函数组合在一起使用,故而传入的预测值不需要先进行 torch.nnSoftmax 操作。
参数:
input(N, C):N 是 batch_size,C 则是类别数,即在定义模型输出时,输出节点个数要定义为 [N, C]。其中特别注意的是 target 的数据类型需要是浮点数,即 float32
target(N):N 是 batch_size,故 target 需要是 1D 张量。其中特别注意的是 target 的数据类型需要是 long,即 int64
例子:
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True, dtype=torch.float32)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output
输出为:
tensor(1.6916, grad_fn=<NllLossBackward>)
以上是关于pytorch 常用函数参数详解的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch常用函数:torch.ge; torch.gt; torch.le; equal; eq