Module API
Posted 三年一梦
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Module API相关的知识,希望对你有一定的参考价值。
module或简写为mod,提供一个用于执行Symbol算的中高级接口,可理解为module是执行Symbol定义好的程序的机器。
module.Module接受Symbol作为输入:
data = mx.sym.Variable(\'data\') fc1 = mx.sym.FullyConnected(data, name=\'fc1\', num_hidden=128) act1 = mx.sym.Activation(fc1, name=\'relu1\', act_type="relu") fc2 = mx.sym.FullyConnected(act1, name=\'fc2\', num_hidden=10) out = mx.sym.SoftmaxOutput(fc2, name = \'softmax\') mod = mx.mod.Module(out) # create a module by given a Symbol 根据symbol建立module
关于module的一套训练流程见这里。本节的目的是选择一些常用module API,包括一些重要的属性和方法做个分析。
module包提供了以下几个module:最主要的还是第二个。
显然BaseModule是其他所有module的基类, 基类提供以下方法:
1. 初始化空间:
2. 参数操作:
3. 训练预测
4. 前反向传播
5. 参数更新
6. 输入输出
7. 其他
以上这些方法是所有类共有的,而Module类自己还有下面的内置方法:
这么多方法选择一些常用重要的方法进行介绍。
module表示计算组件。可把module看作是一台计算机器。模块可以执行向前和向后传递并更新模型中的参数。
一个module有几个状态:
- 初始状态-initial state:内存尚未分配,因此模块尚未准备好进行计算。
- 绑定-binded:输入、输出和参数的形状都是已知的,内存已分配,模块已准备好进行计算。
- 参数已初始化:对于具有参数的模块,在初始化参数之前执行计算可能会导致未定义的输出。
- 优化器已安装:优化器可以安装到module。在此之后,在计算梯度(前向后)之后,可以根据优化器更新模块的参数。
before binding:为了使module产生交互,它必须能够在初始状态(bind之前)已知以下信息:
- 数据名称data_names:表示所需输入数据名称的字符串类型列表。
- 输出名称output_names:表示所需输出名称的字符串类型列表。
after binding:绑定后,module应能提供以下更丰富的信息:
- 状态信息:
binded:bool,指示是否已分配计算所需的内存缓冲区。
for_training:模块是否绑定进行训练。
params_initialized:bool,指示此模块的参数是否已初始化。
optimizer_initialized:bool,指示是否定义并初始化了优化器。
inputs_need_grad:bool,指示是否需要相输入数据的梯度。在实现模块组合时可能很有用。
- 输入输出信息:
data_shapes:(name、shape)的列表。理论上,由于内存是分配的,可以直接提供数据数组。但在数据并行的情况下,数据数组的形状可能与从外部世界看的不同。
label_shapes:(name、shape)的列表。如果模块不需要标签(如顶部不包含loss函数),或者模块未绑定以进行训练,则此值可能为[]。
outpu_shapes:(name、shape)的列表。
- 参数信息:
get_params():返回一个(arg_params,aux_params)的元组。每一个都是一个name到NDArray的映射。由于NDArray总是使用CPU,用于计算的实际参数可能存在于其他设备(GPU)上,此函数将检索最新参数。
set_params(arg_params,aux_params):为执行计算的设备分配参数。
init_params(…):一个更灵活的接口来分配或初始化参数。
- setup:
bind():为计算准备环境。
init_optimizer():安装用于参数更新的优化器。
prepare():根据当前数据批准备模块。
- 计算:
forward(data_batch):前向操作。
backward(out_grads=None):反向操作。
update():根据安装的优化器更新参数。
get_outputs():获取上一个前向操作的输出。
get_input_grads():获取与上一个后向操作中计算的输入的梯度。
update_metric(metric,labels,pre_sliced=False):更新之前前向传播结果的性能度量,就是metric。
当这些中间层API被正确实现时,以下高级API将自动可用于模块:
- fit:在数据集上训练模块参数。
- predict:在数据集上运行预测并收集输出。
- score:在数据集上运行预测并评估性能
1. score
(eval_data, eval_metric, num_batch=None, batch_end_callback=None, score_end_callback=None, reset=True, epoch=0, sparse_row_id_fn=None)[source]
这个方法用于预测eval_data的结果,并根据eval_metric提供的指标进行评估
- eval_data(DataIte类型):要运行预测的评估数据。
- eval_metric(EvalMetric或EvalMetrics列表):要使用的评估度量。
- num_batch(int):要运行的批数。默认为“无”,表示在DataIter完成之前运行。
- batch_end_callback:(函数)也可以是函数列表。
- reset(bool):默认为True。指示在开始计算之前是否应重置计算数据。
- epoch(int):默认为0。为了兼容性,这将传递给回调(如果有的话)。在训练期间,这将与训练epoch数相对应。
- sparse_row_id_fn(回调函数):函数将数据批作为输入并返回str->NDArray的dict。结果dict用于从kvstore中提取行稀疏参数,其中str键是参数的名称,值是要提取的参数的行id。
一个例子:
# An example of using score for prediction. # Evaluate accuracy on val_dataiter metric = mx.metric.Accuracy() mod.score(val_dataiter, metric) mod.score(val_dataiter, [\'mse\', \'acc\'])
2. predict
(eval_data, num_batch=None, merge_batches=True, reset=True, always_output_list=False, sparse_row_id_fn=None)
这方法主要是得到eval_data的测试结果。
# An example of using `predict` for prediction. # Predict on the first 10 batches of val_dataiter mod.predict(eval_data=val_dataiter, num_batch=10)
3. fit
(train_data, eval_data=None, eval_metric=\'acc\', epoch_end_callback=None, batch_end_callback=None, kvstore=\'local\', optimizer=\'sgd\', optimizer_params=((\'learning_rate\', 0.01), ), eval_end_callback=None, eval_batch_end_callback=None, initializer=, arg_params=None, aux_params=None, allow_missing=False, force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None, sparse_row_id_fn=None)[source]
这个是最重要的方法,用于训练网络。
- train_data(DataIter型):训练集数据迭代器。
- eval_data(DataIter型):如果不是None,则将用作验证集,并评估每个epoch之后的性能。
- eval_metric(str或EvalMetric):默认为“准确率”。用于在训练期间显示的性能度量。其他可能的预定义度量是:“ce”(交叉熵)、“f1”、“mae”、“mse”、“rmse”、“top_k_accurity”。
- epoch_end_callback(函数或函数列表):将使用当前epoch、symbol、arg_params和aux_params调用每个回调。
- batch_end_callback(函数或函数列表):将使用BatchEndParam调用每个回调。
- kvstore(str或kvstore):默认为“local”。
- optimizer(str或optimizer):默认为“sgd”。
- optimizer_params(dict)–默认为((“learninf_rate”,0.01),)。优化器构造函数的参数。
- eval_end_callback(函数或函数列表):这些函数将在每次完整评估结束时调用,度量值覆盖整个eval集。
- eval_batch_end_callback(函数或函数列表):在评估期间,将在每个batch之后调用这些函数。
- initializer(initializer):调用初始值设定项以在尚未初始化模块参数时对其进行初始化。
- arg_params(dict)–默认值为None,如果不是None,则应该是来自经过训练的模型的现有参数或从checkpoint(以前保存的模型)加载的参数。在这种情况下,这里的值将用于初始化模块参数,除非它们已经由用户通过调用init_params或fit进行了初始化。arg_params的优先级高于初始值设定项。
- aux_params(dict)-默认为无。类似于arg_params,除了用于辅助状态。
- allow_missing(bool)-默认为False。指示当arg_params和aux_params不是None时是否允许缺少参数。如果这是真的,那么丢失的参数将通过初始化器初始化。
- force_rebind(bool)-默认为False。是否强制重新绑定已绑定的执行器。
- force_init(bool)-默认为False。指示是否强制初始化,即使参数已初始化。
- begin_epoch(int)-默认为0。通常,如果从epoch N上一个训练阶段保存的checkpoint恢复,则该值应为N+1。
- num_epoch(int)–训练的epoch数。
- sparse_row_id_fn(回调函数)–函数将数据批作为输入并返回str->NDArray的dict。结果dict用于从kvstore中提取行稀疏参数,其中str键是参数的名称,值是要提取的参数的行id。
# An example of using fit for training. # Assume training dataIter and validation dataIter are ready # Assume loading a previously checkpointed model sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3) mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer=\'sgd\', optimizer_params={\'learning_rate\':0.01, \'momentum\': 0.9}, arg_params=arg_params, aux_params=aux_params, eval_metric=\'acc\', num_epoch=10, begin_epoch=3)
4. get_params()
返回一对字典。类型分别是arg_params和aux_params。每个字典都是从参数映射到NDArray:
# An example of getting module parameters. print mod.get_params()
5. init_params和set_params
init_params
(initializer=, arg_params=None, aux_params=None, allow_missing=False, force_init=False, allow_extra=False)
set_params
(arg_params, aux_params, allow_missing=False, force_init=True, allow_extra=False)[source]
前者初始化参数和辅助参数的状态,后者给参数和辅助参数赋值。
# An example of initializing module parameters. mod.init_params() # An example of setting module parameters. sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, n_epoch_load) mod.set_params(arg_params=arg_params, aux_params=aux_params)
6. load_params和save_params
载入和保存参数
# An example of saving module parameters. mod.save_params(\'myfile\') # An example of loading module parameters. mod.load_params(\'myfile\')
7. forawrd(data_batch, is_train=None)和backwar(out_grads=None)
data_batch是DataBatch类型。
import mxnet as mx
from collections import namedtuple
Batch = namedtuple(\'Batch\', [\'data\'])
data = mx.sym.Variable(\'data\')
out = data * 2
mod = mx.mod.Module(symbol=out, label_names=None)
mod.bind(data_shapes=[(\'data\', (1, 10))])
mod.init_params()
data1 = [mx.nd.ones((1, 10))]
mod.forward(Batch(data1))
print mod.get_outputs()[0].asnumpy()
# Forward with data batch of different shape
data2 = [mx.nd.ones((3, 5))]
mod.forward(Batch(data2))
print mod.get_outputs()[0].asnumpy()
# An example of backward computation. mod.backward() print mod.get_input_grads()[0].asnumpy() ]
上面在前向和反向传播时分别用到了get_outputs(merge_multi_context=True)和get_input_grads(merge_multi_context=True)两个函数:
前者得到前向计算后的输出,后者得到反向传播后关于输入的梯度。
8. init_optimizer
(kvstore=\'local\', optimizer=\'sgd\', optimizer_params=((\'learning_rate\', 0.01), ), force_init=False)
指定了这个还需指定fit里面那个optimizer吗???
# An example of initializing optimizer. mod.init_optimizer(optimizer=\'sgd\', optimizer_params=((\'learning_rate\', 0.005),))
9. update()和update_metric(eval_metric, labels,pre_sliced=False)
update根据已配置的优化器和上一个前向-反向批处理中计算的梯度更新参数。
# An example of updating module parameters. mod.init_optimizer(kvstore=\'local\', optimizer=\'sgd\', optimizer_params=((\'learning_rate\', 0.01), )) mod.backward() mod.update() print mod.get_params()[0][\'fc3_weight\'].asnumpy() ]
update_metric对上次前向计算的输出求值并累积评估metric。
# An example of updating evaluation metric. mod.forward(data_batch) mod.update_metric(metric, data_batch.label)
10. bind
(data_shapes, label_shapes=None, for_training=True, inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req=\'write\')
非常重要,不用fit这个高级api的话,就需要bind来搞起训练。
# An example of binding symbols. mod.bind(data_shapes=[(\'data\', (1, 10, 10))]) # Assume train_iter is already created. mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
以上是关于Module API的主要内容,如果未能解决你的问题,请参考以下文章
onActivityResult 未在 Android API 23 的片段上调用
JMeter:逻辑控制器_模块控制器(Module Controller)
Android 插件化VirtualApp 源码分析 ( 目前的 API 现状 | 安装应用源码分析 | 安装按钮执行的操作 | 返回到 HomeActivity 执行的操作 )(代码片段
jmeter的Include Controller控件和Test Fragment控件和Module Controller控件