Federated Meta-Learning with Fast Convergence and Efficient Communication 论文阅读笔记+关键代码解读
Posted 编程龙
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Federated Meta-Learning with Fast Convergence and Efficient Communication 论文阅读笔记+关键代码解读相关的知识,希望对你有一定的参考价值。
一. 介绍
联邦学习中数据是非独立同分布的,基于FedAvg算法成功后,作者发现元学习算法MAML应对客户端上数据量较少,数据分布不均的场景提出了FedMeta框架,作为连接元学习方法和联邦学习的桥梁。在元学习中,参数化算法通过元训练过程从大量任务中慢慢学习,在元训练过程中,算法在每个任务中快速训练特定的模型。任务由互不关联的支持集和查询集组成。在支持集上训练特定的模型,然后在查询集上进行测试,测试结果用于更新算法。对于FedMeta来说,算法在服务器上维护并分发给客户端进行训练。训练之后,查询集上的测试结果被上传到服务器进行算法更新。
二. 算法介绍
首先我们定义一下
D
S
T
:
s
u
p
p
o
r
t
s
e
t
D_S^T:support\\ set
DST:support set
D
Q
T
:
q
u
e
r
y
s
e
t
D_Q^T:query\\ set
DQT:query set
A
:
元
学
习
算
法
A:元学习算法
A:元学习算法
ϕ
:
元
学
习
参
数
\\phi:元学习参数
ϕ:元学习参数
θ
T
:
模
型
参
数
\\theta_T:模型参数
θT:模型参数
根据元学习思想,我们首先通过
D
S
T
D_S^T
DST训练A上的模型f,经过更新输出模型参数
θ
T
\\theta_T
θT,这一步叫做inner update(内部更新)。之后训练出来的
θ
T
\\theta_T
θT通过我们的query set
D
Q
T
D_Q^T
DQT进行评估,计算出测试的损失
L
D
Q
T
(
θ
T
)
L_D_Q^T(\\theta_T)
LDQT(θT),通过损失我们可以反映出我们的算法
A
ϕ
A_\\phi
Aϕ上的训练能力,最后我们根据这个测试损失去最小化更新我们的参数
ϕ
\\phi
ϕ,这一步叫outer update(外部更新)。这些过程用数据表达就是:我们的算法
A
ϕ
A_\\phi
Aϕ通过优化下面目标:
min ϕ E T [ L D Q T ( θ T ) ] = min ϕ E T [ L D Q T ( A ϕ ( D S T ) ) ] \\min_\\phi E_T[L_D_Q^T(\\theta_T)]=\\min_\\phi E_T[L_D_Q^T(A_\\phi (D_S^T))] ϕminET[LDQT(θT)]=ϕminET[LDQT(Aϕ(DST))]
如果以maml来看的话,在一开始我们出事参数 ϕ = θ \\phi=\\theta ϕ=θ,然后通过 D S T D_S^T DST训练更新(几步梯度下降) L D S T ( θ ) = 1 ∣ D S T ∣ ∑ ( x , y ) l ( f θ ( x ) , y ) L_D_S^T(\\theta)=\\frac1|D_S^T|\\sum_(x,y)l(f_\\theta(x),y) LDST(θ)=∣DST∣1∑(x,y)l(fθ(x),y)使得 θ = θ T \\theta = \\theta_T θ=θT,之后,将 f θ T f_\\theta_T fθT在 D Q T D_Q^T DQT进行测试,获得测试损失函数 L D S T ( θ ) = 1 ∣ D Q T ∣ ∑ ( x ′ , y ′ ) l ( f θ T ( x ′ ) , y ′ ) L_D_S^T(\\theta)=\\frac1|D_Q^T|\\sum_(x',y')l(f_\\theta_T(x'),y') LDST(θ)=∣DQT∣1∑(x′,y′)l(fθT(x′),y′)。定义好值周上面的最小化目标就可以改变为:
min ϕ E T [ L D Q T ( θ − α ∇ L D S T ( θ ) ) ] \\min_\\phi E_T[L_D_Q^T(\\theta\\ -\\ \\alpha\\nabla L_D_S^T(\\theta))] ϕminET[LDQT(θ − α∇LDST(θ))]。
到这里,meta的部分结束,之后就是联邦学习部分。怎么结合起来呢?作者想到每一个客户端在query set测试完之后,获取到测试的损失,同时根据这个损失计算出对应的梯度,将这个梯度传到服务端,服务端平均梯度后,根据这个梯度更新服务端的参数,最后再把参数传回到客户端,也就是客户端进行inner update和outer update(只进行梯度计算),服务端进行outer update(合并梯度更新)。
算法过程如图所示
这里对maml以及meta learning还有不太清楚,以及query set和support set有疑问的可以看我之前的博客点这里。
四. 代码讲解
本次算法的github地址点这里,代码中很大一部分是实现客户端服务端的交互,这里就不详细说,重点讲解客户端训练过程和服务端的更新过程。
首先我们来看客户端的训练(对应inner update)
for batch_idx, (x, y) in enumerate(support_data_loader):
x, y = x.to(self.device), y.to(self.device)
num_sample = y.size(0)
pred = self.model(x)
loss = self.criterion(pred, y)
# 评估
correct = self.count_correct(pred, y)
# 写入相关的记录, 这份 loss 是平均的
support_loss.append(loss.item())
support_correct.append(correct)
support_num_sample.append(num_sample)
# 计算 loss 关于当前参数的导数, 并更新目前网络的参数(回传到 model)
loss_sum += loss * num_sample
grads = torch.autograd.grad(loss_sum / sum(support_num_sample), listMeta-learning原来有这么多用途,一文汇总元学习在5个问题中的应用
Meta-Learning: Learning to Learn Fast
MAML:Model-Agnostic Meta-Learning