tensorflow中关于vgg16的项目

Posted 徐长卿学数据分析

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow中关于vgg16的项目相关的知识,希望对你有一定的参考价值。

转载请注明链接:http://www.cnblogs.com/SSSR/p/5630534.html

tflearn中的例子训练vgg16项目:https://github.com/tflearn/tflearn/blob/master/examples/images/vgg_network.py 尚未测试成功。

下面的项目是使用别人已经训练好的模型进行预测,测试效果非常好。

github:https://github.com/ry/tensorflow-vgg16 此项目已经测试成功,效果非常好,

如果在Ubuntu中的terminal中运行出现问题,可以参照以下部分解决(解决skimage读取图片的问题)。

#coding:utf-8


import skimage
import skimage.io
import skimage.transform
a=skimage.io.imread(\'cat.jpg\')
import PIL
import numpy as np
import tensorflow as tf
synset = [l.strip() for l in open(\'/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/synset.txt\').readlines()]

def load_image(path):
  # load image
  img = skimage.io.imread(path)
  #img1=PIL.Image.open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/pic/pig.jpg")
  #img=np.array(PIL.Image.open(path))
  #imgx=np.array(img)  
  #print type(imgx),imgx.shape
  img = img/ 255.0
  assert (0 <= img).all() and (img <= 1.0).all()
  #print "Original Image Shape: ", img.shape
  # we crop image from center
  short_edge = min(img.shape[:2])
  yy = int((img.shape[0] - short_edge) / 2)
  xx = int((img.shape[1] - short_edge) / 2)
  crop_img = img[yy : yy + short_edge, xx : xx + short_edge]
  # resize to 224, 224
  resized_img = skimage.transform.resize(crop_img, (224, 224))
  return resized_img
  
# returns the top1 string
def print_prob(prob):
  #print prob
  print "prob shape", prob.shape
  pred = np.argsort(prob)[::-1]
  # Get top1 label
  top1 = synset[pred[0]]
  #print "Top1: ", top1
  # Get top5 label
  top5 = [synset[pred[i]] for i in range(5)]
  #print "Top5: ", top5
  return top1

print u\'加载模型文件\'
with open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/vgg16.tfmodel", mode=\'rb\') as f:
  fileContent = f.read()
  
print u\'创建图\'
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)

images = tf.placeholder("float", [None, 224, 224, 3])

tf.import_graph_def(graph_def, input_map={ "images": images })
print "graph loaded from disk"

graph = tf.get_default_graph()
print u\'加载图片\'
#img=np.array(PIL.Image.open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/pic/pig.jpg"))
#cat = load_image(path)
print u\'进入sess执行\'

sess=tf.Session()
result=[]
for i in [\'cat.jpg\',\'airplane.jpg\',\'zebra.jpg\',\'pig.jpg\',\'12.jpg\',\'23.jpg\']:
  img=load_image(\'pic/\'+i)
  init = tf.initialize_all_variables()
  sess.run(init)
  print "variables initialized"
  batch = img.reshape((1, 224, 224, 3))
  assert batch.shape == (1, 224, 224, 3)
  feed_dict = { images: batch }
  print u\'开始执行\'
  prob_tensor = graph.get_tensor_by_name("import/prob:0")
  prob = sess.run(prob_tensor, feed_dict=feed_dict)
  print u\'输出结果\'
  #print_prob(prob[0])
  result.append(print_prob(prob[0]))


print result
sess.close()


\'\'\'
with tf.Session() as sess:
  init = tf.initialize_all_variables()
  sess.run(init)
  print "variables initialized"
  batch = cat.reshape((1, 224, 224, 3))
  assert batch.shape == (1, 224, 224, 3)
  feed_dict = { images: batch }
  print u\'开始执行\'
  prob_tensor = graph.get_tensor_by_name("import/prob:0")
  prob = sess.run(prob_tensor, feed_dict=feed_dict)

print u\'输出结果\'
print_prob(prob[0])

\'\'\'

  

以上是关于tensorflow中关于vgg16的项目的主要内容,如果未能解决你的问题,请参考以下文章

tensorflow VGG16 网络精度和损失不会改变

VGG-16复现

用于 VGG19 模型参数的 Tensorflow Float16

神经网络学习小记录61——Tensorflow2 搭建常见分类网络平台(VGG16MobileNetResNet50)

『TensorFlow』项目资源分享

当在 tensorflow 1.14 中使用混合精度训练时,张量对象在 keras vgg16 中没有属性“is_initialized”