How to do vector operations on pyspark dataframe columns?
【Posted】2020-07-06 17:14:12
【Question】I have a dataframe like this:
id | vector1 | id2 | vector2
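where `id` is an integer and the vectors are of type `SparseVector`.
For each row, I want to add a column containing the cosine similarity, which would be computed as
vector1.dot(vector2) / (sqrt(vector1.dot(vector1)) * sqrt(vector2.dot(vector2)))
but I can't figure out how to put this into a new column. I tried writing a UDF but couldn't get it to work.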
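【Answer 1】A solution using Scala
The Spark repo contains a utility object, `org.apache.spark.ml.linalg.BLAS`, which uses `com.github.fommil.netlib.BLAS` to compute dot products. However, that object is package-private to Spark, so to use it here we need to copy the utility into the current project, as follows: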
```scala
package utils

import com.github.fommil.netlib.{F2jBLAS, BLAS => NetlibBLAS}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}

/**
 * The utility object org.apache.spark.ml.linalg.BLAS is package-private in the Spark repo,
 * so its dot-product routines for ML vectors are copied here.
 */
object BLAS extends Serializable {

  @transient private var _f2jBLAS: NetlibBLAS = _

  // For level-1 routines, use the pure-Java (F2J) implementation.
  private def f2jBLAS: NetlibBLAS = {
    if (_f2jBLAS == null) {
      _f2jBLAS = new F2jBLAS
    }
    _f2jBLAS
  }

  /**
   * dot(x, y): dispatches on the concrete (dense/sparse) vector types.
   */
  def dot(x: Vector, y: Vector): Double = {
    require(x.size == y.size,
      "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" +
        " x.size = " + x.size + ", y.size = " + y.size)
    (x, y) match {
      case (dx: DenseVector, dy: DenseVector) =>
        dot(dx, dy)
      case (sx: SparseVector, dy: DenseVector) =>
        dot(sx, dy)
      case (dx: DenseVector, sy: SparseVector) =>
        dot(sy, dx)
      case (sx: SparseVector, sy: SparseVector) =>
        dot(sx, sy)
      case _ =>
        throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).")
    }
  }

  /**
   * dot(x, y) for two dense vectors, via BLAS ddot.
   */
  private def dot(x: DenseVector, y: DenseVector): Double = {
    val n = x.size
    f2jBLAS.ddot(n, x.values, 1, y.values, 1)
  }

  /**
   * dot(x, y) for a sparse and a dense vector: only x's non-zeros contribute.
   */
  private def dot(x: SparseVector, y: DenseVector): Double = {
    val xValues = x.values
    val xIndices = x.indices
    val yValues = y.values
    val nnz = xIndices.length

    var sum = 0.0
    var k = 0
    while (k < nnz) {
      sum += xValues(k) * yValues(xIndices(k))
      k += 1
    }
    sum
  }

  /**
   * dot(x, y) for two sparse vectors: merge-walks both sorted index arrays.
   */
  private def dot(x: SparseVector, y: SparseVector): Double = {
    val xValues = x.values
    val xIndices = x.indices
    val yValues = y.values
    val yIndices = y.indices
    val nnzx = xIndices.length
    val nnzy = yIndices.length

    var kx = 0
    var ky = 0
    var sum = 0.0
    // y catching x
    while (kx < nnzx && ky < nnzy) {
      val ix = xIndices(kx)
      while (ky < nnzy && yIndices(ky) < ix) {
        ky += 1
      }
      if (ky < nnzy && yIndices(ky) == ix) {
        sum += xValues(kx) * yValues(ky)
        ky += 1
      }
      kx += 1
    }
    sum
  }
}
```
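The sparse/sparse `dot` above is a merge over two sorted index arrays: only positions present in both vectors contribute to the sum. For readers following along in Python, here is the same merge idea as a standalone sketch over plain `(indices, values)` lists, written with a symmetric two-pointer loop. It assumes strictly increasing indices (as `SparseVector` stores them), and `sparse_dot` is a hypothetical helper name; in PySpark itself, `SparseVector.dot` already handles this.

```python
from typing import Sequence


def sparse_dot(x_indices: Sequence[int], x_values: Sequence[float],
               y_indices: Sequence[int], y_values: Sequence[float]) -> float:
    """Dot product of two sparse vectors given as parallel (indices, values) lists.

    Assumes both index lists are strictly increasing.
    """
    kx, ky = 0, 0
    total = 0.0
    while kx < len(x_indices) and ky < len(y_indices):
        if x_indices[kx] == y_indices[ky]:
            # Index present in both vectors: multiply and advance both pointers.
            total += x_values[kx] * y_values[ky]
            kx += 1
            ky += 1
        elif x_indices[kx] < y_indices[ky]:
            kx += 1  # x has a non-zero where y is zero; skip it
        else:
            ky += 1  # y has a non-zero where x is zero; skip it
    return total
```

Each pointer only moves forward, so the cost is linear in the number of stored non-zeros rather than in the full vector size.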
Computing the cosine similarity with the utility above:
```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions.udf
import spark.implicits._ // toDF and $"..." syntax; assumes a SparkSession named `spark` is in scope

val df = Seq(
  (0, Vectors.dense(0.0, 10.0, 0.5), 1, Vectors.dense(0.0, 10.0, 0.5)),
  (1, Vectors.dense(0.0, 10.0, 0.2), 2, Vectors.dense(0.0, 10.0, 0.2))
).toDF("id", "vector1", "id2", "vector2")
df.show(false)
df.printSchema()
/**
 * +---+--------------+---+--------------+
 * |id |vector1       |id2|vector2       |
 * +---+--------------+---+--------------+
 * |0  |[0.0,10.0,0.5]|1  |[0.0,10.0,0.5]|
 * |1  |[0.0,10.0,0.2]|2  |[0.0,10.0,0.2]|
 * +---+--------------+---+--------------+
 *
 * root
 *  |-- id: integer (nullable = false)
 *  |-- vector1: vector (nullable = true)
 *  |-- id2: integer (nullable = false)
 *  |-- vector2: vector (nullable = true)
 */

// cosine = vector1.dot(vector2) / (sqrt(vector1.dot(vector1)) * sqrt(vector2.dot(vector2)))
val cosine_similarity = udf((vector1: Vector, vector2: Vector) =>
  utils.BLAS.dot(vector1, vector2) /
    (Math.sqrt(utils.BLAS.dot(vector1, vector1)) * Math.sqrt(utils.BLAS.dot(vector2, vector2)))
)

df.withColumn("cosine", cosine_similarity($"vector1", $"vector2"))
  .show(false)
/**
 * +---+--------------+---+--------------+------------------+
 * |id |vector1       |id2|vector2       |cosine            |
 * +---+--------------+---+--------------+------------------+
 * |0  |[0.0,10.0,0.5]|1  |[0.0,10.0,0.5]|0.9999999999999999|
 * |1  |[0.0,10.0,0.2]|2  |[0.0,10.0,0.2]|1.0000000000000002|
 * +---+--------------+---+--------------+------------------+
 */
```
【Comments】
Thanks. This is what I've been trying to do, but I'm still having trouble in Python. Do you know the Python equivalent of the $ operator? I can't figure out how to pass the right variables to the UDF.