LeetCode 847. Shortest Path Visiting All Nodes的强化学习解法

Posted autosoftdev

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了LeetCode 847. Shortest Path Visiting All Nodes的强化学习解法相关的知识,希望对你有一定的参考价值。

这道题的本意并不是考查机器学习，而且在模型已知的情况下可以直接求解，不需要用蒙特卡洛（MC）、时序差分（TD）等方法。使用这段代码，即使得到了解，也无法通过本题的测试。它可以用来初步练习调参，比如设置不同的奖励、探索衰减等，观察机器找路的特点。本题真正考查的是遍历和找字串等技能。

代码如下:

import random as rnd

class Env(object):
    """Graph-walking environment for LeetCode 847 (visit every node).

    ``Connection`` is an adjacency list: ``Connection[i]`` lists the nodes
    reachable from node ``i``.  State ``-1`` means "not yet placed on the
    graph"; from there any node index may be chosen as the first move.
    """

    def __init__(self):
        self.Connection = []  # adjacency list, installed via setConnection()
        self.Visited = []     # per-node visit counters for the current episode
        self.NodeIndex = -1   # current node; -1 = not yet placed on the graph

    def setConnection(self, con):
        """Install the adjacency list ``con``."""
        self.Connection = con

    def reset(self):
        """Clear all visit counters and return the start state ``-1``."""
        self.Visited = [0] * len(self.Connection)
        self.NodeIndex = -1
        return self.NodeIndex

    def stateSpace(self):
        """Return the adjacency list; its length is the number of nodes."""
        return self.Connection

    def actionSpace(self, node):
        """Return the actions available from ``node``.

        From the start state ``-1`` every node index is a valid action (the
        same set ``actionSample`` draws from).  Without this special case,
        Python's negative indexing would make ``Connection[-1]`` silently
        return the last node's neighbours instead.
        """
        if node == -1:
            return list(range(len(self.Connection)))
        return self.Connection[node]

    def actionSample(self, state):
        """Return a uniformly random action available from ``state``."""
        return rnd.choice(self.actionSpace(state))

    def isDone(self):
        """Return True once every node has been visited at least once."""
        return all(self.Visited)

    def step(self, action):
        """Move to node ``action`` and score the move.

        The reward is minus the number of earlier visits to that node (0 on a
        first visit).  It is replaced by 1 once every node has been visited.

        Returns:
            ``(next_state, reward, done)``.
        """
        self.NodeIndex = action
        reward = -1 * self.Visited[action]
        self.Visited[action] += 1

        is_done = self.isDone()
        if is_done:
            reward = 1
        return self.NodeIndex, reward, is_done


