tensorflow中关于vgg16的项目
Posted 徐长卿学数据分析
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了tensorflow中关于vgg16的项目相关的知识,希望对你有一定的参考价值。
转载请注明链接:http://www.cnblogs.com/SSSR/p/5630534.html
tflearn中的例子训练vgg16项目:https://github.com/tflearn/tflearn/blob/master/examples/images/vgg_network.py 尚未测试成功。
下面的项目是使用别人已经训练好的模型进行预测,测试效果非常好。
github:https://github.com/ry/tensorflow-vgg16 此项目已经测试成功,效果非常好,
如果在Ubuntu中的terminal中运行出现问题,可以参照以下部分解决(解决skimage读取图片的问题)。
#coding:utf-8
import skimage
import skimage.io
import skimage.transform
a=skimage.io.imread(\'cat.jpg\')
import PIL
import numpy as np
import tensorflow as tf
synset = [l.strip() for l in open(\'/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/synset.txt\').readlines()]
def load_image(path):
# load image
img = skimage.io.imread(path)
#img1=PIL.Image.open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/pic/pig.jpg")
#img=np.array(PIL.Image.open(path))
#imgx=np.array(img)
#print type(imgx),imgx.shape
img = img/ 255.0
assert (0 <= img).all() and (img <= 1.0).all()
#print "Original Image Shape: ", img.shape
# we crop image from center
short_edge = min(img.shape[:2])
yy = int((img.shape[0] - short_edge) / 2)
xx = int((img.shape[1] - short_edge) / 2)
crop_img = img[yy : yy + short_edge, xx : xx + short_edge]
# resize to 224, 224
resized_img = skimage.transform.resize(crop_img, (224, 224))
return resized_img
# returns the top1 string
def print_prob(prob):
#print prob
print "prob shape", prob.shape
pred = np.argsort(prob)[::-1]
# Get top1 label
top1 = synset[pred[0]]
#print "Top1: ", top1
# Get top5 label
top5 = [synset[pred[i]] for i in range(5)]
#print "Top5: ", top5
return top1
print u\'加载模型文件\'
with open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/vgg16.tfmodel", mode=\'rb\') as f:
fileContent = f.read()
print u\'创建图\'
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
images = tf.placeholder("float", [None, 224, 224, 3])
tf.import_graph_def(graph_def, input_map={ "images": images })
print "graph loaded from disk"
graph = tf.get_default_graph()
print u\'加载图片\'
#img=np.array(PIL.Image.open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/pic/pig.jpg"))
#cat = load_image(path)
print u\'进入sess执行\'
sess=tf.Session()
result=[]
for i in [\'cat.jpg\',\'airplane.jpg\',\'zebra.jpg\',\'pig.jpg\',\'12.jpg\',\'23.jpg\']:
img=load_image(\'pic/\'+i)
init = tf.initialize_all_variables()
sess.run(init)
print "variables initialized"
batch = img.reshape((1, 224, 224, 3))
assert batch.shape == (1, 224, 224, 3)
feed_dict = { images: batch }
print u\'开始执行\'
prob_tensor = graph.get_tensor_by_name("import/prob:0")
prob = sess.run(prob_tensor, feed_dict=feed_dict)
print u\'输出结果\'
#print_prob(prob[0])
result.append(print_prob(prob[0]))
print result
sess.close()
\'\'\'
with tf.Session() as sess:
init = tf.initialize_all_variables()
sess.run(init)
print "variables initialized"
batch = cat.reshape((1, 224, 224, 3))
assert batch.shape == (1, 224, 224, 3)
feed_dict = { images: batch }
print u\'开始执行\'
prob_tensor = graph.get_tensor_by_name("import/prob:0")
prob = sess.run(prob_tensor, feed_dict=feed_dict)
print u\'输出结果\'
print_prob(prob[0])
\'\'\'
以上是关于tensorflow中关于vgg16的项目的主要内容,如果未能解决你的问题,请参考以下文章
用于 VGG19 模型参数的 Tensorflow Float16
神经网络学习小记录61——Tensorflow2 搭建常见分类网络平台(VGG16MobileNetResNet50)
当在 tensorflow 1.14 中使用混合精度训练时,张量对象在 keras vgg16 中没有属性“is_initialized”