保存 OCR 模型以从 Keras 读取验证码作者:A_K_Nain
Posted
技术标签:
【中文标题】保存 OCR 模型以从 Keras 读取验证码作者:A_K_Nain【英文标题】:Save the OCR model for reading Captchas from Keras Author: A_K_Nain 【发布时间】:2021-09-10 02:49:23 【问题描述】:我正在研究应用于 colab 中 Kaggle 的 word mnist 数据集的 OCR 模型。我受到来自 ocr 验证码的模型的启发,该模型具有由 A_K_Nain 在站点托管的 Keras 示例中编写的 LSTM 和 CTC 层:https://keras.io/examples/vision/captcha_ocr/ 我想保存模型,但是当我尝试加载它以对看不见的数据进行预测时。我收到未知 CTClayer 的错误。 ctclaer 不是在模型内部而是在模型外部定义的问题,所以当我尝试加载模型时,我会遇到错误。我找到了使用自定义模型的解决方案,但对我没有任何作用。 如何保存托管在以下站点中的模型:https://keras.io/examples/vision/captcha_ocr/
【问题讨论】:
【参考方案1】:CTC 层不用于进行预测,因此您可以像这样在不使用 CTC 层的情况下进行保存:-
saving_model = keras.models.Model(model.get_layer(name="image").input, model.get_layer(name="dense2").output
)
saving_model.summary()
saving_model.save("model_tf")
除此之外,您必须进行一些更改才能使此代码在变量中工作:-
max_length = max([len(label) for label in labels])
outfile = open("max_length",'wb')
pickle.dump(max_length,outfile)
outfile.close()
import string
chars = string.printable
chars = chars[:-5]
characters = [c for c in chars]
这将给出一组定义的字符,这将有助于预测,因此在预测部分你必须做:-
infile = open("max_length",'rb')
max_length = pickle.load(infile)
infile.close()
import string
chars = string.printable
chars = chars[:-5]
characters = [c for c in chars]
# Mapping characters to integers
char_to_num = layers.experimental.preprocessing.StringLookup(
vocabulary=characters, mask_token=None
)
# Mapping integers back to original characters
num_to_char = layers.experimental.preprocessing.StringLookup(
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
prediction_model = tf.keras.models.load_model('model_tf')
然后继续。
【讨论】:
【参考方案2】:这是我们如何使用作者 A_K_Nain 代码预测新图像的方法。从同一代码中加载相关函数。
test_img_path =['/path/to/test/image/117011.png']
validation_dataset = tf.data.Dataset.from_tensor_slices((test_img_path[0:1], ['']))
validation_dataset = (
validation_dataset.map(
encode_single_sample, num_parallel_calls=tf.data.experimental.AUTOTUNE
)
.batch(batch_size)
.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)
for batch in validation_dataset.take(1):
#print(batch['image'])
preds = reconstructed_model.predict(batch['image']) # reconstructed_model is saved trained model
pred_texts = decode_batch_predictions(preds)
print(pred_texts)
【讨论】:
以上是关于保存 OCR 模型以从 Keras 读取验证码作者:A_K_Nain的主要内容,如果未能解决你的问题,请参考以下文章