矩阵最优链乘及Java实现

Posted Redo

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了矩阵最优链乘及Java实现相关的知识,希望对你有一定的参考价值。

矩阵最优链乘及Java实现

给一系列矩阵(A_1,A_2,...A_n)进行链乘,找出最优运算顺序

矩阵乘法

(egin{bmatrix} {a_1,a_2}{a_3,a_4}\end{bmatrix}*egin{bmatrix} {b_1,b_2}{b_3,b_4}\end{bmatrix}=egin{bmatrix} {a_1*b_1+a_2*b_3,a_1*b_2+a_2*b_4}{a_3*b_1+a_4*b_3,a_3*b_2+a_4*b_4}\end{bmatrix})

一个p行q列的矩阵乘以一个q行r列的矩阵最终得到一个p行r列的矩阵

新矩阵中第i行第j列的数字等于第一个矩阵的第i行数字和第二个矩阵第j列数字一一对应乘积的和

由此可以看出两矩阵相乘总共标量乘法运算数为pqr,这就是衡量的标准

矩阵乘法满足结合律

((AB)C=A(BC))

运算顺序不同带来的差异

假设有三个矩阵A,B,C相乘

矩阵 A B C
维度 10*2 2*10 10*2

(AB)C代价

(Cost(AB)=10*2*10)

(Cost((AB)C)=10*2*10+10*10*2=400)

A(BC)代价

(Cost(BC)=2*10*2)

(Cost(A(BC)=10*2*2+2*10*2=80)

解决思路

最优方案的子方案也是最优方案

当一个链乘的方案最优,那么它的子方案也是最优

思路分解

n个矩阵链乘可以分解为三部分

  • 0-k矩阵链乘
  • k+1-n矩阵链乘
  • 前两者相乘

n个矩阵就可以分为n-1种分解情况

  • 0-0矩阵链乘,1-n矩阵链乘
  • 0-1矩阵链乘,2-n矩阵链乘
  • ...
  • 0-(n-1)矩阵链乘,(n-1)-n矩阵链乘

如果用m[i,j]表示第i到第j个矩阵相乘的代价

p[i]表示矩阵的维度(p[i]表示第i个矩阵的行数,p[i+1]表示第i个矩阵的列数)

(m[i,j]=egin{cases} 0,i=j\min{m[i,k]+m[k+1,j]+p_i*p_{k+1}*p_{j+1}}\end{cases} )

也就是说要求出m[i,j]必须要先求出m[i,k],m[k+1,j]

也就是说要求出n个矩阵链乘的最小代价,要先求出子链乘的代价

自底向上

我们可以从最小的子链乘开始计算最小代价

计算流程

  1. 所有单个矩阵链乘的最小代价(0)
  2. 所有两个矩阵链乘的最小代价
  3. 所有三个矩阵链乘的最小代价
  4. ...
  5. 所有n个矩阵链乘的最小代价

Java实现

数据结构

  1. m

    • 二维矩阵,存储最小代价和最优分解点
    m(0,2) m(1,2) m(2,2)
    m(0,1) m(1,1) m(2,1)
    m(0,0) m(1,0) m(2,0)

    (m[i,j]=egin{cases} 矩阵i到j链乘的最小代价,ile j矩阵j到i链乘最小代价的分解点 i>j\end{cases})

  2. p

    • p[i]表示矩阵的维度(p[i]表示第i个矩阵的行数,p[i+1]表示第i个矩阵的列数)

代码

/**
 * @Date 2020/4/30
 * @Author Redo
 * @Description 矩阵链乘
 **/
public class MCM {
    //内容矩阵
    private int[][] m;
    //维数数组
    private int[] p;
    private int size;

    /**
     * 以维数为输入的构造
     * @param p
     */
    public MCM(int ...p){
        this.p=p;
        //维数数组元素个数较矩阵个数多1
        this.size=p.length-1;
        m=new int[size][size];
    }

    /**
     * 设置分解点
     * @param i 第i个矩阵
     * @param j 链乘到第j个矩阵
     * @param value 的最优分解点
     */
    private void setSplit(int i, int j, int value){
        m[j][i]=value;
    }

    /**
     * 设置分解点
     * @param i 第i个矩阵
     * @param j 链乘到第j个矩阵
     * @return 最优分解点
     */
    private int getSplit(int i,int j){
        return m[j][i];
    }

    /**
     * 获取i到j的矩阵链乘分解点为k的最小代价
     * @param start 起始位置
     * @param k 分解点
     * @param end 结束位置
     * @return 以k为分解点的最小代价
     */
    private int m(int start, int k,int end){
        return m[start][k]+m[k+1][end]+p[start]*p[k+1]*p[end+1];
    }

    /**
     * 获取i到j的矩阵链乘的最小代价
     * @param start 起始位置
     * @param end 结束位置
     * @return 最小代价
     */
    private int m(int start,int end){
        if(start==end)
            return 0;
        else if(start+1==end){
            //相邻的矩阵链乘设置分解点为第一个矩阵
            setSplit(start,end,start);
            return m(start,start,end);
        }
        else {
            //设置为最大值,找出最小代价
            int temp=Integer.MAX_VALUE;
            for(int k=start;k<end;k++){
                if(temp>m(start,k,end)){
                    temp=m(start,k,end);
                    setSplit(start,end,k);
                }
            }
            return temp;
        }
    }

    /**
     * 计算所有最小代价和分解点
     */
    private void cal(){
        for(int i=1;i<p.length-1;i++){
            for (int j=0;j+i<p.length-1;j++){
                m[j][j+i]=m(j,j+i);
            }
        }
    }

    /**
     * 返回 start到end的矩阵链乘的最优方案
     * @param start 起始点
     * @param end 结束点
     * @return 最优方案
     */
    private String subDisplay(int start,int end){
        if(start==end)
            return "A"+end;
        else
            return "("+subDisplay(start,m[end][start])+""+subDisplay(m[end][start]+1,end)+")";
    }

    /**
     * 返回最优方案
     * @return
     */
    public String display(){
        return subDisplay(0,p.length-2);
    }

    private int getMinCost(int start,int end){
        return m[start][end];
    }

    /**
     * 获取最小代价
     * @return
     */
    public int getMinCost(){
        return m[0][size-1];
    }

    public static void main(String[] args) {
        MCM mcm=new MCM(10,2,10,2);
        mcm.cal();
        System.out.println(mcm.display());
        System.out.println("MinCost:"+mcm.getMinCost());
    }
}

以上是关于矩阵最优链乘及Java实现的主要内容,如果未能解决你的问题,请参考以下文章

最优矩阵链乘

dp-最优矩阵链乘

矩阵链乘(解析表达式)

动态规划—矩阵链乘法

矩阵链乘

矩阵链乘问题