使用 Pyspark 计算 Spark 数据帧每列中非 NaN 条目的数量
Posted
技术标签:
【中文标题】使用 Pyspark 计算 Spark 数据帧每列中非 NaN 条目的数量【英文标题】:Count number of non-NaN entries in each column of Spark dataframe with Pyspark 【发布时间】:2016-02-27 07:28:30 【问题描述】:我在 Hive 中加载了一个非常大的数据集。它由大约 190 万行和 1450 列组成。我需要确定每列的“覆盖率”,即每列具有非 NaN 值的行的比例。
这是我的代码:
from pyspark import SparkContext
from pyspark.sql import HiveContext
import string as string
sc = SparkContext(appName="compute_coverages") ## Create the context
sqlContext = HiveContext(sc)
df = sqlContext.sql("select * from data_table")
nrows_tot = df.count()
covgs=sc.parallelize(df.columns)
.map(lambda x: str(x))
.map(lambda x: (x, float(df.select(x).dropna().count()) / float(nrows_tot) * 100.))
在 pyspark shell 中尝试这个,如果我然后执行 covgs.take(10),它会返回一个相当大的错误堆栈。它说在文件/usr/lib64/python2.6/pickle.py
中保存有问题。这是错误的最后一部分:
py4j.protocol.Py4JError: An error occurred while calling o37.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:333)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:342)
at py4j.Gateway.invoke(Gateway.java:252)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:207)
at java.lang.Thread.run(Thread.java:745)
如果有比我尝试的方法更好的方法来实现这一点,我愿意接受建议。不过,我不能使用 pandas,因为它目前在我工作的集群上不可用,而且我无权安装它。
【问题讨论】:
【参考方案1】:让我们从一个虚拟数据开始:
from pyspark.sql import Row
row = Row("v", "x", "y", "z")
df = sc.parallelize([
row(0.0, 1, 2, 3.0), row(None, 3, 4, 5.0),
row(None, None, 6, 7.0), row(float("Nan"), 8, 9, float("NaN"))
]).toDF()
## +----+----+---+---+
## | v| x| y| z|
## +----+----+---+---+
## | 0.0| 1| 2|3.0|
## |null| 3| 4|5.0|
## |null|null| 6|7.0|
## | NaN| 8| 9|NaN|
## +----+----+---+---+
您只需要一个简单的聚合:
from pyspark.sql.functions import col, count, isnan, lit, sum
def count_not_null(c, nan_as_null=False):
"""Use conversion between boolean and integer
- False -> 0
- True -> 1
"""
pred = col(c).isNotNull() & (~isnan(c) if nan_as_null else lit(True))
return sum(pred.cast("integer")).alias(c)
df.agg(*[count_not_null(c) for c in df.columns]).show()
## +---+---+---+---+
## | v| x| y| z|
## +---+---+---+---+
## | 2| 3| 4| 4|
## +---+---+---+---+
或者如果你想对待NaN
一个NULL
:
df.agg(*[count_not_null(c, True) for c in df.columns]).show()
## +---+---+---+---+
## | v| x| y| z|
## +---+---+---+---+
## | 1| 3| 4| 3|
## +---+---+---+---
您还可以利用 SQL NULL
语义来实现相同的结果,而无需创建自定义函数:
df.agg(*[
count(c).alias(c) # vertical (column-wise) operations in SQL ignore NULLs
for c in df.columns
]).show()
## +---+---+---+
## | x| y| z|
## +---+---+---+
## | 1| 2| 3|
## +---+---+---+
但这不适用于NaNs
。
如果你喜欢分数:
exprs = [(count_not_null(c) / count("*")).alias(c) for c in df.columns]
df.agg(*exprs).show()
## +------------------+------------------+---+
## | x| y| z|
## +------------------+------------------+---+
## |0.3333333333333333|0.6666666666666666|1.0|
## +------------------+------------------+---+
或
# COUNT(*) is equivalent to COUNT(1) so NULLs won't be an issue
df.select(*[(count(c) / count("*")).alias(c) for c in df.columns]).show()
## +------------------+------------------+---+
## | x| y| z|
## +------------------+------------------+---+
## |0.3333333333333333|0.6666666666666666|1.0|
## +------------------+------------------+---+
Scala 等价物:
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col, isnan, sum
type JDouble = java.lang.Double
val df = Seq[(JDouble, JDouble, JDouble, JDouble)](
(0.0, 1, 2, 3.0), (null, 3, 4, 5.0),
(null, null, 6, 7.0), (java.lang.Double.NaN, 8, 9, java.lang.Double.NaN)
).toDF()
def count_not_null(c: Column, nanAsNull: Boolean = false) =
val pred = c.isNotNull and (if (nanAsNull) not(isnan(c)) else lit(true))
sum(pred.cast("integer"))
df.select(df.columns map (c => count_not_null(col(c)).alias(c)): _*).show
// +---+---+---+---+
// | _1| _2| _3| _4|
// +---+---+---+---+
// | 2| 3| 4| 4|
// +---+---+---+---+
df.select(df.columns map (c => count_not_null(col(c), true).alias(c)): _*).show
// +---+---+---+---+
// | _1| _2| _3| _4|
// +---+---+---+---+
// | 1| 3| 4| 3|
// +---+---+---+---+
【讨论】:
return sum(col(c).isNotNull().cast("integer")).alias(c) 这里它会自动知道要访问哪个数据帧吗?是因为我们从那个特定的数据框中获取了列名吗? @Roshini 列仅在定义绑定的特定 SQL 表达式的范围内有意义。换句话说,给定select
的上下文定义了列的解析方式。
如果nan
计数大于阈值,如何选择列?
TypeError: Column is not iterable
第一次尝试。【参考方案2】:
您可以使用isNotNull()
:
df.where(df[YOUR_COLUMN].isNotNull()).select(YOUR_COLUMN).show()
【讨论】:
为什么投反对票?这是非常优雅的,至少与上面的 spark sql 代码一样 Pythonic(这也很出色,但在许多更简单的上下文中,这段代码就可以了)。投赞成票。为所有人点赞! Nulls 和 nans 有不同的功能以上是关于使用 Pyspark 计算 Spark 数据帧每列中非 NaN 条目的数量的主要内容,如果未能解决你的问题,请参考以下文章