如何为语言翻译重新训练序列到序列神经网络模型?
Posted
技术标签:
【中文标题】如何为语言翻译重新训练序列到序列神经网络模型?【英文标题】:How to Retrain Sequence to Sequence Neural Network model for Language Translation? 【发布时间】:2019-07-13 18:56:43 【问题描述】:我已经训练了一个 seq2seq tensorflow 模型,用于将句子从英语翻译成西班牙语。我为 615 700 步训练了一个模型,并成功保存了模型检查点。我的英语和西班牙语句子的训练数据量都是 200 000。我想重新训练这个模型,从 615 700 步中得到 10K 新数据句子。为此,我正在使用序列到序列 tensoflow 模型。如何从最后一个检查点开始重新训练模型? Here 是我用于翻译的链接。
我的 train 文件夹中有 3 种类型的文件:
.index
.meta
.data
and checkpoint file.
我的新训练数据集文件是europarl_train.es-en.en
和europarl_train.es-en.es
,分别用于英语和西班牙语句子。
我编写了一个代码来加载我的模型 .meta 文件和权重
import data_utils
import seq2seq_model
import translate
import tensorflow as tf
with tf.Session() as sess:
saver = tf.train.import_meta_graph('/home/i9/L-T_Model_Training/16_NOV_MODEL/train/translate.ckpt-615700.meta')
saver.restore(sess,tf.train.latest_checkpoint('/home/i9/L-T_Model_Training/16_NOV_MODEL/train/.'))
我怎样才能开始为这个数据集保留?
【问题讨论】:
【参考方案1】:保存
根据TensorFlow version 2 doc,您可以使用tf.train.Checkpoint
和tf.train.CheckpointManager
类来保存您的模型。
考虑以下示例:
checkpoint_dir = './training_checkpoints' # custom directory
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(model=model) # your model variable name
manager = tf.train.CheckpointManager(checkpoint=checkpoint, directory=checkpoint_dir, max_to_keep=3) # max_to_keep means how much of last checkpoints number you like to keep
现在,如果您想保存模型,请输入:manager.save()
加载
再次定义 checkpoint 和 checkpointManager 并运行此代码:
if manager.latest_checkpoint:
checkpoint.restore((manager.latest_checkpoint)).assert_consumed()
print("Restored from ".format(manager.latest_checkpoint))
如果遇到类似 (AssertionError: Unresolved object in checkpoint (root)) 的错误,请将 assert_consumed
替换为 expect_partial
。 (去这里看看区别:link)
模型已从检查点加载。 现在您可以加载数据并修复形状并继续训练您的模型。
【讨论】:
以上是关于如何为语言翻译重新训练序列到序列神经网络模型?的主要内容,如果未能解决你的问题,请参考以下文章