第七周.02.Tree LSTM代码讲解
Posted oldmao_2000
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了第七周.02.Tree LSTM代码讲解相关的知识,希望对你有一定的参考价值。
文章目录
本文内容整理自深度之眼《GNN核心能力培养计划》
公式输入请参考: 在线Latex公式
之前的论文带读看这里: 第七周.直播.Tree LSTM带读
官网的代码看这里:
https://docs.dgl.ai/tutorials/models/2_small_graph/3_tree-lstm.html
任务和数据集介绍
数据集是斯坦福的Stanford Sentiment Treebank(SST DATASET)
数据集官网有语法树的示例:https://nlp.stanford.edu/sentiment/treebank.html,贴一个,其他自己去官网看:
语法树中,非叶子节点不包含单词(用PAD_WORD表示,没有表征,训练和测试时embedding初始化设置为0,但是非叶节点参与消息汇聚的操作),最后的标签共有5个分类:Very negative, negative, neutral, positive, and very positive
导入数据
为了演示,这里对原数据集进行缩减,用的是tiny模式,该模式下数据集只包含5个句子。
所有单词用的独热编码来表示。
from collections import namedtuple
import dgl
from dgl.data.tree import SSTDataset
SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = SSTDataset(mode='tiny') # the "tiny" set has only five trees
tiny_sst = trainset.trees
num_vocabs = trainset.num_vocabs
num_classes = trainset.num_classes
vocab = trainset.vocab # vocabulary dict: key -> id
inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word
a_tree = tiny_sst[0]
for token in a_tree.ndata['x'].tolist():#接地转list
if token != trainset.PAD_WORD:#判断是否叶子节点
print(inv_vocab[token], end=" ")#通过id转word后打印
打印结果:
the rock is destined to be the 21st century 's new `` conan ‘’ and that he 's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .
这句话和上面的语法树图是一一对应的。
Step 1: Batching
题外话:
graphviz的安装有坑啊,哪位解决了评论告诉我一下
先下载安装 Graphviz:http://www.graphviz.org/download/
安装过程记得勾选加入系统变量,然后可以cmd里面测试一下,建立一个dot文件,加入以下代码:
//dot a.dot -Tpng -o a.png -Gsplines=line
digraph G {
a -> b;//边
b -> c;//边
subgraph x{
rank=same;//同一行接下个节点
b->d;
}
subgraph y{
rank = same;//同一行接下个节点
d->e;
}
subgraph z{
//rank=same;
c->e;
}
}
在dot文件相同目录下运行:
dot a.dot -Tpng -o a.png -Gsplines=line
得到结果如下:
表示安装成功
然后安装 PyGraphviz,到这里
https://www.lfd.uci.edu/~gohlke/pythonlibs/#pygraphviz
注意里面的数字对应python的版本,不要下错了,不然安装不了
下载后用pip装之,但是运行报错:
No module named _graphviz
先不管了,反正不画图不影响,先注释掉吧。
我的是win10的系统。
借用一下官网的图:
代码
import networkx as nx
import matplotlib.pyplot as plt
graph = dgl.batch(tiny_sst)
#def plot_tree(g):
# # this plot requires pygraphviz package
# pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
# nx.draw(g, pos, with_labels=False, node_size=10,
# node_color=[[.5, .5, .5]], arrowsize=4)
# plt.show()
#plot_tree(graph.to_networkx())
这里的batch是将数据集中的子图的邻接矩阵按对角线进行排列,这样可以把所有子图放到一个大的邻接矩阵里面进行计算。
Step 2: Tree-LSTM cell with message-passing APIs
原文有提出两种Tree-LSTM :
Child-Sum Tree-LSTMs
N-ary Tree-LSTMs
这里的实现主要针对二叉树的语法树,用N-ary Tree-LSTMs来处理。
在N-ary Tree-LSTMs中,每一个节点
j
j
j 包含一个隐层表征
h
j
h_j
hj(公式6)和一个记忆单元
c
j
c_j
cj(公式5),节点
j
j
j吃两个输入,一个是孩子节点的输入
x
j
x_j
xj以及两个孩子的隐层输入
h
j
l
,
1
≤
l
≤
N
h_{jl}, 1\\leq l\\leq N
hjl,1≤l≤N (看公式1,这里二叉树N=2)
i j = σ ( W ( i ) x j + ∑ l = 1 N U l ( i ) h j l + b ( i ) ) , ( 1 ) f j k = σ ( W ( f ) x j + ∑ l = 1 N U k l ( f ) h j l + b ( f ) ) , ( 2 ) o j = σ ( W ( o ) x j + ∑ l = 1 N U l ( o ) h j l + b ( o ) ) , ( 3 ) u j = tanh ( W ( u ) x j + ∑ l = 1 N U l ( u ) h j l + b ( u ) ) , ( 4 ) c j = i j ⊙ u j + ∑ l = 1 N f j l ⊙ c j l , ( 5 ) h j = o j ⋅ tanh ( c j ) , ( 6 ) \\begin{aligned}i_j & = \\sigma\\left(W^{(i)}x_j + \\sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\\right), & (1)\\\\ f_{jk} & = \\sigma\\left(W^{(f)}x_j + \\sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \\right), & (2)\\\\ o_j & = \\sigma\\left(W^{(o)}x_j + \\sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \\right), & (3) \\\\ u_j & = \\textrm{tanh}\\left(W^{(u)}x_j + \\sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \\right), & (4)\\\\ c_j & = i_j \\odot u_j + \\sum_{l=1}^{N} f_{jl} \\odot c_{jl}, &(5) \\\\ h_j & = o_j \\cdot \\textrm{tanh}(c_j), &(6) \\\\\\end{aligned} ijfjkojujcjhj=σ(W(i)xj+l=1∑NUl(i)hjl+b(i)),=σ(W(f)xj+l=1∑NUkl(f)hjl+b(f)),=σ(W(o)xj+l=1∑NUl(o)hjl+b(o)),=tanh(W(u)xj+l=1∑NUl(u)hjl+b(u)),=ij⊙uj+l=1∑Nfjl⊙cjl,=oj⋅tanh(cj),以上是关于第七周.02.Tree LSTM代码讲解的主要内容,如果未能解决你的问题,请参考以下文章