Keras 自定义损失函数

Posted

技术标签:

【中文标题】Keras 自定义损失函数【英文标题】:Keras custom loss-function 【发布时间】:2020-12-19 15:05:38 【问题描述】:

我想实现以下自定义损失函数,参数x 作为最后一层的输出。到目前为止,我将这个函数实现为 Lambda 层,再加上 keras 的 mae 损失,但我不想再这样做了

def GMM_UNC2(self, x):
    tmp = self.create_mr(x) # get mr series
    mr  = k.sum(tmp, axis=1) # sum over time
    tmp = k.square((1/self.T_i) * mr)
    tmp = k.dot(tmp, k.transpose(self.T_i))
    tmp = (1/(self.T * self.N)) * tmp

    f   = self.create_factor(x) # get factor
    std = k.std(f)
    mu  = k.mean(f)
    tmp = tmp + std/mu 

    def loss(y_true, y_pred=tmp):
        return k.abs(y_true-y_pred)

    return loss

self.y_true = np.zeros((1,1))
self.sdf_net = Model(inputs=[self.in_ma, self.in_mi, self.in_re, self.in_si], outputs=w)
self.sdf_net.compile(optimizer=self.optimizer, loss=self.GMM_UNC2(w))
self.sdf_net.fit([self.macro, self.micro, self.R, self.R_sign], self.y_true, epochs=epochs, verbose=1)

代码实际运行,但实际上并未使用tmp 作为损失的输入(我将它乘以某个数字,但损失保持不变)

我做错了什么?

【问题讨论】:

【参考方案1】:

如果您想将GMM_UNC2 函数应用于预测,或者仅应用一次来构建损失,您的问题并不完全清楚。如果是第一个选项,那么所有代码​​都应该在损失内并将其应用于y_pred,例如

def GMM_UNC2(self):

    def loss(y_true, y_pred):
        tmp = self.create_mr(y_pred) # get mr series
        mr  = k.sum(tmp, axis=1) # sum over time
        tmp = k.square((1/self.T_i) * mr)
        tmp = k.dot(tmp, k.transpose(self.T_i))
        tmp = (1/(self.T * self.N)) * tmp
        f   = self.create_factor(x) # get factor
        std = k.std(f)
        mu  = k.mean(f)
        tmp = tmp + std/mu 
        return k.abs(y_true-y_pred)

    return loss

如果是第二个选项,一般来说,在 Python 函数定义中将对象作为默认值传递并不是一个好主意,因为它可以在函数定义中更改。此外,您假设损失的第二个参数有一个名称y_pred,但是当调用它时,它是在没有名称的情况下完成的,作为位置参数。总之,您可以尝试在损失内使用显式比较,例如

    def loss(y_true, y_pred):
        if y_pred is None:
            y_pred = tmp
        return k.abs(y_true - y_pred)

如果你喜欢忽略预测,强行使用tmp,那么你可以忽略loss的y_pred参数,只使用tmp,like

    def loss(y_true, _):
        return k.abs(y_true - tmp)

【讨论】:

以上是关于Keras 自定义损失函数的主要内容,如果未能解决你的问题,请参考以下文章

Keras 中的自定义损失函数(IoU 损失函数)和梯度误差?

Keras 自定义损失函数

Keras 上的自定义损失函数

Keras 中基于输入数据的自定义损失函数

Keras 自定义损失函数 - 生存分析截尾

在 keras 中制作自定义损失函数