Federated Meta-Learning with Fast Convergence and Efficient Communication 论文阅读笔记+关键代码解读

Posted 编程龙

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Federated Meta-Learning with Fast Convergence and Efficient Communication 论文阅读笔记+关键代码解读相关的知识,希望对你有一定的参考价值。

论文地址点这里

一. 介绍

联邦学习中数据是非独立同分布的,基于FedAvg算法成功后,作者发现元学习算法MAML应对客户端上数据量较少,数据分布不均的场景提出了FedMeta框架,作为连接元学习方法和联邦学习的桥梁。在元学习中,参数化算法通过元训练过程从大量任务中慢慢学习,在元训练过程中,算法在每个任务中快速训练特定的模型。任务由互不关联的支持集和查询集组成。在支持集上训练特定的模型,然后在查询集上进行测试,测试结果用于更新算法。对于FedMeta来说,算法在服务器上维护并分发给客户端进行训练。训练之后,查询集上的测试结果被上传到服务器进行算法更新。

二. 算法介绍

首先我们定义一下
D S T : s u p p o r t   s e t D_S^T:support\\ set DST:support set
D Q T : q u e r y   s e t D_Q^T:query\\ set DQT:query set
A : 元 学 习 算 法 A:元学习算法 A:
ϕ : 元 学 习 参 数 \\phi:元学习参数 ϕ:
θ T : 模 型 参 数 \\theta_T:模型参数 θT
根据元学习思想,我们首先通过 D S T D_S^T DST训练A上的模型f,经过更新输出模型参数 θ T \\theta_T θT,这一步叫做inner update(内部更新)。之后训练出来的 θ T \\theta_T θT通过我们的query set D Q T D_Q^T DQT进行评估,计算出测试的损失 L D Q T ( θ T ) L_D_Q^T(\\theta_T) LDQT(θT),通过损失我们可以反映出我们的算法 A ϕ A_\\phi Aϕ上的训练能力,最后我们根据这个测试损失去最小化更新我们的参数 ϕ \\phi ϕ,这一步叫outer update(外部更新)。这些过程用数据表达就是:我们的算法 A ϕ A_\\phi Aϕ通过优化下面目标:

min ⁡ ϕ E T [ L D Q T ( θ T ) ] = min ⁡ ϕ E T [ L D Q T ( A ϕ ( D S T ) ) ] \\min_\\phi E_T[L_D_Q^T(\\theta_T)]=\\min_\\phi E_T[L_D_Q^T(A_\\phi (D_S^T))] ϕminET[LDQT(θT)]=ϕminET[LDQT(Aϕ(DST))]

如果以maml来看的话,在一开始我们出事参数 ϕ = θ \\phi=\\theta ϕ=θ,然后通过 D S T D_S^T DST训练更新(几步梯度下降) L D S T ( θ ) = 1 ∣ D S T ∣ ∑ ( x , y ) l ( f θ ( x ) , y ) L_D_S^T(\\theta)=\\frac1|D_S^T|\\sum_(x,y)l(f_\\theta(x),y) LDST(θ)=DST1(x,y)l(fθ(x),y)使得 θ = θ T \\theta = \\theta_T θ=θT,之后,将 f θ T f_\\theta_T fθT D Q T D_Q^T DQT进行测试,获得测试损失函数 L D S T ( θ ) = 1 ∣ D Q T ∣ ∑ ( x ′ , y ′ ) l ( f θ T ( x ′ ) , y ′ ) L_D_S^T(\\theta)=\\frac1|D_Q^T|\\sum_(x',y')l(f_\\theta_T(x'),y') LDST(θ)=DQT1(x,y)l(fθT(x),y)。定义好值周上面的最小化目标就可以改变为:

min ⁡ ϕ E T [ L D Q T ( θ   −   α ∇ L D S T ( θ ) ) ] \\min_\\phi E_T[L_D_Q^T(\\theta\\ -\\ \\alpha\\nabla L_D_S^T(\\theta))] ϕminET[LDQT(θ  αLDST(θ))]

到这里,meta的部分结束,之后就是联邦学习部分。怎么结合起来呢?作者想到每一个客户端在query set测试完之后,获取到测试的损失,同时根据这个损失计算出对应的梯度,将这个梯度传到服务端,服务端平均梯度后,根据这个梯度更新服务端的参数,最后再把参数传回到客户端,也就是客户端进行inner update和outer update(只进行梯度计算),服务端进行outer update(合并梯度更新)。
算法过程如图所示

这里对maml以及meta learning还有不太清楚,以及query set和support set有疑问的可以看我之前的博客点这里

四. 代码讲解

本次算法的github地址点这里,代码中很大一部分是实现客户端服务端的交互,这里就不详细说,重点讲解客户端训练过程和服务端的更新过程。
首先我们来看客户端的训练(对应inner update)

for batch_idx, (x, y) in enumerate(support_data_loader):
    x, y = x.to(self.device), y.to(self.device)
    num_sample = y.size(0)
    pred = self.model(x)
    loss = self.criterion(pred, y)
    # 评估
    correct = self.count_correct(pred, y)
    # 写入相关的记录, 这份 loss 是平均的
    support_loss.append(loss.item())
    support_correct.append(correct)
    support_num_sample.append(num_sample)
    # 计算 loss 关于当前参数的导数, 并更新目前网络的参数(回传到 model)
    loss_sum += loss * num_sample
grads = torch.autograd.grad(loss_sum / sum(support_num_sample), listMeta-learning原来有这么多用途,一文汇总元学习在5个问题中的应用

Meta-Learning: Learning to Learn Fast

元学习Meta-learning与MAML

MAML:Model-Agnostic Meta-Learning

[notes] model-agnostic meta-learning

卡耐基梅隆大学(CMU)元学习和元强化学习课程 | Elements of Meta-Learning