How to Update Trainable Variables on My Custom Optimizer Using Tensorflow 2

Posted: 2020-12-31 19:59:03

I'm a beginner who is just now learning about convolutional neural networks.

So I have been implementing AlexNet, following the paper "ImageNet Classification with Deep Convolutional Neural Networks", using TensorFlow 2.3 in an Anaconda environment. However, I got frustrated while implementing a custom optimizer.

My problem: I have to modify the optimizer according to the AlexNet paper, but I can't find a reference for how to update a variable in TensorFlow V2, even though I have googled for it. Everything I find uses `tf.assign()`, which is not supported in TensorFlow V2, and even if I were to use that function, I would worry about compatibility between V1 and V2.

I only know that I have to customize the `_resource_apply_dense` function to fit my update rule. I load some hyperparameters there, but I don't know how to update them.

(A `tf.Variable()` can be used like a plain Python variable, so I assumed updating one works the same way...?)

Thanks to all the advanced readers ^_^

Here is the code.

# Update rule in AlexNet:
#   v_(i+1) = momentum * v_i - 0.0005 * learning_rate * w_i - learning_rate * gradient
#   w_(i+1) = w_i + v_(i+1)
# where
#   w_i = weight
#   v_i = velocity
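A quick one-step sanity check of the rule in plain Python (the numbers are made up, just to illustrate the arithmetic):

momentum, lr, wd = 0.9, 0.01, 0.0005
w, v, g = 1.0, 0.0, 0.5

v = momentum * v - wd * lr * w - lr * g   # 0.0 - 0.000005 - 0.005 = -0.005005
w = w + v                                 # 1.0 - 0.005005 = 0.994995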
from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import keras_export
import tensorflow as tf

class AlexSGD(optimizer_v2.OptimizerV2):

    _HAS_AGGREGATE_GRAD = True

    def __init__(self,
                learning_rate=0.01,
                weight_decay=0.0005,
                momentum=0.9,
                name="AlexSGD",
                **kwargs):
        super(AlexSGD, self).__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("decay", self._initial_decay)
        
        self._is_first = True
        self._set_hyper("vx", 0)
        self._set_hyper("pg", 0)
        self._set_hyper("pv", 0)
        self._weight_decay = False
        if isinstance(weight_decay, ops.Tensor) or callable(weight_decay) or weight_decay > 0:
            self._weight_decay = True
        if isinstance(weight_decay, (int, float)) and (weight_decay < 0 or weight_decay > 1):
            raise ValueError("`weight_decay` must be between [0, 1].")
        self._set_hyper("weight_decay", weight_decay)

        self._momentum = False
        if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
            self._momentum = True
        if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
            raise ValueError("`momentum` must be between [0, 1].")
        self._set_hyper("momentum", momentum)

    def _create_slots(self, var_list):
        if self._momentum:
            for var in var_list:
                self.add_slot(var, "momentum")
        if self._weight_decay:
            for var in var_list:
                self.add_slot(var, "weight_decay")
        for var in var_list:
            self.add_slot(var, "pv")  # previous variable, i.e. weight or bias
        for var in var_list:
            self.add_slot(var, "pg")  # previous gradient
        for var in var_list:
            self.add_slot(var, "vx")  # update velocity

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super(AlexSGD, self)._prepare_local(var_device, var_dtype, apply_state)
        apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity(
            self._get_hyper("momentum", var_dtype))
        apply_state[(var_device, var_dtype)]["weight_decay"] = array_ops.identity(
            self._get_hyper("weight_decay", var_dtype))
        apply_state[(var_device, var_dtype)]["vx"] = array_ops.identity(
            self._get_hyper("vx", var_dtype))
        apply_state[(var_device, var_dtype)]["pv"] = array_ops.identity(
            self._get_hyper("pv", var_dtype))
        apply_state[(var_device, var_dtype)]["pg"] = array_ops.identity(
            self._get_hyper("pg", var_dtype))

    # main function
    @tf.function
    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = ((apply_state or {}).get((var_device, var_dtype))
                        or self._fallback_apply_state(var_device, var_dtype))
        momentum_var = self.get_slot(var, "momentum")
        weight_decay_var = self.get_slot(var, "weight_decay")
        vx_var = self.get_slot(var, "vx")
        pv_var = self.get_slot(var, "pv")
        pg_var = self.get_slot(var, "pg")
        lr_t = self._decayed_lr(var_dtype)

        # update rule in AlexNet
        # v_(i+1) = momentum * v_i - 0.0005 * lr * w_i - lr * grad
        # w_(i+1) = w_i + v_(i+1)
        # where
        # w_i = var
        # vx, v_i = velocity (Feel like I need to set this variable as a slot) 
        # lr = learning_rate
        # grad = gradient
        
        # I'm confused about why the pv, pg variables are declared...
        # are they replaced by var & grad?  (pv, pg come from a blog post)
        # pv = previous var
        # pg = previous gradient

        if self._is_first:
            self._is_first = False
            vx_var = grad
            new_var = var + vx_var
        else:
            vx_var = momentum_var * vx_var - weight_decay_var * lr_t * pv_var - lr_t * pg_var
            new_var = var + vx_var

        print("grad:",grad)
        print("var:",var)
        print("vx_var:",vx_var)
        print("new_var:",new_var)
        print("pv_var:",pv_var)
        print("pg_var:",pg_var)
        
        # TODO: I got stuck on how to update the variables because the
        #       tf.assign() function is deprecated in TensorFlow V2
        pg_var = grad
        pv_var = var

        if var == new_var:
            var = new_var
        
        # TODO: In order to update the variables, I can't find the equivalent of
        #       "tf.assign" in TF V2
        # pg_var.assign(grad)
        # vx_var.assign(vx_var)
        # var.assign(new_var)

        """
        # TODO: I referred the below code from Tensorflow official document, and I 
        #       realized the training_ops module is in c++ library, So I thought I 
        #       can't modify it ( Cuz I need to modify an update function of 
        #       velocity 

        # return training_ops.resource_apply_keras_momentum(
        #     var.handle,
        #     momentum_var.handle,
        #     coefficients["lr_t"],
        #     grad,
        #     coefficients["momentum"],
        #     use_locking=self._use_locking,
        #     use_nesterov=self.nesterov)


        # if self._momentum :
        #  momentum_var = self.get_slot(var, "momentum")
        #   return training_ops.resource_apply_keras_momentum(
        #     var.handle,
        #     momentum_var.handle,
        #     coefficients["lr_t"],
        #     grad,
        #     coefficients["momentum"],
        #     use_locking=self._use_locking,
        #     use_nesterov=self.nesterov)
        # else:
        #   return training_ops.resource_apply_gradient_descent(
        #       var.handle, coefficients["lr_t"], grad,
        #       use_locking=self._use_locking)
        """


    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        raise NotImplementedError

    def get_config(self):
        config = super(AlexSGD, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "weight_decay": self._serialize_hyperparameter("weight_decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        })
        return config

Reference: https://www.kdnuggets.com/2018/01/custom-optimizer-tensorflow.html

Here is another reference, which also uses tf.assign():

from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.framework import ops
from tensorflow.python.training import optimizer
import tensorflow as tf

class AlexOptimizer(optimizer.Optimizer):
    def __init__(self, learning_rate=0.01, weight_decay=0.0005, alpha=0.01,
                 beta=0.5, use_locking=False, name="AlexOptimizer"):
        # (numeric defaults above are illustrative; the original snippet had
        #  placeholder strings and weight_decay commented out)
        super(AlexOptimizer, self).__init__(use_locking, name)
        self._lr = learning_rate
        self._wd = weight_decay
        self._alpha = alpha
        self._beta = beta
        # Tensor versions of the constructor arguments, created in _prepare().
        self._lr_t = None
        self._wd_t = None
        self._alpha_t = None
        self._beta_t = None

    def _prepare(self):
        self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
        self._wd_t = ops.convert_to_tensor(self._wd, name="weight_decay")
        self._alpha_t = ops.convert_to_tensor(self._alpha, name="alpha_t")
        self._beta_t = ops.convert_to_tensor(self._beta, name="beta_t")

    def _create_slots(self, var_list):
        # Create slots for the first and second moments.
        for v in var_list:
            self._zeros_slot(v, "m", self._name)

    def _apply_dense(self, grad, var):
        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
        wd_t = math_ops.cast(self._wd_t, var.dtype.base_dtype)
        alpha_t = math_ops.cast(self._alpha_t, var.dtype.base_dtype)
        beta_t = math_ops.cast(self._beta_t, var.dtype.base_dtype)

        eps = 1e-7 #cap for moving average
        m = self.get_slot(var, "m")
        m_t = m.assign(tf.maximum(beta_t * m + eps, tf.abs(grad)))

        var_update = state_ops.assign_sub(
            var, lr_t * grad * tf.exp(tf.math.log(alpha_t) * tf.sign(grad) * tf.sign(m_t)))
        # Update 'ref' by subtracting value
        # Create an op that groups multiple operations.
        # When this op finishes, all ops in input have finished
        return control_flow_ops.group(*[var_update, m_t])

    def _apply_sparse(self, grad, var):
        raise NotImplementedError("Sparse gradient updates are not supported.")

I want to modify this code to use OptimizerV2. Which variables should I update?

(P.S.) Is the usage of `@tf.function` above `def _resource_apply_dense()` correct?

On another note, my model keeps reshuffling during training ㅠ_ㅠ (this happens in the dataset preprocessing step, `tf.data.Dataset.shuffle()`, even though it isn't inside the while loop)... (Sorry for not posting that code... so never mind that part...)


Answer 1:

You don't need `tf.assign` from v1. In v2, `assign` is a `Variable` class method.
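A minimal sketch of the v2 style (toy values, just to show the methods):

import tensorflow as tf

v = tf.Variable(1.0)
v.assign(2.0)       # in-place write; replaces v1's tf.assign(v, 2.0)
v.assign_add(0.5)   # in-place add; replaces v1's tf.assign_add
v.assign_sub(0.25)  # in-place subtract; replaces v1's tf.assign_sub
print(v.numpy())    # 2.25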

Your second example isn't actually using `tf.assign`; the line

m_t = m.assign(tf.maximum(beta_t * m + eps, tf.abs(grad)))

works in v2 as-is. It is just calling the class method through the v2 API.
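Putting it together, here is a minimal, untested sketch of your update rule on top of OptimizerV2, using a single `velocity` slot plus `assign`/`assign_add`. The `pv`/`pg` slots aren't needed: when `_resource_apply_dense` runs, `var` and `grad` already hold the current weight and gradient. Also note that rebinding a Python name (`var = new_var`) never touches the underlying variable; only the `assign*` methods do.

import tensorflow as tf

class AlexSGD(tf.keras.optimizers.Optimizer):
    def __init__(self, learning_rate=0.01, weight_decay=0.0005,
                 momentum=0.9, name="AlexSGD", **kwargs):
        super(AlexSGD, self).__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("weight_decay", weight_decay)
        self._set_hyper("momentum", momentum)

    def _create_slots(self, var_list):
        for var in var_list:
            self.add_slot(var, "velocity")  # v_i, zero-initialized

    def _resource_apply_dense(self, grad, var, apply_state=None):
        var_dtype = var.dtype.base_dtype
        lr = self._decayed_lr(var_dtype)
        momentum = self._get_hyper("momentum", var_dtype)
        wd = self._get_hyper("weight_decay", var_dtype)
        v = self.get_slot(var, "velocity")

        # v_(i+1) = momentum * v_i - wd * lr * w_i - lr * grad
        new_v = momentum * v - wd * lr * var - lr * grad
        # assign/assign_add mutate the variables in place and return ops;
        # group them so the caller can depend on both updates.
        v_update = v.assign(new_v, use_locking=self._use_locking)
        var_update = var.assign_add(new_v, use_locking=self._use_locking)
        return tf.group(v_update, var_update)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        raise NotImplementedError

    def get_config(self):
        config = super(AlexSGD, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "weight_decay": self._serialize_hyperparameter("weight_decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        })
        return config

You can then use it like any other Keras optimizer, e.g. model.compile(optimizer=AlexSGD(learning_rate=0.01), loss="categorical_crossentropy"). As for your P.S.: the built-in optimizers do not put @tf.function on _resource_apply_dense; the Keras training step is already traced for you, so you can drop that decorator.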

