Q-learning简明实例Java代码实现

Posted coshaho

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Q-learning简明实例Java代码实现相关的知识,希望对你有一定的参考价值。

在《Q-learning简明实例》中我们介绍了Q-learning算法的简单例子,从中我们可以总结出Q-learning算法的基本思想

本次选择的经验得分 = 本次选择的反馈得分 + 本次选择后场景的历史最佳经验得分

其中反馈得分是单个步骤的价值分值(固定的分值),经验得分是完成目标的学习分值(动态的分值)。

简明实例的Java实现如下

package com.coshaho.learn.qlearning;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * 
 * QLearning.java Create on 2017年9月4日 下午10:08:49    
 *    
 * 类功能说明:   QLearning简明例子实现
 *
 * Copyright: Copyright(c) 2013 
 * Company: COSHAHO
 * @Version 1.0
 * @Author coshaho
 */
public class QLearning 
{
    FeedbackMatrix R = new FeedbackMatrix();
    
    ExperienceMatrix Q = new ExperienceMatrix();
    
    public static void main(String[] args)
    {
        QLearning ql = new QLearning();
        
        for(int i = 0; i < 500; i++)
        {
            Random random = new Random();
            int x = random.nextInt(100) % 6;
            
            System.out.println("第" + i + "次学习, 初始房间是" + x);
            ql.learn(x);
            System.out.println();
        }
    }
    
    public void learn(int x)
    {
        do
        {
            // 随机选择一个联通的房间进入
            int y =  chooseRandomRY(x);
            
            // 获取以进入的房间为起始点的历史最佳得分
            int qy = getMaxQY(y);
            
            // 计算此次移动的得分
            int value = calculateNewQ(x, y, qy);
            Q.set(x, y, value);
            x = y;
        }
        // 走出房间则学习结束
        while(5 != x);
        
        Q.print();
    }
    
    public int chooseRandomRY(int x)
    {
        int[] qRow = R.getRow(x);
        List<Integer> yValues = new ArrayList<Integer>();
        for(int i = 0; i < qRow.length; i++)
        {
            if(qRow[i] >= 0)
            {
                yValues.add(i);
            }
        }

        Random random = new Random();
        int i = random.nextInt(yValues.size()) % yValues.size();
        return yValues.get(i);
    }
    
    public int getMaxQY(int x)
    {
        int[] qRow = Q.getRow(x);
        int length = qRow.length;
        List<YAndValue> yValues = new ArrayList<YAndValue>();
        for(int i = 0; i < length; i++)
        {
            YAndValue yv = new YAndValue(i, qRow[i]);
            yValues.add(yv);
        }
        
        Collections.sort(yValues);
        int num = 1;
        int value = yValues.get(0).getValue();
        for(int i = 1; i < length; i++)
        {
            if(yValues.get(i).getValue() == value)
            {
                num = i + 1;
            }
            else
            {
                break;
            }
        }
        
        Random random = new Random();
        int i = random.nextInt(num) % num;
        return yValues.get(i).getY();
    }
    
    // Q(x,y) = R(x,y) + 0.8 * max(Q(y,i))
    public int calculateNewQ(int x, int y, int qy)
    {
        return (int) (R.get(x, y) + 0.8 * Q.get(y, qy));
    }
    
    public static class YAndValue implements Comparable<YAndValue>
    {
        int y;
        int value;
        
        public int getY() {
            return y;
        }
        public void setY(int y) {
            this.y = y;
        }
        public int getValue() {
            return value;
        }
        public void setValue(int value) {
            this.value = value;
        }
        public YAndValue(int y, int value)
        {
            this.y = y;
            this.value = value;
        }
        public int compareTo(YAndValue o) 
        {
            return o.getValue() - this.value;
        }
    }
}

package com.coshaho.learn.qlearning;

/**
 * 
 * FeedbackMatrix.java Create on 2017年9月4日 下午9:52:41    
 *    
 * 类功能说明:   反馈矩阵
 *
 * Copyright: Copyright(c) 2013 
 * Company: COSHAHO
 * @Version 1.0
 * @Author coshaho
 */
public class FeedbackMatrix 
{
    public int get(int x, int y)
    {
        return R[x][y];
    }
    
    public int[] getRow(int x)
    {
        return R[x];
    }
    
