Pysyft学习笔记二:伪分布式模型训练的实现

Posted 一只特立独行的猫

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pysyft学习笔记二:伪分布式模型训练的实现相关的知识,希望对你有一定的参考价值。

不熟悉send和get机制的小伙伴可以看一下我的上一篇博客:Pysyft学习笔记一:dome思路,然后再看这篇博客,效果会更好哦。

导入基本库

import torch 
#分布式训练
import syft as sy

#用于搭建神经网络
from torch import nn
#用于构造优化器
from torch import optim

#在torch上使用hook技术
hook = sy.TorchHook(torch)

创建客户机

id可以理解为客户机的名称

Bob = sy.VirtualWorker(hook,id='Bob')
Alice = sy.VirtualWorker(hook,id='Alice')

建立测试集与验证集

自己想了几个数据集,但是训练了都不收敛,所以还是采用参考博客的数据集了。

data = torch.tensor([[0,1],[0,1], [1,0], [1,1.]],requires_grad=True)
targe = torch.tensor([[0],[0],[1], [1.]],requires_grad=True)

建立全连接模型

模型建立好以后,因为采用分布式训练的方法,所以需要将数据传输给客户端。当然,在实际使用用这一步肯定不需要,因为对整个架构而言,数据只对拥有者可见。所以本地一定已经存储好了数据。

#全连接模型,输入是二维,输出是一维,可以理解为读两个数,输出一个数
model = nn.Linear(2,1)

data_Bob_ptr = data[:2].send(Bob)
targe_Bob_ptr = targe[:2].send(Bob)
data_Alice_ptr = data[2:].send(Alice)
targe_Alice_ptr = data[2:].send(Alice)

#定义训练模块,后面会对每个客户进行计算
datasets = [(data_Bob_ptr,targe_Bob_ptr),(data_Alice_ptr,targe_Alice_ptr)]

定义训练模块

def train():
    #定义优化器
    opt = optim.SGD(params=model.parameters(),lr=0.1)
    for epoch in range(10):
        #分发-收集50次模型,相当于迭代训练50次
        for data, targe in datasets:
            model.send(data.location)
            #梯度清0,如果没有清零,则会进行累加,导致结果错误
            opt.zero_grad()
            pred = model(data)
            #sum((pred - targe)**2)
            loss = ((pred - targe)**2).sum()
            #反向传播梯度计算
            loss.backward()
            #更新参数
            opt.step()
            model.get()
            print(loss.get().data)

唯一与传统训练不同的是,这里涉及到了send和get函数,即需要先将模型传输给客户机,然后再从客户机这里回收模型。因为模拟的时候非并行计算,为了简化表述,这里不涉及到FL中核心的加密算法与分布式模型合并算法,并且训练也是串行训练。并行训练参考我的下一篇博客。

训练效果

train()
print(model(torch.tensor([[0,1],[0,1], [1,0], [1,1.]])).data)

数据随便想的,模型也比较简单,所以训练的效果一般

参考博客https://zhuanlan.zhihu.com/p/391114733

以上是关于Pysyft学习笔记二:伪分布式模型训练的实现的主要内容,如果未能解决你的问题,请参考以下文章

Pysyft学习笔记三:分布式模型实现(Fed_avg算法整合模型)

Pysyft学习笔记三:分布式模型实现(Fed_avg算法整合模型)

Pysyft学习笔记四:MINIST数据集下的联邦学习(并行训练与非并行训练)

Pysyft学习笔记四:MINIST数据集下的联邦学习(并行训练与非并行训练)

.NET 云原生架构师训练营(基于 OP Storming 和 Actor 的大型分布式架构二)--学习笔记...

Pytorch模型训练实用教程学习笔记:二模型的构建