如何在 Tensorflow 2.x Keras 自定义层中使用多个输入?

Posted

技术标签:

【中文标题】如何在 Tensorflow 2.x Keras 自定义层中使用多个输入?【英文标题】:How to use multiple inputs in Tensorflow 2.x Keras Custom Layer? 【发布时间】:2020-09-05 12:37:11 【问题描述】:

我正在尝试在 Tensorflow-Keras 的自定义层中使用多个输入。用法可以是任何东西,现在它被定义为将蒙版与图像相乘。我搜索了 SO,我能找到的唯一答案是 TF 1.x,所以它没有任何好处。

class mul(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # I've added pass because this is the simplest form I can come up with.
        pass
          
    def call(self, inputs):
        # magic happens here and multiplications occur
        return(Z)

【问题讨论】:

【参考方案1】:

编辑:自 TensorFlow v2.3/2.4 起,合约将使用call 方法的输入列表。对于keras(不是tf.keras),我认为下面的答案仍然适用。

在您的类的call 方法中实现多个输入,有两种选择:

列表输入,这里inputs参数应该是一个包含所有输入的列表,这里的好处是它可以是可变大小的。您可以使用= 运算符对列表进行索引或解压缩参数:

  def call(self, inputs):
      Z = inputs[0] * inputs[1]

      #Alternate
      input1, input2 = inputs
      Z = input1 * input2

      return Z

call 方法中的多个输入参数,可以工作,但是在定义层时参数的数量是固定的:

  def call(self, input1, input2):
      Z = input1 * input2

      return Z

您选择哪种方法来实现它取决于您需要固定大小或可变大小的参数数量。当然,每个方法都会改变调用层的方式,要么通过传递参数列表,要么在函数调用中逐个传递参数。

您也可以在第一种方法中使用*args 以允许具有可变数量参数的call 方法,但总体而言,keras 自己的层需要多个输入(如ConcatenateAdd)使用列表实现。

【讨论】:

你必须使用一个列表,而不是多个参数。请参阅此“文档”:github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/python/… 多个输入参数违反了tf.keras.Layer.call() (tensorflow.org/api_docs/python/tf/keras/layers/Layer#call) 的约定,其中明确指出inputs 应该是多个输入张量的列表/元组。【参考方案2】:

这样试试

class mul(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # I've added pass because this is the simplest form I can come up with.
        pass

    def call(self, inputs):
        inp1, inp2 = inputs
        Z = inp1*inp2
        return Z

inp1 = Input((10))
inp2 = Input((10))
x = mul()([inp1,inp2])
x = Dense(1)(x)
model = Model([inp1,inp2],x)
model.summary()

【讨论】:

以上是关于如何在 Tensorflow 2.x Keras 自定义层中使用多个输入?的主要内容,如果未能解决你的问题,请参考以下文章

使用Keras的模型类将Tensorflow 1.x代码迁移到Tensorflow 2.x

图像分类手撕ResNet——复现ResNet(Keras,Tensorflow 2.x)

如何在 Tensorflow 中从 tf.keras 导入 keras?

如何在训练 tensorflow.keras 期间替换损失函数

如何使用 tensorflow 在 keras 中禁用 GPU?

如何让 Keras 在 Anaconda 中使用 Tensorflow 后端?