numpy笔记整理 multivariate_normal(多元正态分布采样)
Posted UQI-LIUWJ
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了numpy笔记整理 multivariate_normal(多元正态分布采样)相关的知识,希望对你有一定的参考价值。
1 基本用法
np.random.multivariate_normal(
mean,
cov,
size=None,
check_valid=None,
tol=None)
根据均值和协方差矩阵的情况生成一个多元正态分布矩阵
2 参数说明
mean | mean是多维分布的均值,维度为1 |
cov | 协方差矩阵 注意:协方差矩阵必须是对称的且需为半正定矩阵; |
size | 指定生成的正态分布矩阵的维度 eg:若size=(1, 1, 2),则输出的矩阵的shape即形状为 1X1X2XN(N为mean的长度) |
check_valid | 这个参数用于决定当cov即协方差矩阵不是半正定矩阵时程序的处理方式。 它一共有三个值:warn,raise以及ignore。
|
3 举例说明
import numpy as np;
mean = (1, 2)
#均值向量
alpha=[[1. , 0. ],
[0. , 0.5]]
#精度矩阵,协方差矩阵的倒数
cov=np.linalg.inv(alpha)
#协方差矩阵
x = np.random.multivariate_normal(mean, cov)
x
#array([1.52666855, 1.77005564])
x1 = np.random.multivariate_normal(mean, cov,size=(3,2))
x1
#一个3*3*N, 即3*2*2的矩阵,每一行表示一个样本
'''
array([[[ 0.34648623, -0.45031679],
[-0.04902216, 0.76346656]],
[[ 3.45690489, 1.88220564],
[-0.20572284, 1.36394544]],
[[ 0.62486475, -0.56346348],
[ 1.53072488, 3.5600083 ]]])
'''
4 “手动”采样
看到一种别的思路,从精度矩阵(协方差矩阵的倒数)出发,来进行采样的
数据还是这几个:
import numpy as np;
mean = (1, 2)
#均值向量
alpha=[[1. , 0. ],
[0. , 0.5]]
#精度矩阵,协方差矩阵的倒数
cov=np.linalg.inv(alpha)
#协方差矩阵
大致思路是:
我们以一维高斯分布N(μ,σ^2)为例,对于满足这个分布的x,我可以通过这种方式进行归一化:。那么相反地,我们从N(0,1)出发的x',可以通过以下方式转换成满足N(μ,σ^2)的分布:
那么多维的运算就是,我们找到协方差矩阵Σ,然后对他进行cholesky分解(线性代数笔记: Cholesky分解_UQI-LIUWJ的博客-CSDN博客),得到L和L^T,然后用L*X+μ 来采样(这里的X满足N(0,I),+μ操作可能是一个广播操作)
【我们这里使用精度矩阵α来实现的,所以进行cholesky分解前/后,需要进行矩阵求逆操作】
import scipy
src=np.random.normal(size=(2,5))
#从N(0,1)里面采样,每一列是一个样本
print(src)
'''
[[-1.73043484 0.97310463 -1.26852045 0.18902608 0.46363832]
[ 1.17343672 -1.10486709 -1.19307965 0.72384929 -0.17661455]]
'''
L_upp=scipy.linalg.cholesky(alpha,check_finite = False)
print(L_upp)
#将精度矩阵转换成L*L^T的形式(其中L是下三角矩阵)
'''
[[1. 0. ]
[0. 0.70710678]]
'''
x3=scipy.linalg.solve_triangular(
L_upp,
src,
lower=False,
check_finite = False)
#这里直接用np.linalg.inv(L_upp) @ src 也可以
#找x3,使得L_upp*x3=src
print(x3)
'''
[[-1.73043484 0.97310463 -1.26852045 0.18902608 0.46363832]
[ 1.65949013 -1.56251802 -1.68726942 1.02367749 -0.24977069]]
'''
x3=x3.T+mean
#加上均值
x3
'''
array([[-0.73043484, 3.65949013],
[ 1.97310463, 0.43748198],
[-0.26852045, 0.31273058],
[ 1.18902608, 3.02367749],
[ 1.46363832, 1.75022931]])
'''
以上是关于numpy笔记整理 multivariate_normal(多元正态分布采样)的主要内容,如果未能解决你的问题,请参考以下文章