实现一个 java UDF 并从 pyspark 调用它
Posted
技术标签:
【中文标题】实现一个 java UDF 并从 pyspark 调用它【英文标题】:Implement a java UDF and call it from pyspark 【发布时间】:2016-07-10 07:59:29 【问题描述】:我需要创建一个供 PySpark(Python)使用的 UDF,它在内部使用 Java 对象进行计算。
如果它只是一个简单的 Python 函数,我会这样做:
def f(x):
return 7
fudf = pyspark.sql.functions.udf(f,pyspark.sql.types.IntegerType())
并使用以下方法调用它:
df = sqlContext.range(0,5)
df2 = df.withColumn("a",fudf(df.id)).show()
但是,我需要的函数的实现是在 java 中而不是在 python 中。我需要以某种方式包装它,以便我可以从 python 中以类似的方式调用它。
我的第一次尝试是先实现 Java 对象,然后在 PySpark 中用 Python 将其包装并转换为 UDF。结果失败了,并报出序列化错误。
Java 代码:
package com.test1.test2;
public class TestClass1
Integer internalVal;
public TestClass1(Integer val1)
internalVal = val1;
public Integer do_something(Integer val)
return internalVal;
pyspark 代码:
from py4j.java_gateway import java_import
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
java_import(sc._gateway.jvm, "com.test1.test2.TestClass1")
a = sc._gateway.jvm.com.test1.test2.TestClass1(7)
audf = udf(a,IntegerType())
错误:
---------------------------------------------------------------------------
Py4JError Traceback (most recent call last)
<ipython-input-2-9756772ab14f> in <module>()
4 java_import(sc._gateway.jvm, "com.test1.test2.TestClass1")
5 a = sc._gateway.jvm.com.test1.test2.TestClass1(7)
----> 6 audf = udf(a,IntegerType())
/usr/local/spark/python/pyspark/sql/functions.py in udf(f, returnType)
1595 [Row(slen=5), Row(slen=3)]
1596 """
-> 1597 return UserDefinedFunction(f, returnType)
1598
1599 blacklist = ['map', 'since', 'ignore_unicode_prefix']
/usr/local/spark/python/pyspark/sql/functions.py in __init__(self, func, returnType, name)
1556 self.returnType = returnType
1557 self._broadcast = None
-> 1558 self._judf = self._create_judf(name)
1559
1560 def _create_judf(self, name):
/usr/local/spark/python/pyspark/sql/functions.py in _create_judf(self, name)
1565 command = (func, None, ser, ser)
1566 sc = SparkContext.getOrCreate()
-> 1567 pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
1568 ctx = SQLContext.getOrCreate(sc)
1569 jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
/usr/local/spark/python/pyspark/rdd.py in _prepare_for_python_RDD(sc, command, obj)
2297 # the serialized command will be compressed by broadcast
2298 ser = CloudPickleSerializer()
-> 2299 pickled_command = ser.dumps(command)
2300 if len(pickled_command) > (1 << 20): # 1M
2301 # The broadcast will have same life cycle as created PythonRDD
/usr/local/spark/python/pyspark/serializers.py in dumps(self, obj)
426
427 def dumps(self, obj):
--> 428 return cloudpickle.dumps(obj, 2)
429
430
/usr/local/spark/python/pyspark/cloudpickle.py in dumps(obj, protocol)
644
645 cp = CloudPickler(file,protocol)
--> 646 cp.dump(obj)
647
648 return file.getvalue()
/usr/local/spark/python/pyspark/cloudpickle.py in dump(self, obj)
105 self.inject_addons()
106 try:
--> 107 return Pickler.dump(self, obj)
108 except RuntimeError as e:
109 if 'recursion' in e.args[0]:
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in dump(self, obj)
222 if self.proto >= 2:
223 self.write(PROTO + chr(self.proto))
--> 224 self.save(obj)
225 self.write(STOP)
226
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
284 f = self.dispatch.get(t)
285 if f:
--> 286 f(self, obj) # Call unbound method with explicit self
287 return
288
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj)
566 write(MARK)
567 for element in obj:
--> 568 save(element)
569
570 if id(obj) in memo:
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
284 f = self.dispatch.get(t)
285 if f:
--> 286 f(self, obj) # Call unbound method with explicit self
287 return
288
/usr/local/spark/python/pyspark/cloudpickle.py in save_function(self, obj, name)
191 if islambda(obj) or obj.__code__.co_filename == '<stdin>' or themodule is None:
192 #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule)
--> 193 self.save_function_tuple(obj)
194 return
195 else:
/usr/local/spark/python/pyspark/cloudpickle.py in save_function_tuple(self, func)
234 # create a skeleton function object and memoize it
235 save(_make_skel_func)
--> 236 save((code, closure, base_globals))
237 write(pickle.REDUCE)
238 self.memoize(func)
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
284 f = self.dispatch.get(t)
285 if f:
--> 286 f(self, obj) # Call unbound method with explicit self
287 return
288
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj)
552 if n <= 3 and proto >= 2:
553 for element in obj:
--> 554 save(element)
555 # Subtle. Same as in the big comment below.
556 if id(obj) in memo:
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
284 f = self.dispatch.get(t)
285 if f:
--> 286 f(self, obj) # Call unbound method with explicit self
287 return
288
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_list(self, obj)
604
605 self.memoize(obj)
--> 606 self._batch_appends(iter(obj))
607
608 dispatch[ListType] = save_list
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in _batch_appends(self, items)
637 write(MARK)
638 for x in tmp:
--> 639 save(x)
640 write(APPENDS)
641 elif n:
/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj)
304 reduce = getattr(obj, "__reduce_ex__", None)
305 if reduce:
--> 306 rv = reduce(self.proto)
307 else:
308 reduce = getattr(obj, "__reduce__", None)
/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
811 answer = self.gateway_client.send_command(command)
812 return_value = get_return_value(
--> 813 answer, self.gateway_client, self.target_id, self.name)
814
815 for temp_arg in temp_args:
/usr/local/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
43 def deco(*a, **kw):
44 try:
---> 45 return f(*a, **kw)
46 except py4j.protocol.Py4JJavaError as e:
47 s = e.java_exception.toString()
/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
310 raise Py4JError(
311 "An error occurred while calling 012. Trace:\n3\n".
--> 312 format(target_id, ".", name, value))
313 else:
314 raise Py4JError(
Py4JError: An error occurred while calling o18.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:335)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:344)
at py4j.Gateway.invoke(Gateway.java:252)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:209)
at java.lang.Thread.run(Thread.java:745)
编辑:我也尝试使 java 类可序列化,但无济于事。
我的第二次尝试是在 java 中定义 UDF,但失败了,因为我不确定如何正确包装它:
Java 代码:
package com.test1.test2;
import org.apache.spark.sql.api.java.UDF1;
public class TestClassUdf implements UDF1<Integer, Integer>
Integer retval;
public TestClassUdf(Integer val)
retval = val;
@Override
public Integer call(Integer arg0) throws Exception
return retval;
但我将如何使用它? 我试过了:
from py4j.java_gateway import java_import
java_import(sc._gateway.jvm, "com.test1.test2.TestClassUdf")
a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
dfint = sqlContext.range(0,15)
df = dfint.withColumn("a",a(dfint.id))
但我明白了:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-5-514811090b5f> in <module>()
3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
4 dfint = sqlContext.range(0,15)
----> 5 df = dfint.withColumn("a",a(dfint.id))
TypeError: 'JavaObject' object is not callable
我尝试使用 a.call 而不是 a:
df = dfint.withColumn("a",a.call(dfint.id))
但得到:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-...> in <module>()
3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7)
4 dfint = sqlContext.range(0,15)
----> 5 df = dfint.withColumn("a",a.call(dfint.id))
/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
796 def __call__(self, *args):
797 if self.converters is not None and len(self.converters) > 0:
--> 798 (new_args, temp_args) = self._get_args(args)
799 else:
800 new_args = args
/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in _get_args(self, args)
783 for converter in self.gateway_client.converters:
784 if converter.can_convert(arg):
--> 785 temp_arg = converter.convert(arg, self.gateway_client)
786 temp_args.append(temp_arg)
787 new_args.append(temp_arg)
/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_collections.py in convert(self, object, gateway_client)
510 HashMap = JavaClass("java.util.HashMap", gateway_client)
511 java_map = HashMap()
--> 512 for key in object.keys():
513 java_map[key] = object[key]
514 return java_map
TypeError: 'Column' object is not callable
如有任何帮助,将不胜感激。
【问题讨论】:
【参考方案1】:借助你自己提出的另一个关于 UDAF 的问题(及其回答),我把它跑通了。
Spark 提供了 udf() 方法来包装 Scala 的 FunctionN,因此我们可以在 Scala 中包装 Java 函数并使用它。您的 Java 方法需要是静态方法,或者定义在 implements Serializable 的类上。
package com.example
import org.apache.spark.sql.UserDefinedFunction
import org.apache.spark.sql.functions.udf
class MyUdf extends Serializable
def getUdf: UserDefinedFunction = udf(() => MyJavaClass.MyJavaMethod())
在 PySpark 中的使用:
def my_udf():
from pyspark.sql.column import Column, _to_java_column, _to_seq
pcls = "com.example.MyUdf"
jc = sc._jvm.java.lang.Thread.currentThread() \
.getContextClassLoader().loadClass(pcls).newInstance().getUdf().apply
return Column(jc(_to_seq(sc, [], _to_java_column)))
rdd1 = sc.parallelize([{'c1': 'a'}, {'c1': 'b'}, {'c1': 'c'}])
df1 = rdd1.toDF()
df2 = df1.withColumn('mycol', my_udf())
与您另一个问答中的 UDAF 一样,我们可以使用 return Column(jc(_to_seq(sc, ["col1", "col2"], _to_java_column))) 将列传递给它。
【讨论】:
【参考方案2】:根据https://dzone.com/articles/pyspark-java-udf-integration-1,您可以使用Java 定义UDF1
public class AddNumber implements UDF1<Long, Long>
@Override
public Long call(Long num) throws Exception
return (num + 5);
然后使用 --jars <your-jar> 将 jar 添加到您的 pyspark 之后,您可以在 pyspark 中这样使用它:
from pyspark.sql import functions as F
from pyspark.sql.types import LongType
>>> df = spark.createDataFrame([float(i) for i in range(100)], FloatType()).toDF("a")
>>> spark.udf.registerJavaFunction("addNumber", "com.example.spark.AddNumber", LongType())
>>> df.withColumn("b", F.expr("addNumber(a)")).show(5)
+---+---+
| a| b|
+---+---+
|0.0| 5|
|1.0| 6|
|2.0| 7|
|3.0| 8|
|4.0| 9|
+---+---+
only showing top 5 rows
【讨论】:
以上是关于实现一个 java UDF 并从 pyspark 调用它的主要内容,如果未能解决你的问题,请参考以下文章
Pyspark:从 Python 到 Pyspark 实现 lambda 函数和 udf
当我在 pyspark EMR 5.x 中运行用 Java 编写的 hive UDF 时出错