A*算法在Unity中的实现

Posted PortiaTheGazer

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了A*算法在Unity中的实现相关的知识,希望对你有一定的参考价值。


一、A*算法是什么

   A星算法是一种搜索策略,是一种启发式图搜索策略。不同于深度优先搜索或广度优先搜索等盲目搜索策略,它能够利用与问题有关的启发信息进行搜索。和迪杰斯特拉算法类似,它们之所以是启发式的,是因为融入了人们既嗤之以鼻又甘之如饴的思想:“贪心”。
   为什么说是“贪心”的呢?是因为每次扩展节点的时候都尽可能的选择路径最短的节点,而Dijkstra算法更看重的是已扩展的节点到起点的路径最短,而A星算法兼顾已扩展节点到起点的路径最短和到终点的路径最短,所以说A星算法应该可以说更高级一点。
   如何记录已扩展节点到起点的路径和已扩展节点到终点的路径呢?A星算法的每个节点通过G(n)和H(n)分别记录到起点的花销和到终点的最佳路径的估计代价,两者的和F(n)便是这个节点的估价函数——估价函数也就是该节点到终点的最小代价的估计值,是我们评判某一个节点优越与否的唯一参考标准。
   A星算法一定就是最好的吗?我们定义H*(n)是某个节点到终点的最优路径的代价,即真实的代价,而非估计值,且有H(n)≤H*(n)恒成立。我觉得某人写的A星算法的程序好与不好,主要看H(n)跟H*(n)是否足够接近,如果H(n)=H*(n)的话,我们就不会扩展任何无关的节点,那么这个算法就绝对是最好的。但我们接下来的算法使用节点到终点的对角距离当作H(n)的,所以是会扩展一些不必要的节点的。但瘦死的骆驼比马大,肯定比宽度优先搜索和深度优先搜索还是要快的。

二、为什么要在Unity中用A*

   如果只是用Vector2.MoveTowards,或者只是用transform.position+=方向向量×速度×Time.deltaTime的话,那么这种怪的AI就有点太蠢了,一不会绕开障碍,二只会往主角脸上突。如果一种两种怪的AI是这样那么还可以,如果所有怪的AI都是这种单调乏味的移动方式,那么玩家就会感到疲劳。

这两只怪只能隔着障碍喷我

   想象一下,如果有怪物绕开了障碍物跑到你旁边给你来个背刺,是不是游戏难度一下就加大了?各位高级玩家是不是立马就兴奋起来了♂?
   当然,A星算法在游戏中的应用远远不止这些,欢迎大家来补充。

三、代码实现

   废话说了很多,还是直接上代码吧。我这个程序呢是参考了一个YouTube博主的视频的,大家如果感兴趣的话可以去看那个博主的视频学习。贴一下网址:
https://www.youtube.com/watch?v=alU04hvz6L4

1.创建节点类

public class Node

    private Grid<Node> grid;
    public int gridX;
    public int gridY;
    public int gCost;
    public int hCost;
    public bool isBarrier;
    public Node cameFromNode;
    public int FCost  get  return gCost + hCost;  
    public Node(Grid<Node> _grid,int x,int y)
    
        this.grid = _grid;
        this.gridX = x;
        this.gridY = y;
        isBarrier = false;
    
    public void SetIsBarrier(bool _isBarrier)
    
        this.isBarrier = _isBarrier;
        grid.TriggerGridObjectChanged(gridX, gridY);
    


   解释一下:每个节点有我们之前说的估价函数h(n)节点与起点的实际代价g(n),F(n)直接设一个getter返回它俩的和就好了。
   gridX和gridY是该节点在网格中的位置,我们节点的位置并不是用的World Position,而是一个非负的整型,就像下图这样,左下角的坐标是[0,0],往右x加一,往上y加一,以此类推。

   cameFrom节点很重要,也就是它的父节点,通过这个节点一直向上回溯才能找到我们最终要走的路。
   isBarrier这个布尔值用来记录该节点是不是有障碍物,有障碍物的话直接放到closed表里就不管它了。不是的话再考虑放到open表里扩展。

2.创建网格类

