Electra: 判别还是生成,这是一个选择

Posted 张雨石

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Electra: 判别还是生成,这是一个选择相关的知识,希望对你有一定的参考价值。

最近要开始使用Transformer去做一些事情了,特地把与此相关的知识点记录下来,构建相关的、完整的知识结构体系。

以下是要写的文章,文章大部分都发布在公众号【雨石记】上,欢迎关注公众号获取最新文章。

动机

对于Bert的改进,大家可能注意到我的句式都是Bert很好很强大,但是……,今天这篇也不例外,只是改进的方向有点出乎我的意料。

Bert很好很强大,但是它的训练太低效,为什么呢?我们来回顾一下,在训练Bert的时候,在输入上,把15%的词语给替换成Mask,然后这其中有80%是Mask,有10%是替换成其他词语,最后剩下10%保持原来的词语。

可以看到,Bert的训练中,每次相当于只有15%的输入上是会有loss的,而其他位置是没有的,这就导致了每一步的训练并没有被完全利用上,导致了训练速度慢。

于是,就有了Electra,Electra是Efficiently Learning an Encoder that Classifies Token Replacement Accurately的缩写。从名字中可以看出,相对于Bert的去预测Mask的正确值,Electra则是去预测Token是不是被替换了。那么具体是如何做的呢?

判别器 & 生成器

如下图所示,首先会训练一个生成器来生成假样本,然后Electra去判断每个token是不是被替换了。

大家看到这张图,可能会想到,这不是对抗生成网络咩?其实不是的,不是的奥秘就在损失函数上。

对抗生成网络传送门在此处开启

再来仔细的看一下算法流程,首先,输入经过随机选择设置为[MASK],然后输入给Generator,Generator负责把[MASK]变成替换过的词。

但此时Generator并不像对抗神经网络那样需要等Discriminator中传回来的梯度,而是像Bert一样那样去尝试预测正确的词语,从而计算损失。这就是Electra不是GAN的根本原因。

因此,极端情况下如果Generator的预测准确率是100%,那么Discriminator就学习不到什么了,因为所有的token都是正确词语。但所幸,Generator一般是个小模型,所以效果达不到这么高,同时,Generator刚开始就要和Discriminator联合训练,所以刚开始也不会达到这么高。

Discriminator则是去预测每个位置上的词语是不是被替换过。Discriminator是训练完之后我们得到的预训练模型,Generator在训练完之后就没有用了。

Electra另外一点和对抗生成网络不同的是,如果Generator生成的是和原始输入一样的token,那么这个token会被当做是没有替换,而在对抗生成网络中所有来源于生成器的数据都是fake数据。

用公式来解释上述过程,如下:

损失函数如下:

所以,最后的损失函数如下:

注意到GAN的损失函数是minG maxDV(D, G),跟这个损失函数大有不同。

实验效果

终极目标就是能在计算量等同的情况下,超过同等体量的模型效果。

其他尝试的设置

当然,还有很多其他的设置:

  • Generator是多余的模型部分,在最后不会被用到,但在训练时是要占用计算量的,所以为了保证同等计算量,就需要让Discriminator比其他的Bert变种要小。为了节省这一开支:
    • Generator要尽可能的小
    • embedding层可以和Discriminator共享
  • 真正的像GAN一样去做训练
  • Two-stage训练方法: 先训练Generator,然后再去训练Discriminator。

实验结果如下,可以看到,Discriminator的宽度为768,Generator宽度为256时效果最好。同时,GAN和Two-stage训练都不如Electra。

不同模型尺寸的对比

实验还分别比较了小模型和大模型。可以看到,无论是大模型还是小模型,Electra都可以超过Bert。


消融实验

为了验证Electra的提升到底是哪里来的,做了一些Bert和Electra中间设置的一些模型的实验,包括:

  • Electra 15%: 跟Electra很像,不过loss只来源于15%的位置,这点是仿照Bert。
  • Replace MLM: 跟Bert很像,不过要替换的词语不再是[MASK],而是Generator生成的词语。
  • All-Tokens MLM: 和上面的Replace MLM很像,但更进一步,预测所有的位置而不是仅预测被Mask的位置。

结果如下,可以看到,All-tokens MLM提升最大,但还可以看到,Electra相对于Bert,不仅仅在训练速度上有提升,在最终的结果上也有提升。

如下图所示,Electra可以达到Bert打不到的高度。

思考

勤思考,多提问是Engineer的良好品德。

  • 对抗生成网络在图像上应用很广,在文本上却遇到了很多问题,都遇到了什么问题?为什么?

关注公众号【雨石记】,答案会在后续的文章中。

参考文献

  • [1]. Clark, Kevin, et al. “Electra: Pre-training text encoders as discriminators rather than generators.” arXiv preprint arXiv:2003.10555 (2020).

以上是关于Electra: 判别还是生成,这是一个选择的主要内容,如果未能解决你的问题,请参考以下文章

Bert不完全手册4. 绕不开的MASK?XLNET & ELECTRA

使用生成对抗网络(GAN)生成手写字

判别模型生成模型与朴素贝叶斯方法

卷积生成对抗网络(DCGAN)---生成手写数字

卷积生成对抗网络(DCGAN)---生成手写数字

机器学习算法 之DCGAN