ML-Agents案例之“排序算法超硬核版”
Posted 微笑小星
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了ML-Agents案例之“排序算法超硬核版”相关的知识,希望对你有一定的参考价值。
本案例源自ML-Agents官方的示例,Github地址:https://github.com/Unity-Technologies/ml-agents,本文是详细的配套讲解。
本文基于我前面发的两篇文章,需要对ML-Agents有一定的了解,详情请见:Unity强化学习之ML-Agents的使用、ML-Agents命令及配置大全。
我前面的相关文章有:
环境说明
如图所示,智能体在一个圆形的房间中,墙壁上会随机出现带有数字的方块,智能体需要按照数字从小到大与方块进行碰撞,碰撞过的方块会变成绿色,分数+1,一旦碰撞顺序不对,游戏结束,分数-1。
这个案例的挑战是,我们不会告诉智能体怎么排序是对的,智能体需要在环境中试错,从而自己学习到这种从小到大排序,碰撞对应方块的行为模式,同时墙壁上出现的数字方块的个数是不定的,也就是说每个episode我们都需要接收不同个数的输入,这应该怎么处理呢?
状态输入:这里用到了一个新的传感器Buffer Sensor。
这个传感器的作用是可以接收个数变化的状态输入。我们需要每次传入一个向量,这个向量我们可以用数组listObservation表示。通过 m_BufferSensor.AppendObservation(listObservation)传入到BufferSensor中,而BufferSensor可以接收无数个这样的向量输入,但是每个向量的维度必须相同。也就是说即使我们输入的向量个数每次都不同,我们还是能训练网络还是产生我们所期望的输出,具体是怎么实现的项目代码中没有,集成在了ML-Agents包中,根据我的经验,应该用了Self-attention这种网络的结构,这样就能接收不同个数向量的输入了。
除了传给BufferSensor的输入之外,还传入了四维的向量,分别是智能体位置到场地中心的向量在x轴和z轴上的分量,智能体前进方向在x轴和z轴的分量。
动作输出:输出三个离散值,每个离散值包含0-2三个数,第一个离散值决定了前进后退,第二个离散值决定了左移右移,第三个离散值决定了左转右转。
代码讲解
智能体下挂载的脚本除去万年不变的Decesion Requester,Model Overrider,Behavior Parameters,以及刚刚说明的Buffer Sensor,就只剩下智能体的只有文件SorterAgent.cs了:
头文件:
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;
定义变量:
// 默认数字方块的最大数量,可在编辑器中滑动调节,调节范围为1 - 20
[Range(1, 20)]
public int DefaultMaxNumTiles;
// 方块数字的最大值
private const int k_HighestTileValue = 20;
// 生成方块的数量
int m_NumberOfTilesToSpawn;
// 方块的最大数量
int m_MaxNumberOfTiles;
// 刚体
Rigidbody m_AgentRb;
// BufferSensorComponent 是一个传感器,允许观察不同数量的输入
BufferSensorComponent m_BufferSensor;
// 数字方块的列表
public List<NumberTile> NumberTilesList = new List<NumberTile>();
// 出现在场景中的方块列表
private List<NumberTile> CurrentlyVisibleTilesList = new List<NumberTile>();
// 已经被接触过的方块列表
private List<Transform> AlreadyTouchedList = new List<Transform>();
private List<int> m_UsedPositionsList = new List<int>();
// 初始位置
private Vector3 m_StartingPos;
// 整个场景
GameObject m_Area;
// 环境参数,可以从配置文件中获取
EnvironmentParameters m_ResetParams;
// 下一个想要碰撞的数字方块的索引
private int m_NextExpectedTileIndex;
初始化方法Initialize():
public override void Initialize()
// 获取父物体
m_Area = transform.parent.gameObject;
// 获取方块的最大数量
m_MaxNumberOfTiles = k_HighestTileValue;
// 从配置文件中获取环境参数
m_ResetParams = Academy.Instance.EnvironmentParameters;
// 获取传感器脚本
m_BufferSensor = GetComponent<BufferSensorComponent>();
// 获取刚体
m_AgentRb = GetComponent<Rigidbody>();
// 起始位置
m_StartingPos = transform.position;
状态输入方法:
public override void CollectObservations(VectorSensor sensor)
// 获取智能体到场地中心的x轴和z轴上的距离
sensor.AddObservation((transform.position.x - m_Area.transform.position.x) / 20f);
sensor.AddObservation((transform.position.z - m_Area.transform.position.z) / 20f);
// 获取智能体前进方向的x轴和z轴的值
sensor.AddObservation(transform.forward.x);
sensor.AddObservation(transform.forward.z);
foreach (var item in CurrentlyVisibleTilesList)
// 定义一个数组,存放一系列观察值,数组长度为数字方块最大数量 + 3,默认初始化全部为0
float[] listObservation = new float[k_HighestTileValue + 3];
// 获取方块的数字,设置对应的one-hot向量
listObservation[item.NumberValue] = 1.0f;
// 获取方块的坐标(子物体坐标才是真实坐标的,transform本身的位置保持在场景中央,方便旋转)
var tileTransform = item.transform.GetChild(1);
// 输入数字方块和智能体的x分量和z分量
listObservation[k_HighestTileValue] = (tileTransform.position.x - transform.position.x) / 20f;
listObservation[k_HighestTileValue + 1] = (tileTransform.position.z - transform.position.z) / 20f;
// 该方块是否已经被碰撞过
listObservation[k_HighestTileValue + 2] = item.IsVisited ? 1.0f : 0.0f;
// 把数组添加到Buffer Sensor中(不直接输入到网络的原因是需要添加的数组个数个数是变化的)
m_BufferSensor.AppendObservation(listObservation);
动作输出方法OnActionReceived:
public override void OnActionReceived(ActionBuffers actionBuffers)
// 移动智能体
MoveAgent(actionBuffers.DiscreteActions);
// 时间惩罚,激励智能体越快完成越好
AddReward(-1f / MaxStep);
public void MoveAgent(ActionSegment<int> act)
var dirToGo = Vector3.zero;
var rotateDir = Vector3.zero;
// 获取神经网络三个离散输出
var forwardAxis = act[0];
var rightAxis = act[1];
var rotateAxis = act[2];
// 第一个离散输出决定了前进后退
switch (forwardAxis)
case 1:
dirToGo = transform.forward * 1f;
break;
case 2:
dirToGo = transform.forward * -1f;
break;
// 第二个离散输出决定了左移右移
switch (rightAxis)
case 1:
dirToGo = transform.right * 1f;
break;
case 2:
dirToGo = transform.right * -1f;
break;
// 第三个离散输出决定了左转右转
switch (rotateAxis)
case 1:
rotateDir = transform.up * -1f;
break;
case 2:
rotateDir = transform.up * 1f;
break;
// 执行动作
transform.Rotate(rotateDir, Time.deltaTime * 200f);
m_AgentRb.AddForce(dirToGo * 2, ForceMode.VelocityChange);
每一个episode(回合)开始时执行的方法OnEpisodeBegin:
public override void OnEpisodeBegin()
// 从配置文件中获取方块的数量,没有的话设为DefaultMaxNumTiles
m_MaxNumberOfTiles = (int)m_ResetParams.GetWithDefault("num_tiles", DefaultMaxNumTiles);
// 随机生成方块的数量
m_NumberOfTilesToSpawn = Random.Range(1, m_MaxNumberOfTiles + 1);
// 选择将要生成的对应的方块并加入列表中
SelectTilesToShow();
// 生成方块及调整位置
SetTilePositions();
transform.position = m_StartingPos;
m_AgentRb.velocity = Vector3.zero;
m_AgentRb.angularVelocity = Vector3.zero;
void SelectTilesToShow()
// 清除两个列表
CurrentlyVisibleTilesList.Clear();
AlreadyTouchedList.Clear();
// 共生成nunLeft个方块
int numLeft = m_NumberOfTilesToSpawn;
while (numLeft > 0)
// 在范围内取随机数生成对应方块
int rndInt = Random.Range(0, k_HighestTileValue);
var tmp = NumberTilesList[rndInt];
// 如果对应的方块列表中没有才进行添加
if (!CurrentlyVisibleTilesList.Contains(tmp))
CurrentlyVisibleTilesList.Add(tmp);
numLeft--;
// 给方块列表列表按照数字升序进行排序
CurrentlyVisibleTilesList.Sort((x, y) => x.NumberValue.CompareTo(y.NumberValue));
m_NextExpectedTileIndex = 0;
void SetTilePositions()
// 清空列表
m_UsedPositionsList.Clear();
// 重置所有方块的状态,ResetTile方法可以在数字方块的脚本中看到
foreach (var item in NumberTilesList)
item.ResetTile();
item.gameObject.SetActive(false);
foreach (var item in CurrentlyVisibleTilesList)
bool posChosen = false;
// rndPosIndx决定了我们方块的旋转角度(即在圆形场地的哪里)
int rndPosIndx = 0;
while (!posChosen)
rndPosIndx = Random.Range(0, k_HighestTileValue);
// 这个旋转角度是否被选了,没被选就加入列表中
if (!m_UsedPositionsList.Contains(rndPosIndx))
m_UsedPositionsList.Add(rndPosIndx);
posChosen = true;
// 执行方块角度的旋转并激活物体
item.transform.localRotation = Quaternion.Euler(0, rndPosIndx * (360f / k_HighestTileValue), 0);
item.gameObject.SetActive(true);
当与别的物体开始发生碰撞执行方法OnCollisionEnter:
private void OnCollisionEnter(Collision col)
// 只检测和数字方块的碰撞
if (!col.gameObject.CompareTag("tile"))
return;
// 如果方块已经碰撞过,也排除在碰撞对象之外
if (AlreadyTouchedList.Contains(col.transform))
return;
// 如果碰撞的顺序错误,奖励-1,结束游戏
if (col.transform.parent != CurrentlyVisibleTilesList[m_NextExpectedTileIndex].transform)
AddReward(-1);
EndEpisode();
// 碰撞到正确的方块的情况
else
// 奖励+1
AddReward(1);
// 改变方块的材质
var tile = col.gameObject.GetComponentInParent<NumberTile>();
tile.VisitTile();
// 索引+1
m_NextExpectedTileIndex++;
// 把方块加入到已接触列表中
AlreadyTouchedList.Add(col.transform);
// 如果完成了所有的任务,游戏结束
if (m_NextExpectedTileIndex == m_NumberOfTilesToSpawn)
EndEpisode();
当智能体没有模型,人想手动录制示例时可以采用Heuristic方法:
public override void Heuristic(in ActionBuffers actionsOut)
var discreteActionsOut = actionsOut.DiscreteActions;
//forward
if (Input.GetKey(KeyCode.W))
discreteActionsOut[0] = 1;
if (Input.GetKey(KeyCode.S))
discreteActionsOut[0] = 2;
//rotate
if (Input.GetKey(KeyCode.A))
discreteActionsOut[2] = 1;
if (Input.GetKey(KeyCode.D))
discreteActionsOut[2] = 2;
//right
if (Input.GetKey(KeyCode.E))
discreteActionsOut[1] = 1;
if (Input.GetKey(KeyCode.Q))
discreteActionsOut[1] = 2;
挂载在数字方块上的脚本NumberTile.cs:
using UnityEngine;
public class NumberTile : MonoBehaviour
// 方块上的数字
public int NumberValue;
// 默认材质和成功时转换用的材质
public Material DefaultMaterial;
public Material SuccessMaterial;
// 是否已经碰撞过
private bool m_Visited;
// 渲染,用于转换材质
private MeshRenderer m_Renderer;
public bool IsVisited
get return m_Visited;
// 用于转换材质的方法
public void VisitTile()
m_Renderer.sharedMaterial = SuccessMaterial;
m_Visited = true;
// 重置方块的方法,材质还原,m_Visited状态还原
public void ResetTile()
if (m_Renderer is null)
m_Renderer = GetComponentInChildren<MeshRenderer>();
m_Renderer.sharedMaterial = DefaultMaterial;
m_Visited = false;
配置文件
behaviors:
Sorter:
trainer_type: ppo
hyperparameters:
batch_size: 512
buffer_size: 40960
learning_rate: 0.0003
beta: 0.005
epsilon: 0.2
lambd: 0.95
num_epoch: 3
learning_rate_schedule: constant
network_settings:
normalize: False
hidden_units: 128
num_layers: 2
vis_encode_type: simple
reward_signals:
extrinsic:
gamma: 0.99
strength: 1.0
keep_checkpoints: 5
max_steps: 5000000
time_horizon: 256
summary_freq: 10000
environment_parameters:
num_tiles:
curriculum:
- name: Lesson0 # The '-' is important as this is a list
completion_criteria:
measure: progress
behavior: Sorter
signal_smoothing: true
min_lesson_length: 100
threshold: 0.3
value: 2.0
- name: Lesson1
completion_criteria:
measure: progress
behavior: Sorter
signal_smoothing: true
min_lesson_length: 100
threshold: 0.4
value: 4.0
- name: Lesson2
completion_criteria:
measure: progress
behavior: Sorter
signal_smoothing: true
min_lesson_length: 100
threshold: 0.45
value: 6.0
- name: Lesson3
completion_criteria:
measure: progress
behavior: Sorter
signal_smoothing: true
min_lesson_length: 100
threshold: 0.5
value: 8.0
- name: Lesson4
completion_criteria:
measure: progress
behavior: Sorter
signal_smoothing: true
min_lesson_length: 100
threshold: 0.55
value: 10.0
- name: Lesson5
completion_criteria:
measure: progress
behavior: Sorter
signal_smoothing: true
min_lesson_length: 100
threshold: 0.以上是关于ML-Agents案例之“排序算法超硬核版”的主要内容,如果未能解决你的问题,请参考以下文章