强化学习rllib简明教程 ray
Posted Lejeune
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了强化学习rllib简明教程 ray相关的知识,希望对你有一定的参考价值。
强化学习rllib简明教程 ray
之前说到强化 学习的库,推荐了tianshou,但是tianshou实现的功能还不够多,于是转向rllib,个人还是很期待tianshou的发展。
回到rllib,rllib是基于ray的一个工具(不知道这么说是不是合适),ray和rllib的关系就像,mllib之于spark,ray是个分布式的计算框架。
官网,文档。进入官网,可以看到,蚂蚁金服也在使用这个框架,大厂使用,不过本人只是为了快速实现一个强化学习的实验。
不过其文档存在着一些问题,比如:官方案例运行出错,文档长久未更新等。给我这种为了快速完成强化学习的菜鸟造成了一定的困难,本人亲自采坑,通过阅读源码等方式,把一些坑给踩了,现在记录在此,为后来的人少踩点坑。本文主要介绍rllib的一些基本功能和使用,在基于文档的基础上进行一些问题的解决和修补。最多会带一些tune的案例,对其他功能有需要的,请自行看文档采坑。
私以为一个完善的强化学习库,应该完成以下功能:
-
训练参数
-
训练结果
-
测试过程 (强化学习有时候是为了找到最优方案或是ai认为期望最好的方案,所以在我们需要获得测试的过程)
-
模型存储
-
模型读取
-
自定义环境
-
结果复现seed
主要是以上7个功能,为了快速入门与简化过程,接下来会根据新的顺序来对以上七个功能进行实现。
ray版本1.2.0
1. 自定义环境
首先一个自定义环境必须继承自gym. Env,并实现reset和step方法,其他方法可实现可不实现,具体可以参照gym的标准,我这里是根据tianshou的标准去写,但是和tianshou不同的是,在方法__init__中,必须带第二个参数,用于传递envconfig,否则会报错。
在这里我实现了一个简单的游戏,用于简化之后的实验,规则为
长度为10的线段,每次只能左右移动,节点标为0…9,起点为0,终点为9,超过100步则死亡-100。 到达9则胜利+100
myenv1.py
import random
import gym
import gym.spaces
import numpy as np
import traceback
import pprint
class GridEnv1(gym.Env):
'''
长度为10的线段,每次只能左右移动,节点标为0..9,
起点为0,终点为9,超过100步则死亡-100
到达9则胜利+100
'''
def __init__(self,env_config):
self.action_space=gym.spaces.Discrete(2)
self.observation_space=gym.spaces.Box(np.array([0]),np.array([9]))
self.reset()
def reset(self):
'''
:return: state
'''
self.observation = [0]
#self.reward = 10
self.done=False
self.step_num=0
return [0]
def step(self, action)->tuple:
'''
:param action:
:return: tuple ->[observation,reward,done,info]
'''
#pprint.pprint(traceback.extract_stack())
if action==0:
action=-1
self.observation[0]+=action
self.step_num+=1
reward=-1.0
if self.step_num>100 or self.observation[0]<0:
reward=-100.0
self.done=True
#print('last %d action %d now %d' % (self.observation[0] - action, action, self.observation[0]))
return self.observation,reward,self.done,
if self.observation[0]==9:
reward=100.0
self.done=True
#print('last %d action %d now %d'%(self.observation[0]-action,action,self.observation[0]))
return self.observation,reward,self.done,
def render(self, mode='human'):
pass
2.训练参数
众所周知,深度学习又被称作炼丹,超参数很多,rllib的实验有两种启动方法,一种是rllib的底层api进行组合调用,另一种是tune.run进行调用。以dqn为例
rllib-api
import ray
from ray.rllib.agents.dqn import DQNTrainer
from myenv1 import GridEnv1
ray.init()
trainer=DQNTrainer(
env=GridEnv1,
config='framework': 'tfe',
)
for i in range(10):
trainer.train()
其中train方法调用一次即为训练一个世代,这是底层api,无法快速控制结束条件等其他参数,所以官方更推荐tune.run。
framework参数代表你要用什么框架,
tf:tensorflow,tfe: TensorFlow eager, torch: PyTorch。
其中tfe是工程模式,即刻计算张量,如果是tf,则会在构建图完成之后才计算,调试解阶段tfe可以看到过程。tf速度更快。
在过去,可以设置config中的config[“eager”] = True,完成模式的更改,现在这个设置已被弃用,想用工程模式的请使用framework
tune.run
from ray import tune
import ray
from ray.rllib.agents.dqn import DQNTrainer
from myenv1 import GridEnv1
ray.init()
t=tune.run(
DQNTrainer,#此处可以用字符串,请自行进入文档查阅对应字符串
config=
'env':GridEnv1,
,
stop=
'episode_reward_max':91
)
tune会自动生成报告,并以stop为结束条件,上面为当一个世代的最大得分超过91时,停止训练。同时tune可以进行超参数寻优,但这不是本篇的主要内容。
上面是开始训练的两种方法,那config中有什么可以设置呢,config中的设置主要来源于两个地方,一个是基本的默认设置,另一个是根据你选定的trainer的默认设置,比如dqn就有一些其他算法没有的设置。第一种的设置如下
通用配置,算法配置请在算法列表中自行查找。
3.训练结果
如果是为了查看每一个世代的训练情况按照以下操作即可
rllib-API
t=trainer.train()
print(t)
tune.run
运行tune.run之后会自行打印结果
同时有时候还会有获得过程中最优值的需求
这样的需求则需要调用回调类,回调类必须继承自DefaultCallbacks。因为需要传入一个类,所以我自行完成了一下动态类,供大家参考,主要是回调与全局锁变量。代码
record.py
import ray
@ray.remote
class BestRecord:
def __init__(self):
self.bestVal = 0
self.bestAction = []
self.poolAction =
# eps_id:list[action]
self.poolVal =
# eps_id:reward
def add(self, sample):
'''
sample key
obs new_obs actions rewards dones
agent_index eps_id unroll_id weights
'''
for index,item in enumerate(sample['eps_id']):
if not item in self.poolVal:
self.poolVal[item]=0
self.poolAction[item]=[]
self.poolAction[item].append(sample['actions'][index])
self.poolVal[item]+=sample['rewards'][index]
if self.poolVal[item]>self.bestVal:
self.bestVal=self.poolVal[item]
self.bestAction=self.poolAction[item]
if sample['dones'][index]:
del self.poolVal[item]
del self.poolAction[item]
def getBest(self):
return self.bestVal,self.bestAction
def getAll(self):
return self.poolVal,self.poolAction
mycallback.py
from typing import Dict
import numpy as np
import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
def initDefaultCallbacks(logPrint=False,isRecord=False):
#MyCallbacks = type('MyCallbacks', (DefaultCallbacks))
class MyCallbacks(DefaultCallbacks):
pass
if logPrint:
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy],
episode: MultiAgentEpisode, env_index: int, **kwargs):
print("episode (env-idx=) started.".format(
episode.episode_id, env_index))
episode.user_data["pole_angles"] = []
episode.hist_data["pole_angles"] = []
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
episode: MultiAgentEpisode, env_index: int, **kwargs):
# pole_angle = abs(episode.last_observation_for()[2])
# #print(episode.last_observation_for())
# #返回最后一次观察
# raw_angle = abs(episode.last_raw_obs_for()[2])
# #print(episode.last_raw_obs_for())
# #返回指定代理的最后一个未预处理的对象存储服务
# assert pole_angle == raw_angle
# episode.user_data["pole_angles"].append(pole_angle)
# #print(episode.)
# print('episode are running')
# print(episode.last_observation_for())
assert episode.last_observation_for() == episode.last_raw_obs_for()
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy], episode: MultiAgentEpisode,
env_index: int, **kwargs):
pole_angle = np.mean(episode.user_data["pole_angles"])
print("episode (env-idx=) ended with length and pole "
"angles ".format(episode.episode_id, env_index, episode.length,
pole_angle))
episode.custom_metrics["pole_angle"] = pole_angle
episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
def on_train_result(self, *, trainer, result: dict, **kwargs):
print("trainer.train() result: -> episodes".format(
trainer, result["episodes_this_iter"]))
# you can mutate the result dict to add new fields to return
result["callback_ok"] = True
def on_postprocess_trajectory(
self, *, worker: RolloutWorker, episode: MultiAgentEpisode,
agent_id: str, policy_id: str, policies: Dict[str, Policy],
postprocessed_batch: SampleBatch,
original_batches: Dict[str, SampleBatch], **kwargs):
print("postprocessed steps".format(postprocessed_batch.count))
if "num_batches" not in episode.custom_metrics:
episode.custom_metrics["num_batches"] = 0
episode.custom_metrics["num_batches"] += 1
MyCallbacks.on_train_result=on_train_result
MyCallbacks.on_episode_start=on_episode_start
MyCallbacks.on_episode_step=on_episode_step
MyCallbacks.on_episode_end=on_episode_end
MyCallbacks.on_postprocess_trajectory=on_postprocess_trajectory
if isRecord:
def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch,
**kwargs) -> None:
'''
sample key
obs new_obs actions rewards dones
agent_index eps_id unroll_id weights
:param worker:
:param samples:
:param kwargs:
:return:
'''
# in your envs
record = ray.get_actor("record")
record.add.remote(samples) # async call to increment the global count
MyCallbacks.on_sample_end=on_sample_end
return MyCallbacks
调用,此处用到了一个全局变量的修改
ray.init()
record = BestRecord.options(name="record").remote()
t=tune.run(
DQNTrainer,
config=
'env':GridEnv1,
'callbacks':initDefaultCallbacks(isRecord=True),
,
stop=
'episode_reward_max':91
)
print(ray.get(record.getBest.remote())) # get the latest count
print(ray.get(record.getAll.remote())) # get the latest count
运行结束后,会打印出历史最好的路径,和记录池。
当然还有一种通过tune存储记录点的方法,记录最佳结果,但是无法记录最佳路径,因为存储记录点属于存储模型的那部分,所以此处不介绍。
4.测试过程 自写评估函数
强化学习有时候是为了找到最优方案(上一板块已经实现)或是ai认为期望最好的方案。所谓“ai认为期望最好的方案”,有这样一种需求,我们在进行学习和探索的时候,采用的可能是ε−greedy策略,这个策略是有一定的随机性的,用这个策略进行学习是合理的,以一定的概率进行跳出局部的探索,但是有时候进行评估时,我们希望直接采用已经训练好的ai认为的期望最大值,也就是每一步都采用贪婪策略进行评估。
对于自定义的评估和rllib的自带的评估
rllib-api
自定义评估需要自写评估函数,在trainer训练结束的时候调用一次
'''
this is your trainer training
'''
def eval(trainer):
#your eval function
pass
eval(trainer)
tune.run
也需要自写评估函数,但是调用与配置更加方便,里面有很多选项供选择,可以自定义测试频率,自定义测试函数等,传入config即可。
官方案例已经比较清晰了。需要注意的是,tune.run的测试函数需要传入trainer, eval_workers并返回一个得分。
同时,我也放上我自己需求的用于简单评估的贪婪策略的DQN评估,里面还是有很多细节的。
def eval(trainer):
policy=trainer.get_policy()
logits, _ = policy.model.from_batch("obs": np.array([[0.0],[0.0]]))
logits=policy.model.get_q_value_distributions(logits)[0]
dist=policy.dist_class(logits,policy.model)
print(dist.deterministic_sample())
这是我用贪婪策略进行的最终的策略选择,需要注意的是,dqn和ppo的配置有所不用,ppo的policy.model.from_batch之后,直接就是策略分布,dqn却不是,所以需要再get_q_value_distributions,不同算法实现不同,具体算法需要具体分析。
同时,可以在官网上看到有trainer.compute_action()和policy.dist_class(logits,policy.model).sample()方法,这两个都是获得动作,其中还是有不同,具体 不同之处建议查看源码,本人才疏学浅并且时间不足,没有深刻理解。本人目前能看出来的是,trainer.compute_action,每一次都是重新运行策略进行计算,如果用的是 有随机数的策略,那么多次运行,会出现不同的答案。
policy.dist_class(logits,policy.model).sample(),在policy.dist_class的类初始化时,就已经根据一定的策略选择了动作,在对象生成之后,不论多少次调用sample方法,答案都是一样的。如果需要重新运行策略,则需要重新实例化policy.dist_class。
5.模型存储与读取
rllib以checkpoint检查点进行存储,当然也可以使用tf和pytorch自带的存储,直接checkpoint会比较方便。本文也主要介绍checkpoint。
rllib-api
path=trainer.save()#存储,该方法返回路径
trainer.restore(path)#读取
tune.run
tune.run(
train,
config=config
#checkpoint_at_end=True #结束时存储检查点
#checkpoint_freq=int #几个世代存储一次
# restore=path #载入检查点
)
6. 结果复现
在深度学习中,结果的复现尤为重要,在rllib中没有统一的api实现结果复现,你需要阅读文档,知道所用的库与算法,将其中所有存在随机数的库,都设置种子。简单的说,就是具体问题具体分析,其中以tf为例。
import ray
from ray.rllib.agents.dqn import DQNTrainer
from mycallback import initDefaultCallbacks
from myenv1 import GridEnv1
from record import BestRecord
from ray import tune
import traceback
import pprint
import numpy as np
import random
np.random.seed(1234)
random.seed(1234)
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
tf.random.set_seed(1234)
ray.init()
trainer=DQNTrainer(
env=GridEnv1,
config='framework': 'tfe',
"env": GridEnv1,
"num_workers": 1,
)
for i in range(10):
trainer.train()
policy=trainer.get_policy()
logits, _ = policy.model.from_batch("obs": np.array([[float(i)] for i in range(9)]))
print(logits)
两次结果一样,则复现成功
以上是关于强化学习rllib简明教程 ray的主要内容,如果未能解决你的问题,请参考以下文章
更改 Ray RLlib Training 的 Logdir 而不是 ~/ray_results
如何评估在 rllib (Ray) 中自定义环境中训练的演员?
Ray和hoplite 强化学习基于任务的分布式系统容错高性能的集合通信
Ray和hoplite 强化学习基于任务的分布式系统容错高性能的集合通信