第七周.02.Tree LSTM代码讲解

Posted oldmao_2000

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了第七周.02.Tree LSTM代码讲解相关的知识,希望对你有一定的参考价值。


本文内容整理自深度之眼《GNN核心能力培养计划》
公式输入请参考: 在线Latex公式
之前的论文带读看这里: 第七周.直播.Tree LSTM带读
官网的代码看这里:
https://docs.dgl.ai/tutorials/models/2_small_graph/3_tree-lstm.html

任务和数据集介绍

数据集是斯坦福的Stanford Sentiment Treebank(SST DATASET)
数据集官网有语法树的示例:https://nlp.stanford.edu/sentiment/treebank.html,贴一个,其他自己去官网看:

语法树中,非叶子节点不包含单词(用PAD_WORD表示,没有表征,训练和测试时embedding初始化设置为0,但是非叶节点参与消息汇聚的操作),最后的标签共有5个分类:Very negative, negative, neutral, positive, and very positive

导入数据

为了演示,这里对原数据集进行缩减,用的是tiny模式,该模式下数据集只包含5个句子
所有单词用的独热编码来表示。

from collections import namedtuple

import dgl
from dgl.data.tree import SSTDataset


SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])

# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = SSTDataset(mode='tiny')  # the "tiny" set has only five trees
tiny_sst = trainset.trees
num_vocabs = trainset.num_vocabs
num_classes = trainset.num_classes

vocab = trainset.vocab # vocabulary dict: key -> id
inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word

a_tree = tiny_sst[0]
for token in a_tree.ndata['x'].tolist():#接地转list
    if token != trainset.PAD_WORD:#判断是否叶子节点
        print(inv_vocab[token], end=" ")#通过id转word后打印

打印结果:
the rock is destined to be the 21st century 's new `` conan ‘’ and that he 's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .
这句话和上面的语法树图是一一对应的。

Step 1: Batching

题外话:

graphviz的安装有坑啊,哪位解决了评论告诉我一下
先下载安装 Graphviz:http://www.graphviz.org/download/
安装过程记得勾选加入系统变量,然后可以cmd里面测试一下,建立一个dot文件,加入以下代码:

//dot a.dot -Tpng -o a.png  -Gsplines=line  
digraph G {
	a -> b;//边
	b -> c;//边
	subgraph x{
		rank=same;//同一行接下个节点
		b->d;
	}
	subgraph y{
		rank = same;//同一行接下个节点
		d->e;
	}
	subgraph z{
		//rank=same;
		c->e;
	}
}

在dot文件相同目录下运行:

dot a.dot -Tpng -o a.png  -Gsplines=line 

得到结果如下:

表示安装成功
然后安装 PyGraphviz,到这里
https://www.lfd.uci.edu/~gohlke/pythonlibs/#pygraphviz
注意里面的数字对应python的版本,不要下错了,不然安装不了

下载后用pip装之,但是运行报错:

No module named _graphviz

先不管了,反正不画图不影响,先注释掉吧。
我的是win10的系统。
借用一下官网的图:

代码

import networkx as nx
import matplotlib.pyplot as plt

graph = dgl.batch(tiny_sst)
#def plot_tree(g):
#    # this plot requires pygraphviz package
#    pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
#    nx.draw(g, pos, with_labels=False, node_size=10,
#            node_color=[[.5, .5, .5]], arrowsize=4)
#    plt.show()

#plot_tree(graph.to_networkx())

这里的batch是将数据集中的子图的邻接矩阵按对角线进行排列,这样可以把所有子图放到一个大的邻接矩阵里面进行计算。

Step 2: Tree-LSTM cell with message-passing APIs

原文有提出两种Tree-LSTM :
Child-Sum Tree-LSTMs
N-ary Tree-LSTMs
这里的实现主要针对二叉树的语法树,用N-ary Tree-LSTMs来处理。
N-ary Tree-LSTMs中,每一个节点 j j j 包含一个隐层表征 h j h_j hj(公式6)和一个记忆单元 c j c_j cj(公式5),节点 j j j吃两个输入,一个是孩子节点的输入 x j x_j xj以及两个孩子的隐层输入 h j l , 1 ≤ l ≤ N h_{jl}, 1\\leq l\\leq N hjl,1lN (看公式1,这里二叉树N=2)

i j = σ ( W ( i ) x j + ∑ l = 1 N U l ( i ) h j l + b ( i ) ) , ( 1 ) f j k = σ ( W ( f ) x j + ∑ l = 1 N U k l ( f ) h j l + b ( f ) ) , ( 2 ) o j = σ ( W ( o ) x j + ∑ l = 1 N U l ( o ) h j l + b ( o ) ) , ( 3 ) u j = tanh ( W ( u ) x j + ∑ l = 1 N U l ( u ) h j l + b ( u ) ) , ( 4 ) c j = i j ⊙ u j + ∑ l = 1 N f j l ⊙ c j l , ( 5 ) h j = o j ⋅ tanh ( c j ) , ( 6 ) \\begin{aligned}i_j & = \\sigma\\left(W^{(i)}x_j + \\sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\\right), & (1)\\\\ f_{jk} & = \\sigma\\left(W^{(f)}x_j + \\sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \\right), & (2)\\\\ o_j & = \\sigma\\left(W^{(o)}x_j + \\sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \\right), & (3) \\\\ u_j & = \\textrm{tanh}\\left(W^{(u)}x_j + \\sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \\right), & (4)\\\\ c_j & = i_j \\odot u_j + \\sum_{l=1}^{N} f_{jl} \\odot c_{jl}, &(5) \\\\ h_j & = o_j \\cdot \\textrm{tanh}(c_j), &(6) \\\\\\end{aligned} ijfjkojujcjhj=σ(W(i)xj+l=1NUl(i)hjl+b(i)),=σ(W(f)xj+l=1NUkl(f)hjl+b(f)),=σ(W(o)xj+l=1NUl(o)hjl+b(o)),=tanh(W(u)xj+l=1NUl(u)hjl+b(u)),=ijuj+l=1Nfjlcjl,=ojtanh(cj),以上是关于第七周.02.Tree LSTM代码讲解的主要内容,如果未能解决你的问题,请参考以下文章

第七周收获

第七周.01.Message更新讲解+GCN实例

软件工程第七周psp

第七周学习总结

Linux内核分析——第七周学习笔记20135308

201521123105 第七周Java学习总结