Implement a Java UDF and call it from PySpark

Posted 2016-07-10 07:59:29

【Question description】:

I need to create a UDF, to be used from PySpark (Python), that uses a Java object for its internal calculations.

If it were plain Python, I would do something like:

def f(x):
    return 7
fudf = pyspark.sql.functions.udf(f,pyspark.sql.types.IntegerType())

and call it using:

df = sqlContext.range(0,5)
df2 = df.withColumn("a",fudf(df.id)).show()

However, the implementation of the function I need is in Java, not Python. I need to wrap it somehow so that I can call it from Python in a similar way.

My first attempt was to implement the Java object, wrap it in Python from PySpark, and turn that into a UDF. It failed with a serialization error.

Java code:

package com.test1.test2;

public class TestClass1 {
    Integer internalVal;

    public TestClass1(Integer val1) {
        internalVal = val1;
    }

    public Integer do_something(Integer val) {
        return internalVal;
    }
}

PySpark code:

from py4j.java_gateway import java_import
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
java_import(sc._gateway.jvm, "com.test1.test2.TestClass1")
a = sc._gateway.jvm.com.test1.test2.TestClass1(7)
audf = udf(a,IntegerType())

Error:

---------------------------------------------------------------------------
Py4JError                                 Traceback (most recent call last)
<ipython-input-2-9756772ab14f> in <module>()
      4 java_import(sc._gateway.jvm, "com.test1.test2.TestClass1")
      5 a = sc._gateway.jvm.com.test1.test2.TestClass1(7)
----> 6 audf = udf(a,IntegerType())

/usr/local/spark/python/pyspark/sql/functions.py in udf(f, returnType)
   1595     [Row(slen=5), Row(slen=3)]
   1596     """
-> 1597     return UserDefinedFunction(f, returnType)
   1598 
   1599 blacklist = ['map', 'since', 'ignore_unicode_prefix']

/usr/local/spark/python/pyspark/sql/functions.py in __init__(self, func, returnType, name)
   1556         self.returnType = returnType
   1557         self._broadcast = None
-> 1558         self._judf = self._create_judf(name)
   1559 
   1560     def _create_judf(self, name):

/usr/local/spark/python/pyspark/sql/functions.py in _create_judf(self, name)
   1565         command = (func, None, ser, ser)
   1566         sc = SparkContext.getOrCreate()
-> 1567         pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
   1568         ctx = SQLContext.getOrCreate(sc)
   1569         jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())

/usr/local/spark/python/pyspark/rdd.py in _prepare_for_python_RDD(sc, command, obj)
   2297     # the serialized command will be compressed by broadcast
   2298     ser = CloudPickleSerializer()
-> 2299     pickled_command = ser.dumps(command)
   2300     if len(pickled_command) > (1 << 20):  # 1M
   2301         # The broadcast will have same life cycle as created PythonRDD

/usr/local/spark/python/pyspark/serializers.py in dumps(self, obj)
    426 
    427     def dumps(self, obj):
--> 428         return cloudpickle.dumps(obj, 2)
    429 
    430 

/usr/local/spark/python/pyspark/cloudpickle.py in dumps(obj, protocol)
    644 
    645     cp = CloudPickler(file,protocol)
--> 646     cp.dump(obj)
    647 
    648     return file.getvalue()

/usr/local/spark/python/pyspark/cloudpickle.py in dump(self, obj)
    105         self.inject_addons()
    106         try:
--> 107             return Pickler.dump(self, obj)
    108         except RuntimeError as e:
    109             if 'recursion' in e.args[0]:

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in dump(self, obj)
    222         if self.proto >= 2:
    223             self.write(PROTO + chr(self.proto))
--> 224         self.save(obj)
    225         self.write(STOP)
    226 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
    284         f = self.dispatch.get(t)
    285         if f:
--> 286             f(self, obj) # Call unbound method with explicit self
    287             return
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj)
    566         write(MARK)
    567         for element in obj:
--> 568             save(element)
    569 
    570         if id(obj) in memo:

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
    284         f = self.dispatch.get(t)
    285         if f:
--> 286             f(self, obj) # Call unbound method with explicit self
    287             return
    288 

/usr/local/spark/python/pyspark/cloudpickle.py in save_function(self, obj, name)
    191         if islambda(obj) or obj.__code__.co_filename == '<stdin>' or themodule is None:
    192             #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule)
--> 193             self.save_function_tuple(obj)
    194             return
    195         else:

/usr/local/spark/python/pyspark/cloudpickle.py in save_function_tuple(self, func)
    234         # create a skeleton function object and memoize it
    235         save(_make_skel_func)
--> 236         save((code, closure, base_globals))
    237         write(pickle.REDUCE)
    238         self.memoize(func)

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
    284         f = self.dispatch.get(t)
    285         if f:
--> 286             f(self, obj) # Call unbound method with explicit self
    287             return
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj)
    552         if n <= 3 and proto >= 2:
    553             for element in obj:
--> 554                 save(element)
    555             # Subtle.  Same as in the big comment below.
    556             if id(obj) in memo:

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
    284         f = self.dispatch.get(t)
    285         if f:
--> 286             f(self, obj) # Call unbound method with explicit self
    287             return
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_list(self, obj)
    604 
    605         self.memoize(obj)
--> 606         self._batch_appends(iter(obj))
    607 
    608     dispatch[ListType] = save_list

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in _batch_appends(self, items)
    637                 write(MARK)
    638                 for x in tmp:
--> 639                     save(x)
    640                 write(APPENDS)
    641             elif n:

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
    304             reduce = getattr(obj, "__reduce_ex__", None)
    305             if reduce:
--> 306                 rv = reduce(self.proto)
    307             else:
    308                 reduce = getattr(obj, "__reduce__", None)

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
    811         answer = self.gateway_client.send_command(command)
    812         return_value = get_return_value(
--> 813             answer, self.gateway_client, self.target_id, self.name)
    814 
    815         for temp_arg in temp_args:

/usr/local/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
     43     def deco(*a, **kw):
     44         try:
---> 45             return f(*a, **kw)
     46         except py4j.protocol.Py4JJavaError as e:
     47             s = e.java_exception.toString()

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    310                 raise Py4JError(
    311                     "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n".
--> 312                     format(target_id, ".", name, value))
    313         else:
    314             raise Py4JError(

Py4JError: An error occurred while calling o18.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:335)
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:344)
    at py4j.Gateway.invoke(Gateway.java:252)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:209)
    at java.lang.Thread.run(Thread.java:745)

Edit: I also tried making the Java class Serializable, to no avail.

My second attempt was to define the UDF in Java itself, but that failed because I am not sure how to wrap it correctly:

Java code:

package com.test1.test2;

import org.apache.spark.sql.api.java.UDF1;

public class TestClassUdf implements UDF1<Integer, Integer> {

    Integer retval;

    public TestClassUdf(Integer val) {
        retval = val;
    }

    @Override
    public Integer call(Integer arg0) throws Exception {
        return retval;
    }
}

But how would I use it? I tried:

from py4j.java_gateway import java_import
java_import(sc._gateway.jvm, "com.test1.test2.TestClassUdf")
a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
dfint = sqlContext.range(0,15)
df = dfint.withColumn("a",a(dfint.id))

but I got:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-5-514811090b5f> in <module>()
      3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
      4 dfint = sqlContext.range(0,15)
----> 5 df = dfint.withColumn("a",a(dfint.id))

TypeError: 'JavaObject' object is not callable

I tried using a.call instead of a:

df = dfint.withColumn("a",a.call(dfint.id))

but got:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-...> in <module>()
      3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
      4 dfint = sqlContext.range(0,15)
----> 5 df = dfint.withColumn("a",a.call(dfint.id))

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
    796     def __call__(self, *args):
    797         if self.converters is not None and len(self.converters) > 0:
--> 798             (new_args, temp_args) = self._get_args(args)
    799         else:
    800             new_args = args

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in _get_args(self, args)
    783                 for converter in self.gateway_client.converters:
    784                     if converter.can_convert(arg):
--> 785                         temp_arg = converter.convert(arg, self.gateway_client)
    786                         temp_args.append(temp_arg)
    787                         new_args.append(temp_arg)

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_collections.py in convert(self, object, gateway_client)
    510         HashMap = JavaClass("java.util.HashMap", gateway_client)
    511         java_map = HashMap()
--> 512         for key in object.keys():
    513             java_map[key] = object[key]
    514         return java_map

TypeError: 'Column' object is not callable

Any help would be greatly appreciated.

【Comments】:

【Answer 1】:

I got this working with the help of another question (and answer) of your own about UDAFs.

Spark provides a udf() method for wrapping a Scala FunctionN, so we can wrap the Java function in Scala and use that. Your Java method needs to be static or on a class that implements Serializable.

package com.example

import org.apache.spark.sql.UserDefinedFunction
import org.apache.spark.sql.functions.udf

class MyUdf extends Serializable {
  def getUdf: UserDefinedFunction = udf(() => MyJavaClass.MyJavaMethod())
}

Usage in PySpark:

def my_udf():
    from pyspark.sql.column import Column, _to_java_column, _to_seq
    pcls = "com.example.MyUdf"
    jc = sc._jvm.java.lang.Thread.currentThread() \
        .getContextClassLoader().loadClass(pcls).newInstance().getUdf().apply
    return Column(jc(_to_seq(sc, [], _to_java_column)))

rdd1 = sc.parallelize([{'c1': 'a'}, {'c1': 'b'}, {'c1': 'c'}])
df1 = rdd1.toDF()
df2 = df1.withColumn('mycol', my_udf())

As with the UDAF in your other question and answer, we can pass columns into it with return Column(jc(_to_seq(sc, ["col1", "col2"], _to_java_column))); a fuller sketch of this follows below.
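For illustration, a minimal sketch (not tested) that generalizes the my_udf() wrapper above so it accepts column arguments; com.example.MyUdf and its getUdf method are the names assumed in this answer, and sc is the active SparkContext as in the code above:

from pyspark.sql.column import Column, _to_java_column, _to_seq

def my_udf_on_cols(*cols):
    # Load the Scala wrapper class (assumed name from this answer) and grab
    # the apply method of the UserDefinedFunction it returns.
    pcls = "com.example.MyUdf"
    jc = sc._jvm.java.lang.Thread.currentThread() \
        .getContextClassLoader().loadClass(pcls).newInstance().getUdf().apply
    # _to_seq converts the Python columns / column names into a Scala Seq[Column].
    return Column(jc(_to_seq(sc, list(cols), _to_java_column)))

# e.g. df3 = df1.withColumn('mycol2', my_udf_on_cols("c1"))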

【Discussion】:

【Answer 2】:

According to https://dzone.com/articles/pyspark-java-udf-integration-1, you can define a UDF1 in Java:

public class AddNumber implements UDF1<Long, Long> {

    @Override
    public Long call(Long num) throws Exception {
        return (num + 5);
    }
}

Then add the jar to your PySpark session, for example with --jars <your-jar> on the pyspark/spark-submit command line (or --packages for a Maven coordinate).
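If you prefer to attach the jar from Python rather than on the command line, one option is to set spark.jars when building the SparkSession. A minimal sketch, with a placeholder jar path:

from pyspark.sql import SparkSession

# The path below is a placeholder; point it at the jar that contains AddNumber.
# This has to run before any SparkContext/SparkSession has been created,
# since the jar is added when the JVM starts.
spark = SparkSession.builder \
    .appName("java-udf-example") \
    .config("spark.jars", "/path/to/your-udf.jar") \
    .getOrCreate()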

You can then use it in PySpark as:

from pyspark.sql import functions as F
from pyspark.sql.types import LongType, FloatType


>>> df = spark.createDataFrame([float(i) for i in range(100)], FloatType()).toDF("a")
>>> spark.udf.registerJavaFunction("addNumber", "com.example.spark.AddNumber", LongType())
>>> df.withColumn("b", F.expr("addNumber(a)")).show(5)
+---+---+
|  a|  b|
+---+---+
|0.0|  5|
|1.0|  6|
|2.0|  7|
|3.0|  8|
|4.0|  9|
+---+---+
only showing top 5 rows
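Since registerJavaFunction registers the UDF with Spark SQL, the same call can also be written as a SQL query. A short sketch (the view name "numbers" is just an example):

# Register the DataFrame as a temporary view so the Java UDF can be used in SQL.
df.createOrReplaceTempView("numbers")
spark.sql("SELECT a, addNumber(a) AS b FROM numbers").show(5)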

【Discussion】:
