在 keras 回调中使用带有自定义参数的自定义函数

Posted

技术标签:

【中文标题】在 keras 回调中使用带有自定义参数的自定义函数【英文标题】:Use custom function with custom parameters in keras callback 【发布时间】:2019-11-19 01:24:31 【问题描述】:

我正在 keras 中训练一个模型,我想在每个 epoch 之后绘制结果图。我知道 keras 回调提供了“on_epoch_end”函数,如果一个人想在每个 epoch 之后进行一些计算,那么该函数可以被重载,但是我的函数需要一些额外的参数,当给定这些参数时,元类错误会导致代码崩溃。具体如下:

这是我现在的做法,效果很好:-

class NewCallback(Callback):

def on_epoch_end(self, epoch, logs=):  #working fine, printing epoch after each epoch
    print("EPOCH IS: "+str(epoch))


epochs=5
batch_size = 16
model_saved=False
if model_saved:
    vae.load_weights(args.weights)
else:
    # train the autoencoder
    vae.fit(x_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(x_test, None),
           callbacks=[NewCallback()])

但我想要这样的回调函数:-

class NewCallback(Callback,models,data,batch_size):
   def on_epoch_end(self, epoch, logs=):
     print("EPOCH IS: "+str(epoch))
     x=models.predict(data)
     plt.plot(x)
     plt.savefig(epoch+".png")

如果我这样称呼它:

callbacks=[NewCallback(models, data, batch_size=batch_size)]

我收到此错误:

TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases 

我正在寻找一个更简单的解决方案来调用我的函数或解决元类的这个错误,非常感谢任何帮助!

【问题讨论】:

【参考方案1】:

我认为您想做的是定义一个从回调派生的类,并将模型、数据等...作为构造函数参数。所以:

class NewCallback(Callback):
    """ NewCallback descends from Callback
    """
    def __init__(self, models, data, batch_size):
        """ Save params in constructor
        """
        self.models = models

    def on_epoch_end(self, epoch, logs=):
        x = self.models.predict(self.data)

【讨论】:

【参考方案2】:

如果你想对测试数据进行预测,你可以试试这个

class CustomCallback(keras.callbacks.Callback):
    def __init__(self, model, x_test, y_test):
        self.model = model
        self.x_test = x_test
        self.y_test = y_test

    def on_epoch_end(self, epoch, logs=):
        y_pred = self.model.predict(self.x_test, self.y_test)
        print('y predicted: ', y_pred)

你需要在model.fit期间提及回调

model.sequence()
# your model architecture
model.fit(x_train, y_train, epochs=10, 
          callbacks=[CustomCallback(model, x_test, y_test)])

on_epoch_end类似,keras还提供了很多其他方法

on_train_begin, on_train_end, on_epoch_begin, on_epoch_end, on_test_begin,
on_test_end, on_predict_begin, on_predict_end, on_train_batch_begin, on_train_batch_end,
on_test_batch_begin, on_test_batch_end, on_predict_batch_begin,on_predict_batch_end

【讨论】:

以上是关于在 keras 回调中使用带有自定义参数的自定义函数的主要内容,如果未能解决你的问题,请参考以下文章

Keras 中带有附加变量输入的自定义损失/目标函数

在带有多个参数的自定义模板标签的模板中使用“if”

Keras 中基于输入数据的自定义损失函数

图像分割 - Keras 中的自定义损失函数

如何在 keras 自定义回调中访问 tf.data.Dataset?

带有自定义对象的 Keras load_model 无法正常工作