public class Grid<T>

    
    public event EventHandler<OnGridValueChangedEventArgs> OnGridValueChanged;
    public class OnGridValueChangedEventArgs:EventArgs
    
        public int x;
        public int y;
    
    private int width;
    private int height;
    private T[,] gridArray;//创建一个二维数组用来存储网格的每一个节点,大小为网格长度乘以网格宽度
    private float cellSize;
    private Vector3 originPosition;
    public Grid(int _width, int _height,float _cellSize,Vector3 _originPosition,Func<Grid<T>,int,int,T> _createGridObject)
    
        this.width = _width;
        this.height = _height;
        this.cellSize = _cellSize;
        this.originPosition = _originPosition;
        gridArray = new T[this.width, this.height];
        for (int x = 0; x < width; x++)
        
            for (int y = 0; y < height; y++)
            
                gridArray[x, y] = _createGridObject(this,x,y);
                Debug.DrawLine(GetWorldPosition(x, y), GetWorldPosition(x, y + 1));
                Debug.DrawLine(GetWorldPosition(x, y), GetWorldPosition(x + 1, y));
            
        
        Debug.DrawLine(GetWorldPosition(width, 0), GetWorldPosition(width, height));
        Debug.DrawLine(GetWorldPosition(0, height), GetWorldPosition(width, height));
    
    public int GetWidth()
    
        return this.width;
    
    public int GetHeight()
    
        return this.height;
    
    public float GetCellSize()
    
        return this.cellSize;
    
    public T[,] GetGridArray()
    
        return this.gridArray;
    
    public Vector3 GetOriginPosition()
    
        return this.originPosition;
    

    private Vector3 GetWorldPosition(int x,int y)
    
        return new Vector3(x, y) * cellSize+originPosition;
    
    public Vector2 GetXY(Vector3 _worldPosition)
    
        return new Vector2(_worldPosition.x - originPosition.x / cellSize,
            _worldPosition.y - originPosition.y /cellSize);
    
    public void SetValue(int x, int y, T value)
    
        if(x>=0 && y>=0 && x<width && y<height)
        
            gridArray[x, y] = value;
            OnGridValueChanged?.Invoke(this, new OnGridValueChangedEventArgs  x = x, y = y );
        
    
    public void TriggerGridObjectChanged(int x,int y)
    
        OnGridValueChanged?.Invoke(this, new OnGridValueChangedEventArgs  x = x, y = y );
    
    public void SetValue(Vector3 _worldPosition, T value)
    
        int x, y;
        x = Mathf.FloorToInt(GetXY(_worldPosition).x);
        y = Mathf.FloorToInt(GetXY(_worldPosition).y);
        SetValue(x, y, value);
    
    public T GetValue(int x,int y)
    
        if (x >= 0 && y >= 0 && x < width && y < height)
        
            return gridArray[x, y];
        
        else
        
            return default;
        
    
    public T GetValue(Vector3 _worldPosition)
    
        int x, y;
        x= Mathf.FloorToInt(GetXY(_worldPosition).x);
        y = Mathf.FloorToInt(GetXY(_worldPosition).y);
        return GetValue(x, y);
    

  主要就是网格的初始化,World Position和网格坐标的转来转去,以及一些getter和setter。

3.PathFinding核心代码

