怎样克服神经网络训练中argmax的不可导性?
Posted 人工智能博士
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了怎样克服神经网络训练中argmax的不可导性?相关的知识,希望对你有一定的参考价值。
点上方人工智能算法与Python大数据获取更多干货
在右上方 ··· 设为星标 ★,第一时间获取资源
仅做学术分享,如有侵权,联系删除
转载于 :作者|Zhenyue Qin、hoooz、OwlLite
来源 | 知乎问答
地址 | https://www.zhihu.com/question/460500204
原问题
最近在使用torch做nlp的风格转换,当我利用gan进行学习时,发现seq2seq的输出是(batch size,max length,vocab length)形状的tensor,最后一维表示经过softmax的词典里各单词的出现概率。
根据gan的原理,我想将生成器的输出作为输入得到判决器输出(batch size,labels nums)形状的tensor。然后与标准labels做交叉熵得到 loss。但是判决器的输入tensor应该是(batch size,max length)形状。
这里如果将seq2seq的输出out进行out.argmax(-1)处理的话会导致loss.backward()无法在网络内产生梯度。想请教下各位大神在这里有没有什么解决的好方法。
01
回答一:作者-Zhenyue Qin
有个东西叫strainght through Gumbel (estimator), 可以看一下~
大概思想就是: 假设输入的向量是v, 那么我们用softmax得到softmax(v). 这样, 最大值那个地方就会变得很靠近1, 其他地方就会变得很靠近0. 然后, 我们计算argmax(v), 接着可以得到一个常数c = argmax(v) - softmax(v). 我们这时, 可以用softmax(v) + c来作为argmax(v)的结果. 这个东西的好处是, 我们的softmax(v) + c是有反向传播的能力的. 换句话说, 我们用softmax(v)的梯度来作为反向传播.
如果有哪点没说清楚, 欢迎评论. 谢谢.
P.S. 感谢吕纯川和Towser对于原回答的指正.
02
回答二:作者-hoooz
方案1:加入stop gradient operation, 请参考VQVAE以及对应的pytorch实现 [1][2]
一句话解释: 正向传播就和往常一样,反向传播时,将梯度从不可导那个点copy到 不可导点的前面的最近一个可导点。
(请看红线右端点的梯度,跳过中间的字典模块,直达红线的左端点)
~
问题来了
1/梯度链条怎么隔断不让他经过字典模块?pytorch有个 detach(), 可以隔断梯度,梯度就不会进入 不可导区域 引发编译器报错
2/梯度怎么复制?举个最简单例子
quantize = input + (quantize - input).detach()
# 正向传播和往常一样,
# 反向传播时,detach()这部分梯度为0,quantize和input的梯度相同,
# 即实现将quantize复制给input
# quantize即红线右端点,input即红线左端点
参考:
[1]. Neural Discrete Representation Learning
[2]. https://github.com/rosinality/v
03
回答三:作者-OwlLite
可以对argmax/argmin 这种不可导的操作直接忽视,也就是锁定:
class ArgMax(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
idx = torch.argmax(input, 1)
output = torch.zeros_like(input)
output.scatter_(1, idx, 1)
return output
@staticmethod
def backward(ctx, grad_output):
return grad_output
---------♥---------
声明:本内容来源网络,版权属于原作者
图片来源网络,不代表本公众号立场。如有侵权,联系删除
AI博士私人微信,还有少量空位
点个在看支持一下吧
以上是关于怎样克服神经网络训练中argmax的不可导性?的主要内容,如果未能解决你的问题,请参考以下文章
SwiftUI 误导性错误:“Int”不可转换为“CGFloat”