Java中的多线程矩阵乘法

Posted

技术标签:

【中文标题】Java中的多线程矩阵乘法【英文标题】:Multithreading matrix multiplication in Java 【发布时间】:2018-07-12 22:47:02 【问题描述】:

我正在尝试构建一个程序,用于使用 a*d 线程将两个矩阵 (A[a,b], B[c,d]) 相乘(用于打印完成后的一个索引的总和)矩阵),为此,我使用了一个“监视器”类,该类将用作在线程之间同步的控制器,“乘数”类表示单个线程和一个主程序类。我的想法是线程将进行计算,当 thread(0,0) 打印他的总和时,他将发出下一个信号。出于某种原因,在打印第一个索引后 - 所有线程都处于等待模式,不会测试我的状况。你能看看我的代码并告诉我我的错误在哪里吗?

监控类:

import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

final class Monitor 
    private Lock lock;
    int index;
    Condition cond; 
    public Monitor () 
        lock = new ReentrantLock();
        cond = lock.newCondition();
        this.index = 0;
    

    public synchronized void finished(int x, double sum) throws InterruptedException 
        lock.lock();
        if(index != x) 
            while(index != x) cond.await();
            System.out.printf("%9.2f ",sum);
            index++;
            lock.unlock();
            cond.signalAll();
          
        else 
            System.out.printf("%9.2f ",sum);
            index++;
            try  lock.unlock(); 
            catch (java.lang.IllegalMonitorStateException e) ;
            try  lock.unlock(); 
            catch (java.lang.IllegalMonitorStateException e) ;
        
        if(index % 5 == 0) System.out.println();
    

乘数:

public class Multiplier extends Thread 
    private int index;
    private double [] vectorOne;
    private double [] vectorTwo;
    private Monitor monitor;
    private double sum;

    //constructor
    public Multiplier(int index, Monitor monitor,double [] vectorOne,double [] vectorTwo) 
        this.index = index;
        this.monitor = monitor;
        this.vectorOne = vectorOne;
        this.vectorTwo = vectorTwo;
    

    public void VecMulti() 
        sum = 0;
        for (int i = 0 ; i < vectorOne.length ; i++) 
            sum += vectorOne[i] * vectorTwo[i];
    

    public double getSum() 
        return sum;
    

    public void run() 
        VecMulti();
        try 
            monitor.finished(index, sum);
         catch (InterruptedException e) 
            e.printStackTrace();
        
    

主类:

public class MatrixMultiTest 
    public static void main(String[] args) 
        Monitor monitor = new Monitor(3*5);
        Matrix A = Matrix.random(3,4);
        Matrix B = Matrix.random(4,5);
        System.out.println("Matrix No1");
        A.show();
        System.out.println();
        System.out.println("Matrix No2");
        B.show();
        System.out.println();
        System.out.println("Multi Matrix");

        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 5; j++) 
                Multiplier myThr = new Multiplier(i*5+j,
                        monitor,A.getRow(i),B.getCol(j));
                myThr.start();
                try 
                    myThr.join();
                 catch (InterruptedException e) 
                //  TODO Auto-generated catch block
                    e.printStackTrace();
                
            
    

【问题讨论】:

while(index != x) 永远不会停止,因为方法参数 x 和 index 永远不会改变,你需要检查一个变量 shared amonst all threads ,这里java是按值复制的。 @zapl 'monitor'上的索引不是可以被认为是与所有线程共享的值吗?只要线程“完成”运行,它就会发生变化。 哦,天哪,index 确实是一个共享变量。我会尝试new Multiplier(i*5+j.. 一个,否则你的索引搞砸了 让我们打个赌:如果您将多线程解决方案与简单的单线程解决方案进行基准测试,那么您的解决方案会慢得多……至少尺寸约为 10*10,正如您介绍的那样操作的巨大开销。 @RalfKleberhoff 也许吧,但这不是重点。我不打算在这项任务上效率。 【参考方案1】:

finished() 方法充满了问题:

第一个问题是synchronized 关键字。它必须被删除。有了这个关键字,如果第一个进入的线程有一个非零索引,程序就会死锁——线程将永远停在等待条件发出,永远不会到来,因为没有其他线程可以进入finished()方法.

第二个错误在于else 块:

else 
    System.out.printf("%9.2f ",sum);
    index++;
    try  lock.unlock(); 
    catch (java.lang.IllegalMonitorStateException e) ;
    try  lock.unlock(); 
    catch (java.lang.IllegalMonitorStateException e) ;

它从不调用cond.signalAll(),所以index=0的线程通过后,其他线程将永远保持停顿。

第三个问题是在if(index != x) .. 块中,cond.signalAll()lock.unlock() 顺序错误:
lock.unlock();
cond.signalAll();

ConditionsignalAll() 方法documentation 声明:

当调用此方法时,实现可能(并且通常确实)要求当前线程持有与此条件关联的锁。实现必须记录此前提条件以及如果未持有锁定所采取的任何操作。通常会抛出IllegalMonitorStateException 之类的异常。

这两行代码必须按顺序切换,否则会抛出IllegalMonitorStateException

该方法的工作版本可能如下所示:

public void finished(int x, double sum) throws InterruptedException 
    try 
        lock.lock();
        while (index != x) 
            cond.await();
        
        System.out.printf("%9.2f ", sum);
        index++;
        cond.signalAll();
     finally 
        lock.unlock();
    
    if (index % 5 == 0) System.out.println();


有趣的是,OP 提供的代码即使在所有这些错误的情况下也能正常工作,但这只是由于 MatrixMultiTest 类中的这段代码:

try 
    myThr.join();
 catch (InterruptedException e) 
//  TODO Auto-generated catch block
    e.printStackTrace();

每个线程都是按顺序创建然后启动和加入的,因此只有一个线程在任何时候都试图获取同步 finished() 方法上的隐式锁,i*5+j 索引值确保线程获取这个按index的顺序隐式锁:0、1、2等。这意味着在方法内部,index总是等于x,并且每个线程都经过else块在finished()中,允许程序完成执行。 cond.await() 实际上从未被调用过。

如果join 块被删除,那么一些值可能会被打印出来,但程序最终会死锁。

【讨论】:

以上是关于Java中的多线程矩阵乘法的主要内容,如果未能解决你的问题,请参考以下文章

numpy/pandas矩阵乘法的多线程?

使用 CUDA 进行矩阵乘法:2D 块与 1D 块

矩阵乘法

使用win32线程的矩阵乘法

python 多线程稀疏矩阵乘法

OpenMP 矩阵向量乘法仅在一个线程上执行