在 theano 中保存和重置多层网络的参数

Posted

技术标签:

【中文标题】在 theano 中保存和重置多层网络的参数【英文标题】:Save and reset parameters of multilayer networks in theano 【发布时间】:2016-09-02 02:41:18 【问题描述】:

我们可以使用six.moves.cPickle在python中保存和加载对象。

我使用以下代码保存并重置了 LeNet 的参数。

# save model
# params = layer3.params + layer2.params + layer1.params + layer0.params
import six.moves.cPickle as pickle
f = file('best_cnnmodel.save', 'wb')
pickle.dump(params, f, protocol=pickle.HIGHEST_PROTOCOL)
f.close()

# reset parameters
model_file = file('best_cnnmodel.save', 'rb')
params = pickle.load(model_file)
model_file.close()
layer3.W.set_value(params[0].get_value())
layer3.b.set_value(params[1].get_value())
layer2.W.set_value(params[2].get_value())
layer2.b.set_value(params[3].get_value())
layer1.W.set_value(params[4].get_value())
layer1.b.set_value(params[5].get_value())
layer0.W.set_value(params[6].get_value())
layer0.b.set_value(params[7].get_value())

对于 LeNet,代码似乎没问题。但它并不优雅。对于深度网络,我无法使用此代码保存模型。在这种情况下我该怎么办?

【问题讨论】:

【参考方案1】:

可以考虑使用json格式。它易于阅读且易于使用。

这是一个例子:

准备数据

import json


data = 
    'L1' :  'W': layer1.W, 'b': layer1.b ,
    'L2' :  'W': layer2.W, 'b': layer2.b ,
    'L3' :  'W': layer3.W, 'b': layer3.b ,

json_data = json.dumps(data)

json_data 看起来像这样:

"L2": "b": 2, "W": 17, "L3": "b": 2, "W": 10, "L1": "b": 2, "W": 1

解压数据

params = json.loads(json_data)

for k, v in params.items():
    level = int(k[1:])
    # assume you save the layer in an array, but you can use 
    # different way to store and reference the layers
    layer = layers[level]
    layer.W = v['W']
    layer.b = v['b']

【讨论】:

以上是关于在 theano 中保存和重置多层网络的参数的主要内容,如果未能解决你的问题,请参考以下文章

pandas读取csv数据header参数指定作为列索引的行索引列表形成复合(多层)列索引使用reset_index函数把行索引重置为列数据(level参数设置将原行索引中的指定层转化为列数据)

pandas读取csv数据参数指定作为行索引的数据列索引列表形成复合(多层)行索引使用reset_index函数把行索引重置为列数据(level参数设置原行索引的层列表指定需要转化为数据列的层)

VS怎么重置设置开发环境

如何在 Keras 中重置状态变量?

pandas读取csv数据index_col参数指定作为行索引的数据列索引列表形成复合(多层)行索引使用reset_index函数把行索引重置为列数据(原来的行索名称转化为列索引的最外层)

pandas读取csv数据index_col参数指定作为行索引的数据列索引列表形成复合(多层)行索引使用reset_index函数把行索引重置为列数据(原来的行索名称转化为列索引的最外层)