Pysyft学习笔记二:伪分布式模型训练的实现
Posted 一只特立独行的猫
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pysyft学习笔记二:伪分布式模型训练的实现相关的知识,希望对你有一定的参考价值。
不熟悉send和get机制的小伙伴可以看一下我的上一篇博客:Pysyft学习笔记一:dome思路,然后再看这篇博客,效果会更好哦。
导入基本库
import torch
#分布式训练
import syft as sy
#用于搭建神经网络
from torch import nn
#用于构造优化器
from torch import optim
#在torch上使用hook技术
hook = sy.TorchHook(torch)
创建客户机
id可以理解为客户机的名称
Bob = sy.VirtualWorker(hook,id='Bob')
Alice = sy.VirtualWorker(hook,id='Alice')
建立测试集与验证集
自己想了几个数据集,但是训练了都不收敛,所以还是采用参考博客的数据集了。
data = torch.tensor([[0,1],[0,1], [1,0], [1,1.]],requires_grad=True)
targe = torch.tensor([[0],[0],[1], [1.]],requires_grad=True)
建立全连接模型
模型建立好以后,因为采用分布式训练的方法,所以需要将数据传输给客户端。当然,在实际使用用这一步肯定不需要,因为对整个架构而言,数据只对拥有者可见。所以本地一定已经存储好了数据。
#全连接模型,输入是二维,输出是一维,可以理解为读两个数,输出一个数
model = nn.Linear(2,1)
data_Bob_ptr = data[:2].send(Bob)
targe_Bob_ptr = targe[:2].send(Bob)
data_Alice_ptr = data[2:].send(Alice)
targe_Alice_ptr = data[2:].send(Alice)
#定义训练模块,后面会对每个客户进行计算
datasets = [(data_Bob_ptr,targe_Bob_ptr),(data_Alice_ptr,targe_Alice_ptr)]
定义训练模块
def train():
#定义优化器
opt = optim.SGD(params=model.parameters(),lr=0.1)
for epoch in range(10):
#分发-收集50次模型,相当于迭代训练50次
for data, targe in datasets:
model.send(data.location)
#梯度清0,如果没有清零,则会进行累加,导致结果错误
opt.zero_grad()
pred = model(data)
#sum((pred - targe)**2)
loss = ((pred - targe)**2).sum()
#反向传播梯度计算
loss.backward()
#更新参数
opt.step()
model.get()
print(loss.get().data)
唯一与传统训练不同的是,这里涉及到了send和get函数,即需要先将模型传输给客户机,然后再从客户机这里回收模型。因为模拟的时候非并行计算,为了简化表述,这里不涉及到FL中核心的加密算法与分布式模型合并算法,并且训练也是串行训练。并行训练参考我的下一篇博客。
训练效果
train()
print(model(torch.tensor([[0,1],[0,1], [1,0], [1,1.]])).data)
数据随便想的,模型也比较简单,所以训练的效果一般
以上是关于Pysyft学习笔记二:伪分布式模型训练的实现的主要内容,如果未能解决你的问题,请参考以下文章
Pysyft学习笔记三:分布式模型实现(Fed_avg算法整合模型)
Pysyft学习笔记三:分布式模型实现(Fed_avg算法整合模型)
Pysyft学习笔记四:MINIST数据集下的联邦学习(并行训练与非并行训练)
Pysyft学习笔记四:MINIST数据集下的联邦学习(并行训练与非并行训练)