怎样克服神经网络训练中argmax的不可导性?

Posted 人工智能博士

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了怎样克服神经网络训练中argmax的不可导性?相关的知识,希望对你有一定的参考价值。

点上方人工智能算法与Python大数据获取更多干货

在右上方 ··· 设为星标 ★,第一时间获取资源

仅做学术分享,如有侵权,联系删除

转载于 :作者|Zhenyue Qin、hoooz、OwlLite

来源 | 知乎问答

地址 | https://www.zhihu.com/question/460500204

原问题

最近在使用torch做nlp的风格转换,当我利用gan进行学习时,发现seq2seq的输出是(batch size,max length,vocab length)形状的tensor,最后一维表示经过softmax的词典里各单词的出现概率。

根据gan的原理,我想将生成器的输出作为输入得到判决器输出(batch size,labels nums)形状的tensor。然后与标准labels做交叉熵得到 loss。但是判决器的输入tensor应该是(batch size,max length)形状。

这里如果将seq2seq的输出out进行out.argmax(-1)处理的话会导致loss.backward()无法在网络内产生梯度。想请教下各位大神在这里有没有什么解决的好方法。

01

回答一:作者-Zhenyue Qin

有个东西叫strainght through Gumbel (estimator), 可以看一下~

大概思想就是: 假设输入的向量是v, 那么我们用softmax得到softmax(v). 这样, 最大值那个地方就会变得很靠近1, 其他地方就会变得很靠近0. 然后, 我们计算argmax(v), 接着可以得到一个常数c = argmax(v) - softmax(v). 我们这时, 可以用softmax(v) + c来作为argmax(v)的结果. 这个东西的好处是, 我们的softmax(v) + c是有反向传播的能力的. 换句话说, 我们用softmax(v)的梯度来作为反向传播.

如果有哪点没说清楚, 欢迎评论. 谢谢.

P.S. 感谢吕纯川和Towser对于原回答的指正.

02

回答二:作者-hoooz

方案1:加入stop gradient operation, 请参考VQVAE以及对应的pytorch实现 [1][2]

一句话解释: 正向传播就和往常一样,反向传播时,将梯度从不可导那个点copy到 不可导点的前面的最近一个可导点。

(请看红线右端点的梯度,跳过中间的字典模块,直达红线的左端点)

~

问题来了

1/梯度链条怎么隔断不让他经过字典模块?pytorch有个 detach(), 可以隔断梯度,梯度就不会进入 不可导区域 引发编译器报错

2/梯度怎么复制?举个最简单例子

quantize = input + (quantize - input).detach()
# 正向传播和往常一样,
# 反向传播时,detach()这部分梯度为0,quantize和input的梯度相同,
# 即实现将quantize复制给input
# quantize即红线右端点,input即红线左端点

参考:

  • [1]. Neural Discrete Representation Learning

  • [2]. https://github.com/rosinality/v

03

回答三:作者-OwlLite

可以对argmax/argmin 这种不可导的操作直接忽视,也就是锁定:

class ArgMax(torch.autograd.Function):
	@staticmethod
	def forward(ctx, input):
        idx = torch.argmax(input, 1)
        output = torch.zeros_like(input)
        output.scatter_(1, idx, 1)
        return output
	
	@staticmethod
	def backward(ctx, grad_output):
        return grad_output

---------♥---------

声明:本内容来源网络,版权属于原作者

图片来源网络,不代表本公众号立场。如有侵权,联系删除

AI博士私人微信,还有少量空位

如何画出漂亮的深度学习模型图?

如何画出漂亮的神经网络图?

一文读懂深度学习中的各种卷积

点个在看支持一下吧

以上是关于怎样克服神经网络训练中argmax的不可导性?的主要内容,如果未能解决你的问题,请参考以下文章

怎样克服神经网络训练中argmax的不可导性?

绕过tf.argmax是不可区分的

深度学习领域的数据增强

SwiftUI 误导性错误:“Int”不可转换为“CGFloat”

Tensorflow text_generation 教程中对有状态 GRU 的误导性训练数据洗牌

在神经网络中提取知识:学习用较小的模型学得更好