RcppArmadillo:对角矩阵乘法很慢
Posted
技术标签:
【中文标题】RcppArmadillo:对角矩阵乘法很慢【英文标题】:RcppArmadillo: diagonal matrix multiplication is very slow 【发布时间】:2019-11-02 23:15:00 【问题描述】:设x
为向量,M
为矩阵。
在 R 中,我可以做到
D <- diag(exp(x))
crossprod(M, D%M)
在 RcppArmadillo 中,我有以下慢得多的内容。
#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::export]]
arma::mat multiple_mnv(const arma::vec& x, const arma::mat& M)
arma::colvec diagonal(x.size())
for (int i = 0; i < x.size(); i++)
diagonal(i) = exp(x[i]);
arma::mat D = diagmat(diagonal);
return M.t()*D*M;
为什么这么慢?如何加快速度?
【问题讨论】:
你能提供一个minimal reproducible example吗?你是如何对其进行基准测试的? 尝试将return M.t()*D*M
更改为return M.t()*diagmat(diagonal)*M
。后一个表达式为犰狳提供了更多关于diagonal
是什么的信息,因此它可能会进行更多优化。请注意,diagonal
的生成是多余的,因为这个表达式也应该有效:return M.t()*diagmat(arma::exp(x))*M
【参考方案1】:
欢迎来到 Stack Overflow manju。对于未来的问题,请注意,minimal reproducible example 是预期的,实际上是您的最大利益;它帮助别人帮助你。下面是一个示例,说明如何提供示例数据供其他人使用:
## Set seed for reproducibility
set.seed(123)
## Generate data
x <- rnorm(10)
M <- matrix(rnorm(100), nrow = 10, ncol = 10)
## Output code for others to copy your objects
dput(x)
dput(M)
这是我将使用的数据,以表明您的 C++ 代码实际上并不比 R 慢。我使用了您的 C++ 代码(添加了缺少的分号):
#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::export]]
arma::mat foo(const arma::vec& x, const arma::mat& M)
arma::colvec diagonal(x.size());
for ( int i = 0; i < x.size(); i++ )
diagonal(i) = exp(x[i]);
arma::mat D = diagmat(diagonal);
return M.t() * D * M;
还请注意,我必须对返回对象的类型和函数参数的类型做出一些自己的选择(这是一个最小的可重现示例可以帮助您的地方之一:如果这些选择影响我的结果?)然后我创建一个 R 函数来执行 foo()
所做的事情:
bar <- function(v, M)
D <- diag(exp(v))
return(crossprod(M, D %*% M))
还请注意,我必须更正您的拼写错误,将 D%M
更改为 D %*% M
。让我们仔细检查一下它们是否给出了相同的结果:
all.equal(foo(x, M), bar(x, M))
# [1] TRUE
现在让我们来看看它们有多快:
library(microbenchmark)
bench <- microbenchmark(cpp = foo(x, M), R = foo(x, M), times = 1e5)
bench
# Unit: microseconds
# expr min lq mean median uq max
# cpp 22.185 23.015 27.00436 23.204 23.461 31143.30
# R 22.126 23.028 25.48256 23.216 23.475 29628.86
那些对我来说看起来几乎一样!我们还可以查看时间的密度图(丢弃极值异常值以使事情更清晰):
cpp_times <- with(bench, time[expr == "cpp"])
R_times <- with(bench, time[expr == "R"])
cpp_time_dens <- density(cpp_times[cpp_times < quantile(cpp_times, 0.95)])
R_time_dens <- density(R_times[R_times < quantile(R_times, 0.95)])
plot(cpp_time_dens, col = "blue", xlab = "Time (in nanoseconds)", ylab = "",
main = "Comparing C++ and R execution time")
lines(R_time_dens, col = "red")
legend("topright", col = c("blue", "red"), bty = "n", lty = 1,
legend = c("C++ function (foo)", "R function (bar)"))
为什么?
正如Dirk Eddelbuettel 在 cmets 中有用地指出的那样,最终 R 和犰狳都会调用 LAPACK 或 BLAS 例程——除非你能给犰狳一个提示,否则你不应该期望有太大的不同如何提高效率。
我们可以让犰狳代码更快吗?
是的!正如 cmets 中的mtall 所指出的,我们可以给犰狳暗示我们正在处理对角矩阵。咱们试试吧;我们将使用以下代码:
// [[Rcpp::export]]
arma::mat baz(const arma::vec& x, const arma::mat& M)
return M.t() * diagmat(arma::exp(x)) * M;
并对其进行基准测试:
all.equal(foo(x, M), baz(x, M))
# [1] TRUE
library(microbenchmark)
bench <- microbenchmark(cpp = foo(x, M), R = foo(x, M),
cpp2 = baz(x, M), times = 1e5)
bench
# Unit: microseconds
# expr min lq mean median uq max
# cpp 22.822 23.757 27.57015 24.118 24.632 26600.48
# R 22.855 23.771 26.44725 24.124 24.638 30619.09
# cpp2 20.035 21.218 25.49863 21.587 22.123 36745.72
我们看到了一个小但肯定的改进;让我们像以前一样以图形方式看一下:
cpp_times <- with(bench, time[expr == "cpp"])
cpp2_times <- with(bench, time[expr == "cpp2"])
R_times <- with(bench, time[expr == "R"])
cpp_time_dens <- density(cpp_times[cpp_times < quantile(cpp_times, 0.95)])
cpp2_time_dens <- density(cpp2_times[cpp2_times < quantile(cpp2_times, 0.95)])
R_time_dens <- density(R_times[R_times < quantile(R_times, 0.95)])
xlims <- range(c(cpp_time_dens$x, cpp2_time_dens$x, R_time_dens$x))
ylims <- range(c(cpp_time_dens$y, cpp2_time_dens$y, R_time_dens$y))
ylims <- ylims * c(1, 1.15)
cols <- c("#0072b2", "#f0e442", "#d55e00")
cols <- c("#e69f00", "#56b4e9", "#009e73")
labs <- c("C++ original", "C++ improved", "R")
plot(cpp_time_dens, col = cols[1], xlim = xlims, ylim = ylims,
xlab = "Time (in nanoseconds)", ylab = "",
main = "Comparing C++ and R execution time")
lines(cpp2_time_dens, col = cols[2])
lines(R_time_dens, col = cols[3])
legend("topleft", col = cols, bty = "n", lty = 1, legend = labs, horiz = TRUE)
【讨论】:
非常好的和耐心的回答。当然,“tl;dr”最终都在同一个 LAPACK/BLAS 子例程中结束,因此我们期望性能相似(“数据如何到达那里”的模块开销)。 尝试将return M.t()*D*M
更改为return M.t()*diagmat(diagonal)*M
。后一个表达式为犰狳提供了更多关于diagonal
是什么的信息,因此它可能会进行更多优化。请注意,diagonal
的生成是多余的,因为这个表达式也应该起作用:return M.t()*diagmat(arma::exp(x))*M
@mtall 是的。我还将使用高效的 Armadillo 代码添加基准测试,但我的主要观点是证明鉴于 OP 的代码,它实际上并不慢
@mtall 我在您提到的代码改进中添加了一些内容;谢谢!以上是关于RcppArmadillo:对角矩阵乘法很慢的主要内容,如果未能解决你的问题,请参考以下文章
numpy 和 tensorflow 中的各种乘法(点乘和矩阵乘)