NumPy:对角元素的点积
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了NumPy:对角元素的点积相关的知识,希望对你有一定的参考价值。
考虑x
,一个n x 3
矢量。
- 是否有可能使用
numpy
或tensorflow
的内置方法或任何Python库来获取n x 1
阶的向量,使得每一行都是3 x 1
阶的向量?也就是说,如果x
是[[1,2,3],[4,5,6],[7,8,9],[10,11,12]] T,那么形式的矢量[[1] ,2,3] T,[4,5,6] T,[7,8,9] T,[10,11,12] T] T没有for
循环或引入新的轴,比如np.newaxis
? - 这背后的动机是只获得
x
及其转置的点积的对角元素。当然,我们可以做像np.diag(x.dot(x.T))
这样的事情。但是,如果n
非常大,比如202933,那么可以听到CPU的风扇患有喘息声。如何实际避免做所有元素的点积,只做虚拟点积的对角线而没有迭代?
答案
让我们来看看x
乘以自己的转置结果中每个元素的公式。我不想尝试强制Stack Overflow UI允许我使用张量表示法,所以我们从概念上看。
结果的第i行,第j列的每个元素是x
中第i行和x.T
中第j列的点积。现在x.T
中的第j列只是x
中的第j行,而对角线是i和j相同的地方。所以你想要的是x
平方元素行的总和:
d = (x * x).sum(axis=1)
为了解决问题的第一部分,numpy中的转置操作很少复制您的数据,因此即使是最大的数组,x.T
或np.transpose(x)
也是常量操作。原因是numpy数组被存储为一个数据块以及一些元数据,如维度,每个维度中元素之间的跨距和数据大小。转置数组只需要修改数组对象中的少量元数据,例如沿每个维度和步幅的大小,而不是复制整个数据集。
耗时的部分是执行乘法。简单地让对象x
和x.T
几乎没有成本:它们都使用相同的数据缓冲区。
以上是关于NumPy:对角元素的点积的主要内容,如果未能解决你的问题,请参考以下文章
numpy使用np.dot函数或者@操作符计算两个numpy数组的点积数量积(dot productscalar product)