How to interpolate a column within a grouped object in PySpark?

【Title】How to interpolate a column within a grouped object in PySpark? 【Posted】2019-02-10 01:50:29 【Question】:

How can I interpolate a PySpark DataFrame within grouped data?

For example:

I have a PySpark DataFrame with the following columns:

+--------+-------------------+--------+
|webID   |timestamp          |counts  |
+--------+-------------------+--------+
|John    |2018-02-01 03:00:00|60      |
|John    |2018-02-01 03:03:00|66      |
|John    |2018-02-01 03:05:00|70      |
|John    |2018-02-01 03:08:00|76      |
|Mo      |2017-06-04 01:05:00|10      |
|Mo      |2017-06-04 01:07:00|20      |
|Mo      |2017-06-04 01:10:00|35      |
|Mo      |2017-06-04 01:11:00|40      |
+--------+-------------------+--------+

I need to interpolate the counts data for both John and Mo down to one data point per minute, within each of their own time ranges. I'm open to any simple linear interpolation - but note that my real data points are a few seconds apart, and I want to interpolate down to one point per second.

So the result should be:

+--------+-------------------+--------+
|webID   |timestamp          |counts  |
+--------+-------------------+--------+
|John    |2018-02-01 03:00:00|60      |
|John    |2018-02-01 03:01:00|62      |
|John    |2018-02-01 03:02:00|64      |
|John    |2018-02-01 03:03:00|66      |
|John    |2018-02-01 03:04:00|68      |
|John    |2018-02-01 03:05:00|70      |
|John    |2018-02-01 03:06:00|72      |
|John    |2018-02-01 03:07:00|74      |
|John    |2018-02-01 03:08:00|76      |
|Mo      |2017-06-04 01:05:00|10      |
|Mo      |2017-06-04 01:06:00|15      |
|Mo      |2017-06-04 01:07:00|20      |
|Mo      |2017-06-04 01:08:00|25      |
|Mo      |2017-06-04 01:09:00|30      |
|Mo      |2017-06-04 01:10:00|35      |
|Mo      |2017-06-04 01:11:00|40      |
+--------+-------------------+--------+

The new rows need to be added to my original DataFrame. I'm looking for a PySpark solution.
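To make the expected semantics concrete, here is a small Pandas sketch of the per-group linear interpolation I have in mind (illustration only - my real data is far too large for a single Pandas frame):

import pandas as pd

pdf = pd.DataFrame({
    "webID": ["John"] * 4 + ["Mo"] * 4,
    "timestamp": pd.to_datetime([
        "2018-02-01 03:00:00", "2018-02-01 03:03:00",
        "2018-02-01 03:05:00", "2018-02-01 03:08:00",
        "2017-06-04 01:05:00", "2017-06-04 01:07:00",
        "2017-06-04 01:10:00", "2017-06-04 01:11:00"]),
    "counts": [60, 66, 70, 76, 10, 20, 35, 40],
})

# Per group: resample to 1-minute buckets and linearly interpolate counts.
expected = (pdf.set_index("timestamp")
               .groupby("webID")["counts"]
               .apply(lambda s: s.resample("60S").interpolate())
               .reset_index())
print(expected)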

【Comments】:

【Answer 1】:

If you're using Python, the shortest way to get this done is to reuse existing Pandas functionality with a GROUPED_MAP pandas_udf:

from operator import attrgetter
from pyspark.sql.types import StructType
from pyspark.sql.functions import pandas_udf, PandasUDFType

def resample(schema, freq, timestamp_col="timestamp", **kwargs):
    @pandas_udf(
        # Sort the schema fields by name so they line up with the Pandas
        # columns, which are sorted the same way before being returned.
        StructType(sorted(schema, key=attrgetter("name"))),
        PandasUDFType.GROUPED_MAP)
    def _(pdf):
        pdf.set_index(timestamp_col, inplace=True)
        # Resample each group to the requested frequency and linearly
        # interpolate; forward-fill columns interpolation can't handle
        # (e.g. the string key).
        pdf = pdf.resample(freq).interpolate()
        pdf.ffill(inplace=True)
        pdf.reset_index(drop=False, inplace=True)
        pdf.sort_index(axis=1, inplace=True)
        return pdf
    return _

Applied to your data:

from pyspark.sql.functions import to_timestamp

df = spark.createDataFrame([
    ("John",   "2018-02-01 03:00:00", 60),  
    ("John",   "2018-02-01 03:03:00", 66),  
    ("John",   "2018-02-01 03:05:00", 70),  
    ("John",   "2018-02-01 03:08:00", 76),  
    ("Mo",     "2017-06-04 01:05:00", 10),  
    ("Mo",     "2017-06-04 01:07:00", 20),  
    ("Mo",     "2017-06-04 01:10:00", 35),  
    ("Mo",     "2017-06-04 01:11:00", 40),
], ("webID", "timestamp", "counts")).withColumn(
  "timestamp", to_timestamp("timestamp")
)

df.groupBy("webID").apply(resample(df.schema, "60S")).show()

it produces:

+------+-------------------+-----+
|counts|          timestamp|webID|
+------+-------------------+-----+
|    60|2018-02-01 03:00:00| John|
|    62|2018-02-01 03:01:00| John|
|    64|2018-02-01 03:02:00| John|
|    66|2018-02-01 03:03:00| John|
|    68|2018-02-01 03:04:00| John|
|    70|2018-02-01 03:05:00| John|
|    72|2018-02-01 03:06:00| John|
|    74|2018-02-01 03:07:00| John|
|    76|2018-02-01 03:08:00| John|
|    10|2017-06-04 01:05:00|   Mo|
|    15|2017-06-04 01:06:00|   Mo|
|    20|2017-06-04 01:07:00|   Mo|
|    25|2017-06-04 01:08:00|   Mo|
|    30|2017-06-04 01:09:00|   Mo|
|    35|2017-06-04 01:10:00|   Mo|
|    40|2017-06-04 01:11:00|   Mo|
+------+-------------------+-----+

This works under the assumption that both the input and the interpolated data for a single webID fit into the memory of a single node (in general, other exact, non-iterative solutions have to make similar assumptions). If that is not the case, you can easily approximate it with overlapping windows:

from pyspark.sql.functions import window

partial = (df
    .groupBy("webID", window("timestamp", "5 minutes", "3 minutes")["start"])
    .apply(resample(df.schema, "60S")))

and then aggregating the final result:

from pyspark.sql.functions import mean

(partial
    .groupBy("webID", "timestamp")
    .agg(mean("counts").alias("counts"))
    # Order by key and timestamp, only for consistent presentation
    .orderBy("webID", "timestamp")
    .show())

This is of course much more expensive (there are two shuffles, and some values are computed more than once), but it can also leave gaps where an overlap isn't wide enough to include the next observation - note how John's 03:06 and 03:07 rows are missing below:

+-----+-------------------+------+
|webID|          timestamp|counts|
+-----+-------------------+------+
| John|2018-02-01 03:00:00|  60.0|
| John|2018-02-01 03:01:00|  62.0|
| John|2018-02-01 03:02:00|  64.0|
| John|2018-02-01 03:03:00|  66.0|
| John|2018-02-01 03:04:00|  68.0|
| John|2018-02-01 03:05:00|  70.0|
| John|2018-02-01 03:08:00|  76.0|
|   Mo|2017-06-04 01:05:00|  10.0|
|   Mo|2017-06-04 01:06:00|  15.0|
|   Mo|2017-06-04 01:07:00|  20.0|
|   Mo|2017-06-04 01:08:00|  25.0|
|   Mo|2017-06-04 01:09:00|  30.0|
|   Mo|2017-06-04 01:10:00|  35.0|
|   Mo|2017-06-04 01:11:00|  40.0|
+-----+-------------------+------+

【Comments】:

FYI - the unit for seconds is an uppercase "S". Otherwise this works. Thanks!

Thanks for the code, it works with 60S but I get an error with 1S. Do you know why?

【Answer 2】:

A native PySpark implementation (no udf) that solves this problem is:

import pyspark.sql.functions as F
resample_interval = 1  # Resample interval size in seconds

df_interpolated = (
  df_data
  # Get timestamp and Counts of previous measurement via window function
  .selectExpr(
    "webID",
    "LAG(Timestamp) OVER (PARTITION BY webID ORDER BY Timestamp ASC) as PreviousTimestamp",
    "Timestamp as NextTimestamp",
    "LAG(Counts) OVER (PARTITION BY webID ORDER BY Timestamp ASC) as PreviousCounts",
    "Counts as NextCounts",
  )
  # To determine resample interval round up start and round down end timeinterval to nearest interval boundary
  .withColumn("PreviousTimestampRoundUp", F.expr(f"to_timestamp(ceil(unix_timestamp(PreviousTimestamp)/resample_interval)*resample_interval)"))
  .withColumn("NextTimestampRoundDown", F.expr(f"to_timestamp(floor(unix_timestamp(NextTimestamp)/resample_interval)*resample_interval)"))
  # Make sure we don't get any negative intervals (whole interval is within resample interval)
  .filter("PreviousTimestampRoundUp<=NextTimestampRoundDown")
  # Create resampled time axis by creating all "interval" timestamps between previous and next timestamp
  .withColumn("Timestamp", F.expr(f"explode(sequence(PreviousTimestampRoundUp, NextTimestampRoundDown, interval resample_interval second)) as Timestamp"))
  # Sequence has inclusive boundaries for both start and stop. Filter out duplicate Counts if original timestamp is exactly a boundary.
  .filter("Timestamp<NextTimestamp")
  # Interpolate Counts between previous and next
  .selectExpr(
    "webID",
    "Timestamp", 
    """(unix_timestamp(Timestamp)-unix_timestamp(PreviousTimestamp))
        /(unix_timestamp(NextTimestamp)-unix_timestamp(PreviousTimestamp))
        *(NextCounts-PreviousCounts) 
        +PreviousCounts
        as Counts"""
  )
)
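For illustration, a minimal way to try this on the question's sample data might look like the sketch below; the df_data name and the column capitalization are assumptions (Spark resolves Counts/counts case-insensitively by default), and resample_interval = 60 would give the per-minute series from the question. As far as I can tell, the Timestamp<NextTimestamp filter above means the final observation of each webID is not re-emitted in the resampled output.

from pyspark.sql.functions import to_timestamp

# Hypothetical input: the question's sample data under the column names
# referenced above (webID, Timestamp, Counts).
df_data = spark.createDataFrame([
    ("John", "2018-02-01 03:00:00", 60),
    ("John", "2018-02-01 03:03:00", 66),
    ("John", "2018-02-01 03:05:00", 70),
    ("John", "2018-02-01 03:08:00", 76),
    ("Mo",   "2017-06-04 01:05:00", 10),
    ("Mo",   "2017-06-04 01:07:00", 20),
    ("Mo",   "2017-06-04 01:10:00", 35),
    ("Mo",   "2017-06-04 01:11:00", 40),
], ("webID", "Timestamp", "Counts")).withColumn("Timestamp", to_timestamp("Timestamp"))

# After evaluating the df_interpolated expression above with resample_interval = 60:
# df_interpolated.orderBy("webID", "Timestamp").show()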

I recently wrote a blog post explaining this approach and showing that it scales much better on large datasets than the pandas udf approach above: https://medium.com/delaware-pro/interpolate-big-data-time-series-in-native-pyspark-d270d4b592a1

【Comments】:

【Answer 3】:

This isn't a Python solution, but I imagine the Scala solution below could be implemented in Python with a similar approach. It uses the lag Window function to create a time range in each row, plus a UDF that expands the time range into a list of per-minute timestamps and interpolated counts via the java.time API, and then flattens that list with Spark's explode method:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import spark.implicits._

val df = Seq(
  ("John", "2018-02-01 03:00:00", 60),
  ("John", "2018-02-01 03:03:00", 66),
  ("John", "2018-02-01 03:05:00", 70),
  ("Mo", "2017-06-04 01:07:00", 20),
  ("Mo", "2017-06-04 01:10:00", 35),
  ("Mo", "2017-06-04 01:11:00", 40)
).toDF("webID", "timestamp", "count")

val winSpec = Window.partitionBy($"webID").orderBy($"timestamp")

def minuteList(timePattern: String) = udf{ (ts1: String, ts2: String, c1: Int, c2: Int) =>
  import java.time.LocalDateTime
  import java.time.format.DateTimeFormatter

  val timeFormat = DateTimeFormatter.ofPattern(timePattern)

  val perMinTS = if (ts1 == ts2) Vector(ts1) else {
      val t1 = LocalDateTime.parse(ts1, timeFormat)
      val t2 = LocalDateTime.parse(ts2, timeFormat)
      Iterator.iterate(t1.plusMinutes(1))(_.plusMinutes(1)).takeWhile(! _.isAfter(t2)).
        map(_.format(timeFormat)).
        toVector
    }

  val sz = perMinTS.size

  val perMinCount = for { i <- 1 to sz } yield c1 + ((c2 - c1) * i / sz)

  perMinTS zip perMinCount
}
df.
  withColumn("timestampPrev", when(row_number.over(winSpec) === 1, $"timestamp").
    otherwise(lag($"timestamp", 1).over(winSpec))).
  withColumn("countPrev", when(row_number.over(winSpec) === 1, $"count").
    otherwise(lag($"count", 1).over(winSpec))).
  withColumn("minuteList",
    minuteList("yyyy-MM-dd HH:mm:ss")($"timestampPrev", $"timestamp", $"countPrev", $"count")).
  withColumn("minute", explode($"minuteList")).
  select($"webID", $"minute._1".as("timestamp"), $"minute._2".as("count")).
  show
// +-----+-------------------+-----+
// |webID|          timestamp|count|
// +-----+-------------------+-----+
// | John|2018-02-01 03:00:00|   60|
// | John|2018-02-01 03:01:00|   62|
// | John|2018-02-01 03:02:00|   64|
// | John|2018-02-01 03:03:00|   66|
// | John|2018-02-01 03:04:00|   68|
// | John|2018-02-01 03:05:00|   70|
// |   Mo|2017-06-04 01:07:00|   20|
// |   Mo|2017-06-04 01:08:00|   25|
// |   Mo|2017-06-04 01:09:00|   30|
// |   Mo|2017-06-04 01:10:00|   35|
// |   Mo|2017-06-04 01:11:00|   40|
// +-----+-------------------+-----+
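For completeness (this is not part of the original answer), a rough PySpark sketch of the same idea might pair lag with a Python UDF that expands each interval into (timestamp, count) pairs and then explode the result. The names below are illustrative; it assumes a df like the Scala one above (string timestamps, integer counts, columns webID/timestamp/count), uses coalesce instead of the row_number check for the first row of each group, and mirrors the Scala Int division so the counts come out as whole numbers:

from datetime import datetime, timedelta

from pyspark.sql import functions as F
from pyspark.sql.types import (ArrayType, IntegerType, StringType,
                               StructField, StructType)
from pyspark.sql.window import Window

win = Window.partitionBy("webID").orderBy("timestamp")

@F.udf(ArrayType(StructType([
    StructField("timestamp", StringType()),
    StructField("count", IntegerType()),
])))
def minute_list(ts1, ts2, c1, c2):
    # Expand the span from ts1 to ts2 into per-minute timestamps with
    # linearly interpolated counts (integer arithmetic, like the Scala UDF).
    fmt = "%Y-%m-%d %H:%M:%S"
    if ts1 == ts2:
        return [(ts1, c1)]
    t1, t2 = datetime.strptime(ts1, fmt), datetime.strptime(ts2, fmt)
    steps = int((t2 - t1).total_seconds()) // 60
    return [((t1 + timedelta(minutes=i)).strftime(fmt),
             c1 + (c2 - c1) * i // steps)
            for i in range(1, steps + 1)]

result = (df
    # For the first row of each webID there is no previous row, so fall back
    # to the row's own values (the Scala version does this via row_number).
    .withColumn("timestampPrev",
                F.coalesce(F.lag("timestamp").over(win), F.col("timestamp")))
    .withColumn("countPrev",
                F.coalesce(F.lag("count").over(win), F.col("count")))
    .withColumn("minute",
                F.explode(minute_list("timestampPrev", "timestamp",
                                      "countPrev", "count")))
    .select("webID",
            F.col("minute.timestamp").alias("timestamp"),
            F.col("minute.count").alias("count")))

result.orderBy("webID", "timestamp").show()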

【Comments】:
