解决multivariate_normal中:output parameter (typecode ‘d‘) according to the casting rule ‘‘same_kind‘‘(代
Posted 沉迷单车的追风少年
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了解决multivariate_normal中:output parameter (typecode ‘d‘) according to the casting rule ‘‘same_kind‘‘(代相关的知识,希望对你有一定的参考价值。
项目场景:
sketch rnn
问题描述:
完整报错:
File "/root/DiffusionModel/Pytorch-Sketch-RNN-master/sketch_rnn.py", line 416, in sample_bivariate_normal
x = np.random.multivariate_normal(mean, cov, 1)
File "mtrand.pyx", line 4114, in numpy.random.mtrand.RandomState.multivariate_normal
TypeError: ufunc 'add' output (typecode 'O') could not be coerced to provided output parameter (typecode 'd') according to the casting rule ''same_kind''
报错代码段:
def sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, greedy=False):
if greedy:
return mu_x, mu_y
mean = [mu_x, mu_y]
sigma_x *= np.sqrt(hp.temperature)
sigma_y *= np.sqrt(hp.temperature)
cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y], \\
[rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]]
x = np.random.multivariate_normal(mean, cov, 1)
return x[0][0], x[0][1]
问题出在这一行:
x = np.random.multivariate_normal(mean, cov, 1)
原因分析:
np.random.multivariate_normal函数中使用了np.add操作,所以找到的资料大多是在np.add()中添加参数casting='unsafe'
但是multivariate_normal这个函数中是没有casting参数的。
解决方案:
把张量取出来运算即可。
def sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, greedy=False):
mu_x = mu_x.item()
mu_y = mu_y.item()
sigma_x = sigma_x.item()
sigma_y = sigma_y.item()
if greedy:
return mu_x, mu_y
mean = [mu_x, mu_y]
sigma_x *= np.sqrt(hp.temperature)
sigma_y *= np.sqrt(hp.temperature)
cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y], \\
[rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]]
x = np.random.multivariate_normal(mean, cov, 1)
return x[0][0], x[0][1]
以上是关于解决multivariate_normal中:output parameter (typecode ‘d‘) according to the casting rule ‘‘same_kind‘‘(代的主要内容,如果未能解决你的问题,请参考以下文章
Numpy之高斯分布 multivariate_normal
numpy.random.multivariate_normal()函数解析
np.random.multivariate_normal方法浅析
关于协方差最小化 scipy.stats.multivariate_normal.logpdf
调用 scipy.stats.multivariate_normal 后,pylab.plot“无法将浮点 NaN 转换为整数”