1.5神经网络可视化显示(matplotlib)

Posted Mr.Zhao

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了1.5神经网络可视化显示(matplotlib)相关的知识,希望对你有一定的参考价值。

神经网络训练+可视化显示

#添加隐层的神经网络结构+可视化显示
import tensorflow as tf

def add_layer(inputs,in_size,out_size,activation_function=None):
    #定义权重--随机生成inside和outsize的矩阵
    Weights=tf.Variable(tf.random_normal([in_size,out_size]))
    #不是矩阵,而是类似列表
    biaes=tf.Variable(tf.zeros([1,out_size])+0.1)
    Wx_plus_b=tf.matmul(inputs,Weights)+biaes
    if activation_function is  None:
        outputs=Wx_plus_b
    else:
        outputs=activation_function(Wx_plus_b)
    return outputs

import numpy as np
x_data=np.linspace(-1,1,300)[:,np.newaxis] #300行数据
noise=np.random.normal(0,0.05,x_data.shape)
y_data=np.square(x_data)-0.5+noise
#None指定sample个数,这里不限定--输出属性为1
xs=tf.placeholder(tf.float32,[None,1])  #这里需要指定tf.float32,
ys=tf.placeholder(tf.float32,[None,1])

#建造第一层layer
#输入层(1)
l1=add_layer(xs,1,10,activation_function=tf.nn.relu)
#隐层(10)
prediction=add_layer(l1,10,1,activation_function=None)
#输出层(1)
#预测prediction
loss=tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),
                   reduction_indices=[1])) #平方误差
train_step=tf.train.GradientDescentOptimizer(0.1).minimize(loss)

init=tf.initialize_all_variables()
sess=tf.Session()
#直到执行run才执行上述操作
sess.run(init)


import matplotlib.pyplot as plt
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(x_data,y_data)
plt.ion() #图像会连续显示
#plt.show()不会终止整个函数。在2.x时候使用plt.show(block=False)
plt.show()


for i in range(1000):
    #这里假定指定所有的x_data来指定运算结果
    sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
    if i%50:
        # print (sess.run(loss,feed_dict={xs:x_data,ys:y_data}))
        try:
            #忽略第一次的错误
            ax.lines.remove(lines[0]) #在图片中去掉lines的第1条线段,不然线会混乱
        except Exception:
            prediction_value=sess.run(prediction,feed_dict={xs:x_data})
            lines=ax.plot(x_data,prediction_value,\'r-\',lw=5)
            # ax.lines.remove(lines[0]) 移动上上面,先移除第一条线
            plt.pause(0.2) #每次暂停0.2s

显示:

 

以上是关于1.5神经网络可视化显示(matplotlib)的主要内容,如果未能解决你的问题,请参考以下文章

matplotlib可视化连接成对数据点的线图只显示线条不显示数据点(Paired Line Plot with Matplotlib but without points)

Python可视化库matplotlib(超详细)

Py修行路 Matplotlib 绘图及可视化模块

Python使用matplotlib可视化柱状图坐标轴标签的符号(-)显示为了方框□□设置rcParams参数配置解决

python使用matplotlib可视化雷达图(polar函数可视化雷达图极坐标图通过径向方向来显示数据之间的关系)

python数据可视化之matplotlib.pyplot绘图时图片显示不全的解决方法(图文并茂版!!!)