使用 one-hot 代码的 TensorFlow 混淆矩阵
Posted
技术标签:
【中文标题】使用 one-hot 代码的 TensorFlow 混淆矩阵【英文标题】:Tensorflow confusion matrix using one-hot code 【发布时间】:2018-03-30 08:25:30 【问题描述】:我使用 RNN 进行多类分类,这是我的 RNN 主要代码:
def RNN(x, weights, biases):
x = tf.unstack(x, input_size, 1)
lstm_cell = rnn.BasicLSTMCell(num_unit, forget_bias=1.0, state_is_tuple=True)
stacked_lstm = rnn.MultiRNNCell([lstm_cell]*lstm_size, state_is_tuple=True)
outputs, states = tf.nn.static_rnn(stacked_lstm, x, dtype=tf.float32)
return tf.matmul(outputs[-1], weights) + biases
logits = RNN(X, weights, biases)
prediction = tf.nn.softmax(logits)
cost =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(cost)
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
我必须将所有输入分类为 6 个类,每个类都由 one-hot 代码标签组成,如下所示:
happy = [1, 0, 0, 0, 0, 0]
angry = [0, 1, 0, 0, 0, 0]
neutral = [0, 0, 1, 0, 0, 0]
excited = [0, 0, 0, 1, 0, 0]
embarrassed = [0, 0, 0, 0, 1, 0]
sad = [0, 0, 0, 0, 0, 1]
问题是我无法使用tf.confusion_matrix()
函数打印混淆矩阵。
有没有办法使用这些标签打印混淆矩阵?
如果不是,如何仅在需要打印混淆矩阵时将 one-hot 代码转换为整数索引?
【问题讨论】:
【参考方案1】:
您不能使用 one-hot 向量作为 labels
和 predictions
的输入参数来生成混淆矩阵。您必须直接为其提供包含标签的一维张量。
要将您的一个热向量转换为普通标签,请使用argmax 函数:
label = tf.argmax(one_hot_tensor, axis = 1)
之后你可以像这样打印你的confusion_matrix
:
import tensorflow as tf
num_classes = 2
prediction_arr = tf.constant([1, 1, 1, 1, 0, 0, 0, 0, 1, 1])
labels_arr = tf.constant([0, 1, 1, 1, 1, 1, 1, 1, 0, 0])
confusion_matrix = tf.confusion_matrix(labels_arr, prediction_arr, num_classes)
with tf.Session() as sess:
print(confusion_matrix.eval())
输出:
[[0 3]
[4 3]]
【讨论】:
以上是关于使用 one-hot 代码的 TensorFlow 混淆矩阵的主要内容,如果未能解决你的问题,请参考以下文章
你如何在 Tensorflow 中解码 one-hot 标签?
TensorFlow实现one-hot编码TensorFlow2入门手册
如何从 PNG 为 Tensorflow 2 中的每个像素分类创建 One-hot 编码矩阵