torch.mul() 、 torch.mm() 及torch.matmul()的区别

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch.mul() 、 torch.mm() 及torch.matmul()的区别相关的知识,希望对你有一定的参考价值。

参考技术A 1、torch.mul(a, b)是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵;
2、torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵。
PS:更接地气来说区别就是点乘,和矩阵乘法的区别

torch.bmm()

torch.matmul()

torch.bmm()强制规定维度和大小相同

torch.matmul()没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作

当进行操作的两个tensor都是3D时,两者等同。

torch.bmm()

官网: https://pytorch.org/docs/stable/torch.html#torch.bmm

torch.bmm(input, mat2, out=None) → Tensor

torch.bmm()是tensor中的一个相乘操作,类似于矩阵中的A*B。

参数:

input,mat2:两个要进行相乘的tensor结构,两者必须是3D维度的,每个维度中的大小是相同的。

output:输出结果

并且相乘的两个矩阵,要满足一定的维度要求:input(p,m,n) * mat2(p,n,a) ->output(p,m,a)。这个要求,可以类比于矩阵相乘。前一个矩阵的列等于后面矩阵的行才可以相乘。

例子:

torch.matmul()也是一种类似于矩阵相乘操作的tensor联乘操作。但是它可以利用python 中的广播机制,处理一些维度不同的tensor结构进行相乘操作。这也是该函数与torch.bmm()区别所在。

参数:

input,other:两个要进行操作的tensor结构

output:结果

一些规则约定:

(1)若两个都是1D(向量)的,则返回两个向量的点积

(2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D

(3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系。

(4)若input是2D,other是1D,则返回两者的点积结果。(个人觉得这块也可以理解成给other添加了维度,然后再去掉此维度,只不过维度是(3, )而不是规则(3)中的( ,4)了,但是可能就是因为内部机制不同,所以官方说的是点积而不是维度的升高和下降)

(5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)。

(a)若input是1D,other是大于2D的,则类似于规则(3)。

(b)若other是1D,input是大于2D的,则类似于规则(4)。

(c)若input和other都是3D的,则与torch.bmm()函数功能一样。

(d)如果input中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 input(j,1,n,m)* other (k,m,p) = output(j,k,n,p)。

这个例子中,可以理解为x中dim=1这个维度可以扩充(广播),y中可以添加一个维度,然后在进行批乘操作。

pytorch基本运算:加减乘除对数幂次等

1、加减乘除

  • a + b = torch.add(a, b)
  • a - b = torch.sub(a, b)
  • a * b = torch.mul(a, b)
  • a / b = torch.div(a, b)
import torch

a = torch.rand(3, 4)
b = torch.rand(4)
a
# 输出:
    tensor([[0.6232, 0.5066, 0.8479, 0.6049],
            [0.3548, 0.4675, 0.7123, 0.5700],
            [0.8737, 0.5115, 0.2106, 0.5849]])

b
# 输出:
    tensor([0.3309, 0.3712, 0.0982, 0.2331])
    
# 相加
# b会被广播
a + b
# 输出:
    tensor([[0.9541, 0.8778, 0.9461, 0.8380],
            [0.6857, 0.8387, 0.8105, 0.8030],
            [1.2046, 0.8827, 0.3088, 0.8179]])   
# 等价于上面相加
torch.add(a, b)
# 输出:
    tensor([[0.9541, 0.8778, 0.9461, 0.8380],
            [0.6857, 0.8387, 0.8105, 0.8030],
            [1.2046, 0.8827, 0.3088, 0.8179]])  

# 比较两个是否相等
torch.all(torch.eq(a + b, torch.add(a, b)))
# 输出:
    tensor(True)    

2、矩阵相乘

  • torch.mm(a, b) # 此方法只适用于2维

  • torch.matmul(a, b)

  • a @ b = torch.matmul(a, b) # 推荐使用此方法

  • 用处:

    1. 降维:比如,[4, 784] @ [784, 512] = [4, 512]
    2. 大于2d的数据相乘:最后2个维度的数据相乘:[4, 3, 28, 64] @ [4, 3, 64, 32] = [4, 3, 28, 32]

      前提是:除了最后两个维度满足相乘条件以外,其他维度要满足广播条件,比如此处的前面两个维度只能是[4, 3]和[4, 1]
a = torch.full((2, 2), 3)
a
# 输出
    tensor([[3., 3.],
            [3., 3.]])

b = torch.ones(2, 2)
b
# 输出
    tensor([[1., 1.],
            [1., 1.]])
    
torch.mm(a, b)
# 输出
    tensor([[6., 6.],
            [6., 6.]])

torch.matmul(a, b)
# 输出
    tensor([[6., 6.],
            [6., 6.]])
    
a @ b
# 输出
    tensor([[6., 6.],
            [6., 6.]])    

3、幂次计算

  • pow, sqrt, rsqrt
a = torch.full([2, 2], 3)
a
# 输出
    tensor([[3., 3.],
            [3., 3.]])
    
a.pow(2)
# 输出
    tensor([[9., 9.],
            [9., 9.]])    
    
aa = a ** 2
aa
# 输出
    tensor([[9., 9.],
            [9., 9.]]) 
    
# 平方根
aa.sqrt()
# 输出
    tensor([[3., 3.],
            [3., 3.]])
# 平方根    
aa ** (0.5)
# 输出
    tensor([[3., 3.],
            [3., 3.]])    
# 平方根    
aa.pow(0.5)
# 输出
    tensor([[3., 3.],
            [3., 3.]])    
    
# 平方根的倒数
aa.rsqrt()
# 输出
    tensor([[0.3333, 0.3333],
            [0.3333, 0.3333]])        
tensor([[3., 3.],
        [3., 3.]])

4、自然底数与对数

a = torch.ones(2, 2)
a
# 输出
    tensor([[1., 1.],
            [1., 1.]])
    
# 自认底数e
torch.exp(a)
# 输出
    tensor([[2.7183, 2.7183],
            [2.7183, 2.7183]])

# 对数
# 默认底数是e
# 可以更换为Log2、log10
torch.log(a)
# 输出
tensor([[0., 0.],
        [0., 0.]])    

5、近似值

  • a.floor() # 向下取整:floor,地板
  • a.ceil() # 向上取整:ceil,天花板
  • a.trunc() # 保留整数部分:truncate,截断
  • a.frac() # 保留小数部分:fraction,小数
  • a.round() # 四舍五入:round,大约

6、限幅

  • a.max() # 最大值
  • a.min() # 最小值
  • a.median() # 中位数
  • a.clamp(10) # 将最小值限定为10
    • a.clamp(0, 10) # 将数据限定在[0, 10],两边都是闭区间


以上是关于torch.mul() 、 torch.mm() 及torch.matmul()的区别的主要内容,如果未能解决你的问题,请参考以下文章

用OneFlow实现数据类型自动提升

pytorch基本运算:加减乘除对数幂次等

pytorch基本运算:加减乘除对数幂次等

pytorch矩阵乘法

pytorch 乘法运算汇总与解析

PyTorch中的matmul函数详解