CIFAR-10 DEMO代码阅读与理解
Posted aofengdaxia
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了CIFAR-10 DEMO代码阅读与理解相关的知识,希望对你有一定的参考价值。
1、首先打开cifar_train.py 找到最后
if __name__ == '__main__':
tf.app.run()
这个代码是让所有的参数生效类似
tf.app.flags.DEFINE_string()
2、开始执行main()函数
def main(argv = none):
cifar10.maybe_download_and_extract()
if tf.gfile.Exists(FLAGS.train_dir):
tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(Flags.train_dir)
train()
以上代码好理解,就是首先尝试下载cifar10的数据文件,如果培训目录存在,那么删除,创建一个新的培训dir,可以扩展的学习点是:关于tf.gfile的学习 https://blog.csdn.net/pursuit_zhangyu/article/details/80557958
3、研究train()函数
def train():
with tf.Graph().as_default():
global_step = tf.train.get_or_create_global_step()
with tf.device("/cpu:0"):
images,labels = cifar10.distorted_imputs()
logits = cifar10.inference(images)
loss = cifar10.loss(logits,labels)
train_op = cifar10.train(loss,global_step)
#此处省略记录信息的类
with tf.train.MonitoredTrainingSession(
checkpoint _dir = FLAGS.train_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps]=),tf.train.NanTensorHook(loss),_LoggerHook()],
config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
) as mon_sess:
以上是关于CIFAR-10 DEMO代码阅读与理解的主要内容,如果未能解决你的问题,请参考以下文章