public class PathFinding

    private const int MOVE_STRAIGHT_COST=10;
    private const int MOVE_DIAGONAL_COST = 14;//本A*算法使用对角距离
    private List<Node> openList;
    private List<Node> closedList;
    public static PathFinding Instance  get; private set; 
    public Grid<Node> Grid  get; set; 
    public Node GetNode(int x, int y)
    
        return Grid.GetValue(x, y);
    
    public PathFinding(string sceneName)
    
        Instance = this;
        Vector3 barrierGridPosition = Vector3.zero;
        switch (sceneName)
        
            case "Hell_Mid":
                Grid = new Grid<Node>(13, 12, 1, new Vector3(-3, -8, 0), (Grid<Node> g, int x, int y) => new Node(g, x, y));
                break;
        
        
    
    public List<Vector3> FindPath(Vector3 _startWorldPosition,Vector3 _endWorldPosition)
    
        Vector2 startPosition=Grid.GetXY(_startWorldPosition);
        Vector2 endPosition=Grid.GetXY(_endWorldPosition);
        List<Node> path = FindPath(Mathf.FloorToInt(startPosition.x), Mathf.FloorToInt(startPosition.y), 
            Mathf.FloorToInt(endPosition.x), Mathf.FloorToInt(endPosition.y));
        if(path==null)
        
            return null;
        else
        
            List<Vector3> worldPath=new List<Vector3> ;
            foreach(Node node in path)
            
                worldPath.Add(Grid.GetOriginPosition()+new Vector3(node.gridX, node.gridY) * Grid.GetCellSize() + new Vector3(1, 1, 0) * Grid.GetCellSize() * .5f);
            
            return worldPath;
        
    
    public List<Node> FindPath(int _startX,int _startY,int _endX,int _endY)
    
        Node startNode = Grid.GetValue(_startX, _startY);//定义起始节点,起始节点将作为Open表中的第一个元素
        Node endNode = Grid.GetValue(_endX, _endY);
        openList = new List<Node>  startNode;
        closedList = new List<Node>();
        #region
        //初始化所有节点,让每个节点的gCost设为无穷大,前一节点设为空值
        for(int x=0;x<Grid.GetWidth();x++)
        
            for(int y=0;y<Grid.GetHeight();y++)
            
                Node node = Grid.GetValue(x,y);
                node.gCost = int.MaxValue;
                node.cameFromNode = null;
            
        
        #endregion
        startNode.gCost = 0;
        startNode.hCost = CaculateDistanceCost(startNode, endNode);
        while(openList.Count>0)
        
            SortList();
            Node currentNode = openList[0];
            if (currentNode==endNode)
            
            	openList.Remove(currentNode);
                closedList.Add(currentNode);
                return CaculatePath(currentNode);
            
            else
            
                openList.Remove(currentNode);
                closedList.Add(currentNode);
                foreach (Node neighbourNode in GetNeighbourList(currentNode))
                
                    if (closedList.Contains(neighbourNode)) continue;
                    if(neighbourNode.isBarrier)
                    
                        closedList.Add(neighbourNode);
                        continue;
                    
                    int tentativeGCost = currentNode.gCost + CaculateDistanceCost(currentNode, neighbourNode);
                    if(tentativeGCost<neighbourNode.gCost)
                    
                        neighbourNode.cameFromNode = currentNode;
                        neighbourNode.gCost = tentativeGCost;
                        neighbourNode.hCost = CaculateDistanceCost(neighbourNode, endNode);
                        if(!openList.Contains(neighbourNode))
                        
                            openList.Add(neighbourNode);
                        
                    
                
            
        
        return null;
    
    private List<Node> GetNeighbourList(Node _currentNode)
    
        List<Node> neighbourList = new List<Node>  ;
        if((_currentNode.gridX-1)>=0)
        
            neighbourList.Add(GetNode(_currentNode.gridX - 1, _currentNode.gridY));//左邻居
            if((_currentNode.gridY-1)>=0)
            
                neighbourList.Add(GetNode(_currentNode.gridX - 1, _currentNode.gridY - 1));//左下邻居
            
            if((_currentNode.gridY+1)<Grid.GetHeight())
            
                neighbourList.Add(GetNode(_currentNode.gridX - 1, _currentNode.gridY + 1));//左上邻居
            
        
        if((_currentNode.gridX+1)<Grid.GetWidth())
        
            neighbourList.Add(GetNode(_currentNode.gridX + 1, _currentNode.gridY));//右邻居
            if((_currentNode.gridY-1)>=0)
            
                neighbourList.Add(GetNode(_currentNode.gridX + 1, _currentNode.gridY - 1));//右下邻居
            
            if((_currentNode.gridY+1)<Grid.GetHeight())
            
                neighbourList.Add(GetNode(_currentNode.gridX + 1, _currentNode.gridY + 1));//右上邻居
            
        
        if((_currentNode.gridY-1)>=0)
        
            neighbourList.以上是关于A*算法在Unity中的实现的主要内容,如果未能解决你的问题,请参考以下文章

Unity中的AI算法和实现1-Waypoint

Unity A星(A Star/A*)寻路算法

Unity中的AI算法和实现2-有限状态机FSM(上)

Cg入门20:Fragment shader - 片段级模型动态变色(实现汽车动态换漆)

unity动画一个片段播放完怎么让它不会到初始状态

Unity 实现A* 寻路算法