如何在 tf.keras 自定义损失函数中触发 python 函数?

Posted

技术标签:

【中文标题】如何在 tf.keras 自定义损失函数中触发 python 函数?【英文标题】:How to trigger a python function inside a tf.keras custom loss function? 【发布时间】:2019-10-23 18:21:47 【问题描述】:

在我的自定义损失函数中,我需要调用一个纯 python 函数,传入计算的 TD 错误和一些索引。该函数不需要返回任何东西或被区分。这是我要调用的函数:

def update_priorities(self, traces_idxs, td_errors):
    """Updates the priorities of the traces with specified indexes."""
    self.priorities[traces_idxs] = td_errors + eps

我已经尝试使用tf.py_function 调用包装函数,但只有在它嵌入到图形中时才会调用它,即如果它具有输入和输出并且使用了输出。因此,我尝试通过一些张量而不对它们执行任何操作,现在该函数被调用。这是我的整个自定义损失函数:

def masked_q_loss(data, y_pred):
    """Computes the MSE between the Q-values of the actions that were taken and the cumulative
    discounted rewards obtained after taking those actions. Updates trace priorities.
    """
    action_batch, target_qvals, traces_idxs = data[:,0], data[:,1], data[:,2]
    seq = tf.cast(tf.range(0, tf.shape(action_batch)[0]), tf.int32)
    action_idxs = tf.transpose(tf.stack([seq, tf.cast(action_batch, tf.int32)]))
    qvals = tf.gather_nd(y_pred, action_idxs)

    def update_priorities(_qvals, _target_qvals, _traces_idxs):
        """Computes the TD error and updates memory priorities."""
        td_error = _target_qvals - _qvals
        _traces_idxs = tf.cast(_traces_idxs, tf.int32)
        mem.update_priorities(_traces_idxs, td_error)
        return _qvals

    qvals = tf.py_function(func=update_priorities, inp=[qvals, target_qvals, traces_idxs], Tout=[tf.float32])
    return tf.keras.losses.mse(qvals, target_qvals)

但是由于调用mem.update_priorities(_traces_idxs, td_error),我收到以下错误

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

我不需要为update_priorities 计算梯度,我只想在图形计算中的特定点调用它并忘记它。我该怎么做?

【问题讨论】:

【参考方案1】:

在包装函数内部的张量上使用.numpy() 解决了这个问题:

def update_priorities(_qvals, _target_qvals, _traces_idxs):
    """Computes the TD error and updates memory priorities."""
    td_error = np.abs((_target_qvals - _qvals).numpy())
    _traces_idxs = (tf.cast(_traces_idxs, tf.int32)).numpy()
    mem.update_priorities(_traces_idxs, td_error)
    return _qvals

【讨论】:

以上是关于如何在 tf.keras 自定义损失函数中触发 python 函数?的主要内容,如果未能解决你的问题,请参考以下文章

如何在具有使用@tf.keras.utils.register_keras_serializable 注册的自定义函数的 Tensorflow Serving 中提供模型?

在 tf.keras 的 model.fit 中,有没有办法将每个样本分批传递 n 次?

在 TF2/Keras 中正确实现 Autoencoder MSE 损失函数

tf.keras中度量和损失MSE之间的差异[重复]

在TF 2.0中使用tf.keras,如何定义依赖于学习阶段的自定义层?

Tensorflow2深度学习基础和tf.keras