Stockham FFT 的更高基数(或更好)公式

Posted

技术标签:

【中文标题】Stockham FFT 的更高基数(或更好)公式【英文标题】:Higher radix (or better) formulation for Stockham FFT 【发布时间】:2013-05-01 15:13:38 【问题描述】:

背景

我已经使用 OpenCL 实现了 Microsoft Research 的 this 算法,用于基数 2 FFT(Stockham 自动排序)。

我在内核中使用浮点纹理(256 列 X N 行)进行输入和输出,因为我需要在非整数点进行采样,我认为最好将其委托给纹理采样硬件。请注意,我的 FFT 始终是 256 点序列(我的纹理中的每一行)。此时,我的 N 为 16384 或 32768,具体取决于我使用的 GPU 和允许的最大 2D 纹理大小。

我还需要一次执行 4 个实值序列的 FFT,因此内核将 FFT(a, b, c, d) 执行为 FFT(a + ib, c + id),我可以从中提取稍后使用 O(n) 算法输出 4 个复杂序列。如果有人愿意,我可以详细说明这一点 - 但我不认为它属于这个问题的范围。

内核源代码

const sampler_t fftSampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST;

__kernel void FFT_Stockham(read_only image2d_t input, write_only image2d_t output, int fftSize, int size)

    int x = get_global_id(0);
    int y = get_global_id(1);
    int b = floor(x / convert_float(fftSize)) * (fftSize / 2);
    int offset = x % (fftSize / 2);
    int x0 = b + offset;
    int x1 = x0 + (size / 2);

    float4 val0 = read_imagef(input, fftSampler, (int2)(x0, y));
    float4 val1 = read_imagef(input, fftSampler, (int2)(x1, y));

    float angle = -6.283185f * (convert_float(x) / convert_float(fftSize));

    // TODO: Convert the two calculations below into lookups from a __constant buffer
    float tA = native_cos(angle);
    float tB = native_sin(angle);

    float4 coeffs1 = (float4)(tA, tB, tA, tB);
    float4 coeffs2 = (float4)(-tB, tA, -tB, tA);
    float4 result = val0 + coeffs1 * val1.xxzz + coeffs2 * val1.yyww;

    write_imagef(output, (int2)(x, y), result);

主机代码简单地调用这个内核 log2(256) 次,乒乓球输入和输出纹理。

注意:我尝试删除 native_cosnative_sin 以查看这是否会影响时间,但它似乎并没有改变太多。无论如何,这不是我要寻找的因素。

访问模式 知道我可能受内存带宽限制,这是我的 radix-2 FFT 的内存访问模式(每行)。

X0 - 要组合的元素 1(读取) X1 - 要组合的元素 2(读取) X - 要写入的元素(写入)

问题

所以我的问题是 - 有人可以帮助我/为我指出该算法的更高基数公式吗?我问是因为大多数 FFT 都针对大型案例和单实/复值序列进行了优化。他们的内核生成器也非常依赖于大小写,当我试图弄乱他们的内部时很快就会崩溃。

还有其他选择比简单地使用 radix-8 或 16 内核更好吗?

我的一些限制是 - 我必须使用 OpenCL(没有 cuFFT)。我也不能为此目的使用 ACML 的clAmdFft。还可以谈谈 CPU 优化(这个内核在 CPU 上很糟糕)——但让它在 GPU 上以更少的迭代运行是我的主要用例。

提前感谢您阅读所有这些并尝试提供帮助!

【问题讨论】:

【参考方案1】:

我尝试了几个版本,但在 CPU 和 GPU 上性能最好的版本是针对我的特定情况的 radix-16 内核。

这里是内核供参考。它取自 Eric Bainville 的(最优秀的)website,并注明出处。

// #define M_PI 3.14159265358979f

//Global size is x.Length/2, Scale = 1 for direct, 1/N to inverse (iFFT)
__kernel void ConjugateAndScale(__global float4* x, const float Scale)

   int i = get_global_id(0);

   float temp = Scale;
   float4 t = (float4)(temp, -temp, temp, -temp);

   x[i] *= t;