    private static int[][] R = new int[6][6];
    static 
    {
        R[0][0] = -1;
        R[0][1] = -1;
        R[0][2] = -1;
        R[0][3] = -1;
        R[0][4] = 0;
        R[0][5] = -1;
        
        R[1][0] = -1;
        R[1][1] = -1;
        R[1][2] = -1;
        R[1][3] = 0;
        R[1][4] = -1;
        R[1][5] = 100;
        
        R[2][0] = -1;
        R[2][1] = -1;
        R[2][2] = -1;
        R[2][3] = 0;
        R[2][4] = -1;
        R[2][5] = -1;
        
        R[3][0] = -1;
        R[3][1] = 0;
        R[3][2] = 0;
        R[3][3] = -1;
        R[3][4] = 0;
        R[3][5] = -1;
        
        R[4][0] = 0;
        R[4][1] = -1;
        R[4][2] = -1;
        R[4][3] = 0;
        R[4][4] = -1;
        R[4][5] = 100;
        
        R[5][0] = -1;
        R[5][1] = 0;
        R[5][2] = -1;
        R[5][3] = -1;
        R[5][4] = 0;
        R[5][5] = 100;
    }
}

package com.coshaho.learn.qlearning;

/**
 * 
 * ExperienceMatrix.java Create on 2017年9月4日 下午10:03:08    
 *    
 * 类功能说明:   经验矩阵
 *
 * Copyright: Copyright(c) 2013 
 * Company: COSHAHO
 * @Version 1.0
 * @Author coshaho
 */
public class ExperienceMatrix 
{
    public int get(int x, int y)
    {
        return Q[x][y];
    }
    
    public int[] getRow(int x)
    {
        return Q[x];
    }
    
    public void set(int x, int y, int value)
    {
        Q[x][y] = value;
    }
    
    public void print()
    {
        for(int i = 0; i < 6; i++)
        {
            for(int j = 0; j < 6; j++)
            {
                String s = Q[i][j] + "  ";
                if(Q[i][j] < 10)
                {
                    s = s + "  ";
                }
                else if(Q[i][j] < 100)
                {
                    s = s + " ";
                }
                System.out.print(s);
            }
            System.out.println();
        }
    }
    
    private static int[][] Q = new int[6][6];
    static
    {
        Q[0][0] = 0;
        Q[0][1] = 0;
        Q[0][2] = 0;
        Q[0][3] = 0;
        Q[0][4] = 0;
        Q[0][5] = 0;
        
        Q[1][0] = 0;
        Q[1][1] = 0;
        Q[1][2] = 0;
        Q[1][3] = 0;
        Q[1][4] = 0;
        Q[1][5] = 0;
        
        Q[2][0] = 0;
        Q[2][1] = 0;
        Q[2][2] = 0;
        Q[2][3] = 0;
        Q[2][4] = 0;
        Q[2][5] = 0;
        
        Q[3][0] = 0;
        Q[3][1] = 0;
        Q[3][2] = 0;
        Q[3][3] = 0;
        Q[3][4] = 0;
        Q[3][5] = 0;
        
        Q[4][0] = 0;
        Q[4][1] = 0;
        Q[4][2] = 0;
        Q[4][3] = 0;
        Q[4][4] = 0;
        Q[4][5] = 0;
        
        Q[5][0] = 0;
        Q[5][1] = 0;
        Q[5][2] = 0;
        Q[5][3] = 0;
        Q[5][4] = 0;
        Q[5][5] = 0;
    }
}

经过500次计算得到如下结果

第499次学习, 初始房间是1
0    0    0    0    396  0    
0    0    0    316  0    496  
0    0    0    316  0    0    
0    396  252  0    396  0    
316  0    0    316  0    496  
0    396  0    0    396  496  

此时,我们从任意一个房间进入,每次选取最高分值步骤移动,总可以找到最短的逃离路径。

以上是关于Q-learning简明实例Java代码实现的主要内容,如果未能解决你的问题,请参考以下文章

A Painless Q-learning Tutorial (一个 Q-learning 算法的简明教程)

一个 Q-learning 算法的简明教程

强化学习 5 —— SARSA 和 Q-Learning算法代码实现

Adaboost算法原理分析和实例+代码(简明易懂)

用java给html文件添加必要的控制html代码片

强化学习 Q-learning 实例详解