如何逐行读取 Matlab mex 函数的输入矩阵?
Posted
技术标签:
【中文标题】如何逐行读取 Matlab mex 函数的输入矩阵?【英文标题】:How to read input matrix of a Matlab mex function row-wise? 【发布时间】:2021-11-05 03:19:46 【问题描述】:我需要创建一个 Matlab mex 函数,该函数将采用输入矩阵并返回矩阵对角线。
输入:
1 2 3
4 5 6
预期输出:
1 2 3 0 0 0
0 0 0 4 5 6
我的问题是,由于 Matlab 是按列而不是按行读取矩阵,所以我的 mex 函数给出了错误的输出。
当前输出:
1 4 0 0 0 0
0 0 2 5 0 0
0 0 0 0 3 6
您将如何更改我的代码以逐行读取输入矩阵,以便获得正确的输出?
我的代码如下:
#include <matrix.h>
#include <mex.h>
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
mxArray *a_in, *b_out;
const mwSize *dims;
double *a, *b;
int rows, cols;
// Input array:
a_in = mxDuplicateArray(prhs[0]);
// Get dimensions of input array:
dims = mxGetDimensions(prhs[0]);
rows = (int) dims[0];
cols = (int) dims[1];
// Output array:
if(rows == cols)
b_out = plhs[0] = mxCreateDoubleMatrix(rows, rows*cols, mxREAL);
else
b_out = plhs[0] = mxCreateDoubleMatrix(cols, rows*cols, mxREAL);
// Access the contents of the input and output arrays:
a = mxGetPr(a_in);
b = mxGetPr(b_out);
// Compute exdiag function of the input array
int count = 0;
for (int i = 0; i < rows; i++)
for(int j = 0; j<cols;j++)
if(rows == cols)
b[rows*count+count/rows] = a[j + rows * i];
count++;
else if(rows < cols)
b[cols*count+count/rows] = a[j + cols * i];
count++;
else if(rows>cols)
b[cols*count+count/rows] = a[j + cols * i];
count++;
【问题讨论】:
你不能只是转置输入吗? 【参考方案1】:在循环中,i
是行索引,j
是列索引。你做a[j + rows * i]
,混合了两个索引。 MATLAB 按列存储数据,因此您需要执行a[i + rows * j]
才能正确读取输入矩阵。
为了索引输出,您希望行保持i
,并且希望列保持i * cols + j
:
b[i + rows * (i * cols + j)] = a[i + rows * j];
请注意,您不需要执行a_in = mxDuplicateArray(prhs[0])
,因为您没有写入a_in
。您可以直接访问prhs[0]
矩阵,如果您想要别名,也可以使用a_in = prhs[0]
。
此外,如果数组非常大,将数组大小转换为 int
会导致问题。数组大小和索引最好使用mwSize
和mwIndex
。
最后,你应该经常检查输入数组的类型,如果你得到一个不是双精度的数组,你很可能会导致读取越界错误。
这是我的代码:
#include <matrix.h>
#include <mex.h>
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
mwSize const* dims;
double *a, *b;
mwSize rows, cols;
if (!mxIsDouble(prhs[0]))
mexErrMsgTxt("Input must be doubles");
// Get dimensions of input array:
dims = mxGetDimensions(prhs[0]);
rows = dims[0];
cols = dims[1];
// Output array:
plhs[0] = mxCreateDoubleMatrix(rows, rows*cols, mxREAL);
// Access the contents of the input and output arrays:
a = mxGetPr(prhs[0]);
b = mxGetPr(plhs[0]);
// Compute exdiag function of the input array
for (mwIndex i = 0; i < rows; i++)
for (mwIndex j = 0; j < cols; j++)
b[i + rows * (i * cols + j)] = a[i + rows * j];
【讨论】:
以上是关于如何逐行读取 Matlab mex 函数的输入矩阵?的主要内容,如果未能解决你的问题,请参考以下文章