PySpark DataFrame: create a new column based on a function's return value
Posted: 2018-11-22 00:52:28

【Question】:
I have a DataFrame and I want to add a new column based on the value returned by a function. The arguments to this function are four columns from the same DataFrame. This one and this one are somewhat similar to what I want, but do not answer my question.
Here is my DataFrame (it has more columns than these four):
+--------+--------+--------+--------+
|   lat1 |   lng1 |   lat2 |   lng2 |
+--------+--------+--------+--------+
| -32.92 | 151.80 | -32.89 | 151.71 |
| -32.92 | 151.80 | -32.89 | 151.71 |
| -32.92 | 151.80 | -32.89 | 151.71 |
| -32.92 | 151.80 | -32.89 | 151.71 |
| -32.92 | 151.80 | -32.89 | 151.71 |
+--------+--------+--------+--------+
I want to add another column, "distance", which is the total distance between the two location points (latitude/longitude). I have a function that takes the four location values as arguments and returns the distance as a float.
import math

def get_distance(lat_1, lng_1, lat_2, lng_2):
    # Haversine-style distance; the inputs are used exactly as passed in
    # (math.sin/math.cos expect radians).
    d_lat = lat_2 - lat_1
    d_lng = lng_2 - lng_1
    temp = (
        math.sin(d_lat / 2) ** 2
        + math.cos(lat_1)
        * math.cos(lat_2)
        * math.sin(d_lng / 2) ** 2
    )
    return 6367.0 * (2 * math.asin(math.sqrt(temp)))
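For orientation, here is a quick standalone call with the sample coordinates from the table above (my own sanity check, not part of the original question); once the columns are cast to float, the accepted answer below reports a value of about 196.456 for the same inputs.

# Assumes get_distance and the import of math from the snippet above are in scope.
print(get_distance(-32.92, 151.80, -32.89, 151.71))
# Prints a value close to the 196.45587 shown in the accepted answer's output.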
Here is my attempt, which results in an error. I am also not sure about this approach; it is based on the other questions I already mentioned.
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

udf_func = udf(lambda lat_1, lng_1, lat_2, lng_2: get_distance(lat_1, lng_1, lat_2, lng_2), returnType=FloatType())
df1 = df.withColumn('difference', udf_func(df.lat1, df.lng1, df.lat2, df.lng2))
df1.show()
Here is the error stack trace:
An error occurred while calling o1300.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 50.0 failed 4 times, most recent failure: Lost task 0.3 in stage 50.0 (TID 341, data05.dac.local, executor 255): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 171, in main
process()
File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 166, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 103, in <lambda>
func = lambda _, it: map(mapper, it)
File "<string>", line 1, in <lambda>
File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 70, in <lambda>
return lambda *a: f(*a)
File "<stdin>", line 2, in <lambda>
File "<stdin>", line 5, in get_distance
TypeError: unsupported operand type(s) for -: 'unicode' and 'unicode'
at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRDD.scala:193)
at org.apache.spark.api.python.PythonRunner$$anon$1.<init>(PythonRDD.scala:234)
at org.apache.spark.api.python.PythonRunner.compute(PythonRDD.scala:152)
at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$doExecute$1.apply(BatchEvalPythonExec.scala:144)
at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$doExecute$1.apply(BatchEvalPythonExec.scala:87)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:797)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:797)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:99)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:322)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1435)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1423)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1422)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1422)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
at scala.Option.foreach(Option.scala:257)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:802)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1650)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1605)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1594)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:628)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:1928)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:1941)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:1954)
at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:336)
at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$execute$1$1.apply(Dataset.scala:2386)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$withNewExecutionId(Dataset.scala:2788)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$execute$1(Dataset.scala:2385)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collect(Dataset.scala:2392)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2128)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2127)
at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2818)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2127)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2342)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:248)
at sun.reflect.GeneratedMethodAccessor94.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:280)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:214)
at java.lang.Thread.run(Thread.java:745)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 171, in main
process()
File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 166, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 103, in <lambda>
func = lambda _, it: map(mapper, it)
File "<string>", line 1, in <lambda>
File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 70, in <lambda>
return lambda *a: f(*a)
File "<stdin>", line 2, in <lambda>
File "<stdin>", line 5, in get_distance
TypeError: unsupported operand type(s) for -: 'unicode' and 'unicode'
at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRDD.scala:193)
at org.apache.spark.api.python.PythonRunner$$anon$1.<init>(PythonRDD.scala:234)
at org.apache.spark.api.python.PythonRunner.compute(PythonRDD.scala:152)
at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$doExecute$1.apply(BatchEvalPythonExec.scala:144)
at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$doExecute$1.apply(BatchEvalPythonExec.scala:87)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:797)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:797)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:99)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:322)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
... 1 more
Please advise.
【Answer 1】:
Let me rewrite this so that people can understand the context. There are two steps:
1. The columns of the originally created DataFrame are of String type, so no calculation can be done on them. Therefore, as a first step, all 4 columns have to be cast to Float.
2. Apply a UDF on this DataFrame to create the new column distance.
import math
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType
df = sqlContext.createDataFrame(
    [('-32.92', '151.80', '-32.89', '151.71'),
     ('-32.92', '151.80', '-32.89', '151.71'),
     ('-32.92', '151.80', '-32.89', '151.71'),
     ('-32.92', '151.80', '-32.89', '151.71'),
     ('-32.92', '151.80', '-32.89', '151.71')],
    ("lat1", "lng1", "lat2", "lng2"))

print('Original Schema - columns imported as "String"')
df.printSchema()  # All columns are strings.

# Convert the string-based numbers into floats.
df = df.withColumn('lat1', df.lat1.cast("float")) \
       .withColumn('lng1', df.lng1.cast("float")) \
       .withColumn('lat2', df.lat2.cast("float")) \
       .withColumn('lng2', df.lng2.cast("float"))
print('Schema after converting "String" to "Float"')
df.printSchema() #All columns are float now.
df.show()
# Function defined by the user to calculate the distance between two points on the globe.
def get_distance(lat_1, lng_1, lat_2, lng_2):
    d_lat = lat_2 - lat_1
    d_lng = lng_2 - lng_1
    temp = (
        math.sin(d_lat / 2) ** 2
        + math.cos(lat_1)
        * math.cos(lat_2)
        * math.sin(d_lng / 2) ** 2
    )
    return 6367.0 * (2 * math.asin(math.sqrt(temp)))
udf_func = udf(get_distance, FloatType())  # Create a user-defined function to calculate the distance between two points.
df = df.withColumn("distance", udf_func(df.lat1, df.lng1, df.lat2, df.lng2))  # Create the "distance" column from get_distance.
df.show()
Output:
Original Schema - columns imported as "String"
root
|-- lat1: string (nullable = true)
|-- lng1: string (nullable = true)
|-- lat2: string (nullable = true)
|-- lng2: string (nullable = true)
Schema after converting "String" to "Float"
root
|-- lat1: float (nullable = true)
|-- lng1: float (nullable = true)
|-- lat2: float (nullable = true)
|-- lng2: float (nullable = true)
+------+-----+------+------+
| lat1| lng1| lat2| lng2|
+------+-----+------+------+
|-32.92|151.8|-32.89|151.71|
|-32.92|151.8|-32.89|151.71|
|-32.92|151.8|-32.89|151.71|
|-32.92|151.8|-32.89|151.71|
|-32.92|151.8|-32.89|151.71|
+------+-----+------+------+
+------+-----+------+------+---------+
| lat1| lng1| lat2| lng2| distance|
+------+-----+------+------+---------+
|-32.92|151.8|-32.89|151.71|196.45587|
|-32.92|151.8|-32.89|151.71|196.45587|
|-32.92|151.8|-32.89|151.71|196.45587|
|-32.92|151.8|-32.89|151.71|196.45587|
|-32.92|151.8|-32.89|151.71|196.45587|
+------+-----+------+------+---------+
The code now runs as expected.
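As a side note (not part of either answer), the same formula can also be expressed with Spark's built-in column functions, which avoids the Python UDF overhead. The following is a rough sketch that assumes the float-cast DataFrame df from the code above; df_native is an illustrative name.

from pyspark.sql import functions as F

# The same haversine-style expression, built from native column functions
# instead of a Python UDF.
d_lat = F.col("lat2") - F.col("lat1")
d_lng = F.col("lng2") - F.col("lng1")
temp = (
    F.pow(F.sin(d_lat / 2), 2)
    + F.cos(F.col("lat1")) * F.cos(F.col("lat2")) * F.pow(F.sin(d_lng / 2), 2)
)
df_native = df.withColumn("distance", 6367.0 * 2 * F.asin(F.sqrt(temp)))
df_native.show()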
【Comments】:
While this code may answer the question, providing additional context about how and/or why it solves the problem would improve the answer's long-term value.
Sir, I have made the changes. It should be clearer now.
Thank you very much for your effort. The problem was only the type of the columns: the string columns had to be cast to float.

【Answer 2】:
The part of the stack trace about unicode indicates that the column type is StringType, since you cannot subtract two strings. You can check this with df.printSchema().
If you convert all latitudes and longitudes to floats (e.g. float(lat1)) before any calculation, your udf should execute fine.
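A minimal sketch of that suggestion (my illustration, not code taken from the answer), casting inside the Python function so the UDF also works while the columns are still strings. It assumes get_distance, udf, FloatType and the original df are already in scope; get_distance_str_safe is just a hypothetical helper name.

def get_distance_str_safe(lat_1, lng_1, lat_2, lng_2):
    # Cast the (possibly string) inputs to float before doing any math, as suggested above.
    return get_distance(float(lat_1), float(lng_1), float(lat_2), float(lng_2))

udf_func = udf(get_distance_str_safe, FloatType())
df1 = df.withColumn('distance', udf_func(df.lat1, df.lng1, df.lat2, df.lng2))
df1.show()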