Pytorch 常用函数
Posted king-lps
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch 常用函数相关的知识,希望对你有一定的参考价值。
1. torch.
renorm
(input, p, dim, maxnorm, out=None) → Tensor
Returns a tensor where each sub-tensor of input
along dimension dim
is normalized such that the p-norm of the sub-tensor is lower than the value maxnorm。
解释:返回一个张量,包含规范化后的各个子张量,使得沿着dim
维划分的各子张量的p范数小于maxnorm
。
>>> x = torch.Tensor([[1,2,3]])
>>> torch.renorm(x,2,0,1) tensor([[ 0.2673, 0.5345, 0.8018]])
2. torch. scatter_
(dim, index, src) → Tensor
将src
中的所有值按照index
确定的索引写入本tensor中。其中索引是根据给定的dimension,dim按照gather()
描述的规则来确定。
注意,index的值必须是在0到(self.size(dim)-1)之间,
参数:
- input (Tensor)-源tensor
- dim (int)-索引的轴向
- index (LongTensor)-散射元素的索引指数
- src (Tensor or float)-散射的源元素
1 >>> x = torch.rand(2, 5) 2 >>> x 3 0.4319 0.6500 0.4080 0.8760 0.2355 4 0.2609 0.4711 0.8486 0.8573 0.1029 5 [torch.FloatTensor of size 2x5]
6 >>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) #将 x 按照格式写入新的Tensor里 7 0.4319 0.4711 0.8486 0.8760 0.2355 8 0.0000 0.6500 0.0000 0.8573 0.0000 9 0.2609 0.0000 0.4080 0.0000 0.1029 10 [torch.FloatTensor of size 3x5]
11 >>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23) 12 >>> z 13 0.0000 0.0000 1.2300 0.0000 14 0.0000 0.0000 0.0000 1.2300 15 [torch.FloatTensor of size 2x4]
3. torch.gather(input, dim, index, out=None) → Tensor
沿给定轴dim
,将输入索引张量index
指定位置的值进行聚合。
参数:
- input (Tensor) – 源张量
- dim (int) – 索引的轴
- index (LongTensor) – 聚合元素的下标
- out (Tensor, optional) – 目标张量
>>> t = torch.Tensor([[1,2],[3,4]]) >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])) 1 1 4 3 [torch.FloatTensor of size 2x2]
or:
>>> s=torch.randn(3,6) >>> s tensor([[-0.4857, -0.0982, -0.6532, -1.0273, -0.9205, -0.7440], [-0.6890, -0.3474, -1.4337, -0.3511, -0.2443, -0.6398], [ 1.2902, 1.1210, 1.7374, 0.0902, -0.4524, -0.6898]]) >>> s.gather(1,torch.LongTensor([[1,2,1],[1,2,3],[1,2,3]])) tensor([[-0.0982, -0.6532, -0.0982], [-0.3474, -1.4337, -0.3511], [ 1.1210, 1.7374, 0.0902]])
以上是关于Pytorch 常用函数的主要内容,如果未能解决你的问题,请参考以下文章
Pytorch常用损失函数nn.BCEloss();nn.BCEWithLogitsLoss();nn.CrossEntropyLoss();nn.L1Loss(); nn.MSELoss();(代码