神经网络框架原理
Posted 未来可期-2018
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了神经网络框架原理相关的知识,希望对你有一定的参考价值。
文章目录
想看详细代码请到-> 从零编写一个简单神经网络框架
1.建立计算图
定义变量和使用各种算子的过程就是一个建立一个计算图的过程
import networkx as nx
import random
import numpy
from matplotlib.animation import FuncAnimation
%matplotlib notebook
seed=1
random.seed(seed)
np.random.seed(seed)
sample_graph =
'x':["linear_01"],
'k':["linear_01"],
'b':["linear_01"],
'linear_01':["sigmoid"],
"sigmoid":["linear_02"],
"k2":["linear_02"],
"b2":["linear_02"],
"linear_02":["yhat"],
"yhat":["loss"]
graph = nx.DiGraph(sample_graph)
2.拓扑排序
-
利用拓扑排序获得前馈计算和反馈传播顺序,通过前馈计算计算出loss, 反向传播计算出loss对k和b的偏导
-
所有节点的求导法则
∂ l o s s ∂ s e l f = ∂ l o s s ∂ o u t p u t ∗ ∂ o u t p u t ∂ s e l f \\frac\\partial loss\\partial self=\\frac\\partial loss\\partial output*\\frac\\partial output\\partial self ∂self∂loss=∂output∂loss∗∂self∂output
valid_feedforward_order = list(nx.topological_sort(graph))
3.前馈计算
def forward(i):
color_map = ("red", "green")
pre_color, after_color = color_map
changed = valid_feedforward_order[:i]
color_order = [
after_color if c in changed else pre_color for c in valid_feedforward_order]
if i == len(valid_feedforward_order):
ani.event_source.stop()
nx.draw(graph, layout, node_color=color_order, with_labels=True)
fig = plt.gcf()
nx.draw(graph, layout, with_labels=True)
ani = FuncAnimation(fig, forward, frames=range(len(valid_feedforward_order)),interval=300)
plt.show()
<IPython.core.display.javascript object>
4.反向传播
def backward(i):
color_map = ("red", "green")
pre_color, after_color = color_map
changed = valid_feedforward_order[:i]
color_order = [
after_color if c in changed else pre_color for c in valid_feedforward_order][::-1]
if i == len(valid_feedforward_order):
ani.event_source.stop()
nx.draw(graph, layout, node_color=color_order, with_labels=True)
fig = plt.gcf()
nx.draw(graph, layout, with_labels=True)
ani = FuncAnimation(fig, backward, frames=range(len(valid_feedforward_order)),interval=300)
plt.show()
<IPython.core.display.Javascript object>
FuncAnimation
ani = FuncAnimation(fig,update,frames,init_func,interval=1,blit)
- fig figure对象
- update 不断更新图像的函数,生成新的xdata和ydata
- frames list 不断提供frame给update用于生成xdata和ydata
- init_func 初始函数为init,自定义开始帧
- interval 时间间隔1ms
以上是关于神经网络框架原理的主要内容,如果未能解决你的问题,请参考以下文章