My implementation of the evaluation function and Alpha-beta pruning for Connect Four is not smart enough
Posted: 2019-11-05 08:33:33
Question:
I am trying to implement a Connect Four game AI correctly, but my AI plays foolishly:
- It does not block opponent patterns that can lead to the AI losing.
- It does not make moves that could lead to the AI winning.
My project consists of the following two GitHub repositories:
- GameAI
- ConnectFour
GameAI contains:
SortingAlphaBetaPruningGameEngine
package net.coderodde.zerosum.ai.impl;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.coderodde.zerosum.ai.EvaluatorFunction;
import net.coderodde.zerosum.ai.GameEngine;
import net.coderodde.zerosum.ai.State;
/**
* This class implements the
* <a href="https://en.wikipedia.org/wiki/Minimax">Minimax</a> algorithm for
* zero-sum two-player games.
*
* @param <S> the game state type.
* @param <P> the player color type.
* @author Rodion "rodde" Efremov
* @version 1.6 (May 26, 2019)
*/
public final class SortingAlphaBetaPruningGameEngine
        <S extends State<S>, P extends Enum<P>>
        extends GameEngine<S, P> {

    /**
     * Stores the terminal node or a node at the depth zero with the best value
     * so far, which belongs to the maximizing player moves.
     */
    private S bestTerminalMaximizingState;

    /**
     * Stores the value of {@code bestTerminalMaximizingState}.
     */
    private double bestTerminalMaximizingStateValue;

    /**
     * Stores the terminal node or a node at the depth zero with the best value
     * so far, which belongs to the minimizing player moves.
     */
    private S bestTerminalMinimizingState;

    /**
     * Stores the value of {@code bestTerminalMinimizingState}.
     */
    private double bestTerminalMinimizingStateValue;

    /**
     * Indicates whether we are computing a next ply for the minimizing player
     * or not. If not, we are computing a next ply for the maximizing player.
     */
    private boolean makingPlyForMinimizingPlayer;

    /**
     * Maps each visited state to its parent state.
     */
    private final Map<S, S> parents = new HashMap<>();

    /**
     * Constructs this minimax game engine.
     * @param evaluatorFunction the evaluator function.
     * @param depth the search depth.
     */
    public SortingAlphaBetaPruningGameEngine(
            EvaluatorFunction<S> evaluatorFunction,
            int depth) {
        super(evaluatorFunction, depth, Integer.MAX_VALUE);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public S makePly(S state,
                     P minimizingPlayer,
                     P maximizingPlayer,
                     P initialPlayer) {
        // Reset the best known values:
        bestTerminalMaximizingStateValue = Double.NEGATIVE_INFINITY;
        bestTerminalMinimizingStateValue = Double.POSITIVE_INFINITY;
        makingPlyForMinimizingPlayer = initialPlayer != minimizingPlayer;

        // Do the game tree search:
        makePlyImpl(state,
                    depth,
                    Double.NEGATIVE_INFINITY, // initial alpha
                    Double.POSITIVE_INFINITY, // initial beta
                    minimizingPlayer,
                    maximizingPlayer,
                    initialPlayer);

        // Find the next game state starting from 'state':
        S returnState =
                inferBestState(
                        initialPlayer == minimizingPlayer ?
                                bestTerminalMinimizingState :
                                bestTerminalMaximizingState);

        // Release the resources:
        parents.clear();
        bestTerminalMaximizingState = null;
        bestTerminalMinimizingState = null;

        // We are done with a single move:
        return returnState;
    }

    private S inferBestState(S bestTerminalState) {
        List<S> statePath = new ArrayList<>();
        S state = bestTerminalState;

        while (state != null) {
            statePath.add(state);
            state = parents.get(state);
        }

        if (statePath.size() == 1) {
            // The root node is terminal. Return null:
            return null;
        }

        // Return the second upmost state:
        Collections.reverse(statePath);
        return statePath.get(1);
    }

    /**
     * Performs a single step down the game tree branch.
     *
     * @param state the starting state.
     * @param depth the maximum depth of the game tree.
     * @param alpha the current alpha bound.
     * @param beta the current beta bound.
     * @param minimizingPlayer the minimizing player.
     * @param maximizingPlayer the maximizing player.
     * @param currentPlayer the current player.
     * @return the value of the best ply.
     */
    private double makePlyImpl(S state,
                               int depth,
                               double alpha,
                               double beta,
                               P minimizingPlayer,
                               P maximizingPlayer,
                               P currentPlayer) {
        if (depth == 0 || state.isTerminal()) {
            double value = evaluatorFunction.evaluate(state);

            if (!makingPlyForMinimizingPlayer) {
                if (bestTerminalMinimizingStateValue > value) {
                    bestTerminalMinimizingStateValue = value;
                    bestTerminalMinimizingState = state;
                }
            } else {
                if (bestTerminalMaximizingStateValue < value) {
                    bestTerminalMaximizingStateValue = value;
                    bestTerminalMaximizingState = state;
                }
            }

            return value;
        }

        if (currentPlayer == maximizingPlayer) {
            double value = Double.NEGATIVE_INFINITY;
            List<S> children = state.children();

            // Best-looking children first (descending static evaluation):
            children.sort((S a, S b) -> {
                double valueA = super.evaluatorFunction.evaluate(a);
                double valueB = super.evaluatorFunction.evaluate(b);
                return Double.compare(valueB, valueA);
            });

            for (S child : children) {
                value = Math.max(
                        value,
                        makePlyImpl(child,
                                    depth - 1,
                                    alpha,
                                    beta,
                                    minimizingPlayer,
                                    maximizingPlayer,
                                    minimizingPlayer));

                parents.put(child, state);
                alpha = Math.max(alpha, value);

                if (alpha >= beta) {
                    break;
                }
            }

            return value;
        } else {
            // Here, 'currentPlayer == minimizingPlayer'.
            double value = Double.POSITIVE_INFINITY;
            List<S> children = state.children();

            // Best-looking children first (ascending static evaluation):
            children.sort((S a, S b) -> {
                double valueA = super.evaluatorFunction.evaluate(a);
                double valueB = super.evaluatorFunction.evaluate(b);
                return Double.compare(valueA, valueB);
            });

            for (S child : children) {
                value = Math.min(
                        value,
                        makePlyImpl(child,
                                    depth - 1,
                                    alpha,
                                    beta,
                                    minimizingPlayer,
                                    maximizingPlayer,
                                    maximizingPlayer));

                parents.put(child, state);
                beta = Math.min(beta, value);

                if (alpha >= beta) {
                    break;
                }
            }

            return value;
        }
    }
}
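For comparison with the engine above, here is a minimal sketch of alpha-beta search in which the root move is chosen by the backed-up value of each root child. It uses a toy subtraction game (take 1 to 3 stones, whoever takes the last stone wins) rather than Connect Four, because the optimal move there is known: leave the opponent a multiple of four. The class `AlphaBetaSketch`, its methods, and the game itself are assumptions made for illustration; none of them are part of GameAI.

```java
final class AlphaBetaSketch {

    // Hypothetical toy game: a player may take 1..MAX_TAKE stones per move.
    static final int MAX_TAKE = 3;

    /**
     * Returns the game value from the maximizer's point of view:
     * +1 if the maximizer wins with best play, -1 if the minimizer wins.
     */
    static int search(int stones, boolean maximizerToMove, int alpha, int beta) {
        if (stones == 0) {
            // The player who moved last took the final stone and won.
            return maximizerToMove ? -1 : 1;
        }

        int limit = Math.min(MAX_TAKE, stones);

        if (maximizerToMove) {
            int value = Integer.MIN_VALUE;
            for (int take = 1; take <= limit; take++) {
                value = Math.max(value, search(stones - take, false, alpha, beta));
                alpha = Math.max(alpha, value);
                if (alpha >= beta) {
                    break; // beta cutoff
                }
            }
            return value;
        } else {
            int value = Integer.MAX_VALUE;
            for (int take = 1; take <= limit; take++) {
                value = Math.min(value, search(stones - take, true, alpha, beta));
                beta = Math.min(beta, value);
                if (alpha >= beta) {
                    break; // alpha cutoff
                }
            }
            return value;
        }
    }

    /**
     * Chooses the maximizer's move at the root by comparing the backed-up
     * values of the root children. Requires stones >= 1.
     */
    static int bestTake(int stones) {
        int bestTake = -1;
        int bestValue = Integer.MIN_VALUE;
        int alpha = Integer.MIN_VALUE;

        for (int take = 1; take <= Math.min(MAX_TAKE, stones); take++) {
            int value = search(stones - take, false, alpha, Integer.MAX_VALUE);
            if (value > bestValue) {
                bestValue = value;
                bestTake = take;
            }
            alpha = Math.max(alpha, bestValue);
        }
        return bestTake;
    }

    public static void main(String[] args) {
        for (int stones = 1; stones <= 8; stones++) {
            System.out.println(stones + " stones: take " + bestTake(stones));
        }
    }
}
```

A tiny game with a known solution like this one could also serve as a first regression test for a generic engine before Connect Four is plugged in.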
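I have two evaluation functions, taken partly from the web and partly from my own head. The first one (see below) finds all patterns of length 2, 3 and 4 and multiplies their counts by constants that favor the longer patterns. It does not seem to work. The other one maintains a matrix of integers; each integer denotes the number of patterns that may occupy that slot. It does not work either.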
BruteForceConnectFourStateEvaluatorFunction
package net.coderodde.games.connect.four.impl;
import net.coderodde.games.connect.four.ConnectFourState;
import net.coderodde.games.connect.four.PlayerColor;
import net.coderodde.zerosum.ai.EvaluatorFunction;
/**
* This class implements the default Connect Four state evaluator. The white
* player wants to maximize, the red player wants to minimize.
*
* @author Rodion "rodde" Efremov
* @version 1.6 (May 24, 2019)
*/
public final class BruteForceConnectFourStateEvaluatorFunction
        implements EvaluatorFunction<ConnectFourState> {

    private static final double POSITIVE_WIN_VALUE = 1e9;
    private static final double NEGATIVE_WIN_VALUE = -1e9;
    private static final double POSITIVE_CLOSE_TO_WIN_VALUE = 1e6;
    private static final double NEGATIVE_CLOSE_TO_WIN_VALUE = -1e6;
    private static final double BASE_VALUE = 1e1;

    /**
     * The weight matrix. Maps each position to its weight. We need this in
     * order to discourage positions that are close to the borders.
     */
    private final double[][] weightMatrix;

    /**
     * The winning length.
     */
    private final int winningLength;

    /**
     * Constructs the default heuristic function for Connect Four game states.
     *
     * @param width the game board width.
     * @param height the game board height.
     * @param maxWeight the maximum weight in the weight matrix.
     * @param winningPatternLength the winning pattern length.
     */
    public BruteForceConnectFourStateEvaluatorFunction(
            final int width,
            final int height,
            final double maxWeight,
            final int winningPatternLength) {
        this.weightMatrix = getWeightMatrix(width, height, maxWeight);
        this.winningLength = winningPatternLength;
    }

    /**
     * Evaluates the given input {@code state} and returns the estimate.
     * @param state the state to estimate.
     * @return the estimate.
     */
    @Override
    public double evaluate(ConnectFourState state) {
        PlayerColor winnerPlayerColor = state.checkVictory();

        if (winnerPlayerColor == PlayerColor.MAXIMIZING_PLAYER) {
            return POSITIVE_WIN_VALUE - state.getDepth();
        }

        if (winnerPlayerColor == PlayerColor.MINIMIZING_PLAYER) {
            return NEGATIVE_WIN_VALUE + state.getDepth();
        }

        // 'minimizingPatternCounts[i]' gives the number of patterns of
        // length 'i':
        int[] minimizingPatternCounts = new int[state.getWinningLength() + 1];
        int[] maximizingPatternCounts = new int[minimizingPatternCounts.length];

        // Do not consider patterns of length one!
        for (int targetLength = 2;
                 targetLength <= winningLength;
                 targetLength++) {
            int count = findMinimizingPatternCount(state, targetLength);

            if (count == 0) {
                // Once here, it is not possible to find patterns of larger
                // length than targetLength:
                break;
            }

            minimizingPatternCounts[targetLength] = count;
        }

        for (int targetLength = 2;
                 targetLength <= state.getWinningLength();
                 targetLength++) {
            int count = findMaximizingPatternCount(state, targetLength);

            if (count == 0) {
                // Once here, it is not possible to find patterns of larger
                // length than targetLength:
                break;
            }

            maximizingPatternCounts[targetLength] = count;
        }

        double score = computeBaseScore(minimizingPatternCounts,
                                        maximizingPatternCounts);
        score += computeAlmostFullPatternScores(state, winningLength);
        return score + getWeights(weightMatrix, state);
    }

    private static final double
            computeAlmostFullPatternScores(ConnectFourState state,
                                           int winningLength) {
        final int targetLength = winningLength - 2;
        double score = 0.0;

        for (int y = state.getHeight() - 1; y >= 0; y--) {
            loop:
            for (int x = 0; x < state.getWidth() - targetLength; x++) {
                if (state.readCell(x, y) == null) {
                    // Try to find 'targetLength' marks:
                    PlayerColor targetPlayerColor = state.readCell(x + 1, y);

                    if (targetPlayerColor == null) {
                        continue loop;
                    }

                    int currentLength = 1;

                    for (int xx = x + 1; xx < state.getWidth() - 1; xx++) {
                        if (state.readCell(xx, y) == targetPlayerColor) {
                            currentLength++;

                            if (currentLength == targetLength) {
                                if (state.getPlayerColor() ==
                                        PlayerColor.MINIMIZING_PLAYER) {
                                    score += NEGATIVE_CLOSE_TO_WIN_VALUE;
                                } else {
                                    score += POSITIVE_CLOSE_TO_WIN_VALUE;
                                }

                                continue loop;
                            }
                        }
                    }
                }

                return score;
            }
        }

        return score;
    }

    /**
     * Finds the number of red patterns of length {@code targetLength}.
     * @param state the target state.
     * @param targetLength the length of the pattern to find.
     * @return the number of red patterns of length {@code targetLength}.
     */
    private static final int findMinimizingPatternCount(ConnectFourState state,
                                                        int targetLength) {
        return findPatternCount(state,
                                targetLength,
                                PlayerColor.MINIMIZING_PLAYER);
    }

    /**
     * Finds the number of white patterns of length {@code targetLength}.
     * @param state the target state.
     * @param targetLength the length of the pattern to find.
     * @return the number of white patterns of length {@code targetLength}.
     */
    private static final int findMaximizingPatternCount(ConnectFourState state,
                                                        int targetLength) {
        return findPatternCount(state,
                                targetLength,
                                PlayerColor.MAXIMIZING_PLAYER);
    }

    /**
     * Implements the target pattern counting function for both the player
     * colors.
     * @param state the state to search.
     * @param targetLength the length of the patterns to count.
     * @param playerColor the target player color.
     * @return the number of patterns of length {@code targetLength} and color
     * {@code playerColor}.
     */
    private static final int findPatternCount(ConnectFourState state,
                                              int targetLength,
                                              PlayerColor playerColor) {
        int count = 0;

        count += findHorizontalPatternCount(state,
                                            targetLength,
                                            playerColor);

        count += findVerticalPatternCount(state,
                                          targetLength,
                                          playerColor);

        count += findAscendingDiagonalPatternCount(state,
                                                   targetLength,
                                                   playerColor);

        count += findDescendingDiagonalPatternCount(state,
                                                    targetLength,
                                                    playerColor);
        return count;
    }

    /**
     * Scans the input state for diagonal <b>descending</b> patterns and
     * returns the number of such patterns.
     * @param state the target state.
     * @param patternLength the target pattern length.
     * @param playerColor the target player color.
     * @return the number of patterns.
     */
    private static final int
            findDescendingDiagonalPatternCount(ConnectFourState state,
                                               int patternLength,
                                               PlayerColor playerColor) {
        int patternCount = 0;

        for (int y = 0; y < state.getWinningLength() - 1; y++) {
            inner:
            for (int x = 0;
                     x <= state.getWidth() - state.getWinningLength();
                     x++) {
                for (int i = 0; i < patternLength; i++) {
                    if (state.readCell(x + i, y + i) != playerColor) {
                        continue inner;
                    }
                }

                patternCount++;
            }
        }

        return patternCount;
    }

    /**
     * Scans the input state for diagonal <b>ascending</b> patterns and returns
     * the number of such patterns.
     * @param state the target state.
     * @param patternLength the target pattern length.
     * @param playerColor the target player color.
     * @return the number of patterns.
     */
    private static final int
            findAscendingDiagonalPatternCount(ConnectFourState state,
                                              int patternLength,
                                              PlayerColor playerColor) {
        int patternCount = 0;

        for (int y = state.getHeight() - 1;
                 y > state.getHeight() - state.getWinningLength();
                 y--) {
            inner:
            for (int x = 0;
                     x <= state.getWidth() - state.getWinningLength();
                     x++) {
                for (int i = 0; i < patternLength; i++) {
                    if (state.readCell(x + i, y - i) != playerColor) {
                        continue inner;
                    }
                }

                patternCount++;
            }
        }

        return patternCount;
    }

    /**
     * Scans the input state for <b>horizontal</b> patterns and returns the
     * number of such patterns.
     * @param state the target state.
     * @param patternLength the target pattern length.
     * @param playerColor the target player color.
     * @return the number of patterns.
     */
    private static final int findHorizontalPatternCount(
            ConnectFourState state,
            int patternLength,
            PlayerColor playerColor) {
        int patternCount = 0;

        for (int y = state.getHeight() - 1; y >= 0; y--) {
            inner:
            for (int x = 0; x <= state.getWidth() - patternLength; x++) {
                if (state.readCell(x, y) == null) {
                    continue inner;
                }

                for (int i = 0; i < patternLength; i++) {
                    if (state.readCell(x + i, y) != playerColor) {
                        continue inner;
                    }
                }

                patternCount++;
            }
        }

        return patternCount;
    }

    /**
     * Scans the input state for <b>vertical</b> patterns and returns the
     * number of such patterns.
     * @param state the target state.
     * @param patternLength the target pattern length.
     * @param playerColor the target player color.
     * @return the number of patterns.
     */
    private static final int findVerticalPatternCount(ConnectFourState state,
                                                      int patternLength,
                                                      PlayerColor playerColor) {
        int patternCount = 0;

        outer:
        for (int x = 0; x < state.getWidth(); x++) {
            inner:
            for (int y = state.getHeight() - 1;
                     y > state.getHeight() - state.getWinningLength();
                     y--) {
                if (state.readCell(x, y) == null) {
                    continue outer;
                }

                for (int i = 0; i < patternLength; i++) {
                    if (state.readCell(x, y - i) != playerColor) {
                        continue inner;
                    }
                }

                patternCount++;
            }
        }

        return patternCount;
    }

    /**
     * Gets the state weight. We use this in order to discourage the positions
     * that are close to borders/far away from the center of the game board.
     * @param weightMatrix the weighting matrix.
     * @param state the state to weight.
     * @return the state weight.
     */
    private static final double getWeights(final double[][] weightMatrix,
                                           final ConnectFourState state) {
        double score = 0.0;

        outer:
        for (int x = 0; x < state.getWidth(); x++) {
            for (int y = state.getHeight() - 1; y >= 0; y--) {
                PlayerColor playerColor = state.readCell(x, y);

                if (playerColor == null) {
                    continue outer;
                }

                if (playerColor == PlayerColor.MINIMIZING_PLAYER) {
                    score -= weightMatrix[y][x];
                } else {
                    score += weightMatrix[y][x];
                }
            }
        }

        return score;
    }

    /**
     * Computes the base score that relies on the number of patterns. For
     * example, {@code minimizingPatternCounts[i]} denotes the number of red
     * patterns of length {@code i}.
     * @param minimizingPatternCounts the pattern count map for red patterns.
     * @param maximizingPatternCounts the pattern count map for white patterns.
     * @return the base estimate.
     */
    private static final double computeBaseScore(
            int[] minimizingPatternCounts,
            int[] maximizingPatternCounts) {
        final int winningLength = minimizingPatternCounts.length - 1;
        double value = 0.0;

        if (minimizingPatternCounts[winningLength] != 0) {
            value = NEGATIVE_WIN_VALUE;
        }

        if (maximizingPatternCounts[winningLength] != 0) {
            value = POSITIVE_WIN_VALUE;
        }

        for (int length = 2; length < minimizingPatternCounts.length; length++) {
            int minimizingCount = minimizingPatternCounts[length];
            value -= minimizingCount * Math.pow(BASE_VALUE, length);

            int maximizingCount = maximizingPatternCounts[length];
            value += maximizingCount * Math.pow(BASE_VALUE, length);
        }

        return value;
    }

    /**
     * Computes the weight matrix. The closer the entry in the board is to the
     * center of the board, the closer the weight of that position will be to
     * {@code maxWeight}.
     *
     * @param width the width of the matrix.
     * @param height the height of the matrix.
     * @param maxWeight the maximum weight. The minimum weight will be always
     * 1.0.
     * @return the weight matrix.
     */
    private static final double[][] getWeightMatrix(final int width,
                                                    final int height,
                                                    final double maxWeight) {
        final double[][] weightMatrix = new double[height][width];

        for (int y = 0; y < weightMatrix.length; y++) {
            for (int x = 0; x < weightMatrix[0].length; x++) {
                int left = x;
                int right = weightMatrix[0].length - x - 1;
                int top = y;
                int bottom = weightMatrix.length - y - 1;
                int horizontalDifference = Math.abs(left - right);
                int verticalDifference = Math.abs(top - bottom);

                weightMatrix[y][x] =
                        1.0 + (maxWeight - 1.0) /
                              (horizontalDifference + verticalDifference);
            }
        }

        return weightMatrix;
    }
}
WeightMatrixConnectFourStateEvaluatorFunction
package net.coderodde.games.connect.four.impl;
import net.coderodde.games.connect.four.ConnectFourState;
import net.coderodde.games.connect.four.PlayerColor;
import net.coderodde.zerosum.ai.EvaluatorFunction;
/**
* This evaluation function relies on a weight matrix that reflects how many
* patterns visit each matrix position.
*
* @author Rodion "rodde" Efremov
* @version 1.6 (Jun 19, 2019)
*/
public class WeightMatrixConnectFourStateEvaluatorFunction
        implements EvaluatorFunction<ConnectFourState> {

    private final double[][] matrix;

    public WeightMatrixConnectFourStateEvaluatorFunction() {
        this.matrix = new double[][] {{3, 4,  5,  7,  5, 4, 3},
                                      {4, 6,  8, 10,  8, 6, 4},
                                      {5, 8, 11, 13, 11, 8, 5},
                                      {5, 8, 11, 13, 11, 8, 5},
                                      {4, 6,  8, 10,  8, 6, 4},
                                      {3, 4,  5,  7,  5, 4, 3}};
    }

    @Override
    public double evaluate(ConnectFourState state) {
        PlayerColor winner = state.checkVictory();

        if (winner == PlayerColor.MINIMIZING_PLAYER) {
            return -1e6;
        }

        if (winner == PlayerColor.MAXIMIZING_PLAYER) {
            return 1e6;
        }

        double sum = 0.0;

        for (int y = 0; y < state.getHeight(); y++) {
            for (int x = 0; x < state.getWidth(); x++) {
                if (state.readCell(x, y) == PlayerColor.MAXIMIZING_PLAYER) {
                    sum += matrix[y][x];
                } else if (state.readCell(x, y) ==
                        PlayerColor.MINIMIZING_PLAYER) {
                    sum -= matrix[y][x];
                }
            }
        }

        return sum;
    }
}
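The hard-coded matrix above matches what you get by counting, for each cell, how many winning lines of length four pass through it. As a sketch of that reading, the values can be derived instead of typed in. The class `WindowCountMatrix` and its `compute` method are hypothetical names for illustration and are not part of the ConnectFour repository.

```java
import java.util.Arrays;

final class WindowCountMatrix {

    /**
     * For every cell, counts the straight lines of 'winningLength' cells
     * (horizontal, vertical, and both diagonals) that fit on the board and
     * pass through that cell.
     */
    static int[][] compute(int width, int height, int winningLength) {
        int[][] counts = new int[height][width];

        // Direction vectors as (dx, dy): right, up, up-right, down-right.
        int[][] directions = { {1, 0}, {0, 1}, {1, 1}, {1, -1} };

        for (int[] d : directions) {
            for (int y = 0; y < height; y++) {
                for (int x = 0; x < width; x++) {
                    int endX = x + (winningLength - 1) * d[0];
                    int endY = y + (winningLength - 1) * d[1];

                    // Skip lines that would leave the board.
                    if (endX < 0 || endX >= width
                            || endY < 0 || endY >= height) {
                        continue;
                    }

                    // Credit every cell covered by this line.
                    for (int i = 0; i < winningLength; i++) {
                        counts[y + i * d[1]][x + i * d[0]]++;
                    }
                }
            }
        }

        return counts;
    }

    public static void main(String[] args) {
        for (int[] row : compute(7, 6, 4)) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

For a 7 by 6 board with a winning length of 4, this reproduces the rows listed in the constructor above, and it would adapt to other board sizes without retyping the table.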
I have no idea why neither of these evaluation functions produces intelligent play. Any suggestions?
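Comments:
- Have you tested whether your alpha-beta engine works for other kinds of games? If it does, your evaluation function is probably the problem. Otherwise the engine itself may also contain bugs. I am just trying to work out where to look first.
- If there are more than 2 bugs and you have already fixed about 5, don't even think about fixing the next one. All you are doing is making things more complicated. Copy the files to another directory. Then take things out of the project until the problem goes away. Then add everything back except the last piece you removed. That should solve the problem.
- Suggestion: start writing regression tests. Your GitHub project has almost no tests, so even small bugs will take a long time to debug.
- I will take a look at this in a few hours.
Answer 1:
Winning and losing moves in this situation are not heuristic functions; they are discrete, binary yes/no answers. For a game as simple as Connect 4, you should not treat them heuristically. For each move, test "Does this win?" (and if so, play it). If no move wins, test for each move "Does this stop the other player from winning on their next move?" (again, if so, play it). Only after that do you apply the heuristic to find the best available move.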
I suspect you are running into problems such as "a winning move in the corner (value 3) never beats a losing move in the middle (value 13)".
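The order of checks described in this answer can be sketched as follows. This is a minimal illustration on a plain `int[][]` board, not the asker's `ConnectFourState`; the class `ForcedMoveSketch`, the cell encoding (0 for empty, otherwise a player number), and the convention that row 0 is the bottom row are all assumptions made for the example.

```java
final class ForcedMoveSketch {

    static final int ROWS = 6;
    static final int COLS = 7;
    static final int EMPTY = 0;

    /** Row where a disc dropped into 'col' lands, or -1 if the column is full. */
    static int landingRow(int[][] board, int col) {
        for (int row = 0; row < ROWS; row++) {
            if (board[row][col] == EMPTY) {
                return row;
            }
        }
        return -1;
    }

    /** True if 'player' has four in a row anywhere on the board. */
    static boolean hasFour(int[][] board, int player) {
        // Direction vectors as (dCol, dRow): right, up, up-right, down-right.
        int[][] dirs = { {1, 0}, {0, 1}, {1, 1}, {1, -1} };

        for (int row = 0; row < ROWS; row++) {
            for (int col = 0; col < COLS; col++) {
                for (int[] d : dirs) {
                    int endCol = col + 3 * d[0];
                    int endRow = row + 3 * d[1];
                    if (endCol < 0 || endCol >= COLS
                            || endRow < 0 || endRow >= ROWS) {
                        continue;
                    }
                    boolean all = true;
                    for (int i = 0; i < 4 && all; i++) {
                        all = board[row + i * d[1]][col + i * d[0]] == player;
                    }
                    if (all) {
                        return true;
                    }
                }
            }
        }
        return false;
    }

    /** True if 'player' would win immediately by dropping a disc into 'col'. */
    static boolean winsByPlaying(int[][] board, int col, int player) {
        int row = landingRow(board, col);
        if (row < 0) {
            return false;
        }
        board[row][col] = player;      // try the move
        boolean wins = hasFour(board, player);
        board[row][col] = EMPTY;       // undo it
        return wins;
    }

    /**
     * Returns a column that wins at once, else a column that blocks the
     * opponent's immediate win, else -1 (meaning: fall back to the search).
     */
    static int forcedColumn(int[][] board, int me, int opponent) {
        for (int col = 0; col < COLS; col++) {
            if (winsByPlaying(board, col, me)) {
                return col;
            }
        }
        for (int col = 0; col < COLS; col++) {
            if (winsByPlaying(board, col, opponent)) {
                return col;
            }
        }
        return -1;
    }
}
```

When `forcedColumn` returns -1, control would fall through to the heuristic search. Note that this one-move check cannot rescue a position in which the opponent already has two separate winning columns.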