Callback API
Posted 三年一梦
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Callback API相关的知识,希望对你有一定的参考价值。
Callback API
用于跟踪epoch期间各种状态的回调函数。主要有6个类:
1. mxnet.callback.
module_checkpoint
(mod, prefix, period=1, save_optimizer_states=False)
参数:
- mod:BaseModule的子类。需要做checkpoint的module
- prefix:字符串,该checkpoint文件的前缀
- period:在做checkpoint之前需要等多少个epoch,默认为1
- save_optimizer_states:布尔型,表明是否保存优化器状态用于继续训练
返回:
- callback:callback函数,可被作为iter_end_callback参数传递到fit函数里。
2. mxnet.callback.
do_checkpoint
(prefix, period=1)
这个callback函数用于每隔几个epoch来保存以下模型checkpoint,每个checkpoint由几个binary files组成:一个模型描述文件和一个参数(权重和偏置)文件。模型描述文件名字为prefix-symbol.json,参数文件名字为prefix-epoch_number.params
参数:
- prefix:同上
- period:整型,可选。几个epoch来保存一次。默认为1
返回:
- callback:一个callback函数,可被作为epoch_end_callback参数传递到fit函数里。
>>> module.fit(iterator, num_epoch=n_epoch, ... epoch_end_callback = mx.callback.do_checkpoint("mymodel", 1))
Start training with [cpu(0)] Epoch[0] Resetting Data Iterator Epoch[0] Time cost=0.100 Saved checkpoint to "mymodel-0001.params" Epoch[1] Resetting Data Iterator Epoch[1] Time cost=0.060 Saved checkpoint to "mymodel-0002.params"
3. mxnet.callback.
log_train_metric
(period, auto_reset=False)
callback函数用于每隔几个周期记录训练打印结果
参数:
- period:整型,打印多少个batch的训练结果
- auto_reset:布尔型,每次打印后重置评估函数
返回:
- callback:callback函数,可被作为iter_epoch_callback参数传递到fit函数里。
4. class mxnet.callback.
Speedometer
(batch_size, frequent=50, auto_reset=True)
周期性的打印训练速度和评价指标
参数:
- batch_size:整型
- frequent:打印频率,默认每50个批量打印一次
- auto_set:同上
例子:
>>> # Print training speed and evaluation metrics every ten batches. Batch size is one. >>> module.fit(iterator, num_epoch=n_epoch, ... batch_end_callback=mx.callback.Speedometer(1, 10))
Epoch[0] Batch [10] Speed: 1910.41 samples/sec Train-accuracy=0.200000 Epoch[0] Batch [20] Speed: 1764.83 samples/sec Train-accuracy=0.400000 Epoch[0] Batch [30] Speed: 1740.59 samples/sec Train-accuracy=0.500000
5. class mxnet.callback.
ProgressBar
(total, length=80)
呈现一个进度条,表明每个epoch内批量的进度。
参数:
- total:每个epoch中所有批量的数目
- length:进度条的最大长度
例子:
>>> progress_bar = mx.callback.ProgressBar(total=2) >>> mod.fit(data, num_epoch=5, batch_end_callback=progress_bar) [========--------] 50.0% [================] 100.0%
6. class mxnet.callback.
LogValidationMetricsCallback
打印出一个epoch之后的评估结果
整体的一个例子:train_mnist.py:用到了第2个和第4个类:
model.fit(train, begin_epoch=args.load_epoch if args.load_epoch else 0, num_epoch=args.num_epochs, eval_data=val, eval_metric=eval_metrics, kvstore=kv, optimizer=args.optimizer, optimizer_params=optimizer_params, initializer=initializer, arg_params=arg_params, aux_params=aux_params, batch_end_callback=[mx.callback.Speedometer(args.batch_size, args.disp_batches)], # 每过多少个batch打印一下 epoch_end_callback=mx.callback.do_checkpoint(args.model_prefix , period=args.save_period), # 每过多少period保存模型 allow_missing=True, monitor=monitor)
以上是关于Callback API的主要内容,如果未能解决你的问题,请参考以下文章
Wordpress REST Api:add_action('rest_api_init', callback) 不调用回调