联邦学习代码解读,超详细
Posted 一只揪°
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了联邦学习代码解读,超详细相关的知识,希望对你有一定的参考价值。
参考文献:
[1]Brendan McMahan, H., Moore, E., Ramage, D., Hampson, S., and Agüera y Arcas, B., “Communication-Efficient Learning of Deep Networks from Decentralized Data”, arXiv e-prints, 2016.
参考代码:
https://github.com/AshwinRJ/Federated-Learning-PyTorch
用Pytorch开发项目的时候,常常将项目代码分为数据处理模块、模型构建模块与训练控制模块。
联邦学习伪代码
主函数federated_main.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
import torch
from tensorboardX import SummaryWriter
from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, average_weights, exp_details
if __name__ == '__main__':
start_time = time.time()
# define paths
path_project = os.path.abspath('..')
logger = SummaryWriter('../logs')
args = args_parser()#命令行输入
exp_details(args) #展示参数细节
if args.gpu: #是否使用gpu
torch.cuda.set_device(args.gpu)
device = 'cuda' if args.gpu else 'cpu'
# 加载数据集和用户组
train_dataset, test_dataset, user_groups = get_dataset(args)
# 建立模型
if args.model == 'cnn':
# Convolutional neural netork
if args.dataset == 'mnist':
global_model = CNNMnist(args=args)
elif args.dataset == 'fmnist':
global_model = CNNFashion_Mnist(args=args)
elif args.dataset == 'cifar':
global_model = CNNCifar(args=args)
elif args.model == 'mlp':
# 多层感知器
img_size = train_dataset[0][0].shape
len_in = 1
for x in img_size:
len_in *= x
global_model = MLP(dim_in=len_in, dim_hidden=64,
dim_out=args.num_classes)
else:
exit('Error: unrecognized model')
#设置模型进行训练,并且将其传输给设备
global_model.to(device)
global_model.train()
print(global_model)
# 复制权重
global_weights = global_model.state_dict()
# 训练
train_loss, train_accuracy = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
for epoch in tqdm(range(args.epochs)): #Tqdm 是一个快速,可扩展的Python进度条
local_weights, local_losses = [], []
print(f'\\n | Global Training Round : epoch+1 |\\n')
global_model.train()
m = max(int(args.frac * args.num_users), 1)
idxs_users = np.random.choice(range(args.num_users), m, replace=False)
for idx in idxs_users:
local_model = LocalUpdate(args=args, dataset=train_dataset,
idxs=user_groups[idx], logger=logger)
w, loss = local_model.update_weights(
model=copy.deepcopy(global_model), global_round=epoch)
local_weights.append(copy.deepcopy(w))
local_losses.append(copy.deepcopy(loss))
# 更新全局权重
global_weights = average_weights(local_weights)
#更新全局模型的权重
global_model.load_state_dict(global_weights)
loss_avg = sum(local_losses) / len(local_losses)
train_loss.append(loss_avg)
#通过计算所有用户在每一个回合中的平均训练精度进行计算
list_acc, list_loss = [], []
global_model.eval()
for c in range(args.num_users):
local_model = LocalUpdate(args=args, dataset=train_dataset,
idxs=user_groups[idx], logger=logger)
acc, loss = local_model.inference(model=global_model)
list_acc.append(acc)
list_loss.append(loss)
train_accuracy.append(sum(list_acc)/len(list_acc))
# 在每'i'轮之后打印出全局训练损失
if (epoch+1) % print_every == 0:
print(f' \\nAvg Training Stats after epoch+1 global rounds:')
print(f'Training Loss : np.mean(np.array(train_loss))')
print('Train Accuracy: :.2f% \\n'.format(100*train_accuracy[-1]))
# 在训练后测试验证
test_acc, test_loss = test_inference(args, global_model, test_dataset)
print(f' \\n Results after args.epochs global rounds of training:')
print("|---- Avg Train Accuracy: :.2f%".format(100*train_accuracy[-1]))
print("|---- Test Accuracy: :.2f%".format(100*test_acc))
# 保存目标训练损失和训练精度
file_name = '../save/objects/___C[]_iid[]_E[]_B[].pkl'.\\
format(args.dataset, args.model, args.epochs, args.frac, args.iid,
args.local_ep, args.local_bs)
with open(file_name, 'wb') as f:
pickle.dump([train_loss, train_accuracy], f)
print('\\n Total Run Time: 0:0.4f'.format(time.time()-start_time))
#画图
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
#绘制损失曲线
plt.figure()
plt.title('训练损失 vs 通信回合数')
plt.plot(range(len(train_loss)), train_loss, color='r')
plt.ylabel('训练损失')
plt.xlabel('通信回合数')
plt.savefig('../save/fed____C[]_iid[]_E[]_B[]_loss.png'.
format(args.dataset, args.model, args.epochs, args.frac,
args.iid, args.local_ep, args.local_bs))
# Plot 平均准确性 vs 通信回合数
plt.figure()
plt.title('平均准确性 vs 通信回合数')
plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
plt.ylabel('平均准确性')
plt.xlabel('通信回合数')
plt.savefig('../save/fed____C[]_iid[]_E[]_B[]_acc.png'.
format(args.dataset, args.model, args.epochs, args.frac,
args.iid, args.local_ep, args.local_bs))
参数设置option.py
在代码文件夹option.py
的args_parser()
当中,其中定义了包括
全局回合数、用户数量K、用户选取比例C、本地训回合数E、本地批量大小B、学习速率、SGD的动量为0.5(???)
模型参数:模型、核数量、核大小、通道数、归一化、过滤器数量、最大池化
以及一些其他参数和默认值
参数更新update.py
数据处理模块的主要任务:构建数据集。为方便深度学习项目构建数据集,Pytorch为我们提供了Dataset类。
构建数据集class DatasetSplit(Dataset)
先来看看Dataset类的官方解释:Dataset可以是任何东西,但它始终包含一个__len__函数(通过Python中的标准函数len调用)和一个用来索引到内容中的__getitem__函数。PyTorch官方中文文档
以下参考PyTorch如何构建数据集呢?
构建数据集前:明确需要哪些输入数据、训练时需要哪些数据
比如:现有1元和100元图像样本,分别放在两个文件中。我们的输入为图像数据,除了图像数据,还需要与图像数据相对应的类别标签,来计算损失loss。
明确了需要构建什么数据后,下一步就是通过继承Pytorch的dataset类来编写自己的dataset类。
定义了类DatasetSplit(Dataset)
重构了Pytorch的类Dataset
class DatasetSplit(Dataset): #使用dataset重构
"An abstract Dataset class warpped around Pytorch Dataset class."
def __init__(self, dataset, idx):
self.dataset = dataset
self.idx = [int(i) for i in idx]
def __len__(self):
return len(self.idx)
def __getitem__(self, item):
image, label = self.dataset[self.idx[item]]
return torch.tensor(image), torch.tensor(label) #torch.tensor() #转换为张量形式,且会拷贝data
上面代码中,重写了__len__(self)
方法。比较简单,返回数据列表长度,即数据集的样本数量。
在__getitem__(self, item)
方法中,通过dataset
读取图像数据,最后返回下标为item的图像数据和标签的张量。
这里返回哪些数据主要是由训练代码中需要哪些数据来决定。也就是说,我们根据训练代码需要什么数据来重写__getitem__(self, index)
方法并返回相应的数据。
本地更新模型构建模块calss LocalUpdate(object)
初始化 __init__(self, args, dataset, idxs, logger)
:
def __init__(self, args, dataset, idxs, logger):
self.args = args
self.logger = logger
self.trainloader, self.validloader, self.testloader = self.train_val_test(
dataset, list(idxs) ) #根据train_val_test划分训练、验证和测试数据集
self.device = 'cuda' if args.gpu else 'cpu' #若args.gpu为true则在cuda上运行程序
# Default criterion set to NLL loss function
self.criterion = nn.NLLLoss().to(self.device) #交叉熵损失函数,用于描述系统的混乱程度,值越小,与真实样本越接近
数据集以及索引划分train_val_test(self, dataset, idxs)
输入:数据集、索引
输出:给定数据集的训练、验证、测试的记录器 用户索引
def train_val_test(self, dataset, idxs):
"""
return train, validation and test datalodaers for a given dataset and user indexes
"""
#split indexes for train, validation, and test(80,10,10)
idxs_train = idxs[:int(0.8*len(idxs))]
idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
idxs_test = idxs[int(0.9 * len(idxs)):]
trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
batch_size=self.args.local_bs, shuffle=True)
validloader = DataLoader(DatasetSplit(dataset, idxs_val),
batch_size=int(len(idxs_val) / 10), shuffle=False)
testloader = DataLoader(DatasetSplit(dataset, idxs_test),
batch_size=int(len(idxs_test) / 10), shuffle=False)
return trainloader, validloader, testloader
本地权重更新upadate_weights(self, model, global_round)
输入:模型、全局更新回合数
输出:更新后的权重 、损失平均值
损失函数使用方法
1 optimizer = optim.SGD(model.parameters())
2 fot epoch in range(num_epoches):
3 train_loss=0
4 for step,(seq, label) in enumerate(data_loader):
5 # 损失函数
6 loss = criterion(model(seq), label.to(device))
7 # 将梯度清零
8 opimizer.zero_grad()
9 # 损失函数backward进行反向传播梯度的计算
10 loss.backward()
11 train_loss += loss.item()
12 # 使用优化器的step函数来更新参数
13 optimizer.step()
def upadate_weights(self, model, global_round):
#Set mode to train model
model.train()
epochs_loss=[]
#set optimizer for the local updates
if self.args.optimizer == 'sgd':
optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
momentum=0.5) #使用SGD作为优化器
elif self.args.optimizer == 'adam':
optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr,
weight_decay=1e-4) #使用Adam作为优化器
for iter in range(self.args.local_ep):
batch_loss = []
for batch_idx, (images, labels) in enumerate(self.trainloader):
images, labels = images.to(self.device), labels.to(self.device)
model.zero_grad()
log_probs = model(images)
loss = self.criterion(log_probs, labels)
loss.backward()
optimizer.strep()
if self.args.verbose and (batch_idx %10 == 0):
print('| Global Round : | Local Epoch : | [/ (:.0f%)]\\tLoss: :.6f'.format(
global_round, iter, batch_idx * len(images),
len(self.trainloader.dataset),
100. * batch_idx / len(self.trainloader), loss.item()))
self.logger.add_scalar('loss', loss.item())
batch_loss.append(loss.item())
epochs_loss.append(sum(batch_loss) / len(batch_loss))
return model.state_dict(), sum(epochs_loss) / len(epochs_loss)
计算准确值以及损失值inference(self, model)
def inference(self, model):
" return the inference accuracy and loss"
model.eval() #不改变权值样本训练
loss, total, correct = 0.0, 0.0, 0.0
for batch_idx, (images, labels) in enumerate(self.testloader):
images, labels = images.to(self.device), labels.to(self.device)
#inference
outputs = model(images)
batch_loss = self.criterion(outputs, labels)
loss += batch_loss.item()
#prediction
_, pred_labels = torch.max(outputs, 1) #返回输入tensor中所有元素的最大值
pred_labels = pred_labels.view(-1) #view函数的作用为重构张量的维度,相当于numpy中resize()的功能
correct += torch.sum(torch.eq(pred_labels, labels)).item()
total += len(labels)
accuracy = correct/total
return accuracy, loss
应用集(获取数据集、权重取平均、展示细节)utils.py
获取数据集get_dataset(args)
输入:命令行参数
输出:用于训练和测试的数据集和用户组,其中键是索引,值是每个用户的相应数据
def get_dataset(args):
"返回训练和测试数据集和用户组,用户用户组是字典,其中键是索引,值是每个用户的相应数据。"
if args.dataset == 'cifar':
data_dir = '../data/cifar/'
apply_tramsform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#transforms.Compose()把多个步骤融合到一起
#ToTensor()能够把灰度范围从0-255变换到0-1之间
#而后面的transform.Normalize()则把0-1变换到(-1,1)
train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
transform=apply_tramsform)
test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
transform=apply_tramsform)
#在用户中采集训练数据
if args.iid:
#从Mnist中采集IID用户数据
user_groups = cifar_iid(train_dataset, args.num_users)
else:
#从Mnist中采集Non-IID用户数据
if args.unequal:
#每个用户选择不平等划分
raise NotImplementedError()
else:
#每个用户选择不平等划分
user_groups = cifar_noniid(train_dataset, args.num_users)
elif args.dataset == 'mnist' or SENet代码复现+超详细注释(PyTorch)