DQN predicts the same action value for every state (cart pole)

Posted: 2022-01-19 18:41:43

【Question】:

I'm trying to implement a DQN. As a warm-up I want to solve CartPole-v0 with an MLP consisting of two hidden layers plus the input and output layers. The input is a 4-element array [cart position, cart velocity, pole angle, pole angular velocity] and the output is an action value for each action (left or right). I am not implementing the DQN from the "Playing Atari with Deep Reinforcement Learning" paper exactly (no frame stacking of the inputs, etc.). I also made a few non-standard choices, such as storing done and the target network's predicted action value in the experience replay, but those choices shouldn't affect learning.
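
(For reference, a minimal sketch of the two spaces involved, assuming the classic gym API that the code below uses:)

import gym

env = gym.make("CartPole-v0")
print(env.observation_space.shape)  # (4,): [cart position, cart velocity, pole angle, pole angular velocity]
print(env.action_space.n)           # 2: push left or push right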

In any case, I'm having a hard time getting this to work. No matter how long I train the agent, it keeps predicting a higher value for one action than for the other, e.g. Q(s, Right) > Q(s, Left) for all states s. Below are my learning code, my network definition, and some of the results I'm getting from training.

class DQN:
    def __init__(self, env, steps_per_episode=200):
        self.env = env
        self.agent_network = MlpPolicy(self.env)
        self.target_network = MlpPolicy(self.env)
        self.target_network.load_state_dict(self.agent_network.state_dict())
        self.target_network.eval()
        self.optimizer = torch.optim.RMSprop(
            self.agent_network.parameters(), lr=0.005, momentum=0.95
        )
        self.replay_memory = ReplayMemory()
        self.gamma = 0.99
        self.steps_per_episode = steps_per_episode
        self.random_policy_stop = 1000
        self.start_learning_time = 1000
        self.batch_size = 32

    def learn(self, episodes):
        time = 0
        for episode in tqdm(range(episodes)):
            state = self.env.reset()
            for step in range(self.steps_per_episode):
                if time < self.random_policy_stop:
                    action = self.env.action_space.sample()
                else:
                    action = select_action(self.env, time, state, self.agent_network)
                new_state, reward, done, _ = self.env.step(action)
                target_value_pred = predict_target_value(
                    new_state, reward, done, self.target_network, self.gamma
                )
                experience = Experience(
                    state, action, reward, new_state, done, target_value_pred
                )
                self.replay_memory.append(experience)
                if time > self.start_learning_time:  # learning step
                    experience_batch = self.replay_memory.sample(self.batch_size)
                    target_preds = extract_value_predictions(experience_batch)
                    agent_preds = agent_batch_preds(
                        experience_batch, self.agent_network
                    )
                    loss = torch.square(agent_preds - target_preds).sum()
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                if time % 1_000 == 0:  # how frequently to update target net
                    self.target_network.load_state_dict(self.agent_network.state_dict())
                    self.target_network.eval()

                state = new_state
                time += 1

                if done:
                    break

def agent_batch_preds(experience_batch: list, agent_network: MlpPolicy):
    """
    Calculate the agent action value estimates using the old states and the
    actual actions that the agent took at that step.
    """
    old_states = extract_old_states(experience_batch)
    actions = extract_actions(experience_batch)
    agent_preds = agent_network(old_states)
    experienced_action_values = agent_preds.index_select(1, actions).diag()
    return experienced_action_values

def extract_actions(experience_batch: list) -> list:
    """
    Extract the list of actions from experience replay batch and torchify
    """
    actions = [exp.action for exp in experience_batch]
    actions = torch.tensor(actions)
    return actions
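
(As an aside, index_select(1, actions).diag() above builds a batch_size x batch_size matrix only to read off its diagonal; torch.gather is the more usual way to pick out Q(s, a) for each sampled transition. A minimal equivalent sketch, reusing the names from agent_batch_preds:)

# equivalent to agent_preds.index_select(1, actions).diag():
# takes agent_preds[i, actions[i]] for every row i of the batch
experienced_action_values = agent_preds.gather(1, actions.unsqueeze(1)).squeeze(1)
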
class MlpPolicy(nn.Module):
    """
    This class implements the MLP which will be used as the Q network. I only
    intend to solve classic control problems with this.
    """

    def __init__(self, env):
        super(MlpPolicy, self).__init__()
        self.env = env
        self.input_dim = self.env.observation_space.shape[0]
        self.output_dim = self.env.action_space.n
        self.fc1 = nn.Linear(self.input_dim, 32)
        self.fc2 = nn.Linear(32, 128)
        self.fc3 = nn.Linear(128, 32)
        self.fc4 = nn.Linear(32, self.output_dim)

    def forward(self, x):
        if type(x) != torch.Tensor:
            x = torch.tensor(x).float()
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

Learning results:

Here I see that one action is always valued more highly than the other (Q(s, Right) > Q(s, Left)), and it is also clear that the network predicts the same action values for every state.
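
(A quick way to confirm that symptom, as a hedged diagnostic sketch where dqn is assumed to be a trained instance of the DQN class above: push a batch of random states through the network and compare the two output columns.)

import torch

# five random states: [cart position, cart velocity, pole angle, pole angular velocity]
states = torch.randn(5, 4)
with torch.no_grad():
    q_values = dqn.agent_network(states)
print(q_values)  # rows come out (nearly) identical, and the same column is always largest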

Does anyone see what's going on? I've done a lot of debugging and careful re-reading of the original papers (I also considered "normalizing" the observation space, even though the velocities can be unbounded), and at this point I may be missing something obvious. I can add more code for the helper functions if that would be useful.

【Comments】:

【Answer 1】:

Nothing is wrong with the network definition. It turned out that the learning rate was too high; lowering it to 0.00025 (as in the original Nature paper that introduced DQN) gave an agent that solves CartPole-v0.
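
(Concretely, the change amounts to one line in the constructor shown in the question; a sketch:)

# lower learning rate, as in the Nature DQN paper; momentum unchanged
self.optimizer = torch.optim.RMSprop(
    self.agent_network.parameters(), lr=0.00025, momentum=0.95
)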

That said, the learning algorithm was also incorrect. In particular, I was using the wrong target action-value predictions: note that the algorithm listed above does not make its predictions with the most recent version of the target network. As training proceeds this leads to poor results, because the agent is learning from stale target data. The fix is to store (s, a, r, s', done) in the replay memory instead, and to compute the target predictions with the latest version of the target network when the mini-batch is sampled. See the updated learning loop below.

    def learn(self, episodes):
        time = 0
        for episode in tqdm(range(episodes)):
            state = self.env.reset()
            for step in range(self.steps_per_episode):
                if time < self.random_policy_stop:
                    action = self.env.action_space.sample()
                else:
                    action = select_action(self.env, time, state, self.agent_network)
                new_state, reward, done, _ = self.env.step(action)
                experience = Experience(state, action, reward, new_state, done)
                self.replay_memory.append(experience)
                if time > self.start_learning_time:  # learning step.
                    experience_batch = self.replay_memory.sample(self.batch_size)
                    target_preds = target_batch_preds(
                        experience_batch, self.target_network, self.gamma
                    )
                    agent_preds = agent_batch_preds(
                        experience_batch, self.agent_network
                    )
                    loss = torch.square(agent_preds - target_preds).sum()
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                if time % 1_000 == 0:  # how frequently to update target net
                    self.target_network.load_state_dict(self.agent_network.state_dict())
                    self.target_network.eval()

                state = new_state
                time += 1
                if done:
                    break
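
(target_batch_preds is not shown above; here is a sketch of what it has to compute, consistent with how it is called and with the Experience(state, action, reward, new_state, done) tuples stored in this loop. The exact implementation is an assumption: the standard one-step target r + gamma * max_a' Q_target(s', a'), with the bootstrap term dropped on terminal transitions.)

def target_batch_preds(experience_batch, target_network, gamma):
    """Standard DQN targets: r + gamma * max_a' Q_target(s', a'), zeroed when done."""
    new_states = torch.tensor([exp.new_state for exp in experience_batch]).float()
    rewards = torch.tensor([exp.reward for exp in experience_batch]).float()
    dones = torch.tensor([float(exp.done) for exp in experience_batch])
    with torch.no_grad():  # target values must not carry gradients back into the network
        next_q = target_network(new_states).max(dim=1).values
    return rewards + gamma * next_q * (1.0 - dones)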

【Discussion】:
