Deeplearning4j 实战 (13-2):基于Embedding+CNN的文本分类实现

Posted wangongxi

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Deeplearning4j 实战 (13-2):基于Embedding+CNN的文本分类实现相关的知识,希望对你有一定的参考价值。

Deeplearning4j 实战 (13-2):基于Embedding+CNN的文本分类实现

Eclipse Deeplearning4j GitChat课程:Deeplearning4j 快速入门_专栏
Eclipse Deeplearning4j 系列博客:万宫玺的专栏_wangongxi_CSDN博客
Eclipse Deeplearning4j Github:https://github.com/eclipse/deeplearning4j
Eclipse Deeplearning4j 社区:https://community.konduit.ai/

之前的博客中,我们使用TextCNN对中文进行情感分析,其中词嵌入使用的是预训练模型来初始化词向量。在训练过程中,我们并不更新词向量,只更新卷积和池化层以及最后全连接层的模型参数,总的参数量控制在10W浮点数左右,属于非常轻巧的模型,也适合线上实时调用。这篇博客主要探讨下端到端的建模,也就是将词向量的构建融合到整体模型当中,在更新卷积层等参数的同时也更新词向量的分布。相对于单独构建词向量模型的做法,端到端的建模会更加直观一些,也不用去调研开源的预训练模型或者自己构建一个预训练模型。但必须指出的是,模型的参数量会急剧上升,训练阶段也必定会导致计算量和存储的上升,至于线上serving阶段的时效性理论上并不会有太多改变,因为词嵌入模块仅仅提供了类似字典的lookUp的功能。

1. 模型结构和data flow的分析

1.1 模型结构

整体模型结构和之前介绍TextCNN的博客中的结构类似,不同点在于增加了Embedding层以及Reshape层。我们先给出具体的code:

/*Embedding+CNN的端到端模型*/
private ComputationGraph getModel(final int vectorSize, final int numFeatureMap, final int corpusLenLimit, final int vocabSize, final int batchSize) 
	ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
			.weightInit(WeightInit.XAVIER)
			.updater(new Adam(0.01))
			.convolutionMode(ConvolutionMode.Same)
			.graphBuilder()
			.addInputs("input")
			.addLayer("embedding", new EmbeddingSequenceLayer.Builder()
										.nIn(vocabSize).nOut(vectorSize).build(), "input")
			.addVertex("reshape", new ReshapeVertex('c', new int[] -1, 1, vectorSize, corpusLenLimit, new int[] -1,  1, 1, corpusLenLimit), "embedding")
			.addLayer("2-gram",new ConvolutionLayer.Builder().kernelSize(vectorSize, 2).stride(vectorSize, 1).nIn(1)
				.nOut(numFeatureMap).activation(Activation.LEAKYRELU).build(),"reshape")
			.addLayer("3-gram",
						new ConvolutionLayer.Builder().kernelSize(vectorSize, 3).stride(vectorSize, 1).nIn(1)
								.nOut(numFeatureMap).activation(Activation.LEAKYRELU).build(),"reshape")
				.addLayer("4-gram",
						new ConvolutionLayer.Builder().kernelSize(vectorSize, 4).stride(vectorSize, 1).nIn(1)
								.nOut(numFeatureMap).build(),"reshape")
				.addVertex("merge", new MergeVertex(), "2-gram", "3-gram", "4-gram")
				.addLayer("globalPool",
						new GlobalPoolingLayer.Builder()
							.poolingType(PoolingType.MAX).dropOut(0.5).build(), "merge")
				.addLayer("out",
						new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
								.activation(Activation.SOFTMAX).nOut(2).build(),
						"globalPool")
				.setOutputs("out")
				.setInputTypes(InputType.recurrent(vocabSize))
				.build();
	ComputationGraph net = new ComputationGraph(config);
	net.init();
	return net;

这里对新增的EmbeddingSequenceLayer和RedshapeVertex进行说明。

