tf.nn.in_top_k:目标超出范围

Posted

技术标签:

【中文标题】tf.nn.in_top_k:目标超出范围【英文标题】:tf.nn.in_top_k: targets out of range 【发布时间】:2016-06-02 09:18:58 【问题描述】:

我从 tensorflow 改编了 cifar10 网络,以解决我自己的分类问题。我已经训练了网络,现在我尝试使用 cifar10_eval.py 评估训练后的模型

top_k_op = tf.nn.in_top_k(logits, labels, 1)

但我收到以下错误。经过进一步调查,目标指数在2,3和4之间变化

tensorflow.python.framework.errors.InvalidArgumentError: targets[3] is out of range

到目前为止,我知道我的标签张量有问题。它是一个 int32-Tensor,其 shape(50,) 如下所示。

labels = Tensor Tensor("batch_processing/Reshape_1:0", shape=(50,), dtype=int32, device=/device:CPU:0)

我的数据集只有 2 个类/标签。也许这可能是问题所在。有谁知道是什么问题?

【问题讨论】:

logits的形状是什么? logits的形状是shape(50,2) 确保标签只包含0和1 感谢您的建议....关于文档,top_k_op 张量具有正确的类型(布尔)和大小(50)。正如你所提到的,我怀疑标签张量包含的数字多于 0 和 1。但目前我正在努力调试标签张量。我看不到张量的值.. 我在回答中总结了 【参考方案1】:

综上所述,函数tf.nn.in_top_k(predictions, targets, k)(见doc)有参数:

预测:形状[batch_size, num_classes],类型float32 目标(正确的标签):形状[batch_size],输入int32或int64

当元素targets[i] 超出predictions[i] 的范围时,函数会引发错误InvalidArgumentError: targets[i] is out of range

例如,有 2 个类 (num_classes=2) 和 targets=[1, 3]。 使用这些目标,您将看到错误 InvalidArgumentError: targets[1] is out of range,因为 targets[1] = 3 超出了只有形状 2 的 predictions[1] 的范围。


要检查您的labels 是否正确,您可以打印它们的最大值:

labels = ...
labels_max = tf.reduce_max(labels)

sess = tf.Session()
print sess.run(labels_max)

如果打印的值优于num_classes,你就有问题了。

【讨论】:

【参考方案2】:

因此,如果您希望以一种热编码之类的方式进行预测,那么您的目标必须是放置一 (1) 个热编码的正确索引。 例如:

bb=tf.nn.in_top_k([[0,1],[1,0],[0,1]]  ,   [1,1,1],1)

将返回:

[真假真]

所以要回答,你必须将你可能的一个热门目标转换为这个索引方法

麻木:

targetsindex = np.argmax(targets, axis=1)

张量:

targetsindex = tf.argmax(targets, axis=0)

【讨论】:

以上是关于tf.nn.in_top_k:目标超出范围的主要内容,如果未能解决你的问题,请参考以下文章

循环向量不起作用::向量下标超出范围[重复]

SwiftUI:关闭视图导致索引超出范围错误

变量超出范围不起作用

linux进不了,提示信号超出范围。

IndexError:在pyspark shell上使用reduceByKey操作时列出索引超出范围

简单python for循环上的列表索引超出范围错误