CartPole-v1 50行python实现

Posted pysnow530

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了CartPole-v1 50行python实现相关的知识,希望对你有一定的参考价值。

CartPole-v1 50行python实现

背景

很久没有写文章了,github上维护的博客 https://blog.dong.black/ 上篇还是今年5月15号的,距离现在也有一个多月了。

之前在优达学过一小段时间的机器学习,感觉挺有意思,最近又看到了gym,想动手再尝试一把。

CartPole-v1是gym中比较经(jian)典(dian)的题目,号称机器学习中的 hello world,比较适合我这种小白。趁着周末的闲功夫,求解一下。

题目

官方其实已经给出解释了:

Reinforcement learning Q-learning approach to OpenAI Gym’s CartPole environment.

这本质上是一个Q-learning问题,但是作为强化学习的 hello world,其实也有很多其它的解法。

作者尝试过使用DQN解,但是收敛速度和稳定性差强人意。个人电脑吱吱转,算法却死活不收敛。

罢了,使用线性模型蒙一下吧。

线性模型

基本思路,是使用单个神经元。这里也不反向传导了,直接在当前空间随机探索,然后查看效果。

激活函数就是根据结果符号输出action,可以简单理解为 int(input > 0)

基本的过程如下:

  1. 随机选取 weights + bias
  2. 生成随机步长 delta_weights
  3. 计算更新后 weights 可以得到的回报 rewards
  4. 如果 rewards 相较之前增大了,应用 delta_weights;减小了,反向应用 delta_weights

算法简单粗暴,但是针对这个简单粗暴的题目,效果挺好。

代码

代码量只有50行,也没有比较复杂的逻辑,所以这里就直接贴出来了。

import gym
import numpy as np
import matplotlib.pyplot as plt


def predict(state, weight):
    return int(np.dot(state, weight) >= 0.0)


def run_once(env, weight):
    state = env.reset()
    rewards = 0
    while True:
        # env.render()
        rewards += 1
        action = predict(state[1:], weight)
        state, reward, done, info = env.step(action)
        if done:
            return rewards


def fix_weight(weight, delta_weight, this_rewards, last_rewards):
    return weight + delta_weight * (1 if this_rewards > last_rewards else -1)


def main():
    episodes = 100
    env = gym.make('CartPole-v1')
    weight = np.random.random(3)
    last_rewards = None
    reward_list = []

    for i in range(episodes):
        delta_weight = np.random.random(weight.shape) * 2 - 1

        this_rewards = run_once(env, weight + delta_weight)
        if last_rewards is not None and this_rewards != last_rewards:
            weight = fix_weight(weight, delta_weight, this_rewards, last_rewards)
        last_rewards = this_rewards
        reward_list.append(last_rewards)
        print(i, last_rewards, weight)

    plt.plot(reward_list)
    plt.show()


if __name__ == "__main__":
    main()

在这里插入图片描述

过程中的优化及问题

该模型中没有太多机器学习的核心理论。不过过程中,有借鉴一些机器学习的理念,都比较浅显,这里不做过多解释。

总结下过程中优化的几个点:

  1. 这个题目实际上是个二值问题,而且两个值是对称的,bias实际上没有意义,去掉
  2. 输入是4个参数,但是实际上位置信息很难用来做优化,对总体回报影响太小,去掉

问题:

这个算法大部分时候效果可以,但是存在以下两个问题:

  1. 某些情况下来回震荡,不能收敛
  2. 大部分回合500,但是后期仍会出现回报较低的情况

嗯,这两个问题挺典型,值得做进一步研究。

以上是关于CartPole-v1 50行python实现的主要内容,如果未能解决你的问题,请参考以下文章

CartPole-v1 50行python实现

python学习---50行代码实现图片转字符画1

动漫头像一张50元?1行Python代码实现,别再去交智商税了

50行python代码实现个代理server(你懂的)

Cartpole-v0 的 PyTorch PPO 实现陷入局部最优

50 行Python代码,轻松实现中英文翻译