golang蒙特卡洛树算法实现五子棋AI

Posted janbar

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了golang蒙特卡洛树算法实现五子棋AI相关的知识,希望对你有一定的参考价值。

已经实现蒙特卡洛树算法的通用逻辑,只需要对应结构体实现相关接口就可以直接使用该算法。

优化算法主要优化GetActions生成下一步动作,要尽可能少,去掉无意义的动作。

以及优化ActionPolicy从众多动作挑选比较优秀的动作。对应五子棋就是执行该动作后当前局面评分最高。

package main

import (
	"fmt"
	"math"
	"math/rand"
	"strings"
	"time"
)

func main() 
	var (
		board = NewQuZiQi(15)
		x, y  int
	)

	board.Print()
	for board.IsTerminal() == 0 
		board = Search(time.Second*10, board).(*WuZiQi)

		board.Print()
		if board.IsTerminal() == 1 
			fmt.Println("电脑赢了")
			return
		

		for 
			fmt.Print("轮到您执棋,请输入坐标: ")
			_, _ = fmt.Scanln(&x, &y)
			x--
			y--
			if x < 0 || y < 0 || x >= board.size || y >= board.size 
				fmt.Println("您输入的数据超出棋盘范围")
			 else if board.board[x][y] > 0 
				fmt.Println("该位置已有棋子")
			 else 
				board.board[x][y] = 2
				board.player = 1 // 下一步该电脑下
				break
			
		

		board.Print()
		if board.IsTerminal() == 2 
			fmt.Println("你赢了")
			return
		
	


// WuZiQi 五子棋游戏
type WuZiQi struct 
	size   int     // 棋盘大小
	board  [][]int // 棋盘状态
	player int     // 1: 电脑落子,2: 玩家落子


func NewQuZiQi(size int) *WuZiQi 
	w := &WuZiQi
		size:   size,
		board:  make([][]int, size),
		player: 1,
	
	for i := 0; i < size; i++ 
		w.board[i] = make([]int, size)
	
	size /= 2
	// 默认中间落一个棋子
	// 0: 表示没有落子,1: 表示电脑,2: 表示玩家
	w.board[size][size] = 2
	return w


func (w *WuZiQi) Print() 
	var (
		str strings.Builder
		num = func(n int) 
			a, b := n/10, n%10
			if a > 0 
				str.WriteByte(byte(a + \'0\'))
			 else 
				str.WriteByte(\' \') // 1位数前面加空格
			
			str.WriteByte(byte(b + \'0\'))
		
	)
	str.WriteString("   ")
	for i := 1; i <= w.size; i++ 
		str.WriteByte(\' \')
		num(i)
	
	str.WriteByte(\'\\n\')
	for i := 0; i < w.size; i++ 
		str.WriteString("   ")
		for j := 0; j < w.size; j++ 
			str.WriteString(" __")
		

		str.WriteByte(\'\\n\')
		num(i + 1)
		str.WriteByte(\' \')

		for j := 0; j < w.size; j++ 
			str.WriteByte(\'|\')
			switch w.board[i][j] 
			case 0:
				str.WriteByte(\' \')
			case 1:
				str.WriteByte(\'O\')
			case 2:
				str.WriteByte(\'X\')
			
			str.WriteByte(\' \')
		
		str.WriteString("|\\n")
	
	str.WriteString("   ")
	for i := 0; i < w.size; i++ 
		str.WriteString(" __")
	
	fmt.Println(str.String())


func (w *WuZiQi) IsTerminal() int 
	full := -1 // 没有空位且都没赢
	for i := 0; i < w.size; i++ 
		for j := 0; j < w.size; j++ 
			if wc := w.board[i][j]; wc == 0 
				full = 0 // 还有空位,没结束
			 else 
				// 向右
				cnt, x, y := 1, 0, j+1
				for ; y < w.size && w.board[i][y] == wc; y++ 
					cnt++
				
				if cnt >= 5 
					return wc
				
				// 向下
				cnt, x = 1, i+1
				for ; x < w.size && w.board[x][j] == wc; x++ 
					cnt++
				
				if cnt >= 5 
					return wc
				
				// 向右下
				cnt, x, y = 1, i+1, j+1
				for ; x < w.size && y < w.size && w.board[x][y] == wc; x, y = x+1, y+1 
					cnt++
				
				if cnt >= 5 
					return wc
				
				// 向左下
				cnt, x, y = 1, i+1, j-1
				for ; x < w.size && y >= 0 && w.board[x][y] == wc; x, y = x+1, y-1 
					cnt++
				
				if cnt >= 5 
					return wc
				
			
		
	
	return full


func (w *WuZiQi) Result(state int) float64 
	switch state 
	case -1:
		return 0 // 都没赢且没空位
	case 1:
		return -1 // 电脑赢了
	case 2:
		return +1 // 玩家赢了
	default:
		return 0 // 都没赢且有空位
	


func (w *WuZiQi) GetActions() (res []any) 
	// todo 敌方上一步落子附近才是最优搜索范围
	//  某个落子必胜,则直接落子,如果某个落子让对手所有落子都必败则直接落子
	//    因此后续动作进一步缩小范围
	//  可以使用hash判断棋盘状态

	m := map[[2]int]struct // 用于去重
	for i := 0; i < w.size; i++ 
		for j := 0; j < w.size; j++ 
			if w.board[i][j] == 0 || w.board[i][j] == w.player 
				continue // 跳过空位和己方棋子
			

			x0, x1, y0, y1 := i-2, i+2, j-2, j+2
			for ii := x0; ii < x1; ii++ 
				for jj := y0; jj < y1; jj++ 
					if ii >= 0 && jj >= 0 && ii < w.size && jj < w.size &&
						w.board[ii][jj] == 0 

						p := [2]intii, jj
						_, ok := m[p]
						if !ok 
							// 在棋子周围2格范围的空位加到结果中
							// 超过2格的空位落子的意义不大
							res = append(res, p)
							m[p] = struct
						
					
				
			
		
	
	return


func (w *WuZiQi) ActionPolicy(action []any) any 
	// 目前随机选一个动作,应该是好方案先选出来
	return action[rand.Intn(len(action))]


func (w *WuZiQi) Action(action any) TreeState 
	wn := &WuZiQi
		size:   w.size,
		board:  make([][]int, w.size),
		player: 3 - w.player, // 切换电脑和玩家
	
	for i := 0; i < w.size; i++ 
		wn.board[i] = make([]int, w.size)
		for j := 0; j < w.size; j++ 
			wn.board[i][j] = w.board[i][j]
		
	

	ac := action.([2]int) // 在该位置落子
	wn.board[ac[0]][ac[1]] = w.player
	return wn


// MonteCarloTree 下面是算法部分
// 你的对象只需要提供TreeState所有接口,就可以直接使用
// https://github.com/int8/monte-carlo-tree-search
// https://blog.csdn.net/masterhero666/article/details/126325506
type (
	TreeState interface 
		IsTerminal() int        // 0: 未结束,其他为自定义状态
		Result(int) float64     // 计算分数,传入IsTerminal结果
		GetActions() []any      // 获取所有合法动作, todo 考虑获取不到动作时如何处理
		ActionPolicy([]any) any // 按策略挑选一个动作
		Action(any) TreeState   // 执行动作生成子节点
	

	McTreeNode struct 
		parent         *McTreeNode
		children       []*McTreeNode
		score          float64
		visitCount     float64
		untriedActions []any
		nodeState      TreeState
	
)

func Search(simulate any, state TreeState, discount ...float64) TreeState 
	var (
		root = &McTreeNodenodeState: state
		leaf *McTreeNode
		dp   = 1.4 // 折扣参数默认值
	)
	if len(discount) > 0 
		dp = discount[0]
	

	var loop func() bool
	switch s := simulate.(type) 
	case int:
		loop = func() bool 
			s-- // 模拟指定次数后退出
			return s >= 0
		
	case time.Duration:
		ts := time.Now().Add(s) // 超过指定时间后退出
		loop = func() bool  return time.Now().Before(ts) 
	case func() bool:
		loop = s // 或者由外部指定模拟结束方案
	default:
		panic(simulate)
	

	for loop() 
		leaf = root.treePolicy(dp)

		result, curState := 0, leaf.nodeState
		for 
			if result = curState.IsTerminal(); result != 0 
				break // 结束状态
			

			// 根据该节点状态生成所有合法动作
			all := curState.GetActions()
			// 按照某种策略选出1个动作,不同于expand的顺序取出
			one := curState.ActionPolicy(all)
			// 执行该动作,重复该过程,直到结束
			curState = curState.Action(one)
		

		// 根据结束状态计算结果,将该结果反向传播
		leaf.backPropagate(curState.Result(result))
	
	return root.chooseBestChild(dp).nodeState // 选择最优子节点


func (cur *McTreeNode) chooseBestChild(c float64) *McTreeNode 
	var (
		idx        = 0
		maxValue   = -math.MaxFloat64
		childValue float64
	)
	for i, child := range cur.children 
		childValue = (child.score / child.visitCount) +
			c*math.Sqrt(math.Log(cur.visitCount)/child.visitCount)
		if childValue > maxValue 
			maxValue = childValue
			idx = i // 选择分值最高的子节点
		
	
	return cur.children[idx]


func (cur *McTreeNode) backPropagate(result float64) 
	nodeCursor := cur
	for nodeCursor.parent != nil 
		nodeCursor.score += result
		nodeCursor.visitCount++ // 反向传播,增加访问次数,更新分数
		nodeCursor = nodeCursor.parent
	
	nodeCursor.visitCount++


func (cur *McTreeNode) expand() *McTreeNode 
	res := cur.untriedActions[0] // 返回1个未经尝试动作
	cur.untriedActions = cur.untriedActions[1:]

	child := &McTreeNode
		parent:    cur, // 当前节点按顺序弹出1个动作,执行动作生成子节点
		nodeState: cur.nodeState.Action(res),
	
	cur.children = append(cur.children, child)
	return child


func (cur *McTreeNode) treePolicy(discountParamC float64) *McTreeNode 
	nodeCursor := cur // 一直循环直到结束
	for nodeCursor.nodeState.IsTerminal() == 0 
		if nodeCursor.untriedActions == nil 
			// 只会初始化1次,找出该节点所有动作
			nodeCursor.untriedActions = nodeCursor.nodeState.GetActions()
		
		if len(nodeCursor.untriedActions) > 0 
			return nodeCursor.expand() // 存在未处理动作则添加子节点
		
		// 处理完动作,选择最好子节点继续往下处理
		nodeCursor = nodeCursor.chooseBestChild(discountParamC)
	
	return nodeCursor

五子棋对弈——MCTS学习

初识AlphaZero

AlphaZero能够基于强化学习实现较高技巧的棋类博弈,我看过nb网友实现的基于MCTS的五子棋模型后,惊叹不已!特此记录一下其中训练的一些方法和技巧。

MCTS

MCTS是指蒙特卡洛搜索树。

蒙特卡洛搜索树没听过的话,想必你是知道蒙特卡罗模拟的。这个模拟过程就是暴力的按照概率去操作所有过程,最后得出一个统计的结果。举一个很简单的例子,比如你要计算圆周率(pi),那么可以画一个正方形和一个内切圆。用两个面积之比可以得到圆周率的值,于是我们进行蒙特卡洛模拟,具体过程是在正方形内撒点,在每个区域内点数均匀的情况下,我们可以认为一个区域内的点数正比于面积,那么我们通过统计点数之比就可以近似得到面积之比。

而MCTS与模拟有一些区别,分为四个部分:SELECTION,EXPANSION,SIMULATION,BACK_PROPAGATION。

关于MCTS的详细内容可以参考这篇文章

UCB

树上的上限置信区间算法是一个能很好权衡探索与利用的算法。

[UCT(v) = frac{Q(v)}{N(v)} + c sqrt{frac{2ln N(u)}{N(v)}} ]

式中(Q)是赢的次数,(N)是这个点经过次数,(u)(v)节点的父亲节点。通过调节系数(c)我们也能改变对exploration和exploitation的倾向。

SELECTION

第一步,从当前根节点选择一个子节点,作为下一次的根。提供一个判断标准,我们算出每个叶子节点的分数,选择最高的一个吧?

但是直接选择最高的一个其实是有问题的。因为如果每次都从最高的开始选,可能存在一些效果更好的选择但是我们从来没有探索过,所以我们一般采用(UCB)来作为评估手段。

EXPANSION

在第一步中,我们始终在向下进行选择,然而一定会到达一种状态

  1. 游戏结束
  2. 有一个节点没有探索过

对于上面第二种结果,我们就要应用我们的EXPANSION了。扩展这个没有儿子节点的node的所有后续可能局面。

SIMULATION

在上一步的基础上,我们按照游戏规则模拟整局游戏,直到游戏结束,这一步是比较简单的。

BACK_PROPAGATION

得到游戏结果,包括打分和赢家,我们对这一条树上路径进行往回更新,更新祖先节点的分数和行动概率,以改良结果。


值得一提的是,上面讲到的是传统MCTS,我们还实现一个基于深度模型预测的MCTS,这个东西比上面所提到的有些进步,它的EXPANSION决策不再是随机的,而是按照Model给出的预测结果进行选择的,最后的结果也将会影响Model的参数。

Policy Value Net

网络结构

接下来说一说训练代码的模型结构。
公共的三层全卷积网络,然后分成Policy和Value两个输出端。

  • policy端,4个1X1的filter进行滤波,接全连接层,然后做softmax得到落子概率。

  • value端,filter后接全连接层,tanh后得到([-1,1])的评分。

输入描述

输入为4个(width imes height)的矩阵,前两个表示两个玩家的落子位置,第三个是对手上一次落子位置,第四个表示是否先手(原文中用了四个,但我认为前两个矩阵的顺序完全可以决定当前玩家是谁,比如规定第一个矩阵表示当前玩家的落子位置。如果知道原因的大佬希望评论区留言)。

train目标

输入是局面(state),网络输出的落子概率和最终评分分别是(p)(v),MCTS模拟结果的落子概率和评分分别是(pi)(z),我们的目标就是网络输出和MCTS的结果尽量相同,这样模型预测的结果尽量代替上千次MCTS的模拟,MCTS又在模型基础上模拟出更多的训练数据。

定义损失函数为

[l = (z-v)^2 - pi ^T log p + c || heta||^2 ]

前两项分别是下一步概率和评分的损失,最后一项是防止过拟合。

self play

训练时,两个玩家分别是纯粹MCTS策略玩家和有模型优化的MCTS_AlphaZero策略。他们在下棋的时候,不会直接落子,而是自己和自己self-play若干局,这样就在当前局面中构造了一个蒙特卡洛搜索树。MCTS的EXPANSION策略上面讲过了,这里说一下后者是怎么做的。

为了达到exporation的效果,我们的(UCT)还不够。self-play时,我们的“树玩家”并不是严格执行着某一个落子结果,而是按照概率随机进行的,而且在原有move probability的基础上,还加了一个迪利克雷分布的噪声,

[P(s, a) = (1 - varepsilon)p_a + varepsilon eta_a ]

有助于探索更多局面。文中作者找到的一个比较好的参数为

[varepsilon = 0.25, eta_a sim Dir(0.3) ]

AlphaZero玩家还会不停地给Policy Value Net返回训练需要的((state, p, v))结果,但需要注意的是,每次计算的时候,考虑的都是当前玩家的最优策略,注意Game Theory。

Policy Value Net拿到数据之后,去梯度下降,用最新的模型去和纯MCTS的玩家博弈,看看平均胜率,如果比历史的模型胜率更高,那么就更新我们的best_model

Tricks

  • 每次训练得到的数据不要直接扔进去训练,我们做一个操作:由于五子棋游戏本身的各种旋转、对称局面等价性,我们对同一种数据做旋转和对称,那么这些state的结果也是一样的。

参考

https://zhuanlan.zhihu.com/p/32089487
https://www.cnblogs.com/yifdu25/p/8303462.html

以上是关于golang蒙特卡洛树算法实现五子棋AI的主要内容,如果未能解决你的问题,请参考以下文章

[程序设计]-基于人工智能博弈树,极大极小(Minimax)搜索算法并使用Alpha-Beta剪枝算法优化实现的可人机博弈的AI智能五子棋游戏。

五子棋AI算法第二篇-极大极小值搜索算法

蒙特卡洛树搜索:井字游戏的实现

人机ai五子棋 ——五子棋AI算法之Java实现

AI五子棋第二篇-运用极大极小值算法书写AI三子棋,可拓展到五子棋(建议收藏)

蒙特卡洛树搜索 UCT 实现