使用 py4j 将矩阵作为 int[][] 数组从 Python 发送到 Java
Posted
技术标签:
【中文标题】使用 py4j 将矩阵作为 int[][] 数组从 Python 发送到 Java【英文标题】:Using py4j to send matrices to from Python to Java as int[][] arrays 【发布时间】:2016-07-26 23:41:48 【问题描述】:我一直在使用 py4j 围绕一个不太友好的 Java 库构建一个用户友好的 Python 库。在大多数情况下,这很容易,py4j 一直是一个很棒的工具。但是,我在 Python 和 Java 之间发送矩阵时遇到了问题。
具体来说,我在 java 中有一个静态函数,它接受一个整数矩阵作为其参数:
public class MyClass
// ...
public static MyObject create(int[][] matrix)
// ...
我希望能够像这样从 Py4j 调用它:
def create_java_object(numpy_matrix):
# <code here checks that numpy_matrix is a (3 x n) integer matrix>
# ...
return java_instance.jvm.my.namespace.MyClass.create(numpy_matrix)
这不起作用,这并不奇怪,如果 numpy_matrix
被转换为普通 python 列表的列表,它也不起作用。我曾期望解决方案是在函数调用之前构造一个 java 数组并传输数据:
def create_java_object(numpy_matrix):
# <code here checks that numpy_matrix is a (3 x n) integer matrix>
# ...
java_matrix = java_instance.new_array(java_instance.jvm.int, 3, n)
for i in range(numpy_matrix.shape[1]):
java_matrix[0][i] = int(numpy_matrix[0, i])
java_matrix[1][i] = int(numpy_matrix[1, i])
java_matrix[2][i] = int(numpy_matrix[2, i])
return java_instance.jvm.my.namespace.MyClass.create(java_matrix)
现在,这段代码可以正常运行。但是,它需要大约 两分钟 才能运行。顺便说一下,我正在使用的矩阵大约是 (3 x ~300,000) 个元素。
在 Py4j 中是否有一种规范的方法可以做到这一点,而不需要大量的时间来转换矩阵?我不介意花一两秒钟,但这太慢了。如果没有为这种通信设置 Py4j,是否有适用于 Python 的 Java 互操作库?
注意:Java 库将int[][]
矩阵视为不可变数组;即,它从不尝试修改它。
【问题讨论】:
【参考方案1】:我为这种特殊情况找到了一个可行的解决方案;虽然它不是非常优雅:
Py4j 支持将 Python bytearray
对象作为 byte[]
数组有效地传递给 Java。我通过修改原始库和我的 Python 代码解决了这个问题。
新的 Java 代码:
public class MyClass
// ...
public static MyObject create(int[][] matrix)
// ...
public static MyObject createFromPy4j(byte[] data)
java.nio.ByteBuffer buf = java.nio.ByteBuffer.wrap(data);
int n = buf.getInt(), m = buf.getInt();
int[][] matrix = new int[n][m];
for (int i = 0; i < n; ++i)
for (int j = 0; j < m; ++j)
matrix[i][j] = buf.getInt();
return MyClass.create(matrix);
新的 Python 代码:
def create_java_object(numpy_matrix):
header = array.array('i', list(numpy_matrix.shape))
body = array.array('i', numpy_matrix.flatten().tolist());
if sys.byteorder != 'big':
header.byteswap()
body.byteswap()
buf = bytearray(header.tostring() + body.tostring())
return java_instance.jvm.my.namespace.MyClass.createFromPy4j(buf)
这需要几秒钟而不是几分钟。
【讨论】:
这只适用于二维矩阵?如何将形状长度存储为第一个值,以便您可以发送任何形状的矩阵?以上是关于使用 py4j 将矩阵作为 int[][] 数组从 Python 发送到 Java的主要内容,如果未能解决你的问题,请参考以下文章