EmbeddingSequenceLayer是1.0.0-beta版本新增的Layer模块,主要是对于现有的EmbeddingLayer进行功能扩展,直接支持时序数据以及时序Mask功能。现有的EmbeddingLayer的功能更多的是一张lookUp表,可以批量地查询词向量。虽然可以通过reshape的方式来支持时序数据,但并不直观,因此这里使用EmbeddingSequenceLayer来处理时序数据。

从I/O data flow层面分析,EmbeddingSequenceLayer支持[mb, seq_len]或者[mb, 1, seq_len]格式的输入数据,并输出[mb, embedding_size, seq_len]格式的数据。换句话说,支持对时序数据进行向量化的操作。另外,由于时序数据一般都是变长的,因此需要基于Mask机制来标识序列实际有效的长度和位置。这个在1.2部分的data flow详细分析中会具体展开。

RedshapeVertex顾名思义是对数据的reshape操作,当然也包括Mask部分的reshape。RedshapeVertex层的第一个参数是底层数据存储的order,这里可以不关注。第二和第三个参数分别代表输入数据和Mask数据需要被规整后的shape。需要说明的是,由于mb是动态的,因此用-1来代替。如果从处理图像数据的角度看,EmbeddingSequenceLayer的输出可以被认为是一批灰度图,是height=embedding_size,width=seq_len,depth/channel=1的图像数据,这也是文本包括语音等时序数据可以通过CNN来处理的原因。为了适配后续卷积层处理数据的格式,我们通过ReshapeVertex将原始的时序数据增加一个维度且等于1,从3D变换成4D的张量,这个新增维度的物理含义是图像中的channel或者depth,对于灰度图channel/depth即等于1。

对于ReshapeVertex操作,它的I/O data shape其实是开发人员根据需要指定的,比如上面代码中实现了从[mb, embedding_size, seq_len][mb, 1, embedding_size, seq_len]的转换,目的也是为了适配卷积层的操作。对于Mask的reshape操作同样放到1.2的部分中阐述。

除了这两个部分以外,其余的模块和之前TextCNN的博客中描述的是一致的。如果有需要,可以翻阅前面的博客。我们通过summary接口来打印下模型的详细信息,超参数设置如下。

final int vectorSize = 8;
final int numFeatureMap = 3;
final int corpusLenLimit = 10;
final int vocabSize = 10000;
final int batchSize = 2;

ComputationGraph graph = getModel(vectorSize, numFeatureMap, corpusLenLimit, vocabSize, batchSize);
System.out.println(graph.summary(InputType.recurrent(vocabSize)));

可以得到如下的信息:

==============================================================================================================================================================================================================
VertexName (VertexType)              nIn,nOut   TotalParams   ParamsShape          Vertex Inputs              InputShape                                   OutputShape                                        
==============================================================================================================================================================================================================
input (InputVertex)                  -,-        -             -                    -                          -                                            -                                                  
embedding (EmbeddingSequenceLayer)   10000,8    80,000        W:10000,8          [input]                    InputTypeRecurrent(10000,format=NCW)         InputTypeRecurrent(8,timeSeriesLength=1,format=NCW)
reshape (ReshapeVertex)              -,-        -             -                    [embedding]                -                                            InputTypeConvolutional(h=8,w=100,c=1,NCHW)         
2-gram (ConvolutionLayer)            1,3        51            W:3,1,8,2, b:3   [reshape]                  InputTypeConvolutional(h=8,w=100,c=1,NCHW)   InputTypeConvolutional(h=1,w=100,c=3,NCHW)         
3-gram (ConvolutionLayer)            1,3        75            W:3,1,8,3, b:3   [reshape]                  InputTypeConvolutional(h=8,w=100,c=1,NCHW)   InputTypeConvolutional(h=1,w=100,c=3,NCHW)         
4-gram (ConvolutionLayer)            1,3        99            W:3,1,8,4, b:3   [reshape]                  InputTypeConvolutional(h=8,w=100,c=1,NCHW)   InputTypeConvolutional(h=1,w=100,c=3,NCHW)         
merge (MergeVertex)                  -,-        -             -                    [2-gram, 3-gram, 4-gram]   -                                            InputTypeConvolutional(h=1,w=100,c=9,NCHW)         
globalPool (GlobalPoolingLayer)      -,-        0             -                    [merge]                    InputTypeConvolutional(h=1,w=100,c=9,NCHW)   InputTypeFeedForward(9)                            
out (OutputLayer)                    9,2        20            W:9,2, b:2       [globalPool]               InputTypeFeedForward(9)                      InputTypeFeedForward(2)                            
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
            Total Parameters:  80,245
        Trainable Parameters:  80,245
           Frozen Parameters:  0
