模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL
Posted UQI-LIUWJ
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL相关的知识,希望对你有一定的参考价值。
1 GAN (回顾)
GAIL 的设计基于生成判别网络 ( GAN)。这里简单地回顾一下GAN,详细的可见 NTU 课程笔记 7454 GAN_UQI-LIUWJ的博客-CSDN博客 GAN由生成器 (Generator) 和判别器 (Discriminator)组成,它们 各是一个神经网络。 ——>生成器负责生成假的样本 ——>判别器负责判定一个样本是真是假。 我们的目标是希望生成器生成的内容可以“以假乱真”1.1 生成器
生成器 记作 a = G ( s ; θ ) ,其中 θ 是参数。它的输入是向量 s ,向量的每一个元素从均匀分布或标准正态分布 N (0 , 1) 中抽取。生成器的输出是数据(比如图片)x。
1.2 判别器
判别器 记作,其中 ϕ 是参数。 它的输入是图片 x;输出 是介于 0 到 1 之间的概率值,0 表示“假的”, 1 表示“真的”。 判别器的功能是二分类器。1.3 训练生成器
将生成器与判别器相连,固定住判别器的参数,只更新生成器的参数 θ,使得生成的图片 x = G(s; θ) 在判别器的眼里更像真的。
对于任意一个随机生成的向量 s,应该改变 θ,使得判别器的输出尽量接近 1
可以用如下函数作为loss function:
我们希望此时D(x;Φ)越大越好,也就是E(s;θ)越小越好
所以我们用梯度下降来更新生成器的θ
1.4 训练判别器
- 判别器的本质是个二分类器,它的输出值 表示对图片真伪的预测;
- 接近 1 表示“真”,
- 接近 0 表示“假”。
- 从真实数据集中抽取一个样本,记作。
- 再随机生成一个向量 s,用生成器生成
- 训练判别器的目标是改进参数 ϕ,让 更接近 1(真),让更接近 0 (假)。
- ——>也就是说让判别器的分类结果更准确,更好区分真实图片和生成的假图片。
此时的损失函数如下所示
不难发现,判别器越准确,损失函数F越小
所以我们也用梯度下降更新判别器的θ
1.5 整体训练流程
2 生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL
2.1 训练数据
GAIL 的训练数据是被模仿的对象(人类专家)操作智能体得到的轨迹数据集中有 k 条轨迹,把数据集记作:
2.2 生成器
GAIL 的生成器是策略网络 π ( a | s ; θ ) 策略网络的输入是状态 s,输出是一个向量:输出向量 f 的维度是动作空间的大小 A ,它的每个元素对应一个动作,表示执行该动作 的概率。
给定初始状态 s 1 ,并让智能体与环境交互,可以得到一条轨迹:
其中动作是根据策略网络抽样得到的,
下一时刻的状态是环境根据状态转移函数计算出来的
2.3 判别器
GAIL 的判别器记作 D ( s, a ; ϕ ) 判别器的输入是状态 s,输出是一个向量:输出向量 的维度是动作空间的大小 A ,它的每个元素对应一个动作 a ,把一个元素记作:
接近 1 表示 ( s, a ) 为“真”,即动作 a 是人类专家做的。 接近 0 表示 ( s, a ) 为“假”,即动作 a 是策略网络生成的。
2.4 GAIL的训练
2.4.1 训练生成器
设是当前策略网络的参数。用策略网络控制智能体与环境交互,得到一条轨迹: 用判别器评价 的真实情况, 越大,说明 在判别器的眼里越真实。 我们记第t步的回报为:
于是我们的轨迹可以变成
有不同的方法来更新策略网络的参数θ
在GAIL中,使用的是TRPO
强化学习笔记:置信域策略优化 TRPO_UQI-LIUWJ的博客-CSDN博客
即目标函数为
通过解带约束的最大化问题,得到新的参数
2.4.2 训练判别器
训练判别器的目的是让它能区分真的轨迹与生成的轨迹 我们从训练数据中抽样一条轨迹:同时用策略网络控制智能体和环境交互,得到另一条轨迹,记作
注意real和fake轨迹的长度可能不一样
同样地,我们希望尽量趋近于1,尽量趋近于0
于是我们定义损失函数
我们希望损失函数尽量小,也就是说判别器能区分开真假轨迹。可以做梯度下降来更新判别器的参数Φ
2.4.3 整体训练流程
每一轮训练更新一个生成器,更新一次判别器。训练重复以下步骤,直 到收敛。
以上是关于模仿学习笔记:生成判别模仿学习 Generative Adversarial Imitation Learning, GAIL的主要内容,如果未能解决你的问题,请参考以下文章
论文精读:Generative Adversarial Imitation Learning(生成对抗模仿学习)
论文精读:Generative Adversarial Imitation Learning(生成对抗模仿学习)