TensorFlow 的混淆矩阵
Posted
技术标签:
【中文标题】TensorFlow 的混淆矩阵【英文标题】:Confusion Matrix with Tensorflow 【发布时间】:2018-09-26 18:02:45 【问题描述】:我在我自己的数据集上使用@kratzert 编写的finetune AlexNet 架构,它工作正常(我从这里得到代码:https://github.com/kratzert/finetune_alexnet_with_tensorflow),我想弄清楚如何从他的代码中构建混淆矩阵。我曾尝试使用tf.confusion_matrix(labels, predictions, num_classes)
来构建混淆矩阵,但我不能。我很困惑标签和预测的值应该是什么,我的意思是,我知道应该是什么,但是每次输入这些值时都会出错。任何人都可以帮助我或查看代码(以上链接)并指导我吗?
我在计算准确度之后在finetune.py中添加了这两行,以使标签和预测作为类的数量。
with tf.name_scope("accuracy"):
correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
**true_class = tf.argmax(y, 1)
predicted_class = tf.argmax(score, 1)**
在保存模型检查点之前,我在会话的最底部添加了tf.confusion_matrix()
for _ in range(val_batches_per_epoch):
img_batch, label_batch = sess.run(next_batch)
acc, cost = sess.run([accuracy, loss], feed_dict=x: img_batch,
y: label_batch,
keep_prob: 1.)
test_acc += acc
test_count += 1
test_acc /= test_count
print(" Validation Accuracy = :.4f -- Validation Loss = :.4f".format(datetime.now(),test_acc, cost))
print(" Saving checkpoint of model...".format(datetime.now()))
**print(sess.run(tf.confusion_matrix(true_class, predicted_class, num_classes)))**
# save checkpoint of the model
checkpoint_name = os.path.join(checkpoint_path,
'model_epoch'+str(epoch+1)+'.ckpt')
save_path = saver.save(sess, checkpoint_name)
print(" Model checkpoint saved at ".format(datetime.now(),
checkpoint_name))
我也尝试过其他地方,但每次都会出错:
Caused by op 'Placeholder_1', defined at:
File "/home/armin/Desktop/Alexnet_DataPipeline/finetune.py", line 85, in <module>
y = tf.placeholder(tf.float32, [batch_size, num_classes])
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/array_ops.py", line 1777, in placeholder
return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 4521, in placeholder
"Placeholder", dtype=dtype, shape=shape, name=name)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3290, in create_op
op_def=op_def)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1654, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [128,3]
任何帮助将不胜感激,谢谢。
【问题讨论】:
你能贴出你的代码和错误(重要的部分,不是完整的代码)吗? 我添加了部分代码和我添加的行来计算混淆矩阵和我的错误 【参考方案1】:您指的是一段相当长的代码,并且您没有指定将混淆矩阵行放在哪里。
根据经验,混淆矩阵最常见的问题是tf.confusion_matrix()
要求标签和预测都作为类的数量,而不是单热向量。换句话说,标签和预测应该是数字5
的形式,而不是[0,0,0,0,0,1,0,0,0,0]。
在您引用的代码中,y
是 one-hot 格式。网络的输出,score
是一个向量,给出了每个类的概率。这也不是必需的格式。你需要做类似的事情
true_class = tf.argmax( y, 1 )
predicted_class = tf.argmax( score, 1 )
并使用带有混淆矩阵的那些
tf.confusion_matrix( true_class, predicted_class, num_classes )
(基本上,如果您查看finetune.py 的第 123 行,这两个元素都用于确定准确性,但它们没有保存在单独的张量中。)
如果您想保持所有批次的混淆矩阵的总和,您只需将它们相加 - 因为矩阵的每个单元格都会计算属于该类别的示例数量,因此逐元素添加会产生混淆整个集合的矩阵:
cm_running_total = None
cm_nupmy_array = sess.run(tf.confusion_matrix(true_class, predicted_class, num_classes), feed_dict=x: img_batch, y: label_batch, keep_prob: 1. )
if cm_running_total is None:
cm_running_total = cm_numpy_array
else:
cm_running_total += cm_numpy_array
【讨论】:
谢谢,我完全按照你之前说的做了,在这里写下了我的问题。我编辑了我的帖子,请看一下 任何想法,我的错误在哪里? 抱歉有段时间了。该错误表明您需要将feed_dict
提供给您计算混淆矩阵的sess.run()
。所以这样做:print(sess.run(tf.confusion_matrix(true_class, predicted_class, num_classes), feed_dict=x: img_batch, y: label_batch, keep_prob: 1.) ))
。或者,您可以在训练循环中包含每个批次的混淆矩阵,例如 acc, cost, conf_m = sess.run([accuracy, loss, tf.confusion_matrix( true_class, predicted_class, num_classes )], feed_dict=x: img_batch, y: label_batch, keep_prob: 1.)
。
非常感谢,你很棒。现在,我通过在打印验证准确性后添加此行 print(sess.run(tf.confusion_matrix(true_class, predicted_class, num_classes), feed_dict=x: img_batch, y: label_batch, keep_prob: 1.) )
来为每个时代创建混淆矩阵。我的意思是,我在每个时期都有一个 128(批量大小)的混淆矩阵。我没有整个模型的混淆矩阵。我想我必须找到累积它们的方法并获得包含所有内容的最终混淆矩阵。你能帮我解决这个问题吗?
在我的答案中添加了运行总计的建议代码。以上是关于TensorFlow 的混淆矩阵的主要内容,如果未能解决你的问题,请参考以下文章
如何使用 Tensorflow 创建预测标签和真实标签的混淆矩阵?