==============================================================================================================================================================================================================

从调用summary接口的结果来看,主要的参数量集中在EmbeddingSequenceLayer这层。这也符合上文的有关分析。需要说明的是,这里为了简化参数分析,超参数的设置并不是真正training阶段的超参。这批超参数是为了方便说明网络结构和参数以及1.2部分阐述data flow所准备的,因此诸如featureMap数量等都设置的比较少。下面就结合summary的结果来整体说明下上述结构每一层data shape的变换,包括Mask部分data shape的变化。

1.2 Data Flow描述

这部分内容首先给出的每一层Layer的数据shape变换情况,包括Mask部分的变换,并且做些说明。首先来看EmbeddedSeqLayer这一层。

1.2.1 词嵌入层

  • 定义:
    .addLayer("embedding", new EmbeddingSequenceLayer.Builder() .nIn(vocabSize).nOut(vectorSize).build(), "input")

  • Data I/O Shape:
    input:[mb, seq_len]
    output:[mb, embedding_size, seq_len]

  • Mask I/O Shape:
    input:[mb, seq_len]
    output:[mb, seq_len]

  • 说明:这一层Layer的I/O数据格式比较清晰,是对原始数据(比如一段分词后文本序列)进行向量化,那自然output的部分会扩展出向量长度这一维度。Mask部分的目的是标识实际有效的文本序列,原因上文也提到过。这些部分均会参与forward+backward pass的计算。

  • 例子:I/O tensor + mask tensor

input
Rank: 2, DataType: FLOAT, Offset: 0, Order: c, Shape: [2,10],  Stride: [10,1]
[[    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000]]
mask
Rank: 2, DataType: FLOAT, Offset: 0, Order: c, Shape: [2,10],  Stride: [10,1]
[[    1.0000,    1.0000,         0,         0,         0,         0,         0,         0,         0,         0], 
 [    1.0000,    1.0000,    1.0000,         0,         0,         0,         0,         0,         0,         0]]
embedding
Rank: 3, DataType: FLOAT, Offset: 0, Order: c, Shape: [2,8,10],  Stride: [80,1,8]
[[[   -0.4199,   -0.4199,         0,         0,         0,         0,         0,         0,         0,         0], 
  [    0.2123,    0.2123,         0,         0,         0,         0,         0,         0,         0,         0], 
  [   -0.4690,   -0.4690,         0,         0,         0,         0,         0,         0,         0,         0], 
  [    0.3492,    0.3492,         0,         0,         0,         0,         0,         0,         0,         0], 
  [    0.5633,    0.5633,         0,         0,         0,         0,         0,         0,         0,         0], 
  [    0.2636,    0.2636,         0,         0,         0,         0,         0,         0,         0,         0], 
  [    0.0920,    0.0920,         0,         0,         0,         0,         0,         0,         0,         0], 
  [   -1.1781,   -1.1781,         0,         0,         0,         0,         0,         0,         0,         0]], 

 [[   -0.4199,   -0.4199,   -0.4199,         0,         0,         0,         0,         0,         0,         0], 
  [    0.2123,    0.2123,    0.2123,         0,         0,         0,         0,         0,         0,         0], 
  [   -0.4690,   -0.4690,   -0.4690,         0,         0,         0,         0,         0,         0,         0], 
  [    0.3492,    0.3492,    0.3492,         0,         0,         0,         0,         0,         0,         0], 
  [    0.5633,    0.5633,    0.5633,         0,         0,         0,         0,         0,         0,         0], 
  [    0.2636,    0.2636,    0.2636,         0,         0,         0,         0,         0,         0,         0], 
  [    0.0920,    0.0920,    0.0920,         0,         0,         0,         0,         0,         0,         0], 
  [   -1.1781,   -1.1781,   -1.1781,         0,         0,         0,         0,         0,         0,         0]]]

