比较 Python、Numpy、Numba 和 C++ 的矩阵乘法
Posted
技术标签:
【中文标题】比较 Python、Numpy、Numba 和 C++ 的矩阵乘法【英文标题】:Comparing Python, Numpy, Numba and C++ for matrix multiplication 【发布时间】:2016-07-31 07:36:59 【问题描述】:在我正在处理的程序中,我需要重复将两个矩阵相乘。由于其中一个矩阵的大小,此操作需要一些时间,我想看看哪种方法最有效。矩阵的维度为(m x n)*(n x p)
,其中m = n = 3
和10^5 < p < 10^6
。
除了我认为使用优化算法的 Numpy 之外,每个测试都包含 matrix multiplication 的简单实现:
以下是我的各种实现:
Python
def dot_py(A,B):
m, n = A.shape
p = B.shape[1]
C = np.zeros((m,p))
for i in range(0,m):
for j in range(0,p):
for k in range(0,n):
C[i,j] += A[i,k]*B[k,j]
return C
Numpy
def dot_np(A,B):
C = np.dot(A,B)
return C
Numba
代码和Python一样,只是在使用前及时编译:
dot_nb = nb.jit(nb.float64[:,:](nb.float64[:,:], nb.float64[:,:]), nopython = True)(dot_py)
到目前为止,每个方法调用都使用timeit
模块进行了10 次计时。保持最佳结果。矩阵是使用np.random.rand(n,m)
创建的。
C++
mat2 dot(const mat2& m1, const mat2& m2)
int m = m1.rows_;
int n = m1.cols_;
int p = m2.cols_;
mat2 m3(m,p);
for (int row = 0; row < m; row++)
for (int col = 0; col < p; col++)
for (int k = 0; k < n; k++)
m3.data_[p*row + col] += m1.data_[n*row + k]*m2.data_[p*k + col];
return m3;
这里,mat2
是我定义的自定义类,dot(const mat2& m1, const mat2& m2)
是该类的友元函数。它使用来自Windows.h
的QPF
和QPC
进行计时,并使用MinGW 和g++
命令编译程序。同样,保留了从 10 次执行中获得的最佳时间。
结果
正如预期的那样,简单的 Python 代码速度较慢,但对于非常小的矩阵,它仍然胜过 Numpy。在最大规模的情况下,Numba 比 Numpy 快约 30%。
我对 C++ 结果感到惊讶,其中乘法所需的时间几乎比 Numba 多一个数量级。事实上,我预计这些将花费类似的时间。
这引出了我的主要问题:这正常吗?如果不正常,为什么 C++ 比 Numba 慢?我刚开始学习 C++,所以我可能做错了什么。如果是这样,我的错误是什么,或者我可以做些什么来提高我的代码效率(除了选择更好的算法)?
编辑 1
这是mat2
类的标题。
#ifndef MAT2_H
#define MAT2_H
#include <iostream>
class mat2
private:
int rows_, cols_;
float* data_;
public:
mat2() // (default) constructor
mat2(int rows, int cols, float value = 0); // constructor
mat2(const mat2& other); // copy constructor
~mat2(); // destructor
// Operators
mat2& operator=(mat2 other); // assignment operator
float operator()(int row, int col) const;
float& operator() (int row, int col);
mat2 operator*(const mat2& other);
// Operations
friend mat2 dot(const mat2& m1, const mat2& m2);
// Other
friend void swap(mat2& first, mat2& second);
friend std::ostream& operator<<(std::ostream& os, const mat2& M);
;
#endif
编辑 2
正如许多人所建议的,使用优化标志是匹配 Numba 的缺失元素。下面是与以前的曲线相比的新曲线。标记为v2
的曲线是通过切换两个内循环获得的,显示出另外30% 到50% 的改进。
【问题讨论】:
这令人惊讶...我无法想象您会看到极大的加速,但您是否尝试过使用编译器优化标志,例如-O3
?基本用法是g++ *.cpp -std=c++11 -O3
另外你是在以任何方式from python调用这个c++函数还是直接调用一个编译好的程序?
@Eric:这是一个希望,但没有理由以这种方式编写代码。有点像期待你的妻子在你之后收拾东西:-)
查找缓存未命中,这可能是您的 C++ 失败的地方之一。
@TylerS 我用-O3
更新了我的问题(见第二次编辑)。这是你要找的吗?
【参考方案1】:
绝对使用-O3
进行优化。这会打开vectorizations,这应该会显着加快您的代码速度。
Numba 应该已经这样做了。
【讨论】:
【参考方案2】:我会推荐什么
如果你想要最大的效率,你应该使用一个专用的线性代数库,经典是BLAS/LAPACK库。有许多实现,例如。 Intel MKL。您编写的内容不会胜过超优化库。
矩阵矩阵乘法将成为dgemm
例程:d 代表双精度,ge 代表通用,mm 代表矩阵矩阵乘法。如果您的问题有额外的结构,则可能会调用更具体的函数来获得额外的加速。
请注意,Numpy dot 已经调用dgemm
!你可能不会做得更好。
为什么你的 c++ 很慢
与可能的算法相比,您的经典、直观的矩阵-矩阵乘法算法速度较慢。编写利用处理器缓存等的代码会产生重要的性能提升。关键是,无数聪明人毕生致力于让矩阵矩阵乘以极快,你应该利用他们的工作,而不是重新发明***。
【讨论】:
感谢您的回答!我知道 Numpy 使用的是dgemm
(实际上我已经查看了 Fortran 代码)。由于这个原因,我预计它会表现得更好。为了简单起见,我使用了 O(n^3) 算法,因为我已经获得了比 Numpy 更好的结果。最终,我的代码将包含更多带有嵌套循环的自定义函数,这些函数在优化的库中不可用,我现在对如何实现它们有了更好的了解。
我认为优化的dgemm
例程优于幼稚的实现,这主要是由于缓存和其他技术利用了处理器的实际工作方式而不是 O(n^3) 位。不过,我真的不是细节方面的专家。【参考方案3】:
在您当前的实现中,很可能编译器无法自动矢量化最内层循环,因为它的大小为 3。此外,m2
也以“跳跃”方式访问。交换循环以便在最内部的循环中迭代 p
将使其工作得更快(col
不会进行“跳跃”数据访问)并且编译器应该能够做得更好(自动向量化)。
for (int row = 0; row < m; row++)
for (int k = 0; k < n; k++)
for (int col = 0; col < p; col++)
m3.data_[p*row + col] += m1.data_[n*row + k] * m2.data_[p*k + col];
在我的机器上,使用 g++ dot.cpp -std=c++11 -O3 -o dot
标志构建的 p=10^6 元素的原始 C++ 实现采用 12ms
,而使用交换循环的上述实现采用 7ms
。
【讨论】:
【参考方案4】:您仍然可以通过改进内存访问来优化这些循环,您的函数可能看起来像(假设矩阵为 1000x1000):
CS = 10
NCHUNKS = 100
def dot_chunked(A,B):
C = np.zeros(1000,1000)
for i in range(NCHUNKS):
for j in range(NCHUNKS):
for k in range(NCHUNKS):
for ii in range(i*CS,(i+1)*CS):
for jj in range(j*CS,(j+1)*CS):
for kk in range(k*CS,(k+1)*CS):
C[ii,jj] += A[ii,kk]*B[kk,jj]
return C
解释:循环 i 和 ii 显然一起执行的方式与 i 之前相同,j 和 k 保持相同,但这次 A 和 B 中大小为 CSxCS 的区域可以保存在缓存中(我猜)并且可以多次使用。
您可以玩 CS 和 NCHUNKS。对我来说 CS=10 和 NCHUNKS=100 效果很好。使用 numba.jit 时,它将代码从 7 秒加速到 850 毫秒(注意我使用 1000x1000,上面的图形以 3x3x10^5 运行,所以它有点另一种场景)。
【讨论】:
以上是关于比较 Python、Numpy、Numba 和 C++ 的矩阵乘法的主要内容,如果未能解决你的问题,请参考以下文章
与 Python+Numba LLVM/JIT 编译的代码相比,Julia 的性能
Python numpy:无法将 datetime64[ns] 转换为 datetime64[D](与 Numba 一起使用)
为啥在迭代 NumPy 数组时 Cython 比 Numba 慢得多?