tf.keras 模型 多个输入 tf.data.Dataset

Posted li修远

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tf.keras 模型 多个输入 tf.data.Dataset相关的知识,希望对你有一定的参考价值。

import tensorflow as tf 
a = tf.keras.layers.Input(batch_shape=(None,10, 1))
b = tf.keras.layers.Input(batch_shape=(None,1))

fc1 = tf.keras.layers.Dense(16,‘relu‘)(a)
fc2 = tf.keras.layers.Dense(16,‘relu‘)(b)

fc1 = tf.keras.layers.Lambda(lambda x: x[:,0,:])(fc1)
reshape = tf.keras.layers.Lambda(lambda x: tf.reshape(x,(-1, 16)))(fc1)
hidden = tf.keras.layers.concatenate([reshape, fc2],axis=-1)
inputs = [a, b]
outputs = hidden
print(hidden.shape)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

model.compile(optimizer=tf.keras.optimizers.SGD(),
              loss=tf.keras.losses.mean_squared_error)

import numpy as np
data1 = np.random.rand(10, 10, 1)
data2 = np.random.rand(10, 1)
label  = np.random.rand(10, 32)

dataset1 = tf.data.Dataset.from_tensor_slices((data1, data2))
dataset2 = tf.data.Dataset.from_tensor_slices(label)

dataset  = tf.data.Dataset.zip((dataset1, dataset2)).batch(10).repeat()

model.fit(dataset, epochs=5, steps_per_epoch=30)

参考文献
[1] tensorflow使用tf.keras.Mode写模型并使用tf.data.Dataset作为数据输入
[2] Tensorflow keras入门教程
[3] 使用 tf.data 加载 NumPy 数据

以上是关于tf.keras 模型 多个输入 tf.data.Dataset的主要内容,如果未能解决你的问题,请参考以下文章

TensorFlow2 动手训练模型和部署服务

Keras Estimator + tf.data API

如何在Tensorflow中组合feature_columns,model_to_estimator和dataset API

tensorflow2.0新特性

《30天吃掉那只 TensorFlow2.0》五TensorFlow的中阶API

《30天吃掉那只 TensorFlow2.0》五TensorFlow的中阶API