简单说明下这个例子。模型的输入是[2, 10]的tensor/matrix,代表batch=2的时序数据,且为了方便我们将元素值都固定1.0(当然这同实际情况不相符,实际应用中序列中每个元素对应一个词)。这里先不考虑padding的情况,当然如果需要padding,在遵循约定的前提下,通过padding zero即可。接着说明mask的情况,可以比较直观得看到,是一个batch=2的multi-hot的tensor。这个tensor中的第一个序列代表前两个元素是有效的,第二个序列代表前三个元素是有效的,序列长度等于input tensor的长度。最后看下embedding的输出tensor,从打印出的信息也可以看到,是个[2, 8, 10]的tensor,其中dim=1就是新增的代表词向量长度的维度。由于mask tensor的作用,我们可以看到无效的embedding部分都用0来占位了。有效元素的位置和mask tensor本身元素的位置是一致的。

1.2.2 Tensor Reshape层

  • 定义:
    .addVertex("reshape", new ReshapeVertex('c', new int[] -1, 1, vectorSize, corpusLenLimit, new int[] -1, 1, 1, corpusLenLimit), "embedding")

  • Data I/O Shape:
    input: [mb, embedding_size, seq_len]
    output:[mb, 1, embedding_size, seq_len]

  • Mask I/O Shape:
    input: [mb, seq_len]
    output:[mb, 1, 1, seq_len]

  • 说明:ReshapeVertex对上一层输出的tensor进行reshape操作。由于reshape操作并不改变总的元素数量,更多的时候是对了适配不同Layer或者Op的操作,因此从上面给出的shape可以看出data和mask tensor为了适配后续卷积层的操作,将维度都扩展到了4D,另外由于mini-batch size是不定的,还是用-1来代替。

  • 例子:data output tensor

reshape
Rank: 4, DataType: FLOAT, Offset: 0, Order: c, Shape: [2,1,8,10],  Stride: [80,80,10,1]
[[[[    0.1742,    0.1742,         0,         0,         0,         0,         0,         0,         0,         0], 
   [   -0.2452,   -0.2452,         0,         0,         0,         0,         0,         0,         0,         0], 
   [    0.1370,    0.1370,         0,         0,         0,         0,         0,         0,         0,         0], 
   [   -0.1828,   -0.1828,         0,         0,         0,         0,         0,         0,         0,         0], 
   [    0.8138,    0.8138,         0,         0,         0,         0,         0,         0,         0,         0], 
   [   -0.1781,   -0.1781,         0,         0,         0,         0,         0,         0,         0,         0], 
   [    0.3427,    0.3427,         0,         0,         0,         0,         0,         0,         0,         0], 
   [   -0.4277,   -0.4277,         0,         0,         0,         0,         0,         0,         0,         0]]], 


 [[[    0.1742,    0.1742,    0.1742,         0,         0,         0,         0以上是关于Deeplearning4j 实战 (13-2):基于Embedding+CNN的文本分类实现的主要内容,如果未能解决你的问题,请参考以下文章

Deeplearning4j 实战:Deeplearning4j 手写体数字识别Spark实现

DeepLearning4j 实战——手写体数字识别GPU实现与性能比较

deeplearning4j 和 Maven 的错误

deeplearning4j学习一

第02课:DeepLearning4j 开发的基本流程

第02课:DeepLearning4j 开发的基本流程