用jax计算行向(或轴向)点积的最佳方法是什么?

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了用jax计算行向(或轴向)点积的最佳方法是什么?相关的知识,希望对你有一定的参考价值。

我有两个形状为(N, M)的数值数组。我想计算一个行向的点积,即产生一个形状(N,M)的数组。即产生一个形状(N,)的数组,使第n行是每个数组中第n行的点积。

我知道 numpy 的 inner1d 方法。jax的最佳方法是什么?jax.numpy.inner但这是另一回事。

答案

你可以定义你自己的jit编译版本的 inner1d 在几行jax代码中。

import jax
@jax.jit
def inner1d(X, Y):
  return (X * Y).sum(-1)

测试一下

import jax.numpy as jnp
import numpy as np
from numpy.core import umath_tests


X = np.random.rand(5, 10)
Y = np.random.rand(5, 10)

print(umath_tests.inner1d(X, Y))
print(inner1d(jnp.array(X), jnp.array(Y)))
# [2.23219571 2.1013316  2.70353783 2.14094973 2.62582531]
# [2.2321959 2.1013315 2.703538  2.1409497 2.6258256]
另一答案

你可以试试 jax.numpy.einsum. 这里使用numpy einsum来实现

import numpy as np
from numpy.core.umath_tests import inner1d

arr1 = np.random.randint(0,10,[5,5])
arr2 = np.random.randint(0,10,[5,5])

arr = np.inner1d(arr1,arr2)
arr
array([ 87, 200, 229,  81,  53])
np.einsum('...i,...i->...',arr1,arr2)
array([ 87, 200, 229,  81,  53])

以上是关于用jax计算行向(或轴向)点积的最佳方法是什么?的主要内容,如果未能解决你的问题,请参考以下文章

向量乘积的几何意义

向量点积几何意义是啥 向量点积几何意义介绍

python中巨大矩阵的点积的行和

关于向量点积和叉积的回顾和总结

什么是Julia在所需尺寸上点缀产品的方法

NumPy 点积:取向量积的积(而不是求和)