RcppArmadillo:对角矩阵乘法很慢

Posted

技术标签:

【中文标题】RcppArmadillo:对角矩阵乘法很慢【英文标题】:RcppArmadillo: diagonal matrix multiplication is very slow 【发布时间】:2019-11-02 23:15:00 【问题描述】:

x 为向量,M 为矩阵。

在 R 中,我可以做到

D <- diag(exp(x))
crossprod(M, D%M)

在 RcppArmadillo 中,我有以下慢得多的内容。

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::mat multiple_mnv(const arma::vec& x, const arma::mat& M) 
  arma::colvec diagonal(x.size())
  for (int i = 0; i < x.size(); i++)
  
    diagonal(i) = exp(x[i]);
  
  arma::mat D = diagmat(diagonal);
  return     M.t()*D*M;

为什么这么慢?如何加快速度?

【问题讨论】:

你能提供一个minimal reproducible example吗?你是如何对其进行基准测试的? 尝试将return M.t()*D*M 更改为return M.t()*diagmat(diagonal)*M。后一个表达式为犰狳提供了更多关于diagonal 是什么的信息,因此它可能会进行更多优化。请注意,diagonal 的生成是多余的,因为这个表达式也应该有效:return M.t()*diagmat(arma::exp(x))*M 【参考方案1】:

欢迎来到 Stack Overflow manju。对于未来的问题,请注意,minimal reproducible example 是预期的,实际上是您的最大利益;它帮助别人帮助你。下面是一个示例,说明如何提供示例数据供其他人使用:

## Set seed for reproducibility
set.seed(123)
## Generate data
x <- rnorm(10)
M <- matrix(rnorm(100), nrow = 10, ncol = 10)
## Output code for others to copy your objects
dput(x)
dput(M)

这是我将使用的数据,以表明您的 C++ 代码实际上并不比 R 慢。我使用了您的 C++ 代码(添加了缺少的分号):

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
arma::mat foo(const arma::vec& x, const arma::mat& M) 
    arma::colvec diagonal(x.size());
    for ( int i = 0; i < x.size(); i++ )
    
        diagonal(i) = exp(x[i]);
    
    arma::mat D = diagmat(diagonal);
    return M.t() * D * M;

还请注意,我必须对返回对象的类型和函数参数的类型做出一些自己的选择(这是一个最小的可重现示例可以帮助您的地方之一:如果这些选择影响我的结果?)然后我创建一个 R 函数来执行 foo() 所做的事情:

bar <- function(v, M) 
    D <- diag(exp(v))
    return(crossprod(M, D %*% M))

还请注意,我必须更正您的拼写错误,将 D%M 更改为 D %*% M。让我们仔细检查一下它们是否给出了相同的结果:

all.equal(foo(x, M), bar(x, M))
# [1] TRUE

现在让我们来看看它们有多快:

library(microbenchmark)
bench <- microbenchmark(cpp = foo(x, M), R = foo(x, M), times = 1e5)
bench
# Unit: microseconds
#  expr    min     lq     mean median     uq      max
#   cpp 22.185 23.015 27.00436 23.204 23.461 31143.30
#     R 22.126 23.028 25.48256 23.216 23.475 29628.86

那些对我来说看起来几乎一样!我们还可以查看时间的密度图(丢弃极值异常值以使事情更清晰):

cpp_times <- with(bench, time[expr == "cpp"])
R_times   <- with(bench, time[expr == "R"])
cpp_time_dens <- density(cpp_times[cpp_times < quantile(cpp_times, 0.95)])
R_time_dens   <- density(R_times[R_times < quantile(R_times, 0.95)])
plot(cpp_time_dens, col = "blue", xlab = "Time (in nanoseconds)", ylab = "",
     main = "Comparing C++ and R execution time")
lines(R_time_dens, col = "red")
legend("topright", col = c("blue", "red"), bty = "n", lty = 1,
       legend = c("C++ function (foo)", "R function (bar)"))

