功能代码分享模型参数非严格性迁移
Posted MarToony|名角
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了功能代码分享模型参数非严格性迁移相关的知识,希望对你有一定的参考价值。
能够将E模型参数迁移到D模型中,即使两者的键不同。
def load_fortune_model(params, model):
from tqdm import tqdm
checkpoint = torch.load(params['pretrained'], map_location='cpu')
try:
D_dict = model.module.state_dict()
except AttributeError:
D_dict = model.state_dict()
E_dict = checkpoint["state_dict"]
logger.info("D模型中的参数层数目为:{}".format(len(D_dict.keys())))
logger.info("E模型中的参数层数目为:{}".format(len(E_dict.keys())))
logger.info("D中涉及到FC的参数层的名字:{}".format([k for k, v in D_dict.items() if "fc" in k]))
logger.info("E中涉及到FC的参数层的名字:{}".format([k for k, v in E_dict.items() if "fc" in k]))
D_count = [0, 0, 0] # 记录E中键在D中的参数层数目;与replace之后,在D中的参数曾数目;未知情况。
replace_dict = []
# 判断并记录A模型中可被借鉴的键/网络层参数。
for k, v in tqdm(list(E_dict.items())):
if k in D_dict:
D_count[0] += 1
# 情形一:如果 A模型中的键不存在于 B模型中,同时,该键的另一种形式存在于 B模型中,则该键将被B模型使用;
elif k not in D_dict and k.replace('module.', '') in D_dict:
# logger.info(colored('=> Load after remove .net: {}'.format(k), "blue"))
if "fc" not in k.replace('module.', ''):
replace_dict.append((k, k.replace('module.', '')))
# 不是 D模型要新创建键,而是E模型要新创建键;
else:
del E_dict[k]
logger.warning("成功捕捉FC层,同时未被启用,计数正常。")
D_count[1] += 1
else:
D_count[2] += 1
E_count = [0, 0, 0]
for k, v in tqdm(list(D_dict.items())):
if k in E_dict:
E_count[0] += 1
# 情形二:如果B模型中的键 不存在于A模型中,但是该键的另一种形式存在于A模型中,则该键将被B模型使用;
elif k not in E_dict and k.replace('module.', '') in E_dict:
# logger.info(colored('=> Load after adding .net: {}'.format(k), "blue"))
replace_dict.append((k.replace('module.', ''), k)) # 元组中前者属于A模型,后者属于B模型;
E_count[1] += 1
else:
E_count[2] += 1
logger.info("两模型参数文件的相融情况:{}————————{}".format(D_count, E_count))
for k, k_new in tqdm(replace_dict):
E_dict[k_new] = E_dict.pop(k)
keys1 = set(list(E_dict.keys()))
keys2 = set(list(D_dict.keys()))
set_diff = (keys1 - keys2) | (keys2 - keys1)
logger.error('#### Notice: keys that failed to load: {}'.format(len(set_diff)))
D_dict.update(E_dict)
logger.info("D中涉及到FC的参数层的名字:{}".format([k for k, v in D_dict.items() if "fc" in k]))
logger.info("E中涉及到FC的参数层的名字:{}".format([k for k, v in E_dict.items() if "fc" in k]))
model.load_state_dict(D_dict)
# 原论文中
# checkpoint = torch.load(params['pretrained'], map_location='cpu')
# try:
# model_dict = model.module.state_dict()
# except AttributeError:
# model_dict = model.state_dict()
# pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in model_dict and 'fc' not in k}
# # pretrained和revocer的区别就在于全连接层的权重被重新训练!
# logger.info("load pretrained model {}".format(params['pretrained']))
# model_dict.update(pretrained_dict)
# model.load_state_dict(model_dict)
以上是关于功能代码分享模型参数非严格性迁移的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch 迁移学习 (Transfer Learning) 代码详解