犰狳:矩阵乘法精度损失
Posted
技术标签:
【中文标题】犰狳:矩阵乘法精度损失【英文标题】:armadillo: matrix multiplication loss of precision 【发布时间】:2017-05-02 04:11:40 【问题描述】:我似乎在执行矩阵乘法时遇到精度损失,想知道如何防止这种情况发生。例如,假设 feat 和 beta 的大小合适,
Y = feat*beta.rows(0,N);
我正在使用的数字的值相当小,大多数数字都小于 1e-3,因此我试图实现的目标可能是不可能的。我还应该注意,这是一个调用 C++ 函数的 MATLAB 函数,因此涉及到 MEX 编译器。当他们到达时,我确实检查了 mex 函数中的数字并且它们是正确的,只有在上述行之后,我才得到非常错误的答案。
编辑:我认为给出程序的完整上下文不会有什么坏处。这是我到目前为止所拥有的。我有一条用评论标记精度损失的行。
EDIT2:以下是相关矩阵的一些示例。 Feat_2d 是 5x4608
0 0 0 0 0
0 0 0 0 0
0 0 0 0 4.8146
0 0 19.0266 0 0
0 0 0 0 0
Beta_2d 为 4609x4。我删除了最后一行作为 feat*Beta_2d 的乘法
-7.1486e-05 -1.6801e-04 1.0970e-05 3.7837e-04
-8.7524e-05 1.8275e-04 -6.7857e-04 2.6267e-04
-9.1812e-05 -6.5495e-05 -1.7687e-03 -3.2168e-04
0e+00 0e+00 0e+00 0e+00
-4.5089e-04 -5.6013e-05 1.4841e-04 2.4912e-04
Y =
6.8995e-310 0e+00 4.7430e-322 1.7802e-306
6.8995e-310 0e+00 4.9407e-322 1.4463e-307
0e+00 0e+00 0e+00 1.4463e-307
0e+00 0e+00 1.2352e-322 1.2016e-306
6.8996e-310 6.8996e-310 6.8995e-310 1.7802e-306
这是来自 EDIT1 的代码
#include <mex.h>
#include <iostream>
#include <armadillo>
using namespace arma;
void predict_bbox_reg(double *beta, int beta_dim[2], double *t_inv, int tinv_dim[2], double mu, double *feat, int feat_dim[2], double *boxes, int box_dim[2])
//convert pointers
//beta
arma::mat beta_2d = mat(beta_dim[0], beta_dim[1]);
for(int i = 0; i<beta_dim[0]; i++)
for(int j = 0; j<beta_dim[1]; j++)
beta_2d(i, j) = beta[i+beta_dim[0]*j];
//t_inv
arma::mat tinv_2d = mat(tinv_dim[0], tinv_dim[1]);
for(int i = 0; i<tinv_dim[0]; i++)
for(int j = 0; j<tinv_dim[1]; j++)
tinv_2d(i, j) = t_inv[i+tinv_dim[0]*j];
//feadoublet_2d
arma::mat feat_2d = mat(feat_dim[0], feat_dim[1]);
for(int i = 0; i<feat_dim[0]; i++)
for(int j = 0; j<feat_dim[1]; j++)
feat_2d(i, j) = feat[i+feat_dim[0]*j];
//boxes
arma::mat box_2d = mat(box_dim[0], box_dim[1]);
for(int i = 0; i<box_dim[0]; i++)
for(int j = 0; j<box_dim[1]; j++)
box_2d(i, j) = boxes[i+box_dim[0]*j];
arma::mat Y = mat(feat_dim[0], beta_dim[1]);
Y = feat_2d*beta_2d.rows(0,beta_dim[0]-2);// this is the precision loss
arma::mat y1 = beta_2d.row(beta_2d.n_rows-1);
Y.each_row() += y1.row(0);
//RETURNS SOMETHING
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
int M = mxGetM(prhs[0]);
int N = mxGetN(prhs[0]);
int beta_dim[2] = M,N;
double *beta = mxGetPr(prhs[0]);
M = mxGetM(prhs[1]);
N = mxGetN(prhs[1]);
int tinv_dim[2] = M,N;
double *t_inv = mxGetPr(prhs[1]);
double mu = *mxGetPr(prhs[2]);
M = mxGetM(prhs[3]);
N = mxGetN(prhs[3]);
int feat_dim[2] = M,N;
double *feat = mxGetPr(prhs[3]);
M = mxGetM(prhs[4]);
N = mxGetN(prhs[4]);
int box_dim[2] = M,N;
double *ex_boxes = mxGetPr(prhs[4]);
predict_bbox_reg(beta, beta_dim, t_inv, tinv_dim,
mu, feat, feat_dim, ex_boxes, box_dim);
//RETURNS results to matlab
【问题讨论】:
我们需要更多信息。你能展示一些你的矩阵的最小例子吗?另外,您如何定义“精度损失”?你怎么知道你实际上正在失去精度?如果你做得对,你就不能把它和电脑做的比较……因为也许你的比较矩阵不好,而不是你自己的。 @AnderBiguri 我做了一些编辑,显示了一些矩阵,似乎有一个问题,其中 Y 导致基本上为零的操作。当我在matlab中做操作时,结果不一样。 【参考方案1】:我在您的代码中看不到任何直接错误,但我怀疑 Armadillo 存在精度问题。我怀疑它可能是您指向 mat 转换的指针。我会使用犰狳在.../mex_interface/armaMex.hpp
文件中提供的功能。
一个简单的矩阵乘法示例(mult_test.cpp):
#include <armadillo>
#include "armaMex.hpp"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
// Convert to Armadillo
mat x1 = conv_to<mat>::from(armaGetPr(prhs[0],true));
mat x2 = conv_to<mat>::from(armaGetPr(prhs[1],true));
mat y(x1.n_rows,x2.n_cols);
// Do your stuff here:
y = x1*x2;
// Convert back to Matlab
plhs[0] = armaCreateMxMatrix(y.n_rows, y.n_cols, mxDOUBLE_CLASS, mxREAL);
armaSetPr(plhs[0], conv_to<mat>::from(y));
return;
.m 文件是
X1 = [1 2 3 ; 4 5 6];
X2 = 1e-7*rand(3,2);
Y = mult_test(X1,X2);
disp('Matlab:')
disp(X1*X2)
disp('Armadillo:')
disp(Y)
给出输出
Matlab:
1.0e-06 *
0.240798243020273 0.410716970485213
0.559953800808707 0.974915983937571
Armadillo:
1.0e-06 *
0.240798243020273 0.410716970485213
0.559953800808707 0.974915983937571
【讨论】:
请问!正如医生所说! Armadillo 受限于原生类型 double 或 long double 作为最大精度!如果这还不够怎么办!我搜索并没有找到如何使用其他类型,例如一些大型 num 库!任何想法!当然,如果这样做,性能会受到影响!非常感谢! 正如你所说,犰狳只支持“float、double、std::complex以上是关于犰狳:矩阵乘法精度损失的主要内容,如果未能解决你的问题,请参考以下文章