为什么?

正如Dirk Eddelbuettel 在 cmets 中有用地指出的那样,最终 R 和犰狳都会调用 LAPACK 或 BLAS 例程——除非你能给犰狳一个提示,否则你不应该期望有太大的不同如何提高效率。

我们可以让犰狳代码更快吗?

是的!正如 cmets 中的mtall 所指出的,我们可以给犰狳暗示我们正在处理对角矩阵。咱们试试吧;我们将使用以下代码:

// [[Rcpp::export]]
arma::mat baz(const arma::vec& x, const arma::mat& M) 
    return M.t() * diagmat(arma::exp(x)) * M;

并对其进行基准测试:

all.equal(foo(x, M), baz(x, M))
# [1] TRUE
library(microbenchmark)
bench <- microbenchmark(cpp = foo(x, M), R = foo(x, M),
                        cpp2 = baz(x, M), times = 1e5)
bench
# Unit: microseconds
#  expr    min     lq     mean median     uq      max
#   cpp 22.822 23.757 27.57015 24.118 24.632 26600.48
#     R 22.855 23.771 26.44725 24.124 24.638 30619.09
#  cpp2 20.035 21.218 25.49863 21.587 22.123 36745.72

我们看到了一个小但肯定的改进;让我们像以前一样以图形方式看一下:

cpp_times  <- with(bench, time[expr == "cpp"])
cpp2_times <- with(bench, time[expr == "cpp2"])
R_times    <- with(bench, time[expr == "R"])
cpp_time_dens  <- density(cpp_times[cpp_times < quantile(cpp_times, 0.95)])
cpp2_time_dens <- density(cpp2_times[cpp2_times < quantile(cpp2_times, 0.95)])
R_time_dens    <- density(R_times[R_times < quantile(R_times, 0.95)])
xlims <- range(c(cpp_time_dens$x, cpp2_time_dens$x, R_time_dens$x))
ylims <- range(c(cpp_time_dens$y, cpp2_time_dens$y, R_time_dens$y))
ylims <- ylims * c(1, 1.15)
cols  <- c("#0072b2", "#f0e442", "#d55e00")
cols  <- c("#e69f00", "#56b4e9", "#009e73")
labs  <- c("C++ original", "C++ improved", "R")
plot(cpp_time_dens, col = cols[1], xlim = xlims, ylim = ylims,
     xlab = "Time (in nanoseconds)", ylab = "",
     main = "Comparing C++ and R execution time")
lines(cpp2_time_dens, col = cols[2])
lines(R_time_dens, col = cols[3])
legend("topleft", col = cols, bty = "n", lty = 1, legend = labs, horiz = TRUE)

【讨论】:

非常好的和耐心的回答。当然,“tl;dr”最终都在同一个 LAPACK/BLAS 子例程中结束,因此我们期望性能相似(“数据如何到达那里”的模块开销)。 尝试将return M.t()*D*M 更改为return M.t()*diagmat(diagonal)*M。后一个表达式为犰狳提供了更多关于diagonal 是什么的信息,因此它可能会进行更多优化。请注意,diagonal 的生成是多余的,因为这个表达式也应该起作用:return M.t()*diagmat(arma::exp(x))*M @mtall 是的。我还将使用高效的 Armadillo 代码添加基准测试,但我的主要观点是证明鉴于 OP 的代码,它实际上并不慢 @mtall 我在您提到的代码改进中添加了一些内容;谢谢!

以上是关于RcppArmadillo:对角矩阵乘法很慢的主要内容,如果未能解决你的问题,请参考以下文章

numpy 和 tensorflow 中的各种乘法(点乘和矩阵乘)

numpy 和 tensorflow 中的各种乘法(点乘和矩阵乘)

C语言试题129之求一个 3乘3 矩阵对角线元素之和

2×3矩阵乘3×2矩阵要怎么算?

动态规划 - 矩阵链乘法

矩阵乘法和逆