基于stable-baseline3 强化学习DQN的lunar lander的稳定控制
Posted Colin_Fang
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了基于stable-baseline3 强化学习DQN的lunar lander的稳定控制相关的知识,希望对你有一定的参考价值。
基于stable-baseline3 强化学习DQN的lunar lander的稳定控制
依赖包
鉴于不同版本的gym与stable-baselines3会产生冲突,在成功的基础上记录:
gym == 0.21.0
stable-baselines3 == 1.6.2
安装代码:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple gym==0.21.0
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple stable-baselines3[extra]==1.6.2
lunar lander随机初始化action
import gym
# Create environment
env = gym.make("LunarLander-v2")
eposides = 10
for eq in range(eposides):
obs = env.reset()
done = False
rewards = 0
while not done:
action = env.action_space.sample()
obs, reward, done, info = env.step(action)
env.render()
rewards += reward
print(rewards)
随机初始化,视频链接:lunar_lander_random
基于stable-baseline中DQN的实现
模型训练
import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
# Create environment
env = gym.make("LunarLander-v2")
# Instantiate the agent
model = DQN("MlpPolicy", env, verbose=1)
# Train the agent and display a progress bar
model.learn(total_timesteps=int(2e5), progress_bar=True)
# Save the agent
model.save("dqn_lunar")
这里已经将训练好的模型给保存为dqn_lunar.zip
模型测试
直接读取模型训练结果,进行测试
import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
# Create environment
env = gym.make("LunarLander-v2")
model = DQN.load("dqn_lunar", env=env)
# 测试接口
mean_reward, std_reward = evaluate_policy(
model,
model.get_env(),
deterministic=True,
render=True,
n_eval_episodes=10)
print(mean_reward)
自己写测试模块
import gym
from stable_baselines3 import DQN
# Create environment
env = gym.make("LunarLander-v2")
# Instantiate the agent
model = DQN("MlpPolicy", env, verbose=1)
model = DQN.load("dqn_lunar", env=env)
eposides = 10
for eq in range(eposides):
obs = env.reset()
done = False
rewards = 0
while not done:
action, _state = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)
env.render()
rewards += reward
print(rewards)
测试结果:lunar_lander_DQN
网络架构优化
根据上述视频可以看出,在默认的DQN网络及参数,还不能使飞行器稳定停在月球上,将学习率改为5e-4,网络参数改为256,训练次数改为2500,000次,训练代码如下:
import gym
from stable_baselines3 import DQN
# Create environment
env = gym.make("LunarLander-v2")
model = DQN(
"MlpPolicy",
env,
verbose=1,
learning_rate=5e-4,
policy_kwargs='net_arch':[256,256])
model.learn(
total_timesteps=int(2.5e6),
progress_bar=True)
model.save("dqn_Net256_lunar_2500K")
模型测试代码如下:
import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
# Create environment
env = gym.make("LunarLander-v2")
model = DQN.load("dqn_Net256_lunar_2500K", env=env)
mean_reward, std_reward = evaluate_policy(
model,
model.get_env(),
deterministic=True,
render=True,
n_eval_episodes=10)
print(mean_reward)
测试视频:lunar_lander_256_2500K
由视频可以看出,月球车每次都能稳定停留在月球表面。
附录
有问题可以直接查官方文档
stable-baseline3: 手册
gym: 手册
以上是关于基于stable-baseline3 强化学习DQN的lunar lander的稳定控制的主要内容,如果未能解决你的问题,请参考以下文章