class SarsaAgent(object):
    """Tabular SARSA(lambda) agent that learns to walk an ``Env`` graph.

    ``Q`` (action values) and ``E`` (eligibility traces) are nested dicts
    ``{state: {action: value}}``.  State ``-1`` is the start state, whose
    actions are every node index.  Any other state ``s`` is a node, whose
    actions are the neighbours returned by ``env.actionSpace(s)``.
    """

    def __init__(self, env: "Env"):
        # String annotation: the class can be defined without ``Env`` in scope.
        self.env = env
        self.Q = {}
        self.E = {}
        self.initAgent()

    def initStateValues(self, randomized=True):
        """(Re)build the Q table and zero every eligibility trace.

        When ``randomized`` is truthy each Q entry starts at a random value in
        ``[0, 0.1)``; otherwise every entry starts at ``0.0``.
        """
        num_nodes = len(self.env.stateSpace())
        self.Q = {-1: {}}
        # Start state: choosing any node is allowed as the first move.
        for action in range(num_nodes):
            # (the original called a bare ``random()``, which is undefined here)
            self.Q[-1][action] = rnd.random() / 10 if randomized else 0.0
        for state in range(num_nodes):
            self.Q[state] = {}
            for action in self.env.actionSpace(state):
                self.Q[state][action] = rnd.random() / 10 if randomized else 0.0
        self.resetEValue()

    def get(self, QorE, s, a):
        """Return the value stored for state ``s`` / action ``a`` in ``QorE``."""
        return QorE[s][a]

    def set(self, QorE, s, a, value):
        """Store ``value`` for state ``s`` / action ``a`` in ``QorE``."""
        QorE[s][a] = value

    def resetEValue(self):
        """Zero all eligibility traces, using the same (state, action) keys as Q."""
        self.E = {s: dict.fromkeys(actions, 0.0) for s, actions in self.Q.items()}

    def initAgent(self):
        """Reset the environment and start from an all-zero Q table."""
        self.state = self.env.reset()
        self.initStateValues(randomized=False)

    # using simple decaying epsilon greedy exploration
    def curPolicy(self, s, episode_num, use_epsilon):
        """Pick an action for state ``s``.

        With ``use_epsilon`` true, pick a random action with probability
        ``1 / (episode_num + 1)``, which decays over episodes.  Otherwise pick
        the greedy action on ``Q[s]``; ties go to the first key in insertion
        order.
        """
        epsilon = 1.00 / (episode_num + 1)
        rand_value = rnd.random()  # drawn unconditionally, as before
        if use_epsilon and rand_value < epsilon:
            return self.env.actionSample(s)
        q_s = self.Q[s]
        return max(q_s, key=q_s.get)

    # The agent chooses its next action from the current policy and state.
    def performPolicy(self, s, episode_num, use_epsilon=False):
        """Return the next action for state ``s`` (see ``curPolicy``)."""
        return self.curPolicy(s, episode_num, use_epsilon)

    def act(self, a):
        """Forward action ``a`` to ``env.step`` and return its result."""
        return self.env.step(a)

    # SARSA(λ) learning
    def learning(self, lambda_, gamma, alpha, max_episode_num, use_epsilon=False):
        """Run SARSA(lambda) for ``max_episode_num`` episodes.

        Each episode starts from the state returned by ``env.reset()``.  It ends
        when ``env.step`` reports done.  The chosen actions are printed as they
        are taken, followed by the episode's step count.

        Args:
            lambda_: trace-decay factor.
            gamma: discount factor.
            alpha: learning rate.
            max_episode_num: number of episodes to run.
            use_epsilon: enable the decaying epsilon-greedy exploration of
                ``curPolicy``.  The default ``False`` keeps the original, purely
                greedy behaviour.
        """
        for num_episode in range(1, max_episode_num + 1):
            self.state = self.env.reset()
            self.start = self.state
            self.resetEValue()

            s0 = self.state
            a0 = self.performPolicy(s0, num_episode, use_epsilon)

            time_in_episode = 0
            is_done = False
            while not is_done:
                s1, r1, is_done = self.act(a0)
                print(a0, end="")

                if is_done:
                    # Terminal transition: there is no successor value to bootstrap.
                    a1, q_prime = None, 0.0
                else:
                    a1 = self.performPolicy(s1, num_episode, use_epsilon)
                    q_prime = self.get(self.Q, s1, a1)
                delta = r1 + gamma * q_prime - self.get(self.Q, s0, a0)

                # Accumulating trace: bump E for the pair just taken before the sweep.
                self.set(self.E, s0, a0, self.get(self.E, s0, a0) + 1)

                # Sweep the traces' own keys rather than env.actionSpace(s): the
                # start state's actions are all node indices, which an
                # environment's actionSpace(-1) need not return.
                for s, a_es in self.E.items():
                    q_s = self.Q[s]
                    for a, e_value in a_es.items():
                        q_s[a] += alpha * delta * e_value
                        a_es[a] = gamma * lambda_ * e_value  # value update only; keys unchanged

                s0, a0 = s1, a1
                time_in_episode += 1

            print("\nEpisode {0} takes {1} steps.".format(
                num_episode, time_in_episode))


def main():
    """Train a SARSA(lambda) agent on a small sample graph."""
    # Adjacency list: graph[i] holds the neighbours of node i.
    graph = [[1], [0, 2, 4], [1, 3, 4], [2], [1, 2]]

    environment = Env()
    environment.setConnection(graph)
    agent = SarsaAgent(environment)
    environment.reset()

    print("Learning...")
    hyperparams = dict(lambda_=0.01, gamma=1.0, alpha=0.1, max_episode_num=50)
    agent.learning(**hyperparams)


if __name__ == "__main__":
    main()

 

以上是关于LeetCode 847. Shortest Path Visiting All Nodes的强化学习解法的主要内容,如果未能解决你的问题,请参考以下文章

leetcode 847. Shortest Path Visiting All Nodes 无向连通图遍历最短路径

LeetCode 934. Shortest Bridge

LeetCode - 581. Shortest Unsorted Continuous Subarray

LeetCode 0214 Shortest Palindrome

leetcode 934. Shortest Bridge

LeetCode 847 访问所有节点的最短路径[BFS] HERODING的LeetCode之路