使用keras在损失函数中批量逐元素产品
Posted
技术标签:
【中文标题】使用keras在损失函数中批量逐元素产品【英文标题】:Batch element-wise product in loss function with keras 【发布时间】:2021-02-16 21:40:51 【问题描述】:我正在尝试在 keras 中编写一个自定义损失函数,我需要通过中间层的输出(其形状为 (batch_size, 1))。
我需要的操作只是将 MSE 的每个批次元素加权(乘以)一个因子,即 weight_tensor 中的相应批次元素。
我尝试了以下
def loss_aux_wrapper(weight_tensor):
def loss_aux(y_true, y_pred):
K.print_tensor(weight_tensor, message='weight = ')
_shape = K.shape(y_true)
return K.reshape(K.batch_dot(K.batch_flatten(mse(y_true, y_pred)), weight_tensor, axes=[1,1]),_shape)
return loss_aux
但我明白了
tensorflow.python.framework.errors_impl.InvalidArgumentError: In[0] mismatch In[1] shape: 4096 vs. 1: [32,1,4096] [32,1,1] 0 0 [[node loss/aux_motor_output_loss/MatMul (defined at /code/icub_sensory_enhancement/scripts/models.py:327) ]] [Op:__inference_keras_scratch_graph_6548]
K.print_tensor 没有输出任何东西,我相信是因为它是在编译时调用的?
非常感谢任何建议!
【问题讨论】:
【参考方案1】:为了加权 MSE 损失函数,您可以在调用 mse
函数时使用 sample_weight=
参数。根据docs,
如果
sample_weight
是大小为[batch_size]
的张量,那么总损失 对于批次的每个样本,都由相应的元素重新缩放 在sample_weight
向量中。如果sample_weight
的形状是[batch_size, d0, .. dN-1]
(或者可以广播到这个形状),然后y_pred
的每个损失元素都按对应的值缩放sample_weight
。 (注意ondN-1:所有损失函数都减少一维, 通常是axis=-1
。)
在您的情况下,weight_tensor
的形状为 ( batch size , 1 )
。所以首先我们需要重塑它,比如,
reshaped_weight_tensor = K.reshape( weight_tensor , shape=( batch_size ) )
然后这个张量可以和MeanSquaredError
一起使用,
def loss_aux_wrapper(weight_tensor):
def loss_aux(y_true, y_pred):
reshaped_weight_tensor = K.reshape( weight_tensor , shape=( batch_size ) )
return mse( y_true , y_pred , sample_weight=reshaped_weight_tensor )
return loss_aux
【讨论】:
以上是关于使用keras在损失函数中批量逐元素产品的主要内容,如果未能解决你的问题,请参考以下文章
Keras 中的像素加权损失函数 - TensorFlow 2.0