李沐动手学深度学习V2-Encoder-Decoder编码器和解码器架构
Posted cv_lhp
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了李沐动手学深度学习V2-Encoder-Decoder编码器和解码器架构相关的知识,希望对你有一定的参考价值。
一. encoder-decoder编码器和解码器架构
1. 介绍
机器翻译是序列转换模型的一个核心问题, 其输入和输出都是长度可变的序列。 为了处理这种类型的输入和输出, 可以设计一个包含两个主要组件的架构: 第一个组件是一个编码器(encoder): 它接受一个长度可变的序列作为输入, 并将其转换为具有固定形状的编码状态。 第二个组件是解码器(decoder): 它将固定形状的编码状态映射到长度可变的序列。 这被称为编码器-解码器(encoder-decoder)架构,如下图所示。
以英语到法语的机器翻译为例: 给定一个英文的输入序列:“They”、“are”、“watching”、“.”。 首先,这种“编码器-解码器”架构将长度可变的输入序列编码成一个“状态”, 然后对该状态进行解码, 一个词元接着一个词元地生成翻译后的序列作为输出: “Ils”、“regordent”、“.”,下面实现编码器和解码器的接口。
2. 编码器
在编码器接口中,指定长度可变的序列作为编码器的输入X。 任何继承这个Encoder 基类的模型将完成代码实现。
from torch import nn
"""编码器-解码器架构的基本编码器接口"""
class Encoder(nn.Module):
def __init__(self):
super(Encoder,self).__init__()
def forward(self,X,*args):
raise NotImplementedError
3. 解码器
在解码器接口中,新增一个init_state()函数, 用于将编码器的输出(enc_outputs)转换为编码后的状态。 注意,此步骤可能需要额外的输入,例如:输入序列的有效长度。 为了逐个地生成长度可变的词元序列, 解码器在每个时间步都会将输入 (例如:在前一时间步生成的词元或者使用label序列词元)和编码后的状态输入到网络中,从而得到当前时间步的输出词元。
"""编码器-解码器架构的基本解码器接口"""
class Decoder(nn.Module):
def __init__(self):
super(Decoder,self).__init__()
def init_state(self,enc_outputs,*args):
raise NotImplementedError
def forward(self,X,state):
raise NotImplementedError
4. 合并编码器和解码器
“编码器-解码器”架构包含了一个编码器和一个解码器, 并且还拥有可选的额外的参数。 在前向传播中,编码器的输出用于生成编码状态, 这个状态又被解码器作为其输入的一部分。
"""编码器-解码器架构的基类"""
class EncoderDecoder(nn.Module):
def __init__(self,encoder,decoder):
super(EncoderDecoder,self).__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self,enc_X,dec_X,*args):
enc_outputs = self.encoder(enc_X,*args)
dec_state = self.decoder.init_state(enc_outputs,*args)
return self.decoder(dec_X,dec_state)
5. 小结
- “编码器-解码器”架构可以将长度可变的序列作为输入和输出,因此适用于机器翻译等序列转换问题。
- 编码器将长度可变的序列作为输入,并将其转换为具有固定形状的编码状态,这个状态又被解码器作为其输入的一部分。
- 解码器将具有固定形状的编码状态和根据在前一时间步生成的预测词元或者使用label序列当前时间步的词元,从而使网络生成具有长度可变的序列结果。
6. 全部代码
from torch import nn
"""编码器-解码器架构的基本编码器接口"""
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
def forward(self, X, *args):
raise NotImplementedError
"""编码器-解码器架构的基本解码器接口"""
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
def init_state(self, enc_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
"""编码器-解码器架构的基类"""
class EncoderDecoder(nn.Module):
def __init__(self, encoder, decoder):
super(EncoderDecoder, self).__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_outputs, *args)
return self.decoder(dec_X, dec_state)
6. 相关链接
机器翻译第一篇:李沐动手学深度学习V2-机器翻译和数据集
机器翻译第二篇:李沐动手学深度学习V2-Encoder-Decoder编码器和解码器架构
机器翻译第三篇:李沐动手学深度学习V2-seq2seq和代码实现
机器翻译第四篇:李沐动手学深度学习V2-基于注意力机制的seq2seq
跟李沐导师:动手学深度学习!
Datawhale学习
预告:三月学习计划,内容:深度学习入门
二月学习需求收集
李沐动手学深度学习热度排名第二。根据读者的学习建议,Datawhale团队联系了李沐老师,将组织动手学深度学习课程的学习。
关于动手学深度学习
《动手学深度学习》这本书由李沐等人主导编写,介绍了深度学习从模型构造到模型训练的方方面面,以及在计算机视觉和自然语言处理中的应用。
它最大的特色在于,不仅阐述了算法原理,还提供了实际可运行的代码。
更令人暖心的是,这本书不要求读者有任何深度学习或者机器学习的背景知识,书中会从头开始解释每一个概念。读者只需了解基础的数学和编程,如基础的线性代数、微分和概率,以及基础的Python编程,就可以愉快地开始啃这本书了。
目前已经上线了最新PyTorch版本。
提前进群学习
关注公众号,回复 “三月” 进学习群
以上是关于李沐动手学深度学习V2-Encoder-Decoder编码器和解码器架构的主要内容,如果未能解决你的问题,请参考以下文章