Unity强化学习之ML-Agents的使用
Posted 微笑小星
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Unity强化学习之ML-Agents的使用相关的知识,希望对你有一定的参考价值。
Github下载链接:https://github.com/Unity-Technologies/ml-agents
参考视频:年轻人的第一个游戏AI:Unity强化学习工具MLAgents全流程实例教程、unity人工智能工具mlagents最新版教程
ML-Agents是游戏引擎Unity3D中的一个插件,也就是说,这个软件的主业是用来开发游戏的,实际上,它也是市面上用得最多的游戏引擎之一。而在几年前随着人工智能的兴起,强化学习算法的不断改进,使得越来越多的强化学习环境被开发出来,例如总所周知的OpenAI的Gym,同时还有许多实验室都采用的星际争霸2环境来进行多智能体强化学习的研究。那么,我们自然想到,可不可以开发属于自己的强化学习环境来实现自己的算法,实际上,作为一款备受欢迎的游戏引擎,Unity3D很早就有了这么一个想法。
详情见论文:Unity: A General Platform for Intelligent Agents
作为一个对游戏开发和强化学习都非常感兴趣的人,自然也了解到了这款插件,使得能够自己创造一个独一无二,与众不同,又充满智慧的游戏AI成为可能。
在Github的官方包中有更为详细的英文指导,如果不想翻译,也可以跟着我下面的步骤走。
文件说明
ML-Agents工具包包含几个部分:
- Unity包com.unity.ml-agents包含了集成到Unity 项目中的 Unity C# SDK。,可以帮助你使用 ML-Agents 的Demo。
- Unity包ml-agents.extensions依赖于com.unity.ml-agents,包含的是实验性组件,还未成为com.unity.ml-agents的一部分。
- 三个Python包:mlagents包含机器学习算法可以让你训练智能体,大多数ML-Agents的用户值需要直接安装这个文件。mlagents_envs包含了Python的API可以让其与Unity场景进行交互,这使得Python机器学习算法和Unity场景间的管理变得便利,mlagents依赖于mlagents_envs。gym_unity提供了Unity场景支持OpenAI Gym接口的封装。
- Project文件夹包含了几个示例环境,展示了ML-Agents的几个特点,可以帮助你快速上手。
需要的环境
- 安装高于Unity2019.4的版本
- 安装高于Python3.6.1的版本
- 克隆Github仓库(可选)
- 安装com.unity.ml-agents这个Unity包
- 安装com.unity.ml-agents.extensions这个Unity包(可选)
- 安装mlagents这个Python包
环境配置
在正式开始我们的项目之前是比较痛苦的配置环境的环节了。
-
首先通过上面的链接从Github上下载ML-Agents的官方包。**切记路径中不能有中文!**否则无法正常训练!
-
安装anaconda。
-
在Anaconda Prompt中创建一个专门用于ML-Agents的Python环境。
# 查看所有环境 conda-env list # 安装新环境 conda create -n ml-agents python=3.6 # 激活新环境 activete ml-agent
-
解压从Github上下载的文件,分别cd到ml-agents和ml-agents-envs两个目录,执行以下命令安装两个包:
pip install .
-
执行以下命令看看安装是否成功
mlagents-learn --help
如果安装不成功,出现Error,那么可能是因为新配的环境没有装Pytorch,此时要去官网复制需要版本的pip命令进行安装。注意在安装pytorch的时候直接使用官方的命令下载速度极慢(甚至多次中断),因此我们需要配置清华源:
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ conda config --set show_channel_urls yes#下载时显示文件来源
并且在使用官方的命令时去掉结尾的-c pytorch,因为这个指令是强制在官网下载,这样下载速度会有大幅提升。
-
下一步打开Unity编辑器,创建一个3D工程,点击windows菜单下的Package Manager,点击左上角的加号,选择Add package from disk,然后弹出的窗口中打开com.unity.ml-agents中的package.json文件,即可把包添加到Unity中,com.unity.ml-agents.extensions同理。安装完之后,就可以在下方文件中的Packages看到ML-Agents以及ML-Agent Extensions。
小试牛刀
首先把官方包下的Project用Unity编辑器打开,进入到ML-Agents -->Examples目录下,里面的全部都是ML-Agents各种实现的示例,包含了ML-Agents的主要功能展示,我们打开第一个项目3DBall,点击场景Scenes中的第一个:
现在可以看到画面了,这是一个相当简单的初级项目,我们只需要训练一个智能体顶着小球不让它落下即可。
我们打开Anaconda Prompt,切换到我们创建的环境,cd到官方包的目录,执行命令:
mlagents-learn config/ppo/3DBall.yaml --run-id=3DBallTest --force
这条命令的意思是采用的配置文件是3DBall.yaml,以3DBallTest这个名字来进行训练,训练数据也将保存在官方包下results目录下的同名文件夹,–force是强制执行,这会覆盖上一次的数据,如果没有写–force而存在同名文件夹,则训练无法执行。
执行命令后,它会弹出以下界面,然后你可以在Unity中点击运行了:
如果前面的步骤没有问题,就可以看到Unity中的画面是加速运行的,并且控制台逐渐输出以下信息:
[INFO] Connected to Unity environment with package version 2.1.0-exp.1 and communication version 1.5.0
[INFO] Connected new brain: 3DBall?team=0
[INFO] Hyperparameters for behavior name 3DBall:
trainer_type: ppo
hyperparameters:
batch_size: 64
buffer_size: 12000
learning_rate: 0.0003
beta: 0.001
epsilon: 0.2
lambd: 0.99
num_epoch: 3
learning_rate_schedule: linear
beta_schedule: linear
epsilon_schedule: linear
network_settings:
normalize: True
hidden_units: 128
num_layers: 2
vis_encode_type: simple
memory: None
goal_conditioning_type: hyper
reward_signals:
extrinsic:
gamma: 0.99
strength: 1.0
network_settings:
normalize: False
hidden_units: 128
num_layers: 2
vis_encode_type: simple
memory: None
goal_conditioning_type: hyper
init_path: None
keep_checkpoints: 5
checkpoint_interval: 500000
max_steps: 500000
time_horizon: 1000
summary_freq: 12000
threaded: False
self_play: None
behavioral_cloning: None
[INFO] Listening on port 5004. Start training by pressing the Play button in the Unity Editor.
[INFO] Connected to Unity environment with package version 2.1.0-exp.1 and communication version 1.5.0
[INFO] Connected new brain: 3DBall?team=0
[INFO] 3DBall. Step: 12000. Time Elapsed: 152.628 s. Mean Reward: 1.182. Std of Reward: 0.671. Training.
[INFO] 3DBall. Step: 24000. Time Elapsed: 194.462 s. Mean Reward: 1.430. Std of Reward: 0.893. Training.
当结束Unity的运行时,模型会自动保存到官方包下results下对应的文件夹,找到onnx后缀的文件,这是训练好的神经网络模型,导进项目中后,拖到Behavior Parameters组件的Model参数中,点击运行就可以查看实际的运行效果啦!
具体实现
首先在布置好环境之后,我们需要在智能体下挂载以下脚本:
为了驱动智能体,我们需要在智能体下挂上Behavior Parameters组件来调节各种参数,然后我们可以在里面设置组件参数:
其中我们需要修改的有Space Size,代表输入的维度,通常位置的输入是三维的,旋转涉及四元数所以是四维的,速度和角速度都是三维的。这样我们就可以根据需要观察的智能体的参数来计算输入的维度。Action是输出,其中的Continuous Action是输出的连续动作,Descrete Branch是离散动作,我们根据需求填写数量。
还有一个必要的组件是Decision Requester,这个组件提供了方便快捷的方式触发智能体决策过程。可以调节采取决策的步数,和不采取决策时是否执行动作。
最后一个必要组件就是需要我们自己写的Agent脚本了。这个脚本必须继承Agent类。默认只需要设置MaxStep一个参数,设置为0代表无限。下面详细讲解其中函数的用法。
ML-Agents提供了若干个方法供我们实现。
Initialize方法,初始化环境,获取组件信息,设置参数在这里完成。
CollectObservations方法,这个方法会收集当前游戏的各种环境,包括智能体的位置,速度等信息,ML-Agents会把这些信息自动生成Tensor,进行计算。这里相当于设置神经网络的输入,如果是摄像机输入而不是向量输入的情况下此函数什么都不用做。
然后是OnActionReceived方法,实现的是整个游戏中一个Step中的操作,接收神经网络的输出,使其转换为智能体的动作,设置奖励函数,并且判断游戏是否结束。
OnEpisodeBegin方法,每次游戏结束后,重开一轮需要做的处理,比如重置位置信息等。
如果我们想自己操作智能体,还需要定义一个Heuristic方法,这样游戏就会采集玩家的输出信息,可以学习玩家的思维,大大促进训练教程。
CollectDiscreteActionMasks方法,在特殊情况下屏蔽某些不需要的AI操作(如地图边界阻止)。
接下来看看3D小球游戏的智能体代码:
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;
类的成员变量:
[Header("Specific to Ball3D")]
public GameObject ball;
[Tooltip("Whether to use vector observation. This option should be checked " +
"in 3DBall scene, and unchecked in Visual3DBall scene. ")]
public bool useVecObs;
Rigidbody m_BallRb;
EnvironmentParameters m_ResetParams;
Initialize方法:
public override void Initialize()
{
m_BallRb = ball.GetComponent<Rigidbody>();
m_ResetParams = Academy.Instance.EnvironmentParameters;
SetResetParameters();
}
CollectObservations方法:
public override void CollectObservations(VectorSensor sensor)
{
if (useVecObs)
{
sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);
sensor.AddObservation(ball.transform.position - gameObject.transform.position);
sensor.AddObservation(m_BallRb.velocity);
}
}
在hard版本的3DBall中,采用了另一种方式代替CollectObservations方法(输入维度为5,忽略小球速度):
// 标记该变量可观察,观察状态的帧数为9
[Observable(numStackedObservations: 9)]
Vector2 Rotation
{
get
{
return new Vector2(gameObject.transform.rotation.z, gameObject.transform.rotation.x);
}
}
[Observable(numStackedObservations: 9)]
Vector3 PositionDelta
{
get
{
return ball.transform.position - gameObject.transform.position;
}
}
第三种替代CollectObservations是加上相机脚本即Camera Sensor脚本,再把相机的预制体拖入其中,设置好参数,相机自己就会获取画面作为输入,但会大大增加训练时间。
OnActionReceived方法:
public override void OnActionReceived(ActionBuffers actionBuffers)
{
var actionZ = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f);
if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
(gameObject.transform.rotation.z > -0.25f && actionZ < 0f))
{
gameObject.transform.Rotate(new Vector3(0, 0, 1), actionZ);
}
if ((gameObject.transform.rotation.x < 0.25f && actionX > 0f) ||
(gameObject.transform.rotation.x > -0.25f && actionX < 0f))
{
gameObject.transform.Rotate(new Vector3(1, 0, 0), actionX);
}
if ((ball.transform.position.y - gameObject.transform.position.y) < -2f ||
Mathf.Abs(ball.transform.position.x - gameObject.transform.position.x) > 3f ||
Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
{
SetReward(-1f);
EndEpisode();
}
else
{
SetReward(0.1f);
}
}
OnEpisodeBegin方法:
public override void OnEpisodeBegin()
{
gameObject.transform.rotation = new Quaternion(0f, 0f, 0f, 0f);
gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));
gameObject.transform.Rotate(new Vector3(0, 0, 1), Random.Range(-10f, 10f));
m_BallRb.velocity = new Vector3(0f, 0f, 0f);
ball.transform.position = new Vector3(Random.Range(-1.5f, 1.5f), 4f, Random.Range(-1.5f, 1.5f))
+ gameObject.transform.position;
//Reset the parameters when the Agent is reset.
SetResetParameters();
}
Heuristic方法:
public override void Heuristic(in ActionBuffers actionsOut)
{
var continuousActionsOut = actionsOut.ContinuousActions;
continuousActionsOut[0] = -Input.GetAxis("Horizontal");
continuousActionsOut[1] = Input.GetAxis("Vertical");
}
其他方法:
public void SetBall()
{
//Set the attributes of the ball by fetching the information from the academy
m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
var scale = m_ResetParams.GetWithDefault("scale", 1.0f);
ball.transform.localScale = new Vector3(scale, scale, scale);
}
public void SetResetParameters()
{
SetBall();
}
到这里智能体的脚本就写完了。接下来我们需要设置算法和参数。进入到config文件夹下,里面存放着算法的配置,对于单智能体强化学习,Unity官方提供了两种算法,PPO和SAC,这里我们使用PPO算法。进入PPO文件夹中,打开3DBall.yaml,通过修改其中的参数就能改变训练的配置,这需要对相应的强化学习算法有一定的了解。具体可以参考配置文档:https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Training-Configuration-File.md
behaviors:
3DBall:
trainer_type: ppo
hyperparameters:
batch_size: 64
buffer_size: 12000
learning_rate: 0.0003
beta: 0.001
epsilon: 0.2
lambd: 0.99
num_epoch: 3
learning_rate_schedule: linear
network_settings:
normalize: true
hidden_units: 128
num_layers: 2
vis_encode_type: simple
reward_signals:
extrinsic:
gamma: 0.99
strength: 1.0
keep_checkpoints: 5
max_steps: 500000
time_horizon: 1000
summary_freq: 12000
其中在Behavior Parameters中,Behavior Name中的名字必须要和第二行的那个名字一致,如果想设置不同的智能体使用不同的配置同时进行训练,只需要在下面加上不同名字的配置,然后再相应的智能体的Behavior Name中使用那个名字即可。
训练时文件路径要写对:
mlagents-learn config/ppo/3DBall.yaml --run-id=3DBallTest --force
如果想继续上次的训练:
mlagents-learn config/ppo/3DBall.yaml --run-id=3DBallTest --rusume
TensorBoard的使用
在上面的控制台环境下输入下列命令(训练过程中可以另外开一个Anaconda Prompt):
tensorboard --logdir .\\results\\ --port 6006
在浏览器中的网址栏输入localhost:6006,就可以看到tensorboard的界面了。这样就能把数据进行可视化。奖励和Loss的变化一清二楚。
加速训练
在编辑器中的训练是需要消耗非常多性能的,因此我们需要先把它打包成exe文件,具体操作是File–>Build Settings–>Build。
打包好后,把配置文件yaml文件也放入文件夹中。cd到该文件夹中,输入以下命令:
mlagents-learn 配置文件名.yaml --run-id=自己随意起名 --env=执行文件名.exe --num-envs=9 --force
就可以开启9个窗口同时训练。如果不显示图形界面可以快一大截,只需在命令后面加上–no-graphics即可。
总结
本文中介绍了Unity插件ML-Agents的安装和使用方法,并且跑通了ML-Agents的第一个入门级示例。现在算是对ML-Agents的整体框架有了大概的了解,后面还有更加复杂功能的实现,自创环境,自创算法等着我们去探索,我会另外开文章对每个要点进行详细的解读。
以上是关于Unity强化学习之ML-Agents的使用的主要内容,如果未能解决你的问题,请参考以下文章