GAN 拟合高斯分布数据Pytorch实现
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了GAN 拟合高斯分布数据Pytorch实现相关的知识,希望对你有一定的参考价值。
参考技术A GAN本身是一种生成式模型,所以在数据生成上用的是最普遍的,最常见的是图片生成,常用的有DCGAN WGAN,BEGAN。目前比较有意思的应用就是GAN用在图像风格迁移,图像降噪修复,图像超分辨率了,都有比较好的结果。目前也有研究者将GAN用在对抗性攻击上,具体就是训练GAN生成对抗文本,有针对或者无针对的欺骗分类器或者检测系统等等,但是目前没有见到很典范的文章。好吧,笔者有一个项目和对抗性攻击有关,所以要学习一下GAN。GANs组成:生成器和判别器。结构如图1所示
针对问题: 给定一批样本,训练一个系统能够生成类似的新样本
核心思想:博弈论中的纳什均衡,
判别器D 的目的是判断数据来自生成器还是训练集,
生成器G 的目的是学习真实数据的分布,使得生成的数据更接近真实数据,
两者不断学习优化最后得到纳什平衡点。
D( x) 表示真实数据的概率分布,
G( z) 表示输入噪声z 产生的生成数据的概率分布
训练目标:G( Z)在判别器上的分布D( G( Z) ) 更接近真实数据在判别器上的分布D( X)
接下来就来实现我们的例子把,目标是把标准正态分布的数据,通过训练的GAN网络之后,得到的数据x_fake能尽量拟合均值为3方差为1的高斯分布N(3,1)的数据。
可以看出生成器其实就是简单的全连接网络,当然CNN,RNN等网络都是适合GAN的,根据需要选择。
可以看出判别器其实也是简单的全连接网络,当然CNN,RNN等网络都是适合GAN的,根据需要选择。
在这里想说的是对于判别器和生成器的训练是分开的,训练判别器的时候固定生成器,训练生成器的时候固定判别器,如此循环。本例子中先训练三次判别器,接着训练一次生成器。
为了便于理解具体训练过程,图2 、图3展示了判别器和生成器训练时的数据流向,具体就不展开了,参考注释。
画图函数敬上
然后调用main()函数就好了
红色是目标分布,蓝色是生成分布,还是有一定效果的额。
感受到是在调参了,请教我如何学习生成(xie)对抗(lun)网络(wen)。
python 拟合曲线并求参
需要对数据进行函数拟合,首先画一下二维散点图,目测一下大概的分布,
所谓正态分布,就是高斯分布,正态曲线是一种特殊的高斯曲线。
python的scipy.optimize包里的curve_fit函数来拟合曲线,当然还可以拟合很多类型的曲线。scipy.optimize提供了函数最小值(标量或多维)、曲线拟合和寻找等式的根的有用算法。
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import math
#单个高斯模型,如果曲线有多个波峰,可以分段拟合
def func(x, a,u, sig):
return a*np.exp(-(x - u) ** 2 / (2 * sig ** 2)) / (sig * math.sqrt(2 * math.pi))
#混合高斯模型,多个高斯函数相加
def func3(x, a1, a2, a3, m1, m2, m3, s1, s2, s3):
return a1 * np.exp(-((x - m1) / s1) ** 2) + a2 * np.exp(-((x - m2) / s2) ** 2) + a3 * np.exp(-((x - m3) / s3) ** 2)
#正弦函数拟合
#def fmax(x,a,b,c): # return a*np.sin(x*np.pi/6+b)+c #fita,fitb=optimize.curve_fit(fmax,x,ymax,[1,1,1])
#非线性最小二乘法拟合
#def func(x, a, b,c):
# return a*np.sqrt(x)*(b*np.square(x)+c)
#用3次多项式拟合,可推广到n次多项式,数学上可以证明,任意函数都可以表示为多项式形式
#f1 = np.polyfit(x, y, 3)
#p1 = np.poly1d(f1)
#yvals = p1(x) #拟合y值
#也可使用yvals=np.polyval(f1, x)
拟合,并对参数进行限制,bounds里面代表参数上下限,p0是初始范围,默认是[1,1,1]
x=np.arange(1,206,1)
num = []<-自己的y值
numhunt = []<-自己的y值
y = np.array(num)
yhunt = np.array(numhunt)
popt, pcov = curve_fit(func3, x, y)
popthunt, pcovhunt = curve_fit(func, x, yhunt,p0=[2,2,2])
ahunt = popthunt[0]
uhunt = popthunt[1]
sighunt = popthunt[2]
a1 = popt[0]
u1 = popt[1]
sig1 = popt[2]
a2 = popt[3]
u2 = popt[4]
sig2 = popt[5]
a3 = popt[6]
u3 = popt[7]
sig3 = popt[8]
yvals = func3(x,a1,u1,sig1,a2,u2,sig2,a3,u3,sig3) #拟合y值
yhuntvals = func(x,ahunt,uhunt,sighunt) #拟合y值
print(u‘系数ahunt:‘, ahunt)
print(u‘系数uhunt:‘, uhunt)
print(u‘系数sighunt:‘, sighunt)
#绘图
plot1 = plt.plot(x, y, ‘s‘,label=‘insect original values‘)
plot2 = plt.plot(x, yvals, ‘r‘,label=‘insect polyfit values‘)
plot3 = plt.plot(x, yhunt, ‘s‘,label=‘predator original values‘)
plot4 = plt.plot(x, yhuntvals, ‘g‘,label=‘predator polyfit values‘)
plt.xlabel(‘date‘)
plt.ylabel(‘Nightly catches log10(N+1)‘)
plt.legend(loc=4) #指定legend的位置右下角
plt.title(‘insect/predator‘)
plt.show()
下图是单个和多个高斯拟合图像
下图是多项式拟合图像
图例的位置可以自定义设置
lower left
upper center
lower right
center
upper left
center left
upper right
lower center
best
center right
right
以上是关于GAN 拟合高斯分布数据Pytorch实现的主要内容,如果未能解决你的问题,请参考以下文章