谷歌的inception模型是怎么训练的

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了谷歌的inception模型是怎么训练的相关的知识,希望对你有一定的参考价值。

参考技术A

Inception (GoogLeNet)是Google 2014年发布的Deep Convolutional Neural Network,其它几个流行的CNN网络还有QuocNet、AlexNet、BN-Inception-v2、VGG、ResNet等等。

Inception V3模型源码定义:tensorflow/contrib/slim/python/slim/nets/inception_v3.py

训练大的网络模型很耗资源,幸亏TensorFlow支持分布式:

    把计算任务Distribution到服务器集群

    把计算任务Distribution到多个GPU

    TensorBoard可视化Inception V3模型

    1

    2

    3

    4

    5

    6

    7

    8

    9

    10

    11

    12

    13

    14

    15

    16

    17

    18

    19

    20

    21

    22

    23

    24

    25

    26

    27

    28

    29

    30

    31

    32

    33

    34

    35

    36

    37

    38

    39

    40

       

    import tensorflow as tf

    import os

    import tarfile

    import requests

    inception_pretrain_model_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'

    # 下载inception pretrain模型

    inception_pretrain_model_dir = "inception_pretrain"

    if not os.path.exists(inception_pretrain_model_dir):

    os.makedirs(inception_pretrain_model_dir)

    filename = inception_pretrain_model_url.split('/')[-1]

    filepath = os.path.join(inception_pretrain_model_dir, filename)

    if not os.path.exists(filepath):

    print("开始下载: ", filename)

    r = requests.get(inception_pretrain_model_url, stream=True)

    with open(filepath, 'wb') as f:

    for chunk in r.iter_content(chunk_size=1024):

    if chunk:

    f.write(chunk)

    print("下载完成, 开始解压: ", filename)

    tarfile.open(filepath, 'r:gz').extractall(inception_pretrain_model_dir)

    # TensorBoard log目录

    log_dir = 'inception_log'

    if not os.path.exists(log_dir):

    os.makedirs(log_dir)

    # 加载inception graph

    inception_graph_def_file = os.path.join(inception_pretrain_model_dir, 'classify_image_graph_def.pb')

    with tf.Session() as sess:

    with tf.gfile.FastGFile(inception_graph_def_file, 'rb') as f:

    graph_def = tf.GraphDef()

    graph_def.ParseFromString(f.read())

    tf.import_graph_def(graph_def, name='')

    writer = tf.train.SummaryWriter(log_dir, sess.graph)

    writer.close()

       

    使用TensorBoard查看Graph:

    1

       

    $ tensorboard --logdir inception_log

       

    浏览器访问:http://127.0.0.1:6006

    如要转载,请保持本文完整,并注明作者@斗大的熊猫和本文原始地址: http://blog.topspeedsnail.com/archives/10919

以上是关于谷歌的inception模型是怎么训练的的主要内容,如果未能解决你的问题,请参考以下文章

机器学习谷歌的速成课程

Opencv+TF-Slim实现图像分类及深度特征提取

计算机视觉:用inception-v3模型重新训练自己的数据模型

tensorflow 1.0 学习:用别人训练好的模型来进行图像分类

下载inception v3 google训练好的模型并解压08-3

将预训练的 inception_resnet_v2 与 Tensorflow 结合使用