利用cublasHgemm来实现cublasHgemv

Posted thisjiang

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了利用cublasHgemm来实现cublasHgemv相关的知识,希望对你有一定的参考价值。

前几天做half量化时发现cublas竟然没有提供half版本的矩阵-向量乘,也就是half版本的cublasHgemv。自己写一个又太麻烦,重点是精度和耗时不一定比cublas提供的要好,不过cublas提供了half版本的矩阵-矩阵乘函数cublasHgemm,只要维度没啥问题,用cublasHgemm实现cublasHgemv,既方便又好用。

废话不多说,直接上。


前置准备

对于矩阵A和向量V,我们要计算(y=alpha * A * V + eta * y),其中矩阵A的维度为(m*n),向量V的长度为(n),二维表示就是(n*1)(alpha)(eta)都是标量,所以y的维度就是(m*1)。由于是用gemm实现,因此还有个ld参数,矩阵A的ldm为m,向量V的ldv为1(这里暂定,待会儿解释),最后值得注意的是结果向量(y)的ldy应该是m而不是1.

对于half矩阵-向量乘,这里我们假设AV(y)都是half类型(不然就用不了cublasHgemm,只能尝试用cublasGemmEx来实现了),当然(alpha)(eta)也都得是half类型数值(__float2half)。

不考虑转置,接下来直接上代码:

half版本cublasHgemv

cublasStatus_t cublasHgemv(cublasHandle_t handle, cublasOperation_t trans,
                           int m, int n,
                           const half           *alpha,
                           const half           *A, int ldm,//由于cublas库是列优先存储,因此ldm常为m,ldv常为1(暂定),ldy常为m
                           const half           *V, int ldv,
                           const half           *beta,
                           half  *y, int ldy
){
    return cublasHgemm(handle, trans, trans, m, 1, n, alpha, A, ldm, V, ldv, beta, y, ldy); 
}

对比cublas库中已有的float版本的cublasSgemv:

cublasSgemv

声明

cublasStatus_t cublasSgemv(cublasHandle_t handle, cublasOperation_t trans,
                           int m, int n,
                           const float           *alpha,
                           const float           *A, int lda,
                           const float           *V, int incv,
                           const float           *beta,
                           float           *y, int incy)

调用

status = cublasSgemv(handle, trans, m, n, alpha, A, ldm, V, incv, beta, y, incy);

关于ld参数

注意,相比较于原生的cublasSgemv,自实现版本不是用的incvincy参数而是使用的ldvldy参数,这主要是因为我自己在使用时遇到的大部分情况是这个向量只是矩阵的某一行而不,此时ldv参数应该设置为该矩阵的行数而不是1。当然,由于无论是多少维的张量,在计算机中都是以一维连续空间存放的,因此ldv和incv,ldy和incy大部分情况下都相同。

以上是关于利用cublasHgemm来实现cublasHgemv的主要内容,如果未能解决你的问题,请参考以下文章

利用反射来实现动态代理

C#利用反射来判断对象是否包含某个属性的实现方法

C#利用反射来判断对象是否包含某个属性的实现方法

独家 | 利用OpenCV和深度学习来实现人类活动识别(附链接)

利用Knockout来实现可重复利用的组件Component

Windows中利用共享内存来实现不同进程间的通信