使用 numpy 生成具有 case-when 条件的随机数据

Posted

技术标签:

【中文标题】使用 numpy 生成具有 case-when 条件的随机数据【英文标题】:Use numpy to generate random data with case-when condition 【发布时间】:2020-01-10 05:50:24 【问题描述】:

问题

我想生成要建模的数据。基于以下逻辑:

输入:2个名为zd的numpy数组。

z:一维数组,值 0/1

d:一维数组,值 0/1

返回: y:一维数组。 value:范数随机数。

如果 z == 0 且 d==0, y ~ norm(1,1),

如果 z == 0 且 d == 1, y~ norm(0,1),

如果 z == 1 且 d == 0, y ~ norm(1,1),

如果 z == 1 且 d == 1,则 y ~ norm(2,1)。

我想以一种超级快速、清晰和 Pythonic 的方式来完成。

似乎基础数学和np.where 更快。在这种情况下,我只有 3 个条件(你可以从基础数学部分看得很清楚)。如果我有 10 个或更多条件,在 if-else 中键入它们有时会令人困惑。我想进行数据模拟,这意味着我将在不同的n生成数百万次数据。那么,最好的方法是什么?

我尝试过的:

# generate data
n = 2000
z = np.random.binomial(1,0.5,n)
d = np.random.binomial(1,0.5,n)

dict case-when

def myfun(x):
    return (0,1):np.random.normal(0,1),\
            (0,0):np.random.normal(1,1),\
            (1,0):np.random.normal(1,1),\
            (1,1):np.random.normal(2,1)[x]
%%timeit
y = [myfun(i) for i in zip(z,d)]

输出:

每个循环 16.2 ms ± 139 µs(平均值 ± 标准偏差,7 次运行,每次 100 个循环)

简单if-else

%%timeit
y = np.random.normal([0 if (i == 0) & (j ==1) else 2 if (i == 1) & (j == 1) else 1 for i,j in zip(z,d)],1)

输出:

每个循环 1.38 ms ± 22.1 µs(7 次运行的平均值 ± 标准偏差,每次 1000 个循环)

基础数学

%%timeit
h0 = np.random.normal(0,1,n)
h1 = np.random.normal(1,1,n)
h2 = np.random.normal(2,1,n)
y = (1-z)*d*h0 + (1-d)*h1 + z*d*h2

输出:

每个循环 140 µs ± 135 ns(7 次运行的平均值 ± 标准偏差,每次 10000 个循环)

np.where

%%timeit
h0 = np.random.normal(0,1,n)
h1 = np.random.normal(1,1,n)
h2 = np.random.normal(2,1,n)
y = np.where((d== 0),h1,0) + np.where((z ==1) & (d== 1),h2,0) + np.where((z ==0) & (d== 1),h0,0)

输出:

每个循环 156 µs ± 598 ns(7 次运行的平均值 ± 标准偏差,每次 10000 次循环)

还有其他新方法吗?

【问题讨论】:

我确定您的 if-else 不会做您认为的事情,因此请确保您的所有选项都先做同样的事情。 @AndrasDeak 抱歉,修改了我的帖子 更好,但你也应该仔细看看i == 0 & j ==1 【参考方案1】:

看来你已经完成了这里的工作。结果现在基于权衡。上述所有解决方案都在不同程度上符合标准。

这个代码是用来教别人的吗?也许它一天只执行一次或两次?它是一个更大项目的一部分,需要非常清楚以供其他人维护?如果是这种情况,请选择速度较慢但更易于理解和阅读的选项。

每天执行数千次还是数百万次?资源成本降低了产品的利润?如果是这样,请评论它并使用更快的选项。

似乎basic mathematics 选项是最好的折衷方案,因为它简单、易于理解且执行迅速。

我对每种方法的偏见评论:

dict case-when:慢,需要多次阅读/测试才能完全了解实际发生的情况并确定是否有任何未知的陷阱。 simple if-else:慢,需要多次阅读/测试才能完全了解实际发生的情况并确定是否有任何未知的陷阱。 basic mathematics:如果你有一点数学背景(应该包括大多数程序员),速度很快,很容易理解。 np.where:快速,完成工作,需要多次阅读/测试才能完全了解实际发生的情况,但由于它基于数组,因此不太容易出现问题。

这里是pythonic编写代码的哲学供参考:

美胜于丑。 显式优于隐式。 简单胜于复杂。 复杂胜于复杂。 平面优于嵌套。 稀疏优于密集。 可读性很重要。

使用上述作为标准更容易评估您的代码是否是 Python 的。

【讨论】:

【参考方案2】:

我认为最快的选择是通过使用normal 的数组值参数只生成一次随机数。使用新的随机 API:

import numpy as np

rng = np.random.default_rng()

# generate data
n = 2000
z = rng.binomial(1, 0.5, n)
d = rng.binomial(1, 0.5, n)

def generate_once(z, d):
    """Generate randoms for https://***.com/questions/59676147"""

    # encode mean; scale is always 1 anyway in the example
    means = np.zeros_like(z, dtype=float)
    z_inds = z == 0
    d_inds = d == 0
    means[d_inds] = 1
    means[z_inds & ~d_inds] = 2

    # generate the data
    y = rng.normal(means)
    return y

y = generate_once(z, d)

我并没有尝试与所有其他人竞争,但我希望这具有竞争力。将其视为if-else 的更快变体。将means(通常还包括scales)映射为数组时采取捷径可以减少开销,并且每个正常数字只生成一次应该会减少运行时间。

【讨论】:

以上是关于使用 numpy 生成具有 case-when 条件的随机数据的主要内容,如果未能解决你的问题,请参考以下文章

SQL Select CASE-WHEN - 如何从电话号码中删除格式

Oracle Sql关于case-when,if-then,decode

使用 Numpy 读取使用 C++ 数据类型生成的二进制文件

case-when oracle sql的动态评估

IsNumeric 和 Case-when 计算数值并更改其在结果中的显示方式

numpy:np.random.seed()