numpy中dot, multiply, *区别
Posted bitcarmanlee
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了numpy中dot, multiply, *区别相关的知识,希望对你有一定的参考价值。
1.dot
首先看下dot源码中的注释部分
def dot(a, b, out=None):
"""
dot(a, b, out=None)
Dot product of two arrays. Specifically,
- If both `a` and `b` are 1-D arrays, it is inner product of vectors
(without complex conjugation).
- If both `a` and `b` are 2-D arrays, it is matrix multiplication,
but using :func:`matmul` or ``a @ b`` is preferred.
- If either `a` or `b` is 0-D (scalar), it is equivalent to :func:`multiply`
and using ``numpy.multiply(a, b)`` or ``a * b`` is preferred.
- If `a` is an N-D array and `b` is a 1-D array, it is a sum product over
the last axis of `a` and `b`.
- If `a` is an N-D array and `b` is an M-D array (where ``M>=2``), it is a
sum product over the last axis of `a` and the second-to-last axis of `b`::
dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
.....
关注一下最常用的两种情况:
If both
aand
bare 1-D arrays, it is inner product of vectors
这就是两个向量dot,最后得到的两个向量的内积。
If both
aand
bare 2-D arrays, it is matrix multiplication, but using :func:
matmulor ``a @ b`` is preferred.
2-D arrays指的就是矩阵了。根据上面的解释不难看出,如果是两个矩阵dot,执行的就是矩阵相乘运算。
写段代码测试下
def demo2():
a1 = np.arange(1, 5)
a2 = a1[::-1]
print(a1)
print(a2)
# 两个向量dot为内积
print(a1.dot(a2))
print(np.dot(a1, a2))
print("\\n\\n")
b1 = np.arange(1, 5).reshape(2, 2)
b2 = np.arange(5, 9).reshape(2, 2)
b3 = np.arange(9, 15).reshape(3, 2)
print(b1)
print(b2)
print(b3)
print(np.dot(b1, b2))
# 会报错, 不满足矩阵相乘条件
# print(np.dot(b1, b3))
代码执行的结果
[1 2 3 4]
[4 3 2 1]
20
20
[[1 2]
[3 4]]
[[5 6]
[7 8]]
[[ 9 10]
[11 12]
[13 14]]
[[19 22]
[43 50]]
2.multiply
同样的看一下multiply对应源码的注释部分。
def multiply(x1, x2, *args, **kwargs): # real signature unknown; NOTE: unreliably restored from __doc__
"""
multiply(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Multiply arguments element-wise.
Parameters
----------
x1, x2 : array_like
Input arrays to be multiplied. If ``x1.shape != x2.shape``, they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, None, or tuple of ndarray and None, optional
.....
明白multiply方法的关键就是上面的一句注释:
Multiply arguments element-wise.
说人话就是:按对应的元素相乘。
def demo3():
a1 = np.arange(1, 5)
a2 = a1[::-1]
print(a1)
print(a2)
print(np.multiply(a1, a2))
print("\\n\\n")
b1 = np.arange(1, 5).reshape(2, 2)
b2 = np.arange(5, 9).reshape(2, 2)
print(b1)
print(b2)
print(np.multiply(b1, b2))
运行得到结果
[1 2 3 4]
[4 3 2 1]
[4 6 6 4]
[[1 2]
[3 4]]
[[5 6]
[7 8]]
[[ 5 12]
[21 32]]
参考对应的代码,应该就很容易理解了。
3. *运算符
乘法运算符,最后得到的结果,跟multiply方法得到的结果是一样的。
def demo4():
a1 = np.arange(1, 5)
a2 = a1[::-1]
print(a1)
print(a2)
print(a1 * a2)
print("\\n\\n")
b1 = np.arange(1, 5).reshape(2, 2)
b2 = np.arange(5, 9).reshape(2, 2)
print(b1)
print(b2)
print(b1 * b2)
最终结果
[1 2 3 4]
[4 3 2 1]
[4 6 6 4]
[[1 2]
[3 4]]
[[5 6]
[7 8]]
[[ 5 12]
[21 32]]
以上是关于numpy中dot, multiply, *区别的主要内容,如果未能解决你的问题,请参考以下文章
python中np.multiply()np.dot()和星号(*)三种乘法运算的区别