You must feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,784] for MNIST dataset
Posted: 2018-05-09 13:57:51
Question: This is an example I am testing on the MNIST dataset for quantization. I am testing my model with the following code:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util
from tensorflow.core.framework import graph_pb2
import numpy as np

def test_model(model_file, x_in):
    with tf.Session() as sess:
        with open(model_file, "rb") as f:
            output_graph_def = graph_pb2.GraphDef()
            output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")
        x = sess.graph.get_tensor_by_name('Placeholder_1:0')
        y = sess.graph.get_tensor_by_name('softmax_cross_entropy_with_logits:0')
        new_scores = sess.run(y, feed_dict={x: x_in.test.images})
        # orig_scores and find_top_pred are not defined in this snippet.
        print((orig_scores - new_scores) < 1e-6)
        find_top_pred(orig_scores)
        find_top_pred(new_scores)

#print(epoch_x.shape)
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
test_model('mnist_cnn1.pb', mnist)
I can't see where I am feeding a wrong value. I have added the full error traceback below. The error is:
Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1323, in _do_call
return fn(*args)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1302, in _run_fn
status, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,784]
[[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[?,784], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "tmp.py", line 26, in <module>
test_model('/home/shringa/tensorflowdata/mnist_cnn1.pb',mnist)
File "tmp.py", line 19, in test_model
new_scores = sess.run(y, feed_dict=x:x_in.test.images)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 889, in run
run_metadata_ptr)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1120, in _run
feed_dict_tensor, options, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1317, in _do_run
options, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1336, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,784]
[[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[?,784], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
Caused by op 'Placeholder', defined at:
File "tmp.py", line 26, in <module>
test_model('/home/shringa/tensorflowdata/mnist_cnn1.pb',mnist)
File "tmp.py", line 16, in test_model
_ = tf.import_graph_def(output_graph_def, name="")
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/deprecation.py", line 316, in new_func
return func(*args, **kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/importer.py", line 411, in import_graph_def
op_def=op_def)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3069, in create_op
op_def=op_def)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1579, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,784]
[[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[?,784], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
As shown above, I am using the mnist_cnn1.pb file to load my model and test it on the MNIST test images, but it raises the placeholder shape error.
Below is my CNN model:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
print(type(mnist))

n_classes = 10
batch_size = 128

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')

def maxpool2d(x):
    # ksize: size of window, strides: movement of window
    return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')

def convolutional_network_model(x):
    weights = {'W_conv1': tf.Variable(tf.random_normal([5,5,1,32])),
               'W_conv2': tf.Variable(tf.random_normal([5,5,32,64])),
               'W_fc': tf.Variable(tf.random_normal([7*7*64,1024])),
               'out': tf.Variable(tf.random_normal([1024, n_classes]))}

    biases = {'B_conv1': tf.Variable(tf.random_normal([32])),
              'B_conv2': tf.Variable(tf.random_normal([64])),
              'B_fc': tf.Variable(tf.random_normal([1024])),
              'out': tf.Variable(tf.random_normal([n_classes]))}

    x = tf.reshape(x, shape=[-1,28,28,1])

    conv1 = conv2d(x, weights['W_conv1'])
    conv1 = maxpool2d(conv1)

    conv2 = conv2d(conv1, weights['W_conv2'])
    conv2 = maxpool2d(conv2)

    fc = tf.reshape(conv2, [-1,7*7*64])
    fc = tf.nn.relu(tf.matmul(fc, weights['W_fc']) + biases['B_fc'])

    # The output bias is added after the matmul, not to the weight matrix.
    output = tf.matmul(fc, weights['out']) + biases['out']
    return output
def train_neural_network(x):
    prediction = convolutional_network_model(x)
    # OLD VERSION:
    #cost = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(prediction,y) )
    # NEW:
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction, labels=y))
    optimizer = tf.train.AdamOptimizer().minimize(cost)

    hm_epochs = 25
    with tf.Session() as sess:
        # OLD:
        #sess.run(tf.initialize_all_variables())
        # NEW:
        sess.run(tf.global_variables_initializer())

        for epoch in range(hm_epochs):
            epoch_loss = 0
            for _ in range(int(mnist.train.num_examples/batch_size)):
                epoch_x, epoch_y = mnist.train.next_batch(batch_size)
                _, c = sess.run([optimizer, cost], feed_dict={x: epoch_x, y: epoch_y})
                epoch_loss += c

            print('Epoch', epoch, 'completed out of', hm_epochs, 'loss:', epoch_loss)

        correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
        print('Accuracy:', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

train_neural_network(x)
Using bazel, I created the mnist_cnn1.pb file:
python3 tensorflow/tools/quantization/quantize_graph.py --input=/home/shringa/tensorflowdata/mnist_cnn.pb --output=/home/shringa/tensorflowdata/mnist_cnn1.pb --output_node_names=softmax_cross_entropy_with_logits --mode=eightbit
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/home/shringa/tensorflowdata/mnist_cnn1.pb
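The question does not show how mnist_cnn.pb itself was written. The quantization step presumably starts from a frozen GraphDef (variables folded into constants), so one possible way to produce such a file from a training session is sketched below. It is written against tf.compat.v1 so it should run under TF 1.13+ and TF 2.x; the helper name freeze_to_pb, the toy graph, and the node names x, w and logits are hypothetical stand-ins for the CNN above.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1


def freeze_to_pb(sess, output_node_names, path):
    """Fold variables into constants and write the binary GraphDef to `path` (hypothetical helper).

    output_node_names are op names such as "logits", not tensor names like "logits:0".
    """
    frozen = tf1.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_node_names)
    with tf.io.gfile.GFile(path, "wb") as f:
        f.write(frozen.SerializeToString())
    return frozen


# Toy stand-in for the CNN: one named placeholder, one variable, one named output op.
graph = tf1.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, [None, 2], name="x")
    w = tf1.Variable([[1.0], [2.0]], name="w")
    logits = tf.identity(tf.matmul(x, w), name="logits")
    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.global_variables_initializer())
        pb_path = os.path.join(tempfile.mkdtemp(), "toy_frozen.pb")
        frozen = freeze_to_pb(sess, ["logits"], pb_path)

# Op types left in the frozen graph; no variable ops should remain.
print(sorted({node.op for node in frozen.node}))
```

Because the output op is given an explicit name here, the same name can be passed to --output_node_names and later to get_tensor_by_name (with a ":0" suffix).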
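Comments on the question:
Please include the entire error traceback.
@Stephen I have added the full error traceback.
Where do you get mnist_cnn1.pb? If you are creating it, how are you doing it? Also, when you call get_tensor_by_name, how do you know which names to use? If this comes from a tutorial, linking to it would be useful.
I have pasted my CNN model and how I generated the PB file; using the code above I can get the get_tensor_by_name arguments.
Did you find a solution?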
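Answer 1:
Reason
The cause of the problem is that you did not name your variables/nodes, which led to the confusion.
When you define the placeholders: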
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32)
...TensorFlow assigns x and y the following names:
Tensor("Placeholder:0", shape=(?, 784), dtype=float32) <-- x
Tensor("Placeholder_1:0", dtype=float32) <-- y
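Where do those names come from? An unnamed op gets its op type ("Placeholder") as a base name, and later ops of the same type in the same graph get a _1, _2, ... suffix. The pure-Python sketch below is a simplified model of that rule, not TensorFlow's actual implementation: the function make_namer is hypothetical, and it ignores name scopes and clashes with explicitly chosen names.

```python
def make_namer():
    """Return a function that hands out unique op names, TF1-style (simplified model)."""
    counts = {}

    def unique_name(base):
        i = counts.get(base, 0)
        counts[base] = i + 1
        # First use keeps the bare name; later uses get a _1, _2, ... suffix.
        return base if i == 0 else "%s_%d" % (base, i)

    return unique_name


namer = make_namer()
x_op = namer("Placeholder")  # first placeholder (x)
y_op = namer("Placeholder")  # second placeholder (y)
# A tensor name is "<op_name>:<output_index>"; a placeholder has a single output, index 0.
print(x_op + ":0", y_op + ":0")  # Placeholder:0 Placeholder_1:0
```

The order of definition is therefore what decides which placeholder becomes "Placeholder" and which becomes "Placeholder_1".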
So at test time, the following line pulls the wrong node:
x = sess.graph.get_tensor_by_name('Placeholder_1:0') # this is y!
This is why TensorFlow complains that the placeholder is not fed: it expects x, not y.
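To check which placeholder is which inside a .pb file before calling get_tensor_by_name, one can scan the GraphDef for Placeholder nodes and print their dtype and shape. This is a sketch assuming the tf.compat.v1 API; the helper name list_placeholders is hypothetical. The demo rebuilds the two unnamed placeholders from the training script in a fresh graph.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def list_placeholders(graph_def):
    """Return (op_name, dtype, shape) for every Placeholder node in a GraphDef."""
    found = []
    for node in graph_def.node:
        if node.op == "Placeholder":
            dtype = tf.as_dtype(node.attr["dtype"].type)
            # The 'shape' attr is a TensorShapeProto; -1 dimensions become None.
            shape = tf.TensorShape(node.attr["shape"].shape)
            found.append((node.name, dtype, shape))
    return found


# Same definitions as the training script: no explicit names.
graph = tf1.Graph()
with graph.as_default():
    tf1.placeholder(tf.float32, [None, 784])
    tf1.placeholder(tf.float32)

placeholders = list_placeholders(graph.as_graph_def())
for name, dtype, shape in placeholders:
    print(name, dtype.name, shape)
```

For the quantized file, the GraphDef can be parsed exactly as test_model does and passed in; the entry with shape (?, 784) is the image input.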
Solution
Name them explicitly:
x = tf.placeholder(tf.float32, [None, 784], name='x')
y = tf.placeholder(tf.float32, name='y')
...
x = sess.graph.get_tensor_by_name('x')
I would also give the softmax_cross_entropy_with_logits op a name, so that all inference nodes are easy to access.
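Putting the answer's suggestion together, here is a minimal sketch of the test function with named nodes, using tf.compat.v1. It assumes the training script named the input placeholder x and wrapped the logits in an op named logits (for example tf.identity(output, name='logits')); both names and the helper predict_from_pb are hypothetical. Fetching the logits rather than the softmax_cross_entropy_with_logits loss node matters: the loss also depends on the label placeholder y, so evaluating it would still require feeding y.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def predict_from_pb(model_file, images, input_name="x:0", output_name="logits:0"):
    """Load a frozen GraphDef and return the output tensor's values for `images`."""
    graph_def = tf1.GraphDef()
    with tf.io.gfile.GFile(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())

    graph = tf1.Graph()
    with graph.as_default():
        tf1.import_graph_def(graph_def, name="")

    # Tensor names are "<op_name>:<output_index>".
    x = graph.get_tensor_by_name(input_name)
    logits = graph.get_tensor_by_name(output_name)
    with tf1.Session(graph=graph) as sess:
        return sess.run(logits, feed_dict={x: images})


# Toy graph with explicitly named nodes, written to a temporary .pb file.
toy = tf1.Graph()
with toy.as_default():
    x_in = tf1.placeholder(tf.float32, [None, 784], name="x")
    tf.identity(tf.reduce_sum(x_in, axis=1, keepdims=True), name="logits")
pb_path = os.path.join(tempfile.mkdtemp(), "toy.pb")
with tf.io.gfile.GFile(pb_path, "wb") as f:
    f.write(toy.as_graph_def().SerializeToString())

scores = predict_from_pb(pb_path, np.ones((3, 784), dtype=np.float32))
print(scores.shape)
```

Creating a fresh tf1.Graph for the import keeps the imported names from colliding with anything already in the default graph.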
Comments:
This gave me an error like ValueError: The name 'x' refers to an Operation, not a Tensor. Tensor names must be of the form "<op_name>:<output_index>". I fixed it by using x = graph.get_tensor_by_name('x:0').
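The distinction behind that error can be sketched with tf.compat.v1: 'x' names the Placeholder operation, while 'x:0' names its first output tensor. get_operation_by_name accepts the former, get_tensor_by_name only the latter. The graph below is a toy example.

```python
import tensorflow as tf

tf1 = tf.compat.v1

graph = tf1.Graph()
with graph.as_default():
    tf1.placeholder(tf.float32, [None, 784], name="x")

op = graph.get_operation_by_name("x")      # the Placeholder operation itself
tensor = graph.get_tensor_by_name("x:0")   # output 0 of that operation
print(op.name, tensor.name)                # x x:0

try:
    graph.get_tensor_by_name("x")          # an op name, not a tensor name
    err_msg = None
except ValueError as err:
    err_msg = str(err)
print(err_msg)
```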