R Keras 中的自定义损失函数
Posted
技术标签:
【中文标题】R Keras 中的自定义损失函数【英文标题】:Custom Loss Function in R Keras 【发布时间】:2018-12-21 07:32:54 【问题描述】:我想计算加权均方误差,其中权重是数据中的一个向量。我根据堆栈溢出可用的建议编写了自定义代码。
功能如下:
weighted_mse <- function(y_true, y_pred,weights)
# convert tensors to R objects
K <- backend()
y_true <- K$eval(y_true)
y_pred <- K$eval(y_pred)
weights <- K$eval(weights)
# calculate the metric
loss <- sum(weights*((y_true - y_pred)^2))
# convert to tensor
return(K$constant(loss))
但是,我不确定如何将自定义函数传递给编译器。如果有人可以帮助我,那就太好了。谢谢你。
model <- model %>% compile(
loss = 'mse',
optimizer = 'rmsprop',
metrics = 'mse')
问候
【问题讨论】:
【参考方案1】:我没有将 Keras 与 R 一起使用,但是按照 documentation 中的示例,这应该可以工作:
weighted_mse <- function(y_true, y_pred, weights)
K <- backend()
weights <- K$variable(weights)
# calculate the metric
loss <- K$sum(weights * (K$pow(y_true - y_pred, 2)))
loss
metric_weighted_mse <- custom_metric("weighted_mse", function(y_true, y_pred)
weighted_mse(y_true, y_pred, weights)
)
model <- model %>% compile(
loss = 'mse',
optimizer = 'rmsprop',
metrics = metric_weighted_mse)
请注意,我为损失函数使用了一个包装器,因为它有一个额外的参数。此外,损失函数将输入作为张量处理,这就是为什么您应该使用K$variable(weights)
转换权重。
【讨论】:
我在使用您的功能时收到以下错误。 py_call_impl 中的错误(callable,dots$args,dots$keywords):RuntimeError:评估错误:AttributeError:'function' 对象没有属性'eval'。【参考方案2】:你不能在损失函数中eval
。这将破坏图表。
您应该只使用fit
方法的sample_weight
参数:https://keras.rstudio.com/reference/fit.html
##not sure if this is valid R, but
##at some point you will call `fit` for training with `X_train` and `Y_train`,
##so, just add the weights.
history <- model$fit(X_train, Y_train, ..., sample_weight = weights)
就是这样(不要使用自定义损失)。
仅供参考 - 将损失函数传递给 compile
仅适用于采用 y_true
和 y_pred
的函数。 (如果您使用sample_weights
,则不需要)
model <- model %>% compile(
loss = weighted_mse,
optimizer = 'rmsprop',
metrics = 'mse')
但这不起作用,您需要类似于@spadarian 创建的包装器。
此外,保持数据和权重之间的相关性将非常复杂,因为 Keras 会将您的数据分批划分,也因为数据会被打乱。
【讨论】:
好的。谢谢你。那么 sample_weight 在损失函数的计算中使用这些权重呢?那么例如,使用 sample_weight 的 mse 是否等同于加权 mse?我注意到使用 sample_weight 时我的拟合和预测更差,因此我在问。 是的,使用sample_weight
+ mse
与使用weighted_mse
相同。以上是关于R Keras 中的自定义损失函数的主要内容,如果未能解决你的问题,请参考以下文章
Keras 中的自定义损失函数 - 遍历 TensorFlow