Tensorflow - 断言失败:[预测必须在 [0, 1]]
Posted
技术标签:
【中文标题】Tensorflow - 断言失败:[预测必须在 [0, 1]]【英文标题】:Tensorflow - assertion failed: [predictions must be in [0, 1]] 【发布时间】:2019-03-23 11:44:49 【问题描述】:我正在使用 Tensorflow 的 Estimator
API,我遇到了以下问题。我想检查除准确性之外的 f1 分数,当我在训练后评估时完全没有问题,当我测试时它要求我已经标准化的标准化值。
这是我的网络模型(第一部分省略):
#### architecture omitted #####
predictions =
"classes": tf.argmax(input=logits, axis=1),
"probabilities": tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(labels, tf.float32), logits=tf.cast(logits, tf.float32), name="sigmoid_tensor")
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=labels, logits=logits)
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
#optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.96)
train_op = optimizer.minimize(
loss=loss,
global_step=tf.train.get_global_step())
logging_hook = tf.train.LoggingTensorHook("loss" : loss, every_n_iter=10)
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op, training_hooks = [logging_hook])
eval_metric_ops =
"accuracy": tf.metrics.accuracy(
labels=tf.argmax(input=labels, axis=1),
predictions=predictions["classes"]),
"f1 score" : tf.contrib.metrics.f1_score(
labels = tf.argmax(input=labels, axis=1),
predictions = tf.cast(predictions["classes"],tf.float32)/tf.norm(tf.cast(predictions["classes"], tf.float32)))
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
这是我正在使用的训练脚本,我在评估时没有任何问题(这里没有错误)。
classifier = tf.estimator.Estimator(model_fn=instrument_recognition_model, model_dir=saved_model_path)
train_input_fn = tf.estimator.inputs.numpy_input_fn(x=X_train, y=y_train, batch_size=16, num_epochs=30, shuffle=True)
classifier.train(input_fn=train_input_fn)
# Evaluate results on the training set
eval_input_fn = tf.estimator.inputs.numpy_input_fn(x=X_eval,y=y_eval,num_epochs=1,shuffle=False)
eval_results = classifier.evaluate(input_fn=eval_input_fn)
print(eval_results)
这是我的测试脚本,这里程序失败:
classifier = tf.estimator.Estimator(model_fn=instrument_recognition_model, model_dir=saved_model_path)
# Keep only certain samples
key_indices = [np.where(instruments == x)[0][0] for x in keys]
example_indices = np.array([])
for ind in key_indices:
tmp = np.argwhere(labels[:,ind] == True).flatten()
example_indices = np.union1d(example_indices, tmp).astype(np.int32)
features = features[example_indices].astype(np.float32)
example_indices = [[x for i in key_indices] for x in example_indices]
labels = labels[example_indices, key_indices].astype(np.int)
# Evaluate results on the test set
print(features)
print(labels)
eval_input_fn = tf.estimator.inputs.numpy_input_fn(x=features, y=labels, batch_size=1, num_epochs=1, shuffle=False)
eval_results = classifier.evaluate(input_fn=eval_input_fn)
print(eval_results)
我真的不知道出了什么问题,因为我遵循了相同的评估和测试流程。没有 f1 指标(仅准确度)一切正常,当我添加 f1 指标时,它在测试脚本中失败。
错误片段如下:
### trace error omitted ###
File "../models.py", line 207, in instrument_recognition_model
predictions = tf.cast(predictions["classes"],tf.float32)/tf.norm(tf.cast(predictions["classes"], tf.float32)))
### trace error ommitted ###
TheInvalidArgumentError (see above for traceback): assertion failed: [predictions must be in [0, 1]] [Condition x <= y did not hold element-wise:x (div:0) = ] [nan] [y (f1/Cast_1/x:0) = ] [1]
[[Node: f1/assert_less_equal/Assert/AssertGuard/Assert = Assert[T=[DT_STRING, DT_STRING, DT_FLOAT, DT_STRING, DT_FLOAT], summarize=3, _device="/job:localhost/replica:0/task:0/device:CPU:0"](f1/assert_less_equal/Assert/AssertGuard/Assert/Switch, f1/assert_less_equal/Assert/AssertGuard/Assert/data_0, f1/assert_less_equal/Assert/AssertGuard/Assert/data_1, f1/assert_less_equal/Assert/AssertGuard/Assert/Switch_1, f1/assert_less_equal/Assert/AssertGuard/Assert/data_3, f1/assert_less_equal/Assert/AssertGuard/Assert/Switch_2)]]
提前谢谢你
【问题讨论】:
【参考方案1】:因为您的数据可能是多类的。如果您查看 Tensorflow here 的 f1-score 的官方文档,您可以看到它仅用于二进制分类。 如果你真的想要 f1-score,你可以做类似this 的事情。
【讨论】:
以上是关于Tensorflow - 断言失败:[预测必须在 [0, 1]]的主要内容,如果未能解决你的问题,请参考以下文章
错误:静态断言失败:std::thread 参数在转换为右值后必须是可调用的
我在查询 API 时收到“断言失败:您必须在传递给 `push` 的哈希中包含一个 `id`”
FIRESTORE (7.14.3) 内部断言失败:值必须未定义或 Uint8Array
如何在 Visual Studio 2017 中禁用作为调试错误的失败断言?
InvalidArgumentError:断言失败:[Condition x == y did not hold element-wise:]