// Return a*EXP(-I*PI*1/2) = a*(-I)
float2 mul_p1q2(float2 a)  return (float2)(a.y,-a.x); 

// Return a^2
float2 sqr_1(float2 a)
 return (float2)(a.x*a.x-a.y*a.y,2.0f*a.x*a.y); 

// Return the 2x DFT2 of the four complex numbers in A
// If A=(a,b,c,d) then return (a',b',c',d') where (a',c')=DFT2(a,c)
// and (b',d')=DFT2(b,d).
float8 dft2_4(float8 a)  return (float8)(a.lo+a.hi,a.lo-a.hi); 

// Return the DFT of 4 complex numbers in A
float8 dft4_4(float8 a)

  // 2x DFT2
  float8 x = dft2_4(a);
  // Shuffle, twiddle, and 2x DFT2
  return dft2_4((float8)(x.lo.lo,x.hi.lo,x.lo.hi,mul_p1q2(x.hi.hi)));


// Complex product, multiply vectors of complex numbers

#define MUL_RE(a,b) (a.even*b.even - a.odd*b.odd)
#define MUL_IM(a,b) (a.even*b.odd + a.odd*b.even)

float2 mul_1(float2 a, float2 b)
 float2 x; x.even = MUL_RE(a,b); x.odd = MUL_IM(a,b); return x; 
float4 mul_1_F4(float4 a, float4 b)
 float4 x; x.even = MUL_RE(a,b); x.odd = MUL_IM(a,b); return x; 


float4 mul_2(float4 a, float4 b)
 float4 x; x.even = MUL_RE(a,b); x.odd = MUL_IM(a,b); return x; 

// Return the DFT2 of the two complex numbers in vector A
float4 dft2_2(float4 a)  return (float4)(a.lo+a.hi,a.lo-a.hi); 

// Return cos(alpha)+I*sin(alpha)  (3 variants)
float2 exp_alpha_1(float alpha)

  float cs,sn;
  // sn = sincos(alpha,&cs);  // sincos
  //cs = native_cos(alpha); sn = native_sin(alpha);  // native sin+cos
  cs = cos(alpha); sn = sin(alpha); // sin+cos
  return (float2)(cs,sn);

// Return cos(alpha)+I*sin(alpha)  (3 variants)
float4 exp_alpha_1_F4(float alpha)

  float cs,sn;
  // sn = sincos(alpha,&cs);  // sincos
  // cs = native_cos(alpha); sn = native_sin(alpha);  // native sin+cos
  cs = cos(alpha); sn = sin(alpha); // sin+cos
  return (float4)(cs,sn,cs,sn);



// mul_p*q*(a) returns a*EXP(-I*PI*P/Q)
#define mul_p0q1(a) (a)

#define mul_p0q2 mul_p0q1
//float2  mul_p1q2(float2 a)  return (float2)(a.y,-a.x); 

__constant float SQRT_1_2 = 0.707106781186548; // cos(Pi/4)
#define mul_p0q4 mul_p0q2
float2  mul_p1q4(float2 a)  return (float2)(SQRT_1_2)*(float2)(a.x+a.y,-a.x+a.y); 
#define mul_p2q4 mul_p1q2
float2  mul_p3q4(float2 a)  return (float2)(SQRT_1_2)*(float2)(-a.x+a.y,-a.x-a.y); 

__constant float COS_8 = 0.923879532511287; // cos(Pi/8)
__constant float SIN_8 = 0.382683432365089; // sin(Pi/8)
#define mul_p0q8 mul_p0q4
float2  mul_p1q8(float2 a)  return mul_1((float2)(COS_8,-SIN_8),a); 
#define mul_p2q8 mul_p1q4
float2  mul_p3q8(float2 a)  return mul_1((float2)(SIN_8,-COS_8),a); 
#define mul_p4q8 mul_p2q4
float2  mul_p5q8(float2 a)  return mul_1((float2)(-SIN_8,-COS_8),a); 
#define mul_p6q8 mul_p3q4
float2  mul_p7q8(float2 a)  return mul_1((float2)(-COS_8,-SIN_8),a); 

// Compute in-place DFT2 and twiddle
#define DFT2_TWIDDLE(a,b,t)  float2 tmp = t(a-b); a += b; b = tmp; 

// T = N/16 = number of threads.
// P is the length of input sub-sequences, 1,16,256,...,N/16.
__kernel void FFT_Radix16(__global const float4 * x, __global float4 * y, int pp)

  int p = pp;
  int t = get_global_size(0); // number of threads
  int i = get_global_id(0); // current thread


//////  y[i] = 2*x[i];
//////  return;

  int k = i & (p-1); // index in input sequence, in 0..P-1
  // Inputs indices are I+0,..,15*T
  x += i;
  // Output indices are J+0,..,15*P, where
  // J is I with four 0 bits inserted at bit log2(P)
  y += ((i-k)<<4) + k;

  // Load
  float4 u[16];
  for (int m=0;m<16;m++) u[m] = x[m*t];

  // Twiddle, twiddling factors are exp(_I*PI*0,..,15*K/4P)
  float alpha = -M_PI*(float)k/(float)(8*p);
  for (int m=1;m<16;m++) u[m] = mul_1_F4(exp_alpha_1_F4(m * alpha), u[m]);

  // 8x in-place DFT2 and twiddle (1)
  DFT2_TWIDDLE(u[0].lo,u[8].lo,mul_p0q8);
  DFT2_TWIDDLE(u[0].hi,u[8].hi,mul_p0q8);

  DFT2_TWIDDLE(u[1].lo,u[9].lo,mul_p1q8);
  DFT2_TWIDDLE(u[1].hi,u[9].hi,mul_p1q8);

  DFT2_TWIDDLE(u[2].lo,u[10].lo,mul_p2q8);
  DFT2_TWIDDLE(u[2].hi,u[10].hi,mul_p2q8);

  DFT2_TWIDDLE(u[3].lo,u[11].lo,mul_p3q8);
  DFT2_TWIDDLE(u[3].hi,u[11].hi,mul_p3q8);

  DFT2_TWIDDLE(u[4].lo,u[12].lo,mul_p4q8);
  DFT2_TWIDDLE(u[4].hi,u[12].hi,mul_p4q8);

  DFT2_TWIDDLE(u[5].lo,u[13].lo,mul_p5q8);
  DFT2_TWIDDLE(u[5].hi,u[13].hi,mul_p5q8);

  DFT2_TWIDDLE(u[6].lo,u[14].lo,mul_p6q8);
  DFT2_TWIDDLE(u[6].hi,u[14].hi,mul_p6q8);

  DFT2_TWIDDLE(u[7].lo,u[15].lo,mul_p7q8);
  DFT2_TWIDDLE(u[7].hi,u[15].hi,mul_p7q8);


  // 8x in-place DFT2 and twiddle (2)
  DFT2_TWIDDLE(u[0].lo,u[4].lo,mul_p0q4);
  DFT2_TWIDDLE(u[0].hi,u[4].hi,mul_p0q4);

  DFT2_TWIDDLE(u[1].lo,u[5].lo,mul_p1q4);
  DFT2_TWIDDLE(u[1].hi,u[5].hi,mul_p1q4);

  DFT2_TWIDDLE(u[2].lo,u[6].lo,mul_p2q4);
  DFT2_TWIDDLE(u[2].hi,u[6].hi,mul_p2q4);

  DFT2_TWIDDLE(u[3].lo,u[7].lo,mul_p3q4);
  DFT2_TWIDDLE(u[3].hi,u[7].hi,mul_p3q4);

  DFT2_TWIDDLE(u[8].lo,u[12].lo,mul_p0q4);
  DFT2_TWIDDLE(u[8].hi,u[12].hi,mul_p0q4);

  DFT2_TWIDDLE(u[9].lo,u[13].lo,mul_p1q4);
  DFT2_TWIDDLE(u[9].hi,u[13].hi,mul_p1q4);

  DFT2_TWIDDLE(u[10].lo,u[14].lo,mul_p2q4);
  DFT2_TWIDDLE(u[10].hi,u[14].hi,mul_p2q4);

  DFT2_TWIDDLE(u[11].lo,u[15].lo,mul_p3q4);
  DFT2_TWIDDLE(u[11].hi,u[15].hi,mul_p3q4);

  // 8x in-place DFT2 and twiddle (3)
  DFT2_TWIDDLE(u[0].lo,u[2].lo,mul_p0q2);
  DFT2_TWIDDLE(u[0].hi,u[2].hi,mul_p0q2);

  DFT2_TWIDDLE(u[1].lo,u[3].lo,mul_p1q2);
  DFT2_TWIDDLE(u[1].hi,u[3].hi,mul_p1q2);

  DFT2_TWIDDLE(u[4].lo,u[6].lo,mul_p0q2);
  DFT2_TWIDDLE(u[4].hi,u[6].hi,mul_p0q2);

  DFT2_TWIDDLE(u[5].lo,u[7].lo,mul_p1q2);
  DFT2_TWIDDLE(u[5].hi,u[7].hi,mul_p1q2);

  DFT2_TWIDDLE(u[8].lo,u[10].lo,mul_p0q2);
  DFT2_TWIDDLE(u[8].hi,u[10].hi,mul_p0q2);

  DFT2_TWIDDLE(u[9].lo,u[11].lo,mul_p1q2);
  DFT2_TWIDDLE(u[9].hi,u[11].hi,mul_p1q2);

  DFT2_TWIDDLE(u[12].lo,u[14].lo,mul_p0q2);
  DFT2_TWIDDLE(u[12].hi,u[14].hi,mul_p0q2);

  DFT2_TWIDDLE(u[13].lo,u[15].lo,mul_p1q2);
  DFT2_TWIDDLE(u[13].hi,u[15].hi,mul_p1q2);

  // 8x DFT2 and store (reverse binary permutation)
  y[0]    = u[0]  + u[1];
  y[p]    = u[8]  + u[9];
  y[2*p]  = u[4]  + u[5];
  y[3*p]  = u[12] + u[13];
  y[4*p]  = u[2]  + u[3];
  y[5*p]  = u[10] + u[11];
  y[6*p]  = u[6]  + u[7];
  y[7*p]  = u[14] + u[15];
  y[8*p]  = u[0]  - u[1];
  y[9*p]  = u[8]  - u[9];
  y[10*p] = u[4]  - u[5];
  y[11*p] = u[12] - u[13];
  y[12*p] = u[2]  - u[3];
  y[13*p] = u[10] - u[11];
  y[14*p] = u[6]  - u[7];
  y[15*p] = u[14] - u[15];

请注意,我已修改内核以一次执行 2 个复值序列的 FFT,而不是一个。此外,由于我一次只需要在更大的序列中包含 256 个元素的 FFT,因此我只运行了 2 次该内核,这让我在更大的数组中得到了 256 个长度的 DFT。

这里还有一些相关的主机代码。

var ev = new[]  new Cl.Event() ;
var pEv = new[]  new Cl.Event() ;

int fftSize = 1;
int iter = 0;
int n = distributionSize >> 5;
while (fftSize <= n)

    Cl.SetKernelArg(fftKernel, 0, memA);
    Cl.SetKernelArg(fftKernel, 1, memB);
    Cl.SetKernelArg(fftKernel, 2, fftSize);

    Cl.EnqueueNDRangeKernel(commandQueue, fftKernel, 1, null, globalWorkgroupSize, localWorkgroupSize,
        (uint)(iter == 0 ? 0 : 1),
        iter == 0 ? null : pEv,
        out ev[0]).Check();
    if (iter > 0)
        pEv[0].Dispose();
    Swap(ref ev, ref pEv);

    Swap(ref memA, ref memB); // ping-pong

    fftSize = fftSize << 4;
    iter++;

    Cl.Finish(commandQueue);


Swap(ref memA, ref memB);

希望这对某人有所帮助!

【讨论】:

以上是关于Stockham FFT 的更高基数(或更好)公式的主要内容,如果未能解决你的问题,请参考以下文章

基数 32 FFT 实现

使用 JTransforms 的 FFT:是基数 2 吗?

从向量中最快擦除元素或更好地使用内存(排序基数)

如何在 Scala 中定义一个存在的更高种类的类型

推送/弹出期间类似消息的更高/标准导航栏

从python中的更高目录导入[重复]