PySpark Dataframe create new column based on function return value

【Posted】: 2018-11-22 00:52:28
【Question】:

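I have a DataFrame and want to add a new column whose value is returned by a function. The function takes four columns of the same DataFrame as arguments. This question and this one are somewhat similar to what I want, but they do not answer my question.

Here is my DataFrame (it has more columns than these four):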

 + ------ + ------ + ------ + ------ +
 | lat1   | lng1   | lat2   | lng2   |
 + ------ + ------ + ------ + ------ +
 | -32.92 | 151.80 | -32.89 | 151.71 |
 | -32.92 | 151.80 | -32.89 | 151.71 |
 | -32.92 | 151.80 | -32.89 | 151.71 |
 | -32.92 | 151.80 | -32.89 | 151.71 |
 | -32.92 | 151.80 | -32.89 | 151.71 |
 + ------ + ------ + ------ + ------ +

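I want to add another column, "distance", holding the total distance between the two locations (latitude/longitude). I have a function that takes the four coordinates as arguments and returns the distance as a float.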

def get_distance(lat_1, lng_1, lat_2, lng_2): 
  d_lat = lat_2 - lat_1
  d_lng = lng_2 - lng_1 

  temp = (  
  math.sin(d_lat / 2) ** 2 
    + math.cos(lat_1) 
    * math.cos(lat_2) 
    * math.sin(d_lng / 2) ** 2
  )

  return 6367.0 * (2 * math.asin(math.sqrt(temp))) 

Here is my attempt, which results in an error. I am also not sure about this approach; it is based on the other questions I mentioned.

udf_func = udf(lambda lat_1, lng_1, lat_2, lng_2: get_distance(lat_1, lng_1, lat_2, lng_2), returnType=FloatType())

df1 = df.withColumn('difference', udf_func(df.lat1, df.lng1, df.lat2, df.lng2))
df1.show()

Here is the error stack trace:

An error occurred while calling o1300.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 50.0 failed 4 times, most recent failure: Lost task 0.3 in stage 50.0 (TID 341, data05.dac.local, executor 255): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 171, in main
    process()
  File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 166, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 103, in <lambda>
    func = lambda _, it: map(mapper, it)
  File "<string>", line 1, in <lambda>
  File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 70, in <lambda>
    return lambda *a: f(*a)
  File "<stdin>", line 2, in <lambda>
  File "<stdin>", line 5, in get_distance
TypeError: unsupported operand type(s) for -: 'unicode' and 'unicode'
    at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRDD.scala:193)
    at org.apache.spark.api.python.PythonRunner$$anon$1.<init>(PythonRDD.scala:234)
    at org.apache.spark.api.python.PythonRunner.compute(PythonRDD.scala:152)
    at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$doExecute$1.apply(BatchEvalPythonExec.scala:144)
    at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$doExecute$1.apply(BatchEvalPythonExec.scala:87)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:797)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:797)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:99)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:322)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
    at java.lang.Thread.run(Thread.java:745)
Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1435)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1423)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1422)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1422)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
    at scala.Option.foreach(Option.scala:257)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:802)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1650)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1605)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1594)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:628)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:1928)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:1941)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:1954)
    at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:336)
    at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
    at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$execute$1$1.apply(Dataset.scala:2386)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)
    at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$withNewExecutionId(Dataset.scala:2788)
    at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$execute$1(Dataset.scala:2385)
    at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collect(Dataset.scala:2392)
    at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2128)
    at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2127)
    at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2818)
    at org.apache.spark.sql.Dataset.head(Dataset.scala:2127)
    at org.apache.spark.sql.Dataset.take(Dataset.scala:2342)
    at org.apache.spark.sql.Dataset.showString(Dataset.scala:248)
    at sun.reflect.GeneratedMethodAccessor94.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:280)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:214)
    at java.lang.Thread.run(Thread.java:745)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 171, in main
    process()
  File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 166, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 103, in <lambda>
    func = lambda _, it: map(mapper, it)
  File "<string>", line 1, in <lambda>
  File "/hadoop/log/yarn/local/usercache/baiga/appcache/application_1541820416317_0349/container_e122_1541820416317_0349_01_000269/pyspark.zip/pyspark/worker.py", line 70, in <lambda>
    return lambda *a: f(*a)
  File "<stdin>", line 2, in <lambda>
  File "<stdin>", line 5, in get_distance
TypeError: unsupported operand type(s) for -: 'unicode' and 'unicode'
    at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRDD.scala:193)
    at org.apache.spark.api.python.PythonRunner$$anon$1.<init>(PythonRDD.scala:234)
    at org.apache.spark.api.python.PythonRunner.compute(PythonRDD.scala:152)
    at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$doExecute$1.apply(BatchEvalPythonExec.scala:144)
    at org.apache.spark.sql.execution.python.BatchEvalPythonExec$$anonfun$doExecute$1.apply(BatchEvalPythonExec.scala:87)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:797)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:797)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:99)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:322)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
    ... 1 more

Please advise.

【Answer 1】:

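Let me rewrite this so that people can follow the context. There are two steps:

1. The columns of the originally created DataFrame are Strings, so no arithmetic can be done on them. As a first step, we therefore have to cast all 4 columns to Float.

2. Apply a UDF to this DataFrame to create the new column distance.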

import math
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType
df = sqlContext.createDataFrame([('-32.92','151.80','-32.89','151.71'),('-32.92','151.80','-32.89','151.71'),
                                 ('-32.92','151.80','-32.89','151.71'),('-32.92','151.80','-32.89','151.71'),
                                 ('-32.92','151.80','-32.89','151.71'),], ("lat1", "lng1", "lat2","lng2"))
print('Original Schema - columns imported as "String"')
df.printSchema()       # All columns are Strings.
# Converting String based numbers into float.
df = df.withColumn('lat1', df.lat1.cast("float"))\
       .withColumn('lng1', df.lng1.cast("float"))\
       .withColumn('lat2', df.lat2.cast("float"))\
       .withColumn('lng2', df.lng2.cast("float"))

print('Schema after converting "String" to "Float"')
df.printSchema()       #All columns are float now.   
df.show()    
#Function defined by user, to calculate distance between two points on the globe.
def get_distance(lat_1, lng_1, lat_2, lng_2): 
    d_lat = lat_2 - lat_1
    d_lng = lng_2 - lng_1 

    temp = (  
    math.sin(d_lat / 2) ** 2 
      + math.cos(lat_1) 
      * math.cos(lat_2) 
      * math.sin(d_lng / 2) ** 2
    )
    return 6367.0 * (2 * math.asin(math.sqrt(temp))) 

udf_func = udf(get_distance,FloatType()) #Creating a 'User Defined Function' to calculate distance between two points.
df = df.withColumn("distance",udf_func(df.lat1, df.lng1, df.lat2, df.lng2)) #Creating column "distance" based on function 'get_distance'
df.show()

Output:

Original Schema - columns imported as "String"
root
 |-- lat1: string (nullable = true)
 |-- lng1: string (nullable = true)
 |-- lat2: string (nullable = true)
 |-- lng2: string (nullable = true)

Schema after converting "String" to "Float"
root
 |-- lat1: float (nullable = true)
 |-- lng1: float (nullable = true)
 |-- lat2: float (nullable = true)
 |-- lng2: float (nullable = true)

+------+-----+------+------+
|  lat1| lng1|  lat2|  lng2|
+------+-----+------+------+
|-32.92|151.8|-32.89|151.71|
|-32.92|151.8|-32.89|151.71|
|-32.92|151.8|-32.89|151.71|
|-32.92|151.8|-32.89|151.71|
|-32.92|151.8|-32.89|151.71|
+------+-----+------+------+

+------+-----+------+------+---------+
|  lat1| lng1|  lat2|  lng2| distance|
+------+-----+------+------+---------+
|-32.92|151.8|-32.89|151.71|196.45587|
|-32.92|151.8|-32.89|151.71|196.45587|
|-32.92|151.8|-32.89|151.71|196.45587|
|-32.92|151.8|-32.89|151.71|196.45587|
|-32.92|151.8|-32.89|151.71|196.45587|
+------+-----+------+------+---------+

The code now runs without errors.
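One thing worth checking about get_distance as written: math.sin and math.cos expect radians, while the sample lat/lng values look like degrees. If that assumption holds, the 196.45587 shown above is not a distance in kilometres between the two points. The following is only a sketch under that assumption; the name get_distance_deg is hypothetical and the 6367.0 radius is kept from the code above. For the sample row it should give a value on the order of 9 km.

```python
import math

EARTH_RADIUS_KM = 6367.0  # same constant as in get_distance above


def get_distance_deg(lat_1, lng_1, lat_2, lng_2):
    """Haversine distance in km, assuming all four inputs are in degrees."""
    # Convert degrees to radians before calling the trig functions.
    lat_1, lng_1, lat_2, lng_2 = map(math.radians, (lat_1, lng_1, lat_2, lng_2))
    d_lat = lat_2 - lat_1
    d_lng = lng_2 - lng_1
    temp = (
        math.sin(d_lat / 2) ** 2
        + math.cos(lat_1) * math.cos(lat_2) * math.sin(d_lng / 2) ** 2
    )
    return EARTH_RADIUS_KM * (2 * math.asin(math.sqrt(temp)))


# Sample row from the question's DataFrame.
print(get_distance_deg(-32.92, 151.80, -32.89, 151.71))
```

This function can be passed to udf(..., FloatType()) in the same way as get_distance, once the columns have been cast to float.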

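【Comments】:

While this code may answer the question, providing additional context about how and/or why it solves the problem would improve the answer's long-term value.

Sir, I have made the changes. It should be clear now.

Thank you very much for your effort. The problem was only the column types: the string columns had to be converted to floats.

【Answer 2】: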

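The part of the stack trace that mentions unicode indicates that the columns are of type StringType, since you cannot subtract two strings. You can check this with df.printSchema().

If you convert all the latitudes and longitudes to floats (e.g. float(lat1)) before doing any calculation, your udf should execute fine.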
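The float(lat1) idea can be sketched as follows. This is a minimal illustration, not a definitive implementation: to_float, get_distance_from_strings and add_distance_column are hypothetical names, the formula is copied unchanged from the question, and the Spark imports are done lazily so the pure-Python part can be tried without a cluster.

```python
import math


def to_float(value):
    """Convert a string such as '-32.92' to float; pass None through unchanged."""
    return None if value is None else float(value)


def get_distance_from_strings(lat_1, lng_1, lat_2, lng_2):
    """Same formula as get_distance in the question, but accepts string inputs."""
    values = [to_float(v) for v in (lat_1, lng_1, lat_2, lng_2)]
    if any(v is None for v in values):
        return None  # null cells reach a Python UDF as None
    lat_1, lng_1, lat_2, lng_2 = values
    d_lat = lat_2 - lat_1
    d_lng = lng_2 - lng_1
    temp = (
        math.sin(d_lat / 2) ** 2
        + math.cos(lat_1) * math.cos(lat_2) * math.sin(d_lng / 2) ** 2
    )
    return 6367.0 * (2 * math.asin(math.sqrt(temp)))


def add_distance_column(df):
    """Hypothetical helper: wrap the function in a UDF and add a 'distance' column."""
    # Imported here so the functions above work without Spark installed.
    from pyspark.sql.functions import udf
    from pyspark.sql.types import FloatType

    distance_udf = udf(get_distance_from_strings, FloatType())
    return df.withColumn(
        "distance", distance_udf(df.lat1, df.lng1, df.lat2, df.lng2)
    )
```

Converting inside the function leaves the DataFrame schema as it is; casting the columns with cast("float") instead changes the schema once, so that other calculations can use the numeric columns too.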
