解决multivariate_normal中:output parameter (typecode ‘d‘) according to the casting rule ‘‘same_kind‘‘(代

Posted 沉迷单车的追风少年

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了解决multivariate_normal中:output parameter (typecode ‘d‘) according to the casting rule ‘‘same_kind‘‘(代相关的知识,希望对你有一定的参考价值。

项目场景:

sketch rnn


问题描述:

完整报错:

 File "/root/DiffusionModel/Pytorch-Sketch-RNN-master/sketch_rnn.py", line 416, in sample_bivariate_normal
    x = np.random.multivariate_normal(mean, cov, 1)
  File "mtrand.pyx", line 4114, in numpy.random.mtrand.RandomState.multivariate_normal
TypeError: ufunc 'add' output (typecode 'O') could not be coerced to provided output parameter (typecode 'd') according to the casting rule ''same_kind''

报错代码段:

def sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, greedy=False):

    if greedy:
        return mu_x, mu_y
    mean = [mu_x, mu_y]
    sigma_x *= np.sqrt(hp.temperature)
    sigma_y *= np.sqrt(hp.temperature)
    cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y], \\
           [rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]]

    x = np.random.multivariate_normal(mean, cov, 1)
    return x[0][0], x[0][1]

问题出在这一行:

x = np.random.multivariate_normal(mean, cov, 1)

原因分析:

np.random.multivariate_normal函数中使用了np.add操作,所以找到的资料大多是在np.add()中添加参数casting='unsafe'

但是multivariate_normal这个函数中是没有casting参数的。


解决方案:

把张量取出来运算即可。

def sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, greedy=False):

    mu_x = mu_x.item()
    mu_y = mu_y.item()
    sigma_x = sigma_x.item()
    sigma_y = sigma_y.item()

    if greedy:
        return mu_x, mu_y
    mean = [mu_x, mu_y]
    sigma_x *= np.sqrt(hp.temperature)
    sigma_y *= np.sqrt(hp.temperature)
    cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y], \\
           [rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]]

    x = np.random.multivariate_normal(mean, cov, 1)
    return x[0][0], x[0][1]

以上是关于解决multivariate_normal中:output parameter (typecode ‘d‘) according to the casting rule ‘‘same_kind‘‘(代的主要内容,如果未能解决你的问题,请参考以下文章

Numpy之高斯分布 multivariate_normal

numpy.random.multivariate_normal()函数解析

np.random.multivariate_normal方法浅析

关于协方差最小化 scipy.stats.multivariate_normal.logpdf

调用 scipy.stats.multivariate_normal 后,pylab.plot“无法将浮点 NaN 转换为整数”

multivariate_normal 多元正态分布