tensorflow的keras实现搭配dataset 之一

Posted wdmx

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow的keras实现搭配dataset 之一相关的知识,希望对你有一定的参考价值。

tensorflow的keras实现搭配dataset,几种形式都工作!

tensorflow,keras Sequential模式下:

见代码:

from tensorflow import keras as ks
import tensorflow as tf

# Generate dummy data
import numpy as np
x_train = np.random.random((1000, 20))
y_train = ks.utils.to_categorical(np.random.randint(10, size=(1000, 1)), num_classes=10)
x_test = np.random.random((100, 20))
y_test = ks.utils.to_categorical(np.random.randint(10, size=(100, 1)), num_classes=10)


batch_size = 100
steps_per_epoch = int(np.ceil(x_train.shape[0]/batch_size))

train_ds = tf.data.Dataset.from_tensor_slices((x_train,y_train))
train_ds = train_ds.batch(batch_size)   # batch 能给数据集增加批维度
train_ds = train_ds.repeat()

train_it = train_ds.make_one_shot_iterator()
x_train_it, y_train_it = train_it.get_next()


test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(batch_size)
test_ds = test_ds.repeat()

model = ks.models.Sequential()
# Dense(64) is a fully-connected layer with 64 hidden units.
# in the first layer, you must specify the expected input data shape:
# here, 20-dimensional vectors.
model.add(ks.layers.Dense(64, activation=relu, input_dim=20))
model.add(ks.layers.Dropout(0.5))
model.add(ks.layers.Dense(64, activation=relu))
model.add(ks.layers.Dropout(0.5))
model.add(ks.layers.Dense(10, activation=softmax))

sgd = ks.optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss=categorical_crossentropy, optimizer=sgd,   metrics=[accuracy])

# passing the data to the model with the below to style, both work
model.fit(x_train_it, y_train_it, epochs=20, steps_per_epoch=steps_per_epoch)
print("(+("*20,
*4)
model.fit(train_ds, epochs=20, steps_per_epoch=steps_per_epoch)

score = model.evaluate(test_ds, steps=128)
print(score)

 

以上是关于tensorflow的keras实现搭配dataset 之一的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Tensorflow Keras 中标准化我的图像数据

将 TensorFlow 模型转换为 Keras hdf5

人工智能深度学习:如何使用TensorFlow2.0实现文本分类?

人工智能深度学习:如何使用TensorFlow2.0实现文本分类?

Tensorflow——keras model.save() raise NotImplementedError

TensorFlow2 动手训练模型和部署服务