加速牛顿法求第 n 个根
Posted
技术标签:
【中文标题】加速牛顿法求第 n 个根【英文标题】:Speeding up Newton's Method for finding nth root 【发布时间】:2013-06-12 21:01:34 【问题描述】:让我用一个陈述来断言这个问题;此代码按预期工作,但它的速度非常慢。有没有办法让牛顿法收敛得更快,或者设置一个 __m256 var 等于单个浮点数而不弄乱浮点数数组等?
__m256 nthRoot(__m256 a, int root)
#define aligned __declspec(align(16)) float
// uses the calculation
// n_x+1 = (1/root)*(root * x + a / pow(x,root))
//initial numbers
aligned r[8];
aligned iN[8];
aligned mN[8];
//Function I made to fill arrays
/*
template<class T>
void FillArray(T a[],T b)
int n = sizeof(a)/sizeof(T);
for(int i = 0; i < n; a[i++] = b);
*/
//fills the arrays
FillArray(iN,(1.0f/(float)root));
FillArray(mN,(float)(root-1));
FillArray(r,(float)root);
//loads the arrays into the sse componenets
__m256 R = _mm256_load_ps(r);
__m256 Ni = _mm256_load_ps(iN);
__m256 Nm = _mm256_load_ps(mN);
//sets initaial guess to 1 / (a * root)
__m256 x = _mm256_rcp_ps(_mm256_mul_ps(R,a));
for(int i = 0; i < 20 ; i ++)
__m256 tmpx = x;
for(int k = 0 ; k < root -2 ; k++)
tmpx = _mm256_mul_ps(x,tmpx);
//f over f'
__m256 tar = _mm256_mul_ps(a,_mm256_rcp_ps(tmpx));
//fmac with Ni*X+tar
tar = _mm256_fmadd_ps(Nm,x,tar);
//Multipled by Ni
x = _mm256_mul_ps(Ni,tar);
return x;
编辑#1
__m256 SSEnthRoot(__m256 a, int root)
__m256 R = _mm256_set1_ps((float)root);
__m256 Ni = _mm256_set1_ps((1.0f)/((float)root));
__m256 Nm = _mm256_set1_ps((float)(root -1));
__m256 x = _mm256_mul_ps(a,_mm256_rcp_ps(R));
for(int i = 0; i < 10 ; i ++)
__m256 tmpx = x;
for(int k = 0 ; k < root -2 ; k++)
tmpx = _mm256_mul_ps(x,tmpx);
//f over f'
__m256 tar = _mm256_mul_ps(a,_mm256_rcp_ps(tmpx));
//mult nm x then add tar because my compiler stoped thinking that fmadd is a valid instruction
tar = _mm256_add_ps(_mm256_mul_ps(Nm,x),tar);
//Multiplied by the inverse of power
x = _mm256_mul_ps(Ni,tar);
return x;
任何使牛顿方法收敛更快的提示或指针(不是内存类型)将不胜感激。
在使用 _mm256_rcp_ps() 调用 _mm256_set1_ps() 函数时删除了编辑 #2,因为我已经将所需内容的倒数加载到 R 中
__m256 SSEnthRoot(__m256 a, int root)
__m256 R = _mm256_set1_ps((float)root);
__m256 Ni = _mm256_rcp_ps(R);
__m256 Nm = _mm256_set1_ps((float)(root -1));
__m256 x = _mm256_mul_ps(a,Ni);
for(int i = 0; i < 20 ; i ++)
__m256 tmpx = x;
for(int k = 0 ; k < root -2 ; k++)
tmpx = _mm256_mul_ps(x,tmpx);
//f over f'
__m256 tar = _mm256_mul_ps(a,_mm256_rcp_ps(tmpx));
//fmac with Ni*X+tar
//my compiler believes in fmac again
tar = _mm256_fmadd_ps(Nm,x,tar);
//Multiplied by the inverse of power
x = _mm256_mul_ps(Ni,tar);
return x;
编辑#3
__m256 SSEnthRoot(__m256 a, int root)
__m256 Ni = _mm256_set1_ps(1.0f/(float)root);
__m256 Nm = _mm256_set1_ps((float)(root -1));
__m256 x = _mm256_mul_ps(a,Ni);
for(int i = 0; i < 20 ; i ++)
__m256 tmpx = x;
for(int k = 0 ; k < root -2 ; k++)
tmpx = _mm256_mul_ps(x,tmpx);
__m256 tar = _mm256_mul_ps(a,_mm256_rcp_ps(tmpx));
tar = _mm256_fmadd_ps(Nm,x,tar);
x = _mm256_mul_ps(Ni,tar);
return x;
【问题讨论】:
当您切换到使用_mm256_set1_ps
时,速度提高了多少,需要提高多少?
每个函数的 1000000 只加速了 86 毫秒。改进功能的时间 = 2816。旧功能 2900 磨机的时间。我将 SSEnthRoot 函数的 for 循环迭代次数固定为与未改进的相同。
现在的速度是多少?你的目标是什么?不过,这些数字似乎非常错误,因为每个函数调用需要 3 µs,这太高了(除非 root
是一个很大的数字?)。
我找到了第 29 个根。
您的 pow 函数效率低下。您使用 27 次乘法来计算 x^28。只需 6 次乘法即可完成。我在答案中添加了一个函数,可以更有效地使用 AVX。
【参考方案1】:
您的pow
函数效率低下。
for(int k = 0 ; k < root -2 ; k++)
tmpx = _mm256_mul_ps(x,tmpx);
在您的示例中,您取的是第 29 个根。你需要pow(x, 29-1) = x^28
。目前,您为此使用 27 次乘法,但只需 6 次乘法即可。
x^28 = (x^4)*(x^8)*(x^16)
x^4 = y -> 2 multiplications
x^8 = y*y = z -> 1 multiplication
x^16 = z^2 = w-> 1 multiplications
y*z*w -> 2 multiplications
6 multiplications in total
这是你的代码的改进版本,它在我的系统上的速度大约是我系统的两倍。它使用我创建的一个新的pow_avx_fast
函数,该函数使用 AVX 一次执行 8 个浮点数的 x^n。它确实,例如x^28 是 6 次乘法而不是 27。请进一步了解我的答案。我找到了一个版本,它可以在一定的公差xacc
内找到结果。如果收敛很快,这可能会快得多。
inline __m256 pow_avx_fast(__m256 x, const int n)
//n must be greater than zero
if(n%2 == 0)
return pow_avx_fast(_mm256_mul_ps(x, x), n/2);
else
if(n>1) return _mm256_mul_ps(x,pow_avx_fast(_mm256_mul_ps(x, x), (n-1)/2));
return x;
inline __m256 SSEnthRoot_fast(__m256 a, int root)
// n_x+1 = (1/root)*((root-1) * x + a / pow(x,root-1))
__m256 R = _mm256_set1_ps((float)root);
__m256 Ni = _mm256_rcp_ps(R);
__m256 Nm = _mm256_set1_ps((float)(root -1));
__m256 x = _mm256_mul_ps(a,Ni);
for(int i = 0; i < 20 ; i ++)
__m256 tmpx = pow_avx_fast(x, root-1);
//f over f'
__m256 tar = _mm256_mul_ps(a,_mm256_rcp_ps(tmpx));
//fmac with Ni*X+tar
//tar = _mm256_fmadd_ps(Nm,x,tar);
tar = _mm256_add_ps(_mm256_mul_ps(Nm,x),tar);
//Multiplied by the inverse of power
x = _mm256_mul_ps(Ni,tar);
return x;
有关如何编写高效的pow
函数的更多信息,请参阅这些链接http://en.wikipedia.org/wiki/Addition-chain_exponentiation 和
http://en.wikipedia.org/wiki/Exponentiation_by_squaring
此外,您最初的猜测可能不太好。这是根据您的方法查找第 n 个根的标量代码(但使用数学 pow
函数可能比您的更快)。求解 16 的 4 次根(即 2)大约需要 50 次迭代。对于您使用的 20 次迭代,它返回超过 4000,这与 2.0 相差甚远。因此,您将需要调整您的方法以进行足够的迭代,以确保在一定公差范围内得到合理的答案。
float fx(float a, int n, float x)
return 1.0f/n * ((n-1)*x + a/pow(x, n-1));
float scalar_nthRoot_v2(float a, int root)
//sets initaial guess to 1 / (a * root)
float x = 1.0f/(a*root);
printf("x0 %f\n", x);
for(int i = 0; i<50; i++)
x = fx(a, root, x);
printf("x %f\n", x);
return x;
我从这里得到了牛顿法的公式。 http://en.wikipedia.org/wiki/Nth_root_algorithm
这是您的函数的一个版本,它在xacc
的某个公差范围内给出结果,或者如果在nmax
迭代后没有收敛,则退出。如果收敛发生在少于 20 次迭代中,则此函数可能比您的方法快得多。它要求所有八个浮点数同时收敛。换句话说,如果七个收敛而一个不收敛,那么其他七个必须等待不收敛的那个。这就是 SIMD 的问题(在 GPU 上也是如此),但总的来说它仍然比没有 SIMD 的情况要快。
int get_mask(const __m256 dx, const float xacc)
__m256i mask = _mm256_castps_si256(_mm256_cmp_ps(dx, _mm256_set1_ps(xacc), _CMP_GT_OQ));
return _mm_movemask_epi8(_mm256_castsi256_si128(mask)) + _mm_movemask_epi8(_mm256_extractf128_si256(mask,1));
inline __m256 SSEnthRoot_fast_xacc(const __m256 a, const int root, const int nmax, float xacc)
// n_x+1 = (1/root)*(root * x + a / pow(x,root))
__m256 R = _mm256_set1_ps((float)root);
__m256 Ni = _mm256_rcp_ps(R);
//__m256 Ni = _mm256_set1_ps(1.0f/root);
__m256 Nm = _mm256_set1_ps((float)(root -1));
__m256 x = _mm256_mul_ps(a,Ni);
for(int i = 0; i <nmax ; i ++)
__m256 tmpx = pow_avx_fast(x, root-1);
__m256 tar = _mm256_mul_ps(a,_mm256_rcp_ps(tmpx));
//tar = _mm256_fmadd_ps(Nm,x,tar);
tar = _mm256_add_ps(_mm256_mul_ps(Nm,x),tar);
tmpx = _mm256_mul_ps(Ni,tar);
__m256 dx = _mm256_sub_ps(tmpx,x);
dx = _mm256_max_ps(_mm256_sub_ps(_mm256_setzero_ps(), dx), dx); //fabs(dx)
int cnt = get_mask(dx, xacc);
if(cnt == 0) return x;
x = tmpx;
return x; //at least one value out of eight did not converge by nmax.
这里是 avx 的 pow 函数的更通用版本,它也适用于 n
__m256 pow_avx(__m256 x, const int n)
if(n<0)
return pow_avx(_mm256_rcp_ps(x), -n);
else if(n == 0)
return _mm256_set1_ps(1.0f);
else if(n == 1)
return x;
else if(n%2 ==0)
return pow_avx(_mm256_mul_ps(x, x), n/2);
else
return _mm256_mul_ps(x,pow_avx(_mm256_mul_ps(x, x), (n-1)/2));
其他一些建议
您可以使用查找第 n 个根的 SIMD 数学库。 SIMD math libraries for SSE and AVX
对于英特尔,您可以使用昂贵且封闭源代码的 SVML(英特尔的 OpenCL 驱动程序使用 SVML,因此您可以免费获得它)。对于 AMD,您可以使用免费但封闭源代码的 LIBM。有几个开源 SIMD 数学库,例如 http://software-lisc.fbk.eu/avx_mathfun/ 和 https://bitbucket.org/eschnett/vecmathlib/wiki/Home
【讨论】:
【参考方案2】:要将__m256
向量的所有元素设置为单个值:
__m256 v = _mm256_set1_ps(1.0f);
或在您的特定情况下:
__m256 R = _mm256_set1_ps((float)root);
__m256 Ni = _mm256_set1_ps((1.0f/(float)root));
__m256 Nm = _mm256_set1_ps((float)(root-1));
显然,一旦您进行了此更改,您就可以摆脱 FillArray
的东西。
【讨论】:
【参考方案3】:也许您应该在日志域中执行此操作。
pow(a,1/root) == exp( log(x) /root)
Julien Pommier 有一个sse_mathfun.h,它具有 SSE、SSE2 日志和 exp 函数,但我不能说我特别使用过这些函数。这些技术可以扩展到avx。
【讨论】:
以上是关于加速牛顿法求第 n 个根的主要内容,如果未能解决你的问题,请参考以下文章