NumPy:对角元素的点积

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了NumPy:对角元素的点积相关的知识,希望对你有一定的参考价值。

考虑x,一个n x 3矢量。

  1. 是否有可能使用numpytensorflow的内置方法或任何Python库来获取n x 1阶的向量,使得每一行都是3 x 1阶的向量?也就是说,如果x是[[1,2,3],[4,5,6],[7,8,9],[10,11,12]] T,那么形式的矢量[[1] ,2,3] T,[4,5,6] T,[7,8,9] T,[10,11,12] T] T没有for循环或引入新的轴,比如np.newaxis
  2. 这背后的动机是只获得x及其转置的点积的对角元素。当然,我们可以做像np.diag(x.dot(x.T))这样的事情。但是,如果n非常大,比如202933,那么可以听到CPU的风扇患有喘息声。如何实际避免做所有元素的点积,只做虚拟点积的对角线而没有迭代?
答案

让我们来看看x乘以自己的转置结果中每个元素的公式。我不想尝试强制Stack Overflow UI允许我使用张量表示法,所以我们从概念上看。

结果的第i行,第j列的每个元素是x中第i行和x.T中第j列的点积。现在x.T中的第j列只是x中的第j行,而对角线是i和j相同的地方。所以你想要的是x平方元素行的总和:

d = (x * x).sum(axis=1)

为了解决问题的第一部分,numpy中的转置操作很少复制您的数据,因此即使是最大的数组,x.Tnp.transpose(x)也是常量操作。原因是numpy数组被存储为一个数据块以及一些元数据,如维度,每个维度中元素之间的跨距和数据大小。转置数组只需要修改数组对象中的少量元数据,例如沿每个维度和步幅的大小,而不是复制整个数据集。

耗时的部分是执行乘法。简单地让对象xx.T几乎没有成本:它们都使用相同的数据缓冲区。

以上是关于NumPy:对角元素的点积的主要内容,如果未能解决你的问题,请参考以下文章

numpy使用np.dot函数或者@操作符计算两个numpy数组的点积数量积(dot productscalar product)

python中巨大矩阵的点积的行和

Numpy:点积和 dot() 矩阵相乘

用jax计算行向(或轴向)点积的最佳方法是什么?

向量与 SIMD 的点积

在 Cython 中调用点积和线性代数运算?