My implementation of the evaluation function and Alpha-beta pruning for Connect Four is not smart enough


Posted: 2019-11-05 08:33:33

Question:

I am trying to implement a proper Connect Four game AI, but my AI behaves stupidly:

It does not block opponent patterns that can make the AI lose, and it does not make moves that can make the AI win.

My project consists of the following two GitHub repositories:

GameAI and ConnectFour.

GameAI contains:

SortingAlphaBetaPruningGameEngine

package net.coderodde.zerosum.ai.impl;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.coderodde.zerosum.ai.EvaluatorFunction;
import net.coderodde.zerosum.ai.GameEngine;
import net.coderodde.zerosum.ai.State;

/**
 * This class implements the 
 * <a href="https://en.wikipedia.org/wiki/Minimax">Minimax</a> algorithm for 
 * zero-sum two-player games.
 * 
 * @param <S> the game state type.
 * @param <P> the player color type.
 * @author Rodion "rodde" Efremov
 * @version 1.6 (May 26, 2019)
 */
public final class SortingAlphaBetaPruningGameEngine
        <S extends State<S>, P extends Enum<P>> 
        extends GameEngine<S, P> {

    /**
     * Stores the terminal node or a node at the depth zero with the best value
     * so far, which belongs to the maximizing player moves.
     */
    private S bestTerminalMaximizingState;

    /**
     * Stores the value of {@code bestTerminalMaximizingState}.
     */
    private double bestTerminalMaximizingStateValue;

    /**
     * Stores the terminal node or a node at the depth zero with the best value
     * so far, which belongs to the minimizing player moves.
     */
    private S bestTerminalMinimizingState;

    /**
     * Stores the value of {@code bestTerminalMinimizingState}.
     */
    private double bestTerminalMinimizingStateValue;

    /**
     * Indicates whether we are computing a next ply for the minimizing player 
     * or not. If not, we are computing a next ply for the maximizing player.
     */
    private boolean makingPlyForMinimizingPlayer;

    /**
     * Maps each visited state to its parent state.
     */
    private final Map<S, S> parents = new HashMap<>();

    /**
     * Constructs this minimax game engine.
     * @param evaluatorFunction the evaluator function.
     * @param depth the search depth.
     */
    public SortingAlphaBetaPruningGameEngine(
            EvaluatorFunction<S> evaluatorFunction,
            int depth) {
        super(evaluatorFunction, depth, Integer.MAX_VALUE);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public S makePly(S state, 
                     P minimizingPlayer,
                     P maximizingPlayer,
                     P initialPlayer) {
        // Reset the best known values:
        bestTerminalMaximizingStateValue = Double.NEGATIVE_INFINITY;
        bestTerminalMinimizingStateValue = Double.POSITIVE_INFINITY;
        makingPlyForMinimizingPlayer = initialPlayer != minimizingPlayer;

        // Do the game tree search:
        makePlyImpl(state,
                    depth,
                    Double.NEGATIVE_INFINITY, // initial alpha
                    Double.POSITIVE_INFINITY, // initial beta
                    minimizingPlayer,
                    maximizingPlayer,
                    initialPlayer);

        // Find the next game state starting from 'state':
        S returnState =
                inferBestState(
                        initialPlayer == minimizingPlayer ? 
                                bestTerminalMinimizingState : 
                                bestTerminalMaximizingState);

        // Release the resources:
        parents.clear();
        bestTerminalMaximizingState = null;
        bestTerminalMinimizingState = null;
        // We are done with a single move:
        return returnState;
    }

    private S inferBestState(S bestTerminalState) {
        List<S> statePath = new ArrayList<>();
        S state = bestTerminalState;

        while (state != null) {
            statePath.add(state);
            state = parents.get(state);
        }

        if (statePath.size() == 1) {
            // The root node is terminal. Return null:
            return null;
        }

        // Return the second upmost state:
        Collections.<S>reverse(statePath);
        return statePath.get(1);
    }

    /**
     * Performs a single step down the game tree branch.
     * 
     * @param state the starting state.
     * @param depth the remaining search depth.
     * @param alpha the best value the maximizing player can already guarantee.
     * @param beta the best value the minimizing player can already guarantee.
     * @param minimizingPlayer the minimizing player.
     * @param maximizingPlayer the maximizing player.
     * @param currentPlayer the current player.
     * @return the value of the best ply.
     */
    private double makePlyImpl(S state,
                               int depth,
                               double alpha,
                               double beta,
                               P minimizingPlayer,
                               P maximizingPlayer,
                               P currentPlayer) {
        if (depth == 0 || state.isTerminal()) {
            double value = evaluatorFunction.evaluate(state);

            if (!makingPlyForMinimizingPlayer) {
                if (bestTerminalMinimizingStateValue > value) {
                    bestTerminalMinimizingStateValue = value;
                    bestTerminalMinimizingState = state;
                }
            } else {
                if (bestTerminalMaximizingStateValue < value) {
                    bestTerminalMaximizingStateValue = value;
                    bestTerminalMaximizingState = state;
                }
            }

            return value;
        }

        if (currentPlayer == maximizingPlayer) {
            double value = Double.NEGATIVE_INFINITY;
            List<S> children = state.children();
            // Try the most promising children first (descending value):
            children.sort((S a, S b) -> {
                double valueA = super.evaluatorFunction.evaluate(a);
                double valueB = super.evaluatorFunction.evaluate(b);
                return Double.compare(valueB, valueA);
            });

            for (S child : children) {
                value = Math.max(
                        value, 
                        makePlyImpl(child, 
                                    depth - 1, 
                                    alpha,
                                    beta,
                                    minimizingPlayer, 
                                    maximizingPlayer, 
                                    minimizingPlayer));

                parents.put(child, state);
                alpha = Math.max(alpha, value);

                if (alpha >= beta) {
                    break;
                }
            }

            return value;
        } else {
            // Here, 'currentPlayer == minimizingPlayer'.
            double value = Double.POSITIVE_INFINITY;
            List<S> children = state.children();
            // Try the most promising children first (ascending value):
            children.sort((S a, S b) -> {
                double valueA = super.evaluatorFunction.evaluate(a);
                double valueB = super.evaluatorFunction.evaluate(b);
                return Double.compare(valueA, valueB);
            });

            for (S child : children) {
                value = Math.min(
                        value,
                        makePlyImpl(child, 
                                    depth - 1,
                                    alpha,
                                    beta,
                                    minimizingPlayer, 
                                    maximizingPlayer, 
                                    maximizingPlayer));

                parents.put(child, state);
                beta = Math.min(beta, value);

                if (alpha >= beta) {
                    break;
                }
            }

            return value;
        }
    }
}
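One thing worth checking in SortingAlphaBetaPruningGameEngine above is how makePly chooses the move. It does not pick the root child with the best backed-up value. Instead, it records the single best-valued leaf seen anywhere during the search (bestTerminalMaximizingState or bestTerminalMinimizingState) and walks the parents map back up to the root. That implicitly assumes the opponent will cooperate and steer into that leaf, which would plausibly explain why threats are not blocked, whatever the evaluation function returns.

The following is a minimal sketch, not the project's API, of selecting the root move by the value alpha-beta actually returns. TreeNode, leaf, node, alphaBeta and bestMove are hypothetical names for this illustration, and leafValue stands in for evaluatorFunction.evaluate(state).

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/** Hypothetical game-tree node for this sketch; leaves carry a static value. */
final class TreeNode {
    final double leafValue;          // stands in for evaluatorFunction.evaluate(state)
    final List<TreeNode> children;

    private TreeNode(double leafValue, List<TreeNode> children) {
        this.leafValue = leafValue;
        this.children = children;
    }

    static TreeNode leaf(double value) {
        return new TreeNode(value, Collections.emptyList());
    }

    static TreeNode node(TreeNode... children) {
        return new TreeNode(0.0, Arrays.asList(children));
    }

    boolean isTerminal() {
        return children.isEmpty();
    }
}

final class RootAlphaBeta {
    /** Fail-soft alpha-beta: returns the backed-up value of 'node'. */
    static double alphaBeta(TreeNode node, int depth, double alpha, double beta,
                            boolean maximizing) {
        if (depth == 0 || node.isTerminal()) {
            return node.leafValue;
        }
        if (maximizing) {
            double value = Double.NEGATIVE_INFINITY;
            for (TreeNode child : node.children) {
                value = Math.max(value, alphaBeta(child, depth - 1, alpha, beta, false));
                alpha = Math.max(alpha, value);
                if (alpha >= beta) break;   // the minimizer will avoid this node
            }
            return value;
        }
        double value = Double.POSITIVE_INFINITY;
        for (TreeNode child : node.children) {
            value = Math.min(value, alphaBeta(child, depth - 1, alpha, beta, true));
            beta = Math.min(beta, value);
            if (alpha >= beta) break;       // the maximizer will avoid this node
        }
        return value;
    }

    /** Picks the root child with the best backed-up value, not the best leaf. */
    static TreeNode bestMove(TreeNode root, int depth, boolean maximizing) {
        TreeNode best = null;
        double bestValue = maximizing ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
        double alpha = Double.NEGATIVE_INFINITY;
        double beta = Double.POSITIVE_INFINITY;
        for (TreeNode child : root.children) {
            double v = alphaBeta(child, depth - 1, alpha, beta, !maximizing);
            boolean better = maximizing ? v > bestValue : v < bestValue;
            if (best == null || better) {
                bestValue = v;
                best = child;
            }
            // Narrow the window for the remaining root children.
            if (maximizing) alpha = Math.max(alpha, bestValue);
            else beta = Math.min(beta, bestValue);
        }
        return best;
    }
}
```

Consider a root whose first child leads to leaves {100, -50} and whose second child leads to {5, 6}. A "best leaf" rule heads for 100 under the first child, but a minimizing opponent would answer with -50. bestMove returns the second child, whose guaranteed value is 5. With this approach the parents map is no longer needed.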

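I have two evaluation functions, taken from the web or from my own head. The first one (see below) finds all patterns of length 2, 3 and 4 and multiplies their occurrence counts by constants that favor the longer ones. It does not seem to work. The other one maintains an integer matrix in which each integer is the number of patterns that can occupy that slot. It does not work either.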
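This formula summarizes the first evaluator's score for a position without a winner. It is derived from computeBaseScore, computeAlmostFullPatternScores, getWeights and getWeightMatrix in the listing below, and assumes that checkVictory catches every completed line. The symbols are:

- M_ℓ and m_ℓ: the numbers of maximizing and minimizing patterns of length ℓ. Because of the early break, they stay at zero beyond the first length with no patterns.
- W: the winning length.
- A(s): the ±10^6 term from computeAlmostFullPatternScores.
- w_max: maxWeight.
- σ(x, y): +1 for a maximizing piece and -1 for a minimizing piece.

```latex
\text{score}(s) = \sum_{\ell=2}^{W} \bigl(M_\ell(s) - m_\ell(s)\bigr)\,10^{\ell}
  + A(s)
  + \sum_{(x,y)\ \text{occupied}} \sigma(x,y)\, w(x,y),
\qquad
w(x,y) = 1 + \frac{w_{\max} - 1}{\lvert 2x - (\mathit{width} - 1)\rvert + \lvert 2y - (\mathit{height} - 1)\rvert}
```

With BASE_VALUE = 10, a length-2 pattern is worth 100 and a length-3 pattern 1000. Whenever A(s) fires, its 10^6 magnitude therefore dominates the other non-terminal terms. In the code, the sign of A(s) is taken from state.getPlayerColor() rather than from the color of the pattern that was found.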

BruteForceConnectFourStateEvaluatorFunction

package net.coderodde.games.connect.four.impl;

import net.coderodde.games.connect.four.ConnectFourState;
import net.coderodde.games.connect.four.PlayerColor;
import net.coderodde.zerosum.ai.EvaluatorFunction;

/**
 * This class implements the default Connect Four state evaluator. The white 
 * player wants to maximize, the red player wants to minimize.
 * 
 * @author Rodion "rodde" Efremov
 * @version 1.6 (May 24, 2019)
 */
public final class BruteForceConnectFourStateEvaluatorFunction
        implements EvaluatorFunction<ConnectFourState> {

    private static final double POSITIVE_WIN_VALUE = 1e9;
    private static final double NEGATIVE_WIN_VALUE = -1e9;
    private static final double POSITIVE_CLOSE_TO_WIN_VALUE = 1e6;
    private static final double NEGATIVE_CLOSE_TO_WIN_VALUE = -1e6;
    private static final double BASE_VALUE = 1e1;

    /**
     * The weight matrix. Maps each position to its weight. We need this in 
     * order to 
     */
    private final double[][] weightMatrix;

    /**
     * The winning length.
     */
    private final int winningLength;

    /**
     * Constructs the default heuristic function for Connect Four game states.
     * 
     * @param width the game board width.
     * @param height the game board height.
     * @param maxWeight the maximum weight in the weight matrix.
     * @param winningPatternLength the winning pattern length.
     */
    public BruteForceConnectFourStateEvaluatorFunction(final int width,
                                             final int height,
                                             final double maxWeight,
                                             final int winningPatternLength) {
        this.weightMatrix = getWeightMatrix(width, height, maxWeight);
        this.winningLength = winningPatternLength;
    }

    /**
     * Evaluates the given input {@code state} and returns the estimate.
     * @param state the state to estimate.
     * @return the estimate.
     */
    @Override
    public double evaluate(ConnectFourState state) {
        PlayerColor winnerPlayerColor = state.checkVictory();

        if (winnerPlayerColor == PlayerColor.MAXIMIZING_PLAYER) {
            return POSITIVE_WIN_VALUE - state.getDepth();
        }

        if (winnerPlayerColor == PlayerColor.MINIMIZING_PLAYER) {
            return NEGATIVE_WIN_VALUE + state.getDepth();
        }

        // 'minimizingPatternCounts[i]' gives the number of patterns of 
        // length 'i':
        int[] minimizingPatternCounts = new int[state.getWinningLength() + 1];
        int[] maximizingPatternCounts = new int[minimizingPatternCounts.length];

        // Do not consider patterns of length one!
        for (int targetLength = 2; 
                targetLength <= winningLength; 
                targetLength++) {
            int count = findMinimizingPatternCount(state, targetLength);

            if (count == 0) {
                // Once here, it is not possible to find patterns of larger 
                // length than targetLength:
                break;
            }

            minimizingPatternCounts[targetLength] = count;
        }

        for (int targetLength = 2;
                targetLength <= state.getWinningLength();
                targetLength++) {
            int count = findMaximizingPatternCount(state, targetLength);

            if (count == 0) {
                // Once here, it is not possible to find patterns of larger
                // length than targetLength:
                break;
            }

            maximizingPatternCounts[targetLength] = count;
        }

        double score = computeBaseScore(minimizingPatternCounts, 
                                        maximizingPatternCounts);
        score += computeAlmostFullPatternScores(state, winningLength);
        return score + getWeights(weightMatrix, state);
    }

    private static final double 
        computeAlmostFullPatternScores(ConnectFourState state,
                                       int winningLength) {
        final int targetLength = winningLength - 2;
        double score = 0.0;

        for (int y = state.getHeight() - 1; y >= 0; y--) {
            loop:
            for (int x = 0; x < state.getWidth() - targetLength; x++) {
                if (state.readCell(x, y) == null) {
                    // Try to find 'targetLength' marks:
                    PlayerColor targetPlayerColor = state.readCell(x + 1, y);

                    if (targetPlayerColor == null) {
                        continue loop;
                    }

                    int currentLength = 1;

                    for (int xx = x + 1; xx < state.getWidth() - 1; xx++) {
                        if (state.readCell(xx, y) == targetPlayerColor) {
                            currentLength++;

                            if (currentLength == targetLength) {
                                if (state.getPlayerColor() ==
                                        PlayerColor.MINIMIZING_PLAYER) {
                                    score += NEGATIVE_CLOSE_TO_WIN_VALUE;
                                } else {
                                    score += POSITIVE_CLOSE_TO_WIN_VALUE;
                                }

                                continue loop;
                            }
                        }
                    }
                }
            }

            return score;
        }

        return score;
    }

    /**
     * Finds the number of red patterns of length {@code targetLength}.
     * @param state the target state.
     * @param targetLength the length of the pattern to find.
     * @return the number of red patterns of length {@code targetLength}.
     */
    private static final int findMinimizingPatternCount(ConnectFourState state,
                                                        int targetLength) {
        return findPatternCount(state, 
                                targetLength, 
                                PlayerColor.MINIMIZING_PLAYER);
    }

    /**
     * Finds the number of white patterns of length {@code targetLength}. 
     * @param state the target state.
     * @param targetLength the length of the pattern to find.
     * @return the number of white patterns of length {@code targetLength}.
     */
    private static final int findMaximizingPatternCount(ConnectFourState state,
                                                   int targetLength) {
        return findPatternCount(state,
                                targetLength, 
                                PlayerColor.MAXIMIZING_PLAYER);
    }

    /**
     * Implements the target pattern counting function for both the player 
     * colors.
     * @param state the state to search.
     * @param targetLength the length of the patterns to count.
     * @param playerColor the target player color.
     * @return the number of patterns of length {@code targetLength} and color
     * {@code playerColor}.
     */
    private static final int findPatternCount(ConnectFourState state,
                                              int targetLength,
                                              PlayerColor playerColor) {
        int count = 0;

        count += findHorizontalPatternCount(state, 
                                            targetLength, 
                                            playerColor);

        count += findVerticalPatternCount(state, 
                                          targetLength, 
                                          playerColor);

        count += findAscendingDiagonalPatternCount(state, 
                                                   targetLength,
                                                   playerColor);

        count += findDescendingDiagonalPatternCount(state, 
                                                    targetLength,
                                                    playerColor);
        return count;
    }

    /**
     * Scans the input state for diagonal <b>descending</b> patterns and 
     * returns the number of such patterns.
     * @param state the target state.
     * @param patternLength the target pattern length.
     * @param playerColor the target player color.
     * @return the number of patterns.
     */
    private static final int 
        findDescendingDiagonalPatternCount(ConnectFourState state,
                                           int patternLength,
                                           PlayerColor playerColor) {
        int patternCount = 0;

        for (int y = 0; y < state.getWinningLength() - 1; y++) {
            inner:
            for (int x = 0;
                    x <= state.getWidth() - state.getWinningLength(); 
                    x++) {
                for (int i = 0; i < patternLength; i++) {
                    if (state.readCell(x + i, y + i) != playerColor) {
                        continue inner;
                    }
                }

                patternCount++;
            }
        }

        return patternCount;
    }

    /**
     * Scans the input state for diagonal <b>ascending</b> patterns and returns
     * the number of such patterns.
     * @param state the target state.
     * @param patternLength the target pattern length.
     * @param playerColor the target player color.
     * @return the number of patterns.
     */
    private static final int 
        findAscendingDiagonalPatternCount(ConnectFourState state,
                                          int patternLength,
                                          PlayerColor playerColor) {
        int patternCount = 0;

        for (int y = state.getHeight() - 1;
                y > state.getHeight() - state.getWinningLength();
                y--) {

            inner:
            for (int x = 0; 
                    x <= state.getWidth() - state.getWinningLength();
                    x++) {
                for (int i = 0; i < patternLength; i++) {
                    if (state.readCell(x + i, y - i) != playerColor) {
                        continue inner;
                    }
                }

                patternCount++;
            }
        }

        return patternCount;
    }

    /**
     * Scans the input state for <b>horizontal</b> patterns and returns
     * the number of such patterns.
     * @param state the target state.
     * @param patternLength the target pattern length.
     * @param playerColor the target player color.
     * @return the number of patterns.
     */
    private static final int findHorizontalPatternCount(
            ConnectFourState state,
            int patternLength,
            PlayerColor playerColor) {
        int patternCount = 0;

        for (int y = state.getHeight() - 1; y >= 0; y--) {

            inner:
            for (int x = 0; x <= state.getWidth() - patternLength; x++) {
                if (state.readCell(x, y) == null) {
                    continue inner;
                }

                for (int i = 0; i < patternLength; i++) {
                    if (state.readCell(x + i, y) != playerColor) {
                        continue inner;
                    }
                }

                patternCount++;
            }
        }

        return patternCount;
    }

    /**
     * Scans the input state for <b>vertical</b> patterns and returns
     * the number of such patterns.
     * @param state the target state.
     * @param patternLength the target pattern length.
     * @param playerColor the target player color.
     * @return the number of patterns.
     */
    private static final int findVerticalPatternCount(ConnectFourState state,
                                                      int patternLength,
                                                      PlayerColor playerColor) {
        int patternCount = 0;

        outer:
        for (int x = 0; x < state.getWidth(); x++) {
            inner:
            for (int y = state.getHeight() - 1;
                    y > state.getHeight() - state.getWinningLength(); 
                    y--) {
                if (state.readCell(x, y) == null) {
                    continue outer;
                }

                for (int i = 0; i < patternLength; i++) {
                    if (state.readCell(x, y - i) != playerColor) {
                        continue inner;
                    }
                }

                patternCount++;
            }
        }

        return patternCount;
    }

    /**
     * Gets the state weight. We use this in order to discourage the positions
     * that are close to borders/far away from the center of the game board.
     * @param weightMatrix the weighting matrix.
     * @param state the state to weight.
     * @return the state weight.
     */
    private static final double getWeights(final double[][] weightMatrix,
                                           final ConnectFourState state) {
        double score = 0.0;

        outer:
        for (int x = 0; x < state.getWidth(); x++) {
            for (int y = state.getHeight() - 1; y >= 0; y--) {
                PlayerColor playerColor = state.readCell(x, y);

                if (playerColor == null) {
                    continue outer;
                }

                if (playerColor == PlayerColor.MINIMIZING_PLAYER) {
                    score -= weightMatrix[y][x];
                } else {
                    score += weightMatrix[y][x];
                }
            }
        }

        return score;
    }

    /**
     * Computes the base score that relies on the number of patterns. For
     * example, {@code minimizingPatternCounts[i]} denotes the number of red
     * patterns of length {@code i}.
     * @param minimizingPatternCounts the pattern count map for red patterns.
     * @param maximizingPatternCounts the pattern count map for white patterns.
     * @return the base estimate.
     */
    private static final double computeBaseScore(
            int[] minimizingPatternCounts,
            int[] maximizingPatternCounts) {
        final int winningLength = minimizingPatternCounts.length - 1;

        double value = 0.0;

        if (minimizingPatternCounts[winningLength] != 0) {
            value = NEGATIVE_WIN_VALUE;
        }

        if (maximizingPatternCounts[winningLength] != 0) {
            value = POSITIVE_WIN_VALUE;
        }

        for (int length = 2; length < minimizingPatternCounts.length; length++) {
            int minimizingCount = minimizingPatternCounts[length];
            value -= minimizingCount * Math.pow(BASE_VALUE, length);

            int maximizingCount = maximizingPatternCounts[length];
            value += maximizingCount * Math.pow(BASE_VALUE, length);
        }

        return value;
    }

    /**
     * Computes the weight matrix. The closer the entry in the board is to the
     * center of the board, the closer the weight of that position will be to
     * {@code maxWeight}.
     * 
     * @param width the width of the matrix.
     * @param height the height of the matrix.
     * @param maxWeight the maximum weight. The minimum weight will be always
     * 1.0.
     * @return the weight matrix. 
     */
    private static final double[][] getWeightMatrix(final int width,
                                                    final int height,
                                                    final double maxWeight) {
        final double[][] weightMatrix = new double[height][width];

        for (int y = 0; y < weightMatrix.length; y++) {
            for (int x = 0; x < weightMatrix[0].length; x++) {
                int left = x;
                int right = weightMatrix[0].length - x - 1;
                int top = y;
                int bottom = weightMatrix.length - y - 1;
                int horizontalDifference = Math.abs(left - right);
                int verticalDifference = Math.abs(top - bottom);
                weightMatrix[y][x] =
                        1.0 + (maxWeight - 1.0) / 
                              (horizontalDifference + verticalDifference);
            }
        }

        return weightMatrix;
    }
}
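A common alternative to counting contiguous runs, which is the technique the evaluator above uses, is to score every possible four-cell window instead:

- A window that already holds pieces of both colors can never be completed, so it contributes nothing.
- A window holding k pieces of a single color contributes a weight that grows with k.

This scheme also recognizes split threats such as X_XX, which run counting treats as a two and a one rather than as a three. The sketch makes these assumptions:

- It uses a plain int[][] board (board[y][x] is +1 for the maximizing player, -1 for the minimizing player, 0 for empty) instead of ConnectFourState.
- The winning length is 4.
- The WINDOW_SCORE weights are arbitrary placeholders, not tuned values.

```java
/**
 * Sketch of a window-based Connect Four heuristic (hypothetical helper).
 * Assumed board convention: board[y][x] is +1 for the maximizing player,
 * -1 for the minimizing player and 0 for an empty cell.
 */
final class WindowEvaluator {
    // Score of a window holding k pieces of one color and none of the other.
    private static final double[] WINDOW_SCORE = {0.0, 1.0, 10.0, 1_000.0, 1_000_000.0};
    private static final int WIN = 4;
    // Horizontal, vertical, ascending diagonal, descending diagonal.
    private static final int[][] DIRECTIONS = {{1, 0}, {0, 1}, {1, 1}, {1, -1}};

    static double evaluate(int[][] board) {
        int height = board.length;
        int width = board[0].length;
        double score = 0.0;
        // Each window is identified by its start cell and direction, so it is counted once.
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                for (int[] d : DIRECTIONS) {
                    int endX = x + (WIN - 1) * d[0];
                    int endY = y + (WIN - 1) * d[1];
                    if (endX < 0 || endX >= width || endY < 0 || endY >= height) {
                        continue; // the window would leave the board
                    }
                    int maxPieces = 0;
                    int minPieces = 0;
                    for (int i = 0; i < WIN; i++) {
                        int cell = board[y + i * d[1]][x + i * d[0]];
                        if (cell > 0) maxPieces++;
                        else if (cell < 0) minPieces++;
                    }
                    if (minPieces == 0) {
                        score += WINDOW_SCORE[maxPieces];
                    } else if (maxPieces == 0) {
                        score -= WINDOW_SCORE[minPieces];
                    }
                    // Windows containing both colors are dead and add nothing.
                }
            }
        }
        return score;
    }
}
```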

WeightMatrixConnectFourStateEvaluatorFunction

package net.coderodde.games.connect.four.impl;

import net.coderodde.games.connect.four.ConnectFourState;
import net.coderodde.games.connect.four.PlayerColor;
import net.coderodde.zerosum.ai.EvaluatorFunction;

/**
 * This evaluation function relies on a weight matrix that reflects how many
 * patterns visit each matrix position.
 * 
 * @author Rodion "rodde" Efremov
 * @version 1.6 (Jun 19, 2019)
 */
public class WeightMatrixConnectFourStateEvaluatorFunction implements EvaluatorFunction<ConnectFourState> {

    private final double[][] matrix;

    public WeightMatrixConnectFourStateEvaluatorFunction() {
        this.matrix = new double[][] {{3, 4,  5,  7,  5, 4, 3}, 
                                      {4, 6,  8, 10,  8, 6, 4},
                                      {5, 8, 11, 13, 11, 8, 5}, 
                                      {5, 8, 11, 13, 11, 8, 5},
                                      {4, 6,  8, 10,  8, 6, 4},
                                      {3, 4,  5,  7,  5, 4, 3}};
    }

    @Override
    public double evaluate(ConnectFourState state) {
        PlayerColor winner = state.checkVictory();

        if (winner == PlayerColor.MINIMIZING_PLAYER) {
            return -1e6;
        }

        if (winner == PlayerColor.MAXIMIZING_PLAYER) {
            return 1e6;
        }

        double sum = 0.0;

        // Add the weight of every maximizing piece and subtract the weight
        // of every minimizing piece:
        for (int y = 0; y < state.getHeight(); y++) {
            for (int x = 0; x < state.getWidth(); x++) {
                if (state.readCell(x, y) == PlayerColor.MAXIMIZING_PLAYER) {
                    sum += matrix[y][x];
                } else if (state.readCell(x, y) ==
                        PlayerColor.MINIMIZING_PLAYER) {
                    sum -= matrix[y][x];
                }
            }
        }

        return sum;
    }
}
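The hard-coded matrix in WeightMatrixConnectFourStateEvaluatorFunction matches, cell by cell, the number of possible four-in-a-row lines (horizontal, vertical and both diagonals) that pass through each cell of a 7x6 board. If that is the intended meaning, the matrix can be computed rather than typed in, so it stays consistent if the board size or winning length changes. LineCountMatrix.compute is a hypothetical helper sketched under that assumption.

```java
/**
 * Hypothetical helper: for every cell of a width x height board, counts how
 * many length-'win' lines (horizontal, vertical, both diagonals) contain it.
 */
final class LineCountMatrix {
    static int[][] compute(int width, int height, int win) {
        int[][] counts = new int[height][width];
        int[][] directions = {{1, 0}, {0, 1}, {1, 1}, {1, -1}};
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                for (int[] d : directions) {
                    int endX = x + (win - 1) * d[0];
                    int endY = y + (win - 1) * d[1];
                    if (endX < 0 || endX >= width || endY < 0 || endY >= height) {
                        continue; // this line would leave the board
                    }
                    // Credit every cell on this line once.
                    for (int i = 0; i < win; i++) {
                        counts[y + i * d[1]][x + i * d[0]]++;
                    }
                }
            }
        }
        return counts;
    }
}
```

Such a matrix only rewards occupying well-connected cells. It ignores whether those lines are still open, so on its own it cannot see an opponent's three-in-a-row before the game is already lost.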

I have no idea why neither of these evaluation functions produces intelligent play. Any suggestions?

Comments on the question:

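Have you tested whether your alpha-beta engine works for other kinds of games? If it does, your evaluation function is probably the problem; otherwise, the engine itself may also contain a bug. I'm just wondering where to look first.

If there are more than 2 bugs and you have fixed about 5 of them, don't even think about fixing the next one; all you are doing is making it more complicated. Copy the files to another directory. Then take things out of the project until the problem goes away. Then add everything back except the last piece. That should solve the problem.

Suggestion: start writing regression tests. Your GitHub project has almost zero tests, and even small bugs take a long time to debug.

I'll take a look at this in a few hours.

Answer 1: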

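Winning and losing moves in this situation are not a matter for the heuristic function; they are binary, discrete yes/no answers. For a simple game like Connect 4 you should not treat them heuristically. For every move, test "does this win?" (if so, play it). If none does, test every move for "does this stop the other player from winning on their next move?" (again, if so, play it). Only after that, apply your heuristic to find the best available move.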
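The decision order described above can be sketched as follows. The sketch assumes a plain int[][] board with row 0 at the bottom, +1 and -1 for the two players and 0 for empty cells. TacticalMove and its methods are hypothetical names, and the center-column fallback merely stands in for whatever heuristic or search is applied in the last step.

```java
/**
 * Hypothetical sketch of "win, else block, else heuristic".
 * Assumed board: board[y][x], y = 0 is the bottom row; +1 / -1 are the
 * players and 0 is an empty cell.
 */
final class TacticalMove {
    static final int WIN = 4;

    /** Row a piece dropped into column x would land in, or -1 if full. */
    static int dropRow(int[][] board, int x) {
        for (int y = 0; y < board.length; y++) {
            if (board[y][x] == 0) return y;
        }
        return -1;
    }

    /** True if 'player' has at least WIN in a row through (x, y). */
    static boolean winsAt(int[][] board, int x, int y, int player) {
        int[][] dirs = {{1, 0}, {0, 1}, {1, 1}, {1, -1}};
        for (int[] d : dirs) {
            int run = 1; // the piece at (x, y) itself
            for (int sign = -1; sign <= 1; sign += 2) {
                int cx = x + sign * d[0];
                int cy = y + sign * d[1];
                while (cy >= 0 && cy < board.length && cx >= 0 && cx < board[0].length
                        && board[cy][cx] == player) {
                    run++;
                    cx += sign * d[0];
                    cy += sign * d[1];
                }
            }
            if (run >= WIN) return true;
        }
        return false;
    }

    /** True if dropping a 'player' piece into column x wins immediately. */
    static boolean isWinningDrop(int[][] board, int x, int player) {
        int y = dropRow(board, x);
        if (y < 0) return false;
        board[y][x] = player;                 // try the move...
        boolean win = winsAt(board, x, y, player);
        board[y][x] = 0;                      // ...and undo it
        return win;
    }

    /** Column to play, or -1 if the board is full. */
    static int chooseColumn(int[][] board, int player) {
        int width = board[0].length;
        // 1. Take an immediate win.
        for (int x = 0; x < width; x++) if (isWinningDrop(board, x, player)) return x;
        // 2. Block the opponent's immediate win (same landing cell).
        for (int x = 0; x < width; x++) if (isWinningDrop(board, x, -player)) return x;
        // 3. Placeholder heuristic: the legal column closest to the center.
        int best = -1;
        for (int x = 0; x < width; x++) {
            if (dropRow(board, x) < 0) continue;
            if (best < 0 || Math.abs(2 * x - (width - 1)) < Math.abs(2 * best - (width - 1))) {
                best = x;
            }
        }
        return best;
    }
}
```

In a search-based engine, the same two checks can be applied at the root before calling the search, so the heuristic never has to "discover" a one-move win or loss.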

I suspect you are running into problems such as "a winning move in a corner (value 3) can never beat a losing move in the center (value 13)".
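The comments on the question suggest testing the engine independently of Connect Four and adding regression tests. One cheap check along those lines is that alpha-beta returns exactly the same root value as plain minimax on many random game trees. The sketch is self-contained; GameTree, SearchCheck and the tree shape (branching factor 1 to 3, integer leaf values in [-100, 100], seeded for reproducibility) are hypothetical choices.

```java
import java.util.Random;

/** Hypothetical explicit game tree; 'value' is used only at leaves. */
final class GameTree {
    final double value;
    final GameTree[] children;

    GameTree(double value, GameTree[] children) {
        this.value = value;
        this.children = children;
    }

    /** Random tree of the given depth with 1..3 children per inner node. */
    static GameTree random(Random rnd, int depth) {
        if (depth == 0) {
            return new GameTree(rnd.nextInt(201) - 100, new GameTree[0]);
        }
        GameTree[] kids = new GameTree[1 + rnd.nextInt(3)];
        for (int i = 0; i < kids.length; i++) {
            kids[i] = random(rnd, depth - 1);
        }
        return new GameTree(0.0, kids);
    }
}

final class SearchCheck {
    /** Reference implementation: plain minimax without pruning. */
    static double minimax(GameTree n, boolean max) {
        if (n.children.length == 0) return n.value;
        double best = max ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
        for (GameTree c : n.children) {
            double v = minimax(c, !max);
            best = max ? Math.max(best, v) : Math.min(best, v);
        }
        return best;
    }

    /** Alpha-beta under test (same structure as makePlyImpl). */
    static double alphaBeta(GameTree n, double alpha, double beta, boolean max) {
        if (n.children.length == 0) return n.value;
        if (max) {
            double value = Double.NEGATIVE_INFINITY;
            for (GameTree c : n.children) {
                value = Math.max(value, alphaBeta(c, alpha, beta, false));
                alpha = Math.max(alpha, value);
                if (alpha >= beta) break;
            }
            return value;
        }
        double value = Double.POSITIVE_INFINITY;
        for (GameTree c : n.children) {
            value = Math.min(value, alphaBeta(c, alpha, beta, true));
            beta = Math.min(beta, value);
            if (alpha >= beta) break;
        }
        return value;
    }

    /** Seed of the first tree where the two disagree, or -1 if none. */
    static long firstMismatch(int trials, int depth) {
        for (long seed = 0; seed < trials; seed++) {
            GameTree t = GameTree.random(new Random(seed), depth);
            double expected = minimax(t, true);
            double actual = alphaBeta(t, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, true);
            if (expected != actual) return seed;
        }
        return -1;
    }
}
```

The engine in the question selects its move through the best recorded leaf rather than the root value, so a second test that compares the chosen root child with minimax's choice would also be worth adding.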
