功能代码分享模型参数非严格性迁移

Posted MarToony|名角

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了功能代码分享模型参数非严格性迁移相关的知识,希望对你有一定的参考价值。

能够将E模型参数迁移到D模型中,即使两者的键不同。

def load_fortune_model(params, model):
    from tqdm import tqdm

    checkpoint = torch.load(params['pretrained'], map_location='cpu')
    try:
        D_dict = model.module.state_dict()
    except AttributeError:
        D_dict = model.state_dict()

    E_dict = checkpoint["state_dict"]

    logger.info("D模型中的参数层数目为:{}".format(len(D_dict.keys())))
    logger.info("E模型中的参数层数目为:{}".format(len(E_dict.keys())))

    logger.info("D中涉及到FC的参数层的名字:{}".format([k for k, v in D_dict.items() if "fc" in k]))
    logger.info("E中涉及到FC的参数层的名字:{}".format([k for k, v in E_dict.items() if "fc" in k]))

    D_count = [0, 0, 0]  # 记录E中键在D中的参数层数目;与replace之后,在D中的参数曾数目;未知情况。

    replace_dict = []
    # 判断并记录A模型中可被借鉴的键/网络层参数。
    for k, v in tqdm(list(E_dict.items())):
        if k in D_dict:
            D_count[0] += 1
        # 情形一:如果 A模型中的键不存在于 B模型中,同时,该键的另一种形式存在于 B模型中,则该键将被B模型使用;
        elif k not in D_dict and k.replace('module.', '') in D_dict:
            # logger.info(colored('=> Load after remove .net: {}'.format(k), "blue"))
            if "fc" not in k.replace('module.', ''):
                replace_dict.append((k, k.replace('module.', '')))
                # 不是 D模型要新创建键,而是E模型要新创建键;
            else:
                del E_dict[k]
                logger.warning("成功捕捉FC层,同时未被启用,计数正常。")

            D_count[1] += 1
        else:
            D_count[2] += 1

    E_count = [0, 0, 0]
    for k, v in tqdm(list(D_dict.items())):
        if k in E_dict:
            E_count[0] += 1
        # 情形二:如果B模型中的键 不存在于A模型中,但是该键的另一种形式存在于A模型中,则该键将被B模型使用;
        elif k not in E_dict and k.replace('module.', '') in E_dict:
            # logger.info(colored('=> Load after adding .net: {}'.format(k), "blue"))
            replace_dict.append((k.replace('module.', ''), k))  # 元组中前者属于A模型,后者属于B模型;
            E_count[1] += 1
        else:
            E_count[2] += 1

    logger.info("两模型参数文件的相融情况:{}————————{}".format(D_count, E_count))

    for k, k_new in tqdm(replace_dict):
        E_dict[k_new] = E_dict.pop(k)

    keys1 = set(list(E_dict.keys()))
    keys2 = set(list(D_dict.keys()))
    set_diff = (keys1 - keys2) | (keys2 - keys1)

    logger.error('#### Notice: keys that failed to load: {}'.format(len(set_diff)))

    D_dict.update(E_dict)
    logger.info("D中涉及到FC的参数层的名字:{}".format([k for k, v in D_dict.items() if "fc" in k]))
    logger.info("E中涉及到FC的参数层的名字:{}".format([k for k, v in E_dict.items() if "fc" in k]))

    model.load_state_dict(D_dict)

    # 原论文中
    # checkpoint = torch.load(params['pretrained'], map_location='cpu')
    # try:
    #     model_dict = model.module.state_dict()
    # except AttributeError:
    #     model_dict = model.state_dict()
    # pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in model_dict and 'fc' not in k}
    # # pretrained和revocer的区别就在于全连接层的权重被重新训练!
    # logger.info("load pretrained model {}".format(params['pretrained']))
    # model_dict.update(pretrained_dict)
    # model.load_state_dict(model_dict)

以上是关于功能代码分享模型参数非严格性迁移的主要内容,如果未能解决你的问题,请参考以下文章

严格模型和非严格模式的区别

9.13面经

CPNtools协议建模安全分析---实例变迁标记

PyTorch 迁移学习 (Transfer Learning) 代码详解

PyTorch 迁移学习 (Transfer Learning) 代码详解

如何从 Angular.js 迁移到 Vue.js?