Pytorch 复制替换网络中的部分参数,网络参数的定向赋值

Posted 呆呆象呆呆

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch 复制替换网络中的部分参数,网络参数的定向赋值相关的知识,希望对你有一定的参考价值。

主要目标

尝试复制网络中的部分参数到另一个网络中

尝试将网络的部分参数做定向赋值

实验代码

最核心的部分

testnet2.state_dict()["net1.weight"].copy_(testnet.state_dict()["net1.weight"]) 
或者
testnet2.state_dict()["net1.weight"].copy_(testnet.state_dict()["net1.weight"].clone()) 
或者
testnet2.state_dict()["net1.weight"].copy_(testnet.state_dict()["net1.weight"].clone().detach()) 
好像没差区别都不会把梯度带到新的网络中间去

演示代码

可能有点长,主要是为了看清楚会不会把梯度复制过去,或者会不会下去那个浅拷贝一样进行同步的参数更新

import torch 
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torchsummary import summary
import os
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm


# 设置一下数据集   数据集的构成是随机两个整数,形成一个加法的效果 input1 + input2 = label
class TrainDataset(Dataset):
    def __init__(self):
        super(TrainDataset, self).__init__()
        self.data = []
        for i in range(1,1000):
            for j in range(1,1000):
                self.data.append([i,j])
    def __getitem__(self, index):
        input_data = self.data[index]
        label = input_data[0] + input_data[1]
        return torch.Tensor(input_data),torch.Tensor([label])
    def __len__(self):
        return len(self.data)

class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.net1 = nn.Linear(2,1)
    def forward(self, x):
        x = self.net1(x)
        return x
class TestNet2(nn.Module):
    def __init__(self):
        super(TestNet2, self).__init__()
        self.net1 = nn.Linear(2,1)
    def forward(self, x):
        x = self.net1(x)
        return x


def train():
    traindataset = TrainDataset()
    traindataloader = DataLoader(dataset = traindataset,batch_size=1,shuffle=False)
    testnet = TestNet().cuda()
    testnet2 = TestNet2().cuda()
    myloss = nn.MSELoss().cuda()
    optimizer = optim.SGD(testnet.parameters(), lr=0.001 )


    for epoch in range(100):
        for data,label in traindataloader :
            print("\\n=====迭代开始=====")
            data = data.cuda()
            label = label.cuda()
            output = testnet(data)
            print("输入数据:",data)
            print("输出数据:",output)
            print("标签:",label)
            loss = myloss(output,label)
            optimizer.zero_grad()
            for name, parms in testnet.named_parameters():	
                print('-->name:', name)
                print('-->para:', parms)
                print('-->grad_requirs:',parms.requires_grad)
                print('-->grad_value:',parms.grad)
                print("===")
            loss.backward()
            optimizer.step()
            print("=============更新之后===========")
            for name, parms in testnet.named_parameters():	
                print('-->name:', name)
                print('-->para:', parms)
                print('-->grad_requirs:',parms.requires_grad)
                print('-->grad_value:',parms.grad)
                print("===")
            print(optimizer)

            print("=============改变之前===========")
            for name, parms in testnet2.named_parameters():	
                print('-->name:', name)
                print('-->para:', parms)
                print('-->grad_requirs:',parms.requires_grad)
                print('-->grad_value:',parms.grad)
                print("===")
            testnet2.state_dict()["net1.weight"].copy_(testnet.state_dict()["net1.weight"]) 
            print("=============改变之后===========")
            for name, parms in testnet2.named_parameters():	
                print('-->name:', name)
                print('-->para:', parms)
                print('-->grad_requirs:',parms.requires_grad)
                print('-->grad_value:',parms.grad)
                print("===")
            input("=====迭代结束=====")

if __name__ == '__main__':

    os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(3)
    train()


代码结果

LAST 参考文献

(1条消息) pytorch一种给模型参数赋值的方法。_genous110的博客-CSDN博客_pytorch修改模型参数

以上是关于Pytorch 复制替换网络中的部分参数,网络参数的定向赋值的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch冻结部分层的参数

pytorch固定部分网络参数

[Pytorch]Pytorch 保存模型与加载模型(转)

Pytorch 模型 查看网络参数的梯度以及参数更新是否正确,优化器学习率的分层设置

pytorch中的神经网络子模块(线性模块)——torch.nn.Linear

Pytorch中的自动求导函数backward()所需参数含义