如何在 Tensorflow 上测试自己的图像到 Cifar-10 教程?
Posted
技术标签:
【中文标题】如何在 Tensorflow 上测试自己的图像到 Cifar-10 教程?【英文标题】:How can I test own image to Cifar-10 tutorial on Tensorflow? 【发布时间】:2017-02-12 09:38:39 【问题描述】:我训练了 Tensorflow Cifar10 模型,我想用自己的单张图像(32*32,jpg/png)喂它。
我想查看每个标签的标签和概率作为输出,但我对此有些麻烦..
搜索堆栈溢出后,我发现了一些帖子,this,我修改了 cifar10_eval.py。
但它根本不起作用。
错误信息是:
InvalidArgumentErrorTraceback(最近一次调用最后一次) 在 () ----> 1个评估()
在评估() 86 # 从检查点恢复 87 打印(“ckpt.model_checkpoint_path”,ckpt.model_checkpoint_path) ---> 88 saver.restore(sess, ckpt.model_checkpoint_path) 89 # 假设 model_checkpoint_path 看起来像: 90 # /my-favorite-path/cifar10_train/model.ckpt-0,
/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/training/saver.pyc 在恢复(自我,sess,save_path)1127提高 ValueError("使用无效的保存路径 %s 调用恢复" % save_path) 第1128章 -> 1129 self.saver_def.filename_tensor_name: save_path) 1130 1131 @staticmethod
/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc 在运行中(self,fetches,feed_dict,options,run_metadata) 380尝试: 第381章 --> 382 run_metadata_ptr) 383 如果运行元数据: 384 proto_data = tf_session.TF_GetBuffer(运行元数据ptr)
/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc 在 _run(self、handle、fetches、feed_dict、options、run_metadata) 第653章 654 个结果 = self._do_run(handle, target_list, unique_fetches, --> 655 feed_dict_string,选项,run_metadata) 656 657 # 用户可能多次获取同一个张量,但我们
/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc 在 _do_run(self, handle, target_list, fetch_list, feed_dict, options, 运行元数据) 721 如果句柄为无: 第722章 --> 723 目标列表,选项,运行元数据) 724 其他: 第725章
/home/huray/anaconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc 在 _do_call(self, fn, *args) 741,除了 KeyError: 742通过 --> 743 raise type(e)(node_def, op, message) 744 第745章
InvalidArgumentError:Assign 需要两个张量的形状才能匹配。 lhs shape= [18,384] rhs shape= [2304,384] [[Node: save/Assign_5 = 赋值[T=DT_FLOAT, _class=["loc:@local3/weights"], use_locking=true, 验证形状=真, _device="/job:localhost/replica:0/task:0/cpu:0"](local3/weights, save/restore_slice_5)]]
我们将不胜感激任何对 Cifar10 的帮助。
这是目前已实现的代码,但存在编译问题:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import math
import time
import numpy as np
import tensorflow as tf
import cifar10
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval',
"""Directory where to write event logs.""")
tf.app.flags.DEFINE_string('eval_data', 'test',
"""Either 'test' or 'train_eval'.""")
tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train',
"""Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_integer('eval_interval_secs', 5,
"""How often to run the eval.""")
tf.app.flags.DEFINE_integer('num_examples', 1,
"""Number of examples to run.""")
tf.app.flags.DEFINE_boolean('run_once', False,
"""Whether to run eval only once.""")
def eval_once(saver, summary_writer, top_k_op, summary_op):
"""Run Eval once.
Args:
saver: Saver.
summary_writer: Summary writer.
top_k_op: Top K op.
summary_op: Summary op.
"""
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
# Restores from checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
# Assuming model_checkpoint_path looks something like:
# /my-favorite-path/cifar10_train/model.ckpt-0,
# extract global_step from it.
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
else:
print('No checkpoint file found')
return
print("Check point : %s" % ckpt.model_checkpoint_path)
# Start the queue runners.
coord = tf.train.Coordinator()
try:
threads = []
for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
start=True))
num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
true_count = 0 # Counts the number of correct predictions.
total_sample_count = num_iter * FLAGS.batch_size
step = 0
while step < num_iter and not coord.should_stop():
predictions = sess.run([top_k_op])
true_count += np.sum(predictions)
step += 1
# Compute precision @ 1.
precision = true_count / total_sample_count
print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
summary = tf.Summary()
summary.ParseFromString(sess.run(summary_op))
summary.value.add(tag='Precision @ 1', simple_value=precision)
summary_writer.add_summary(summary, global_step)
except Exception as e: # pylint: disable=broad-except
coord.request_stop(e)
coord.request_stop()
coord.join(threads, stop_grace_period_secs=10)
def evaluate():
"""Eval CIFAR-10 for a number of steps."""
with tf.Graph().as_default() as g:
# Get images and labels for CIFAR-10.
eval_data = FLAGS.eval_data == 'test'
# images, labels = cifar10.inputs(eval_data=eval_data)
# TEST CODE
img_path = "/TEST_IMAGEPATH/image.png"
input_img = tf.image.decode_png(tf.read_file(img_path), channels=3)
casted_image = tf.cast(input_img, tf.float32)
reshaped_image = tf.image.resize_image_with_crop_or_pad(casted_image, 24, 24)
float_image = tf.image.per_image_withening(reshaped_image)
images = tf.expand_dims(reshaped_image, 0)
logits = cifar10.inference(images)
_, top_k_pred = tf.nn.top_k(logits, k=1)
with tf.Session() as sess:
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
print("ckpt.model_checkpoint_path ", ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
else:
print('No checkpoint file found')
return
print("Check point : %s" % ckpt.model_checkpoint_path)
top_indices = sess.run([top_k_pred])
print ("Predicted ", top_indices[0], " for your input image.")
evaluate()
【问题讨论】:
【参考方案1】:视频https://youtu.be/d9mSWqfo0Xw 展示了对单个图像进行分类的示例。
在网络已经通过 python cifar10_train.py 训练后,我们评估 CIFAR-10 数据库的单个图像 deer6.png 和自己的火柴盒照片。 TF教程原源代码最重要的修改如下:
首先需要将这些图像转换为 cifar10_input.py 可以读取的二进制形式。可以使用How to create dataset similar to cifar-10中的代码 sn-p 轻松完成
然后为了读取转换后的图像(称为 input.bin),我们需要修改 cifar10_input.py 中的函数 input():
else:
#filenames = [os.path.join(data_dir, 'test_batch.bin')]
filenames = [os.path.join(data_dir, 'input.bin')]
(data_dir 等于'./')
最后为了得到标签,我们修改了源 cifar10_eval.py 中的函数 eval_once():
#while step < num_iter and not coord.should_stop():
# predictions = sess.run([top_k_op])
print(sess.run(logits[0]))
classification = sess.run(tf.argmax(logits[0], 0))
cifar10classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
print(cifar10classes[classification])
#true_count += np.sum(predictions)
step += 1
# Compute precision @ 1.
precision = true_count / total_sample_count
# print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
当然,您还需要进行一些小的修改。
【讨论】:
以上是关于如何在 Tensorflow 上测试自己的图像到 Cifar-10 教程?的主要内容,如果未能解决你的问题,请参考以下文章
如何利用TensorFlow.js部署简单的AI版「你画我猜」图像识别应用
教程 | 如何利用TensorFlow.js部署简单的AI版「你画我猜」图像识别应用
如何使用队列方法(没有 feed_dict)#tensorflow 在保存的模型上使用测试数据?