Custom metrics work with eager execution only
Posted: 2021-12-23 10:49:36

Question: Following the answer to a question I asked earlier, I am trying to get the custom metrics word_accuracy
and char_accuracy
working with a CRNN-CTC model implementation in TensorFlow. The linked notebook works fine after running the following lines:
import tensorflow as tf
tf.config.run_functions_eagerly(True)
Here is the custom CTC layer along with the accuracy-calculation function:
def calculate_accuracy(y_true, y_pred, metric, unknown_placeholder):
    y_pred = tf.stack(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    unknown_indices = tf.where(y_pred == -1)
    y_pred = tf.tensor_scatter_nd_update(
        y_pred,
        unknown_indices,
        tf.cast(tf.ones(unknown_indices.shape[0]) * unknown_placeholder, tf.int64),
    )
    if metric == 'word':
        return tf.where(tf.reduce_all(y_true == y_pred, 1)).shape[0] / y_true.shape[0]
    if metric == 'char':
        return tf.where(y_true == y_pred).shape[0] / tf.reduce_prod(y_true.shape)
    return 0
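As a self-contained illustration of the same math on toy tensors (run eagerly; the tensors here are illustrative, not from the notebook):

```python
import tensorflow as tf

# Toy labels/predictions: 2 samples of 3 characters each.
y_true = tf.constant([[1, 2, 3], [4, 5, 6]], tf.int64)
y_pred = tf.constant([[1, 2, 3], [4, 0, 6]], tf.int64)

# Word accuracy: fraction of samples whose every character matches.
word = tf.reduce_mean(tf.cast(tf.reduce_all(y_true == y_pred, 1), tf.float32))
# Char accuracy: fraction of individual character positions that match.
char = tf.reduce_mean(tf.cast(y_true == y_pred, tf.float32))

print(word.numpy(), char.numpy())  # 0.5 and 5/6 ≈ 0.8333
```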
class CTCLayer(Layer):
    def __init__(self, max_label_length, unknown_placeholder, **kwargs):
        super().__init__(**kwargs)
        self.max_label_length = max_label_length
        self.unknown_placeholder = unknown_placeholder

    def call(self, *args):
        y_true, y_pred = args
        batch_length = tf.cast(tf.shape(y_true)[0], dtype='int64')
        input_length = tf.cast(tf.shape(y_pred)[1], dtype='int64')
        label_length = tf.cast(tf.shape(y_true)[1], dtype='int64')
        input_length = input_length * tf.ones(shape=(batch_length, 1), dtype='int64')
        label_length = label_length * tf.ones(shape=(batch_length, 1), dtype='int64')
        loss = tf.keras.backend.ctc_batch_cost(
            y_true, y_pred, input_length, label_length
        )
        if y_true.shape[1] is not None:  # this is to prevent an error at model creation
            predictions = decode_batch_predictions(y_pred, self.max_label_length)
            self.add_metric(
                calculate_accuracy(
                    y_true, predictions, 'word', self.unknown_placeholder
                ),
                'word_accuracy',
            )
            self.add_metric(
                calculate_accuracy(
                    y_true, predictions, 'char', self.unknown_placeholder
                ),
                'char_accuracy',
            )
        self.add_loss(loss)
        return y_pred
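For context, the pattern this layer relies on — attaching the loss inside `call` via `add_loss` and compiling the model without a `loss` argument — can be sketched in a minimal functional model (the layer and all names below are illustrative, not from the notebook):

```python
import numpy as np
import tensorflow as tf


class AddLossLayer(tf.keras.layers.Layer):
    """Toy analogue of CTCLayer: computes a loss from (y_true, y_pred)
    inside call() and registers it with add_loss()."""

    def call(self, y_true, y_pred):
        self.add_loss(tf.reduce_mean(tf.square(y_true - y_pred)))
        return y_pred


# Labels enter the model as a second Input, just like in the CTC setup.
features = tf.keras.Input(shape=(4,))
labels = tf.keras.Input(shape=(1,))
out = tf.keras.layers.Dense(1)(features)
out = AddLossLayer()(labels, out)

model = tf.keras.Model([features, labels], out)
model.compile(optimizer='adam')  # no loss argument: add_loss supplies it
```

A training step then runs with `model.fit([x, y], ...)` and no separate target array, since the labels are already an input.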
The if y_true.shape[1] is not None
block is meant to prevent an error at model-creation time, when a placeholder rather than an actual tensor is passed. Without the if statement, the following happens (I get the same error with or without eager execution):
3 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
697 except Exception as e: # pylint:disable=broad-except
698 if hasattr(e, 'ag_error_metadata'):
--> 699 raise e.ag_error_metadata.to_exception(e)
700 else:
701 raise
ValueError: Exception encountered when calling layer "ctc_loss" (type CTCLayer).
in user code:
File "<ipython-input-6-fabf4ec5a640>", line 67, in call *
predictions = decode_batch_predictions(y_pred, self.max_label_length)
File "<ipython-input-6-fabf4ec5a640>", line 23, in decode_batch_predictions *
results = tf.keras.backend.ctc_decode(
File "/usr/local/lib/python3.7/dist-packages/keras/backend.py", line 6436, in ctc_decode
inputs=y_pred, sequence_length=input_length)
ValueError: Shape must be rank 1 but is rank 0 for 'node ctc_loss/CTCGreedyDecoder = CTCGreedyDecoder[T=DT_FLOAT, blank_index=-1, merge_repeated=true](ctc_loss/Log_1, ctc_loss/Cast_9)' with input shapes: [31,?,20], [].
Call arguments received:
• args=('tf.Tensor(shape=(None, None), dtype=float32)', 'tf.Tensor(shape=(None, 31, 20), dtype=float32)')
Note: under graph execution, the labels always have shape (None, None)
, so the code under the if block that adds the metrics never executes. To see the metrics work, just run the notebook I included without modification; modify it afterwards to reproduce the error.
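The difference between the (None, None) static shape seen at trace time and the concrete run-time shape can be seen with a minimal traced function (a sketch, not from the notebook):

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
def dynamic_dim(y):
    # Static shape: fixed when the function is traced; unknown dims are
    # None, which is why `y_true.shape[1] is not None` is False in graph mode.
    static = y.shape[1]  # None inside the traced graph
    # Dynamic shape: a tensor evaluated at run time, always concrete.
    return tf.shape(y)[1]


print(int(dynamic_dim(tf.zeros((2, 7)))))  # 7
```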
Here is what you should see with eager execution enabled:
/usr/local/lib/python3.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py:4527: UserWarning: Even though the `tf.config.experimental_run_functions_eagerly` option is set, this option does not apply to tf.data functions. To force eager execution of tf.data functions, please use `tf.data.experimental.enable_debug_mode()`.
"Even though the `tf.config.experimental_run_functions_eagerly` "
Epoch 1/100
59/Unknown - 42s 177ms/step - loss: 18.1605 - word_accuracy: 0.0000e+00 - char_accuracy: 2.1186e-04
Epoch 00001: val_loss improved from inf to 17.36043, saving model to 1k_captcha.tf
59/59 [==============================] - 44s 213ms/step - loss: 18.1605 - word_accuracy: 0.0000e+00 - char_accuracy: 2.1186e-04 - val_loss: 17.3604 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0000e+00
Epoch 2/100
59/59 [==============================] - ETA: 0s - loss: 16.1261 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0021
Epoch 00002: val_loss improved from 17.36043 to 16.20875, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 210ms/step - loss: 16.1261 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0021 - val_loss: 16.2087 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0000e+00
Epoch 3/100
59/59 [==============================] - ETA: 0s - loss: 15.8597 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0110
Epoch 00003: val_loss improved from 16.20875 to 16.11712, saving model to 1k_captcha.tf
59/59 [==============================] - 12s 204ms/step - loss: 15.8597 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0110 - val_loss: 16.1171 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0071
Epoch 4/100
59/59 [==============================] - ETA: 0s - loss: 15.3741 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0184
Epoch 00004: val_loss did not improve from 16.11712
59/59 [==============================] - 12s 207ms/step - loss: 15.3741 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0184 - val_loss: 16.6811 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0143
Epoch 5/100
59/59 [==============================] - ETA: 0s - loss: 14.9846 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0225
Epoch 00005: val_loss improved from 16.11712 to 15.23923, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 214ms/step - loss: 14.9846 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0225 - val_loss: 15.2392 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0268
Epoch 6/100
59/59 [==============================] - ETA: 0s - loss: 14.4598 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0258
Epoch 00006: val_loss did not improve from 15.23923
59/59 [==============================] - 12s 207ms/step - loss: 14.4598 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0258 - val_loss: 18.6373 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0089
Epoch 7/100
59/59 [==============================] - ETA: 0s - loss: 13.8650 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0335
Epoch 00007: val_loss improved from 15.23923 to 14.37547, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 215ms/step - loss: 13.8650 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0335 - val_loss: 14.3755 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0393
Epoch 8/100
59/59 [==============================] - ETA: 0s - loss: 13.1221 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0422
Epoch 00008: val_loss did not improve from 14.37547
59/59 [==============================] - 13s 208ms/step - loss: 13.1221 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0422 - val_loss: 14.4376 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0393
Epoch 9/100
59/59 [==============================] - ETA: 0s - loss: 12.2508 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0780
Epoch 00009: val_loss did not improve from 14.37547
59/59 [==============================] - 13s 211ms/step - loss: 12.2508 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0780 - val_loss: 14.8398 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.0500
Epoch 10/100
59/59 [==============================] - ETA: 0s - loss: 11.0290 - word_accuracy: 0.0000e+00 - char_accuracy: 0.1460
Epoch 00010: val_loss did not improve from 14.37547
59/59 [==============================] - 13s 215ms/step - loss: 11.0290 - word_accuracy: 0.0000e+00 - char_accuracy: 0.1460 - val_loss: 14.4219 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.1054
Epoch 11/100
59/59 [==============================] - ETA: 0s - loss: 9.8587 - word_accuracy: 0.0011 - char_accuracy: 0.2004
Epoch 00011: val_loss improved from 14.37547 to 10.11944, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 212ms/step - loss: 9.8587 - word_accuracy: 0.0011 - char_accuracy: 0.2004 - val_loss: 10.1194 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.1750
Epoch 12/100
59/59 [==============================] - ETA: 0s - loss: 8.6827 - word_accuracy: 0.0032 - char_accuracy: 0.2388
Epoch 00012: val_loss did not improve from 10.11944
59/59 [==============================] - 13s 216ms/step - loss: 8.6827 - word_accuracy: 0.0032 - char_accuracy: 0.2388 - val_loss: 10.3900 - val_word_accuracy: 0.0089 - val_char_accuracy: 0.1714
Epoch 13/100
59/59 [==============================] - ETA: 0s - loss: 7.4976 - word_accuracy: 0.0127 - char_accuracy: 0.3047
Epoch 00013: val_loss improved from 10.11944 to 8.38430, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 215ms/step - loss: 7.4976 - word_accuracy: 0.0127 - char_accuracy: 0.3047 - val_loss: 8.3843 - val_word_accuracy: 0.0179 - val_char_accuracy: 0.2714
Epoch 14/100
59/59 [==============================] - ETA: 0s - loss: 6.6434 - word_accuracy: 0.0508 - char_accuracy: 0.3519
Epoch 00014: val_loss did not improve from 8.38430
59/59 [==============================] - 13s 217ms/step - loss: 6.6434 - word_accuracy: 0.0508 - char_accuracy: 0.3519 - val_loss: 9.5689 - val_word_accuracy: 0.0000e+00 - val_char_accuracy: 0.2571
Epoch 15/100
59/59 [==============================] - ETA: 0s - loss: 5.3200 - word_accuracy: 0.1398 - char_accuracy: 0.4271
Epoch 00015: val_loss improved from 8.38430 to 6.74445, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 214ms/step - loss: 5.3200 - word_accuracy: 0.1398 - char_accuracy: 0.4271 - val_loss: 6.7445 - val_word_accuracy: 0.0804 - val_char_accuracy: 0.3482
Epoch 16/100
59/59 [==============================] - ETA: 0s - loss: 4.4252 - word_accuracy: 0.2108 - char_accuracy: 0.4799
Epoch 00016: val_loss improved from 6.74445 to 5.40682, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 222ms/step - loss: 4.4252 - word_accuracy: 0.2108 - char_accuracy: 0.4799 - val_loss: 5.4068 - val_word_accuracy: 0.1161 - val_char_accuracy: 0.4446
Epoch 17/100
59/59 [==============================] - ETA: 0s - loss: 3.8119 - word_accuracy: 0.2691 - char_accuracy: 0.5206
Epoch 00017: val_loss improved from 5.40682 to 4.76755, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 220ms/step - loss: 3.8119 - word_accuracy: 0.2691 - char_accuracy: 0.5206 - val_loss: 4.7676 - val_word_accuracy: 0.1964 - val_char_accuracy: 0.4929
Epoch 18/100
59/59 [==============================] - ETA: 0s - loss: 3.1290 - word_accuracy: 0.3379 - char_accuracy: 0.5712
Epoch 00018: val_loss improved from 4.76755 to 4.45828, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 221ms/step - loss: 3.1290 - word_accuracy: 0.3379 - char_accuracy: 0.5712 - val_loss: 4.4583 - val_word_accuracy: 0.2768 - val_char_accuracy: 0.5375
Epoch 19/100
59/59 [==============================] - ETA: 0s - loss: 2.6048 - word_accuracy: 0.4163 - char_accuracy: 0.6267
Epoch 00019: val_loss improved from 4.45828 to 4.13174, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 222ms/step - loss: 2.6048 - word_accuracy: 0.4163 - char_accuracy: 0.6267 - val_loss: 4.1317 - val_word_accuracy: 0.2054 - val_char_accuracy: 0.5143
Epoch 20/100
59/59 [==============================] - ETA: 0s - loss: 2.1555 - word_accuracy: 0.5117 - char_accuracy: 0.6979
Epoch 00020: val_loss improved from 4.13174 to 3.35257, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 223ms/step - loss: 2.1555 - word_accuracy: 0.5117 - char_accuracy: 0.6979 - val_loss: 3.3526 - val_word_accuracy: 0.3482 - val_char_accuracy: 0.5518
Epoch 21/100
59/59 [==============================] - ETA: 0s - loss: 1.8185 - word_accuracy: 0.5604 - char_accuracy: 0.7284
Epoch 00021: val_loss did not improve from 3.35257
59/59 [==============================] - 13s 223ms/step - loss: 1.8185 - word_accuracy: 0.5604 - char_accuracy: 0.7284 - val_loss: 3.5486 - val_word_accuracy: 0.3304 - val_char_accuracy: 0.5500
Epoch 22/100
59/59 [==============================] - ETA: 0s - loss: 1.4279 - word_accuracy: 0.6578 - char_accuracy: 0.8021
Epoch 00022: val_loss improved from 3.35257 to 2.97987, saving model to 1k_captcha.tf
59/59 [==============================] - 14s 229ms/step - loss: 1.4279 - word_accuracy: 0.6578 - char_accuracy: 0.8021 - val_loss: 2.9799 - val_word_accuracy: 0.3750 - val_char_accuracy: 0.6679
Epoch 23/100
59/59 [==============================] - ETA: 0s - loss: 1.1666 - word_accuracy: 0.7278 - char_accuracy: 0.8417
Epoch 00023: val_loss did not improve from 2.97987
59/59 [==============================] - 13s 224ms/step - loss: 1.1666 - word_accuracy: 0.7278 - char_accuracy: 0.8417 - val_loss: 5.2543 - val_word_accuracy: 0.1429 - val_char_accuracy: 0.4768
Epoch 24/100
59/59 [==============================] - ETA: 0s - loss: 1.0938 - word_accuracy: 0.7511 - char_accuracy: 0.8576
Epoch 00024: val_loss improved from 2.97987 to 2.72415, saving model to 1k_captcha.tf
59/59 [==============================] - 14s 226ms/step - loss: 1.0938 - word_accuracy: 0.7511 - char_accuracy: 0.8576 - val_loss: 2.7242 - val_word_accuracy: 0.4911 - val_char_accuracy: 0.7250
Epoch 25/100
59/59 [==============================] - ETA: 0s - loss: 0.8378 - word_accuracy: 0.7977 - char_accuracy: 0.8837
Epoch 00025: val_loss improved from 2.72415 to 2.47315, saving model to 1k_captcha.tf
59/59 [==============================] - 13s 223ms/step - loss: 0.8378 - word_accuracy: 0.7977 - char_accuracy: 0.8837 - val_loss: 2.4731 - val_word_accuracy: 0.4554 - val_char_accuracy: 0.6964
Epoch 26/100
59/59 [==============================] - ETA: 0s - loss: 0.6497 - word_accuracy: 0.8633 - char_accuracy: 0.9195
Epoch 00026: val_loss improved from 2.47315 to 2.10521, saving model to 1k_captcha.tf
59/59 [==============================] - 14s 227ms/step - loss: 0.6497 - word_accuracy: 0.8633 - char_accuracy: 0.9195 - val_loss: 2.1052 - val_word_accuracy: 0.4821 - val_char_accuracy: 0.6929
Epoch 27/100
59/59 [==============================] - ETA: 0s - loss: 0.4810 - word_accuracy: 0.9153 - char_accuracy: 0.9528
Epoch 00027: val_loss did not improve from 2.10521
59/59 [==============================] - 14s 226ms/step - loss: 0.4810 - word_accuracy: 0.9153 - char_accuracy: 0.9528 - val_loss: 2.5292 - val_word_accuracy: 0.4375 - val_char_accuracy: 0.7054
Epoch 28/100
59/59 [==============================] - ETA: 0s - loss: 0.4621 - word_accuracy: 0.9121 - char_accuracy: 0.9500
Epoch 00028: val_loss did not improve from 2.10521
59/59 [==============================] - 14s 224ms/step - loss: 0.4621 - word_accuracy: 0.9121 - char_accuracy: 0.9500 - val_loss: 2.1713 - val_word_accuracy: 0.4821 - val_char_accuracy: 0.7268
To reproduce the problem, you may need to restart the runtime if you ran the notebook before, then try running without eager execution: the metrics will never show up. To reproduce the error, comment out the if y_true.shape[1] is not None
line and merge the if block with the rest of the code. What do I need to modify in the provided notebook to make the metrics work as demonstrated, without using eager execution?
Answer 1: You might not like this solution, but you could try changing your calculate_accuracy
and decode_batch_predictions
functions so that they only use tf
operations:
def decode_batch_predictions(predictions, max_label_length, char_lookup=None, increment=0):
    input_length = tf.cast(tf.ones(tf.shape(predictions)[0]), dtype=tf.int32) * tf.cast(tf.shape(predictions)[1], dtype=tf.int32)
    results = tf.keras.backend.ctc_decode(
        predictions, input_length=input_length, greedy=True
    )[0][0][:, :max_label_length] + increment
    if char_lookup:  # For inference
        output = []
        for result in results:
            result = tf.strings.reduce_join(char_lookup(result)).numpy().decode('utf-8')
            output.append(result)
        return output
    else:  # For training
        output = tf.TensorArray(tf.int64, size=0, dynamic_size=True)
        for result in results:
            output = output.write(output.size(), result)
        return output.stack()


def calculate_accuracy(y_true, y_pred, metric, unknown_placeholder):
    y_pred = tf.stack(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    unknown_indices = tf.where(y_pred == -1)
    y_pred = tf.tensor_scatter_nd_update(
        y_pred,
        unknown_indices,
        tf.cast(tf.ones(tf.shape(unknown_indices)[0]) * unknown_placeholder, tf.int64),
    )
    if metric == 'word':
        return tf.shape(tf.where(tf.reduce_all(y_true == y_pred, 1)))[0] / tf.shape(y_true)[0]
    if metric == 'char':
        return tf.shape(tf.where(y_true == y_pred))[0] / tf.reduce_prod(tf.shape(y_true))
    return 0
Writing example: 936/1040 [90.0 %] to e7fe398b-da12-4176-a91c-84a8ca076937-train.tfrecord
Writing example: 1040/1040 [100.0 %] to e7fe398b-da12-4176-a91c-84a8ca076937-valid.tfrecord
Epoch 1/100
59/Unknown - 107s 470ms/step - loss: 18.2176 - word_accuracy: 0.0000e+00 - char_accuracy: 0.0015
Epoch 00001: val_loss improved from inf to 16.23781, saving model to 1k_captcha.tf
This way you don't have to use tf.config.run_functions_eagerly(True)
or the if y_true.shape[1] is not None
guard.
Comments:
I ended up converting calculate_accuracy
to a numpy equivalent and calling it with tf.numpy_function
, which also solved the problem without eager execution. I think your solution might be slightly faster, so I'll give it a try. Thanks.
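A minimal sketch of that approach for the word metric (the char case is analogous; the function names here are illustrative, not the commenter's actual code):

```python
import numpy as np
import tensorflow as tf


def np_word_accuracy(y_true, y_pred, unknown_placeholder):
    # Plain-numpy equivalent of the 'word' branch of calculate_accuracy:
    # replace the -1 "unknown" markers, then count fully matching rows.
    y_pred = y_pred.copy()
    y_pred[y_pred == -1] = unknown_placeholder
    matches = (y_true.astype(y_pred.dtype) == y_pred).all(axis=1)
    return np.float64(matches.mean())


def word_accuracy(y_true, y_pred, unknown_placeholder=0):
    # tf.numpy_function executes the numpy code eagerly even inside a
    # traced graph, so a layer can call this without run_functions_eagerly.
    return tf.numpy_function(
        np_word_accuracy, [y_true, y_pred, unknown_placeholder], tf.float64
    )


y_true = tf.constant([[1, 2], [3, 4]], tf.int64)
y_pred = tf.constant([[1, 2], [3, -1]], tf.int64)
print(float(word_accuracy(y_true, y_pred)))  # 0.5
```

One trade-off of `tf.numpy_function`: the wrapped op is not differentiable and not serialized into a SavedModel graph, which is fine for a metric but would not work for the loss.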