15cifar10

Posted by pengzhonglian

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了15cifar10相关的知识,希望对你有一定的参考价值。

  1 import  tensorflow as tf
  2 from    tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
  3 from     tensorflow import keras
  4 import  os
  5 
  6 os.environ[TF_CPP_MIN_LOG_LEVEL] = 2
  7 
  8 
  9 def preprocess(x, y):
 10     # [0~255] => [-1~1]
 11     x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1.
 12     y = tf.cast(y, dtype=tf.int32)
 13     return x,y
 14 
 15 
 16 batchsz = 128
 17 # [50k, 32, 32, 3], [10k, 1]
 18 (x, y), (x_val, y_val) = datasets.cifar10.load_data()
 19 y = tf.squeeze(y)
 20 y_val = tf.squeeze(y_val)
 21 y = tf.one_hot(y, depth=10) # [50k, 10]
 22 y_val = tf.one_hot(y_val, depth=10) # [10k, 10]
 23 print(datasets:, x.shape, y.shape, x_val.shape, y_val.shape, x.min(), x.max())
 24 
 25 
 26 train_db = tf.data.Dataset.from_tensor_slices((x,y))
 27 train_db = train_db.map(preprocess).shuffle(10000).batch(batchsz)
 28 test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
 29 test_db = test_db.map(preprocess).batch(batchsz)
 30 
 31 
 32 sample = next(iter(train_db))
 33 print(batch:, sample[0].shape, sample[1].shape)
 34 
 35 
 36 class MyDense(layers.Layer):
 37     # to replace standard layers.Dense()
 38     def __init__(self, inp_dim, outp_dim):
 39         super(MyDense, self).__init__()
 40 
 41         self.kernel = self.add_variable(w, [inp_dim, outp_dim])
 42         # self.bias = self.add_variable(‘b‘, [outp_dim])
 43 
 44     def call(self, inputs, training=None):
 45 
 46         x = inputs @ self.kernel
 47         return x
 48 
 49 class MyNetwork(keras.Model):
 50 
 51     def __init__(self):
 52         super(MyNetwork, self).__init__()
 53 
 54         self.fc1 = MyDense(32*32*3, 256)
 55         self.fc2 = MyDense(256, 128)
 56         self.fc3 = MyDense(128, 64)
 57         self.fc4 = MyDense(64, 32)
 58         self.fc5 = MyDense(32, 10)
 59 
 60 
 61 
 62     def call(self, inputs, training=None):
 63         """
 64 
 65         :param inputs: [b, 32, 32, 3]
 66         :param training:
 67         :return:
 68         """
 69         x = tf.reshape(inputs, [-1, 32*32*3])
 70         # [b, 32*32*3] => [b, 256]
 71         x = self.fc1(x)
 72         x = tf.nn.relu(x)
 73         # [b, 256] => [b, 128]
 74         x = self.fc2(x)
 75         x = tf.nn.relu(x)
 76         # [b, 128] => [b, 64]
 77         x = self.fc3(x)
 78         x = tf.nn.relu(x)
 79         # [b, 64] => [b, 32]
 80         x = self.fc4(x)
 81         x = tf.nn.relu(x)
 82         # [b, 32] => [b, 10]
 83         x = self.fc5(x)
 84 
 85         return x
 86 
 87 
 88 network = MyNetwork()
 89 network.compile(optimizer=optimizers.Adam(lr=1e-3),
 90                 loss=tf.losses.CategoricalCrossentropy(from_logits=True),
 91                 metrics=[accuracy])
 92 network.fit(train_db, epochs=15, validation_data=test_db, validation_freq=1)
 93 
 94 network.evaluate(test_db)
 95 network.save_weights(ckpt/weights.ckpt)
 96 del network
 97 print(saved to ckpt/weights.ckpt)
 98 
 99 
# Rebuild an identical, freshly initialized model and restore the saved weights.
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
# The model's variables do not exist yet (it has not been called). With the
# TF checkpoint format, restoration is deferred until they are created on
# the first call, here inside evaluate().
network.load_weights('ckpt/weights.ckpt')
print('loaded weights from file.')
network.evaluate(test_db)

以上是关于15cifar10的主要内容,如果未能解决你的问题,请参考以下文章

从 cifar-10 数据集加载图像

CIFAR-10 DEMO代码阅读与理解

Cifar-10数据集及Tensorflow代码实现

如何创建类似于 cifar-10 的数据集 [关闭]

CIFAR-10 尺寸误差 Keras

pytorch实现CIFAR10实战