Tensorflow 静态图的动态收缩
Posted Yan_Joy
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Tensorflow 静态图的动态收缩相关的知识,希望对你有一定的参考价值。
Tensorflow 的静态图机制给一个动态调整区间和mask的网络带来了不少麻烦。
问题描述
随着训练的进行,扩大区间 r r r的范围,并对区间内的权重进行量化操作。一次训练可能要量化多个区间,量化后权重冻结。
静态图思路
权重冻结
这是一个老问题,之前的文章中也有介绍。解决方法还是:
def entry_stop_gradients(target, mask):
mask = tf.cast(mask, tf.float32)
mask_h = tf.abs(mask-1.)
return tf.stop_gradient(mask * target) + mask_h * target
卷积、全连接层
计算层比较尴尬,他们不能直接对权重进行更改操作。如果基于slim
或keras
库,conv2d
只能进行计算和mask的赋值操作,这也是目前大部分量化的训练方式。其实这个层只是进行了前向的传播,导致网络数据流经此层后,由于反向传播的更新值覆盖了用户想要赋值固定的值,导致无法量化。更大的一个问题在于不同区间的转换,如果一个层仅对一个区间进行量化,那么在完成一个区间量化,想进行下一个区间量化时,就要重新构建网络。
在这类层,一方面可以通过停止量化区间内的反向传播,另一方面对之前已量化的权重进行mask标记。
def call(self, inputs):
S = tf.cond(self.s < 1.0, lambda: tf.assign_add(self.s, self.d_s), lambda: tf.identity(1.0))
mask_b = tf.logical_and(tf.less_equal(self.kernel, self.center+self.max_dis*S),tf.greater_equal(self.kernel, self.center-self.min_dis*S))
mask_f = tf.assign(self.mask_f, tf.logical_or(self.mask_f, mask_b))
stopped = entry_stop_gradients(self.kernel, mask_f) # stop kernel < 0 1停止,0继续
outputs = tf.nn.conv2d(
inputs,
stopped,
strides=[1, self.strides[0], self.strides[1], 1],
data_format='NHWC',
padding=self.padding)
if self.use_bias:
outputs = K.bias_add(
outputs,
self.bias,
data_format=self.data_format)
if self.activation is not None:
return self.activation(outputs)
return outputs
限制器
如果想要进行权重的赋值和更改,还是要依靠限制器。通过限制器对网络进行赋值,可以不利用mask而保存真正的量化值。
kernel_c = tf.where(self.r_mask, K.clip(w, self.b_retrain[0]+1e-6, self.b_retrain[1]-1e-6), w)
S = tf.cond(self.s < 1.0, lambda: tf.assign_add(self.s, self.d_s), lambda: tf.identity(1.0))
mask_b = tf.logical_and(tf.less_equal(kernel_c, self.center+self.max_dis*S),tf.greater_equal(kernel_c, self.center-self.min_dis*S))
kernel_m = tf.where(mask_b , self.center*tf.ones_like(kernel_c), kernel_c)
stopped = entry_stop_gradients(kernel_m, tf.logical_or(mask_b, self.f_mask) if self.f_mask is not None else mask_b)
总结
以上只是这个问题的一个解决思路,实际在训练中出现了准确率无法在后续收缩中保持的情况。还在寻找其问题根源……
以上是关于Tensorflow 静态图的动态收缩的主要内容,如果未能解决你的问题,请参考以下文章
[TensorFlow系列-22]:基本元素与运行机制 - TensorVariableOperationSessionPlaceholderGraph静态与动态数据流图的比较