你如何在 Tensorflow 中解码 one-hot 标签?
Posted
技术标签:
【中文标题】你如何在 Tensorflow 中解码 one-hot 标签?【英文标题】:How do you decode one-hot labels in Tensorflow? 【发布时间】:2017-05-14 22:42:41 【问题描述】:一直在寻找,但似乎找不到任何关于如何从 TensorFlow 中的 one-hot 值解码或转换回单个整数的示例。
我使用了tf.one_hot
并且能够训练我的模型,但是对于如何在分类后理解标签有点困惑。我的数据是通过我创建的TFRecords
文件输入的。我想过在文件中存储一个文本标签,但无法让它工作。似乎TFRecords
无法存储文本字符串,或者我弄错了。
【问题讨论】:
嘿@Matt,答案能解决你的问题吗? 【参考方案1】:您可以使用tf.argmax
找出矩阵中最大元素的索引。由于您的一个热向量将是一维的,并且只有一个 1
和其他 0
s,因此假设您正在处理单个向量,这将起作用。
index = tf.argmax(one_hot_vector, axis=0)
对于更标准的batch_size * num_classes
矩阵,使用axis=1
得到大小为batch_size * 1
的结果。
【讨论】:
【参考方案2】:由于 one-hot 编码通常只是一个具有batch_size
行和num_classes
列的矩阵,并且每一行都为零,并且有一个与所选类对应的非零,您可以使用tf.argmax()
来恢复整数标签向量:
BATCH_SIZE = 3
NUM_CLASSES = 4
one_hot_encoded = tf.constant([[0, 1, 0, 0],
[1, 0, 0, 0],
[0, 0, 0, 1]])
# Compute the argmax across the columns.
decoded = tf.argmax(one_hot_encoded, axis=1)
# ...
print sess.run(decoded) # ==> array([1, 0, 3])
【讨论】:
OP 似乎只使用了一个向量,因为他提到他想要一个来自单热值的 单个整数【参考方案3】:data = np.array([1, 5, 3, 8])
print(data)
def encode(data):
print('Shape of data (BEFORE encode): %s' % str(data.shape))
encoded = to_categorical(data)
print('Shape of data (AFTER encode): %s\n' % str(encoded.shape))
return encoded
encoded_data = encode(data)
print(encoded_data)
def decode(datum):
return np.argmax(datum)
decoded_Y = []
print("****************************************")
for i in range(encoded_data.shape[0]):
datum = encoded_data[i]
print('index: %d' % i)
print('encoded datum: %s' % datum)
decoded_datum = decode(encoded_data[i])
print('decoded datum: %s' % decoded_datum)
decoded_Y.append(decoded_datum)
print("****************************************")
print(decoded_Y)
【讨论】:
【参考方案4】:
tf.argmax
已折旧(因此,此页面上答案中的所有链接都是 404),现在应该使用tf.math.argmax
.
用法:
import tensorflow as tf
a = [1, 10, 26.9, 2.8, 166.32, 62.3]
b = tf.math.argmax(input = a)
c = tf.keras.backend.eval(b)
# c = 4
# here a[4] = 166.32 which is the largest element of a across axis 0
注意:您也可以使用numpy 执行此操作。
【讨论】:
以上是关于你如何在 Tensorflow 中解码 one-hot 标签?的主要内容,如果未能解决你的问题,请参考以下文章
使用 Tensorflow 数据集解码 RLE(运行长度编码)掩码