[-1,1] 上反正切的最佳机器优化多项式极小极大近似?

Posted

技术标签:

【中文标题】[-1,1] 上反正切的最佳机器优化多项式极小极大近似?【英文标题】:Best machine-optimized polynomial minimax approximation to arctangent on [-1,1]? 【发布时间】:2014-11-01 20:22:01 【问题描述】:

为了以合理的精度简单高效地实现快速数学函数,多项式极小极大近似通常是首选方法。极小极大近似值通常使用 Remez 算法的变体生成。各种广泛使用的工具,例如 Maple 和 Mathematica,都为此提供了内置功能。生成的系数通常使用高精度算术计算。众所周知,简单地将这些系数四舍五入到机器精度会导致最终实现的精度欠佳。

相反,人们会搜索密切相关的系数集,这些系数集可以精确地表示为机器数,以生成机器优化的近似值。两篇相关论文是:

Nicolas Brisebarre、Jean-Michel Muller 和 Arnaud Tisserand,“计算机高效多项式近似”,ACM 数学软件交易,卷。 32,第 2 期,2006 年 6 月,第 236-256 页。

Nicolas Brisebarre 和 Sylvain Chevillard,“高效多项式 L∞-近似”,第 18 届 IEEE 计算机算术研讨会 (ARITH-18),蒙彼利埃(法国),2007 年 6 月,第 169-176 页。

后一篇论文中 LLL 算法的实现可作为Sollya tool 的fpminimax() 命令获得。据我了解,为生成机器优化近似值而提出的所有算法都是基于启发式的,因此通常不知道最佳近似值可以达到什么精度。我不清楚用于评估近似值的 FMA(融合乘加)是否会影响该问题的答案。在我看来,它应该是天真的。

我目前正在研究 [-1,1] 上的反正切的简单多项式逼近,它使用霍纳方案和 FMA 在 IEEE-754 单精度算术中进行评估。请参阅下面 C99 代码中的函数 atan_poly()。由于目前无法访问 Linux 机器,我没有使用 Sollya 来生成这些系数,而是使用了我自己的启发式算法,可以粗略地描述为最陡的体面退火和模拟退火的混合(以避免陷入局部最小值) .我的机器优化多项式的最大误差非常接近 1 ulp,但理想情况下,我希望最大 ulp 误差低于 1 ulp。

我知道我可以更改我的计算以提高准确性,例如通过使用表示为超过单精度精度的前导系数,但我希望保持代码原样(即,尽可能简单尽可能)仅调整系数以提供最准确的结果。

“经过验证的”最优系数集将是理想的,欢迎提供相关文献的指针。我进行了文献检索,但找不到任何论文能够在 Sollya 的 fpminimax() 之外有意义地推进现有技术,也没有任何论文研究 FMA(如果有的话)在本期中的作用。

// max ulp err = 1.03143
float atan_poly (float a)

    float r, s;
    s = a * a;
    r =              0x1.7ed1ccp-9f;
    r = fmaf (r, s, -0x1.0c2c08p-6f);
    r = fmaf (r, s,  0x1.61fdd0p-5f);
    r = fmaf (r, s, -0x1.3556b2p-4f);
    r = fmaf (r, s,  0x1.b4e128p-4f);
    r = fmaf (r, s, -0x1.230ad2p-3f);
    r = fmaf (r, s,  0x1.9978ecp-3f);
    r = fmaf (r, s, -0x1.5554dcp-2f);
    r = r * s;
    r = fmaf (r, a, a);
    return r;


// max ulp err = 1.52637
float my_atanf (float a)

    float r, t;
    t = fabsf (a);
    r = t;
    if (t > 1.0f) 
        r = 1.0f / r;
    
    r = atan_poly (r);
    if (t > 1.0f) 
        r = fmaf (0x1.ddcb02p-1f, 0x1.aee9d6p+0f, -r); // pi/2 - r
    
    r = copysignf (r, a);
    return r;

【问题讨论】:

我喜欢你对 π/2 - r 的计算。你是如何选择这两个因素的?这些值是否经典地被认为是 π 的近似值的好因素,还是您自己通过某种详尽的搜索找到了它们? (PS:抱歉,我无法回答实际问题) @Pascal Cuoq 太糟糕了,我非常希望你能有一些额外的或更好的见解。这个问题多年来一直困扰着我,无论是一般情况还是反正切的具体情况。至于分解 pi/2,我只是从 pi/2 的平方根开始进行了蛮力搜索,并以 1-ulp 的步长递增了一个因子。对于float,只需几秒钟即可达到“最佳”分解。 请注意,您引用的两篇文章根本不需要提及 FMA,因为它们在该级别上不起作用。它们产生实数多项式,其系数恰好是浮点值(并且,作为实数多项式,是目标函数的良好近似)。换句话说,它们考虑了系数必须表示为浮点常数的事实,但没有考虑到运算将是浮点运算的事实。将 FMA 添加到等式时没有任何变化,因为首先忽略了操作错误。 我故意不查看基于表格的方法,例如Gal's accurate table method。 FLOPS 的增长速度超过了内存吞吐量,应用 SIMD 矢量化时表访问很困难,并且缓存访问需要比 FMA 更多的能量(据我了解大约是 10 倍)。 @Nominal Animal: 对于发布的代码错误大于 1 ulp 发生 0.85 ulp = 1.02852 @ -9.31304693e-001 -0x1.dcd3f8p-1 是最大的正误差,ulp = -1.03143 @ -9.84267354e-001 -0x1.f7f1e4p-1 是最大的负误差。 【参考方案1】:

以下函数是arctan[0, 1] 上的忠实实现:

float atan_poly (float a) 
  float s = a * a, u = fmaf(a, -a, 0x1.fde90cp-1f);
  float r1 =               0x1.74dfb6p-9f;
  float r2 = fmaf (r1, u,  0x1.3a1c7cp-8f);
  float r3 = fmaf (r2, s, -0x1.7f24b6p-7f);
  float r4 = fmaf (r3, u, -0x1.eb3900p-7f);
  float r5 = fmaf (r4, s,  0x1.1ab95ap-5f);
  float r6 = fmaf (r5, u,  0x1.80e87cp-5f);
  float r7 = fmaf (r6, s, -0x1.e71aa4p-4f);
  float r8 = fmaf (r7, u, -0x1.b81b44p-3f);
  float r9 = r8 * s;
  float r10 = fmaf (r9, a, a);
  return r10;

如果函数 atan_poly 未能忠实地在 [1e-16, 1] 上进行四舍五入,则以下测试工具将中止,否则打印“成功”:

int checkit(float f) 
  double d = atan(f);
  float d1 = d, d2 = d;
  if (d1 < d) d2 = nextafterf(d1, 1.0/0.0);
  else d1 = nextafterf(d1, -1.0/0.0);
  float p = atan_poly(f);
  if (p != d1 && p != d2) return 0;
  return 1;


int main() 
  for (float f = 1; f > 1e-16; f = nextafterf(f, -1.0/0.0)) 
    if (!checkit(f)) abort();
  
  printf("success\n");
  exit(0);


在每次乘法中使用s 的问题是多项式的系数不会迅速衰减。接近 1 的输入会导致大量取消几乎相等的数字,这意味着您正在尝试找到一组系数,以便计算结束时的累积舍入非常接近 arctan 的残差。

常量0x1.fde90cp-1f 是一个接近1 的数字,(arctan(sqrt(x)) - x) / x^3 非常接近最近的浮点数。也就是说,它是一个用于计算u 的常数,因此三次系数几乎完全确定。 (对于这个程序,三次系数必须是-0x1.b81b44p-3f-0x1.b81b42p-3f。)

su 的交替乘法可以将ri 中的舍入误差对ri+2 的影响减少至多1/4,因为s*u &lt; 1/4 无论a 是什么.这为选择五阶及以上的系数提供了相当大的余地。


我借助两个程序找到了系数:

一个程序插入一堆测试点,写下一个线性不等式系统,并计算该不等式系统的系数界限。请注意,给定a,可以计算出r8 的范围,从而得到一个如实舍入的结果。为了得到 线性 不等式,我假设 r8real-number 算术中将被计算为 floats su 中的多项式;线性不等式将这个实数r8 限制在某个区间内。我使用 Parma Polyhedra Library 来处理这些约束系统。 另一个程序随机测试特定范围内的系数集,首先插入一组测试点,然后按降序插入从11e-8 的所有floats,并检查atan_poly 是否产生忠实的舍入的atan((double)x)。如果某些x 失败,它会打印出x 以及失败的原因。

为了获得系数,我破解了第一个程序来修复c3,为每个测试点计算r7 的界限,然后获得高阶系数的界限。然后我破解了它来修复c3c5 并获得高阶系数的界限。我一直这样做,直到除了三个最高阶系数,c13c15c17 之外的所有系数。

我在第二个程序中增加了一组测试点,直到它停止打印任何内容或打印出“成功”。我需要很少的测试点来拒绝几乎所有错误的多项式——我在程序中计算了 85 个测试点。


在这里,我展示了我选择系数的一些工作。为了得到一个忠实四舍五入的arctan 作为我的初始测试点集,假设r1r8 是在实际算术中评估的(并且以某种方式不愉快地四舍五入,但我不记得了)但是r9r10float 算术中进行评估,我需要:

-0x1.b81b456625f15p-3 <= c3 <= -0x1.b81b416e22329p-3
-0x1.e71d48d9c2ca4p-4 <= c5 <= -0x1.e71783472f5d1p-4
0x1.80e063cb210f9p-5 <= c7 <= 0x1.80ed6efa0a369p-5
0x1.1a3925ea0c5a9p-5 <= c9 <= 0x1.1b3783f148ed8p-5
-0x1.ec6032f293143p-7 <= c11 <= -0x1.e928025d508p-7
-0x1.8c06e851e2255p-7 <= c13 <= -0x1.732b2d4677028p-7
0x1.2aff33d629371p-8 <= c15 <= 0x1.41e9bc01ae472p-8
0x1.1e22f3192fd1dp-9 <= c17 <= 0x1.d851520a087c2p-9

取 c3 = -0x1.b81b44p-3,假设 r8 也在 float 算术中求值:

-0x1.e71df05b5ad56p-4 <= c5 <= -0x1.e7175823ce2a4p-4
0x1.80df529dd8b18p-5 <= c7 <= 0x1.80f00e8da7f58p-5
0x1.1a283503e1a97p-5 <= c9 <= 0x1.1b5ca5beeeefep-5
-0x1.ed2c7cd87f889p-7 <= c11 <= -0x1.e8c17789776cdp-7
-0x1.90759e6defc62p-7 <= c13 <= -0x1.7045e66924732p-7
0x1.27eb51edf324p-8 <= c15 <= 0x1.47cda0bb1f365p-8
0x1.f6c6b51c50b54p-10 <= c17 <= 0x1.003a00ace9a79p-8

取c5 = -0x1.e71aa4p-4,假设r7是在float算术中完成的:

0x1.80e3dcc972cb3p-5 <= c7 <= 0x1.80ed1cf56977fp-5
0x1.1aa005ff6a6f4p-5 <= c9 <= 0x1.1afce9904742p-5
-0x1.ec7cf2464a893p-7 <= c11 <= -0x1.e9d6f7039db61p-7
-0x1.8a2304daefa26p-7 <= c13 <= -0x1.7a2456ddec8b2p-7
0x1.2e7b48f595544p-8 <= c15 <= 0x1.44437896b7049p-8
0x1.396f76c06de2ep-9 <= c17 <= 0x1.e3bedf4ed606dp-9

取c7 = 0x1.80e87cp-5,假设r6是在float算术中完成的:

0x1.1aa86d25bb64fp-5 <= c9 <= 0x1.1aca48cd5caabp-5
-0x1.eb6311f6c29dcp-7 <= c11 <= -0x1.eaedb032dfc0cp-7
-0x1.81438f115cbbp-7 <= c13 <= -0x1.7c9a106629f06p-7
0x1.36d433f81a012p-8 <= c15 <= 0x1.3babb57bb55bap-8
0x1.5cb14e1d4247dp-9 <= c17 <= 0x1.84f1151303aedp-9

取c9 = 0x1.1ab95ap-5,假设r5是在float算术中完成的:

-0x1.eb51a3b03781dp-7 <= c11 <= -0x1.eb21431536e0dp-7
-0x1.7fcd84700f7cfp-7 <= c13 <= -0x1.7ee38ee4beb65p-7
0x1.390fa00abaaabp-8 <= c15 <= 0x1.3b100a7f5d3cep-8
0x1.6ff147e1fdeb4p-9 <= c17 <= 0x1.7ebfed3ab5f9bp-9

我为c11 选择了一个接近范围中间的点,并随机选择了c13c15c17


编辑:我现在已经自动化了这个过程。以下函数也是arctan[0, 1]上的忠实实现:

float c5 = 0x1.997a72p-3;
float c7 = -0x1.23176cp-3;
float c9 = 0x1.b523c8p-4;
float c11 = -0x1.358ff8p-4;
float c13 = 0x1.61c5c2p-5;
float c15 = -0x1.0b16e2p-6;
float c17 = 0x1.7b422p-9;

float juffa_poly (float a) 
  float s = a * a;
  float r1 =              c17;
  float r2 = fmaf (r1, s, c15);
  float r3 = fmaf (r2, s, c13);
  float r4 = fmaf (r3, s, c11);
  float r5 = fmaf (r4, s, c9);
  float r6 = fmaf (r5, s, c7);
  float r7 = fmaf (r6, s, c5);
  float r8 = fmaf (r7, s, -0x1.5554dap-2f);
  float r9 = r8 * s;
  float r10 = fmaf (r9, a, a);
  return r10;

令我惊讶的是,这段代码竟然存在。对于这些附近的系数,您可以得到r10 和在几 ulps 数量级的实数算术中评估的多项式值之间的距离的界限,这要归功于当s 接近@987654390 时该多项式的缓慢收敛@。我曾预计舍入误差会以一种基本上“无法控制”的方式表现,只需通过调整系数即可。

【讨论】:

感谢您为缩小每个系数的范围提供一些指导。不幸的是,您的结果提供了一个非常大的搜索空间。在改进了我最初的启发式方法并计划在接下来的 24 小时内发布各种答案后,我取得了一些进展。 @njuffa:期待您的回答!我希望这种变化能给搜索空间提供更可行的解决方案;对高阶项的更广泛限制似乎意味着高阶项的低阶位可用于对抗剩余的舍入异常。 非常有趣。使用编辑后的系数将 ulp 与 mpfr 进行比较,我获得了 1.3 的最大 ulp。您认为您可以运行您的程序并优化以下 Float64 系数吗:pastebin.com/Zdyk4wA6 这些系数为 Float64 提供了 1.3 的 ulp,但我认为我们可以做得更好。 @musm:64 位浮点数是不同的球赛。 32 位浮点数空间的详尽枚举是该算法中重复完成的一个步骤。【参考方案2】:

我思考了我在 cmets 中收到的各种想法,并根据这些反馈进行了一些实验。最后,我认为精炼的启发式搜索是最好的方法。我现在已经设法将atanf_poly() 的最大错误减少到 1.01036 ulps,只有三个参数超出了我声明的 1 ulp 错误界限的目标:

ulp = -1.00829 @ |a| =  9.80738342e-001 0x1.f62356p-1 (3f7b11ab)
ulp = -1.01036 @ |a| =  9.87551928e-001 0x1.f9a068p-1 (3f7cd034)
ulp =  1.00050 @ |a| =  9.99375939e-001 0x1.ffae34p-1 (3f7fd71a)

根据生成改进近似值的方式,不能保证这是最佳近似值;这里没有科学突破。由于当前解决方案的 ulp 误差尚未完全平衡,并且由于继续搜索继续提供更好的近似值(尽管时间间隔呈指数增长),我的 猜测 是可以实现 1 ulp 误差范围,但与此同时,我们似乎已经非常接近最佳机器优化近似值了。

新近似值的更好质量是精细搜索过程的结果。我观察到多项式中所有最大的 ulp 误差都接近于单位,比如说在 [0.75,1.0] 中是保守的。这允许对最大误差小于某个界限(例如 1.08 ulps)的有趣系数集进行快速扫描。然后,我可以详细而详尽地测试锚定在该点的启发式选择的超锥内的所有系数集。第二步搜索最小 ulp 误差作为主要目标,将正确舍入结果的最大百分比作为次要目标。通过在我的 CPU 的所有四个内核上使用这个两步过程,我能够显着加快搜索过程:到目前为止,我已经能够检查大约 221 个系数集。

基于所有“接近”解决方案中每个系数的范围,我现在估计这个近似问题的总有用搜索空间是 >= 224 个系数集,而不是更乐观的220我之前扔了。对于非常有耐心或拥有大量计算能力的人来说,这似乎是一个可行的问题。

我的更新代码如下:

// max ulp err = 1.01036
float atanf_poly (float a)

    float r, s;
    s = a * a;
    r =              0x1.7ed22cp-9f;
    r = fmaf (r, s, -0x1.0c2c2ep-6f);
    r = fmaf (r, s,  0x1.61fdf6p-5f);
    r = fmaf (r, s, -0x1.3556b4p-4f);
    r = fmaf (r, s,  0x1.b4e12ep-4f);
    r = fmaf (r, s, -0x1.230ae0p-3f);
    r = fmaf (r, s,  0x1.9978eep-3f);
    r = fmaf (r, s, -0x1.5554dap-2f);
    r = r * s;
    r = fmaf (r, a, a);
    return r;


// max ulp err = 1.51871
float my_atanf (float a)

    float r, t;
    t = fabsf (a);
    r = t;
    if (t > 1.0f) 
        r = 1.0f / r;
    
    r = atanf_poly (r);
    if (t > 1.0f) 
        r = fmaf (0x1.ddcb02p-1f, 0x1.aee9d6p+0f, -r); // pi/2 - r
    
    r = copysignf (r, a);
    return r;


更新(两年半后重新审视这个问题)

以 T. Myklebust 的 draft publication 为起点,我发现 [-1,1] 上的反正切近似值具有最小误差,最大误差为 0.94528 ulp。

/* Based on: Tor Myklebust, "Computing accurate Horner form approximations 
   to special functions in finite precision arithmetic", arXiv:1508.03211,
   August 2015. maximum ulp err = 0.94528
*/
float atanf_poly (float a)

    float r, s;
    s = a * a;                        
    r =              0x1.6d2086p-9f;  //  2.78569828e-3
    r = fmaf (r, s, -0x1.03f2ecp-6f); // -1.58660226e-2
    r = fmaf (r, s,  0x1.5beebap-5f); //  4.24722321e-2
    r = fmaf (r, s, -0x1.33194ep-4f); // -7.49753043e-2
    r = fmaf (r, s,  0x1.b403a8p-4f); //  1.06448799e-1
    r = fmaf (r, s, -0x1.22f5c2p-3f); // -1.42070308e-1
    r = fmaf (r, s,  0x1.997748p-3f); //  1.99934542e-1
    r = fmaf (r, s, -0x1.5554d8p-2f); // -3.33331466e-1
    r = r * s;
    r = fmaf (r, a, a);
    return r;

【讨论】:

不错!我使用 AVX(通过&lt;immintrin.h&gt;)编写了一些 C,它发现了 ulps 中的最大/最小误差,以及正确舍入的浮点结果的数量,每个参数的速率为 7.5 到 8 个周期(在 i5-4200U 上测量)。在 [0.75f,1.0f] 范围内,这应该意味着每核每 GHz 每秒大约 30 组系数,使用大约 96 兆字节的预计算表(使用 atan() 预计算在 1 秒内)。我的代码还说明了这些系数的 -1.00829..1.01257 ulps 错误,正确舍入了 3534268 (84.263%)。如果您认为它有用/相关,我很乐意将其清理并在此处发布该代码。 @NominalAnimal 我不确定我是否理解您的评论:大多数参数集都可以提前被拒绝,并且永远不需要计算整个范围的最大误差(只要停止比你已经拥有的更大)。但无论如何,我会对你的代码非常感兴趣。不用担心清理它,只要能编译,就有用! @PascalCuoq:如果我理解正确,njuffa 使用f(C17,C15,C13,..,C5,C3) = maximum_error_in_ULPs 来定义一个表面(或者实际上是两个,一个用于下面的最大误差,一个用于上面的最大误差)。最佳解决方案是该表面的最低点,您可以使用例如共轭梯度或其他方法来找到它。使用 AVX2,我的函数可以在大约 3000 万个 CPU 周期内计算该表面上的每个点(如果仅考虑 x=[0.75,1]),仅使用一个内核。我希望它会加快 njuffa 的方法。我会发布我的代码。 :) @Nominal Anmimal 这将使拥有最新硬件的人有机会加入搜索。我目前在一个没有支持 FMA 的硬件的平台上,这使得搜索速度相当慢。再加上我的参考atan() 可能也不是最快的。 @njuffa:您使用的是双精度多项式引用arctan() 函数吗?我知道库三角函数不可靠,但由于测试表明我可以复制你的结果,所以我没有费心去寻找更好的东西。我想知道,对于多项式参考反正切,偏导数的差异是否。最大误差 x 处的每个系数都会产生误差表面梯度——或者至少是一个足够好的近似值?【参考方案3】:

这不是问题的答案,但太长了,无法放入评论:

您的问题是关于系数 C3、C5、...、C17 在反正切的多项式近似中的最佳选择,其中您将 C1 固定为 1,将 C2、C4、...、C16 固定为 0。

您的问题的标题说您正在寻找 [-1, 1] 上的近似值,将偶数系数固定为 0 的一个很好的理由是,近似值恰好是一个奇函数是充分且必要的。通过仅在 [0, 1] 上应用多项式近似,您问题中的代码与标题“矛盾”。

如果您使用 Remez 算法将系数 C2、C3、...、C8 查找为 [0, 1] 上反正切的多项式近似值,您最终可能会得到类似于以下值的结果:

#include <stdio.h>
#include <math.h>

float atan_poly (float a)

  float r, s;
  s = a;
  //  s = a * a;

  r =             -3.3507930064626076153585890630056286726807491543578e-2;
  r = fmaf (r, s, 1.3859776280052980081098065189344699108643282883702e-1);
  r = fmaf (r, s, -1.8186361916440430105127602496688553126414578766147e-1);
  r = fmaf (r, s, -1.4583047494913656326643327729704639191810926020847e-2);
  r = fmaf (r, s, 2.1335202878219865228365738728594741358740655881373e-1);
  r = fmaf (r, s, -3.6801711826027841250774413728610805847547996647342e-3);
  r = fmaf (r, s, -3.3289852243978319173749528028057608377028846413080e-1);
  r = fmaf (r, s, -1.8631479933914856903459844359251948006605218562283e-5);
  r = fmaf (r, s, 1.2917291732886065585264586294461539492689296797761e-7);

  r = fmaf (r, a, a);
  return r;


int main() 
  for (float x = 0.0f; x < 1.0f; x+=0.1f)
    printf("x: %f\n%a\n%a\n\n", x, atan_poly(x), atan(x));

这与您问题中的代码具有大致相同的复杂性 - 乘法的数量相似。看看这个多项式,没有理由特别想将任何系数固定为 0。如果我们想在 [-1, 1] 上逼近奇函数而不固定偶数系数,它们会自动变得非常小并且受制于吸收,然后我们希望将它们固定为 0,但是对于 [0, 1] 上的这种近似,它们没有,所以我们不必固定它们。

它可能比您问题中的奇数多项式更好或更差。事实证明情况更糟(见下文)。然而,LolRemez 0.2(问题底部的代码)的这种快速而肮脏的应用似乎足以引发系数选择的问题。我特别想知道,如果您将此答案中的系数与用于获取问题中的系数相同的“最陡的体面和模拟退火的混合”优化步骤进行处理,会发生什么。

所以,总结一下这个作为答案发布的评论,您确定您正在寻找最佳系数 C3、C5、...、C17?在我看来,您正在寻找能够产生反正切近似值的最佳单精度浮点运算序列,并且该近似值不必是 17 次奇多项式的霍纳形式。

x: 0.000000 0x0p+0 0x0p+0 x: 0.100000 0x1.983e2cp-4 0x1.983e28938f9ecp-4 x: 0.200000 0x1.94442p-3 0x1.94441ff1e8882p-3 x: 0.300000 0x1.2a73a6p-2 0x1.2a73a71dcec16p-2 x: 0.400000 0x1.85a37ap-2 0x1.85a3770ebe7aep-2 x: 0.500000 0x1.dac67p-2 0x1.dac670561bb5p-2 x: 0.600000 0x1.14b1dcp-1 0x1.14b1ddf627649p-1 x: 0.700000 0x1.38b116p-1 0x1.38b113eaa384ep-1 x: 0.800000 0x1.5977a8p-1 0x1.5977a686e0ffbp-1 x: 0.900000 0x1.773388p-1 0x1.77338c44f8faep-1

这是我链接到 LolRemez 0.2 的代码,以优化 [0, 1] 上反正切的 9 次多项式近似的相对精度:

#include "lol/math/real.h"
#include "lol/math/remez.h"

using lol::real;
using lol::RemezSolver;

real f(real const &y)

  return (atan(y) - y) / y;


real g(real const &y)

  return re (atan(y) / y);


int main(int argc, char **argv)

  RemezSolver<8, real> solver;
  solver.Run("1e-1000", 1.0, f, g, 50);
  return 0;

【讨论】:

启发式地,使用s = a*a 中的多项式而不是a 中的多项式应该确保在更多输入上更快地收敛,从而减少不利舍入错误的可能性。这可能重要,也可能不重要。 @tmyklebu 我最终查看其他非空系数集的一个原因恰恰是,在反正切的情况下,只有奇数系数似乎没有很好地收敛到在实际功能的 1ULP 范围内。我一直不明白为什么人们更喜欢 Padé 逼近,即使定义区间不包括奇点(因为它在这里不包括),但现在我看到反正切不想被多项式逼近。 我的代码使用 [0,1] 的主要近似区间这一事实是 my_atanf() 实现的产物。如果没有初始的fabsf(),数值上不会发生任何变化,因为在atan_poly() 中,变量s 总是正数,但是它会使符号位处理稍微复杂化。就机器效率多项式而言,将偶数系数固定为零具有直观意义,正如 tmyklebu 指出的那样,这种方法也得到了文献的支持。如果不固定导致在相同数量的操作中得到更准确的结果,那是完全可以接受的。【参考方案4】:

这不是一个答案,而是一个扩展评论。

最近的 Intel CPU 和一些未来的 AMD CPU 具有 AVX2。在 Linux 中,在 /proc/cpuinfo 中查找 avx2 标志以查看您的 CPU 是否支持这些。

AVX2 是一个扩展,它允许我们使用 256 位向量(例如,八个单精度数或四个双精度数)来构造和计算,而不仅仅是标量。它包括 FMA3 支持,这意味着此类向量的融合乘加。简而言之,AVX2 允许我们并行评估八个多项式,几乎与我们使用标量运算评估单个多项式的时间相同。

函数error8() 使用x 的预定义值分析一组系数,与atan(x) 的预计算值进行比较,并返回ULP 中的错误(分别低于和高于所需结果),以及与所需浮点值完全匹配的结果数。这些对于简单地测试一组系数是否比目前最知名的一组更好,而不需要这些,但允许不同的策略来测试哪些系数。 (基本上,ULP 中的最大误差形成一个表面,我们试图找到该表面上的最低点;知道每个点的表面“高度”可以让我们对前进的方向做出有根据的猜测 - - 如何改变系数。)

使用了四个预先计算的表:known_x 用于参数,known_f 用于正确舍入的单精度结果,known_a 用于双精度“准确”值(我只是希望库atan() 对此已经足够精确了——但不应该在没有检查的情况下依赖它!),known_m 可以将双精度差异缩放到 ULP。给定所需的参数范围,precalculate() 函数将使用库 atan() 函数预先计算这些参数。 (它还依赖于 IEEE-754 浮点格式,浮点和整数字节顺序相同,但在运行此代码的 CPU 上确实如此。)

请注意,known_xknown_fknown_a 数组可以存储在二进制文件中; known_m 的内容从 known_a 派生而来。使用库atan() 而不验证它不是一个好主意——但是因为我的匹配njuffa 的结果,我没有费心去寻找更好的参考atan()

为简单起见,下面是示例程序形式的代码:

#define _POSIX_C_SOURCE 200809L
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <immintrin.h>
#include <math.h>
#include <errno.h>

/** poly8() - Compute eight polynomials in parallel.
 * @x - the arguments
 * @c - the coefficients.
 *
 * The first coefficients are for degree 17, the second
 * for degree 15, and so on, down to degree 3.
 *
 * The compiler should vectorize the expression using vfmaddXXXps
 * given an AVX2-capable CPU; for example, Intel Haswell,
 * Broadwell, Haswell E, Broadwell E, Skylake, or Cannonlake;
 * or AMD Excavator CPUs. Tested on Intel Core i5-4200U.
 *
 * Using GCC-4.8.2 and
 *     gcc -O2 -march=core-avx2 -mtune=generic
 * this code produces assembly (AT&T syntax)
 *     vmulps       %ymm0, %ymm0, %ymm2
 *     vmovaps      (%rdi), %ymm1
 *     vmovaps      %ymm0, %ymm3
 *     vfmadd213ps  32(%rdi), %ymm2, %ymm1
 *     vfmadd213ps  64(%rdi), %ymm2, %ymm1
 *     vfmadd213ps  96(%rdi), %ymm2, %ymm1
 *     vfmadd213ps  128(%rdi), %ymm2, %ymm1
 *     vfmadd213ps  160(%rdi), %ymm2, %ymm1
 *     vfmadd213ps  192(%rdi), %ymm2, %ymm1
 *     vfmadd213ps  224(%rdi), %ymm2, %ymm1
 *     vmulps       %ymm2, %ymm1, %ymm0
 *     vfmadd132ps  %ymm3, %ymm3, %ymm0
 *     ret
 * if you omit the 'static inline'.
*/
static inline __v8sf poly8(const __v8sf x, const __v8sf *const c)

    const __v8sf xx = x * x;
    return (((((((c[0]*xx + c[1])*xx + c[2])*xx + c[3])*xx + c[4])*xx + c[5])*xx + c[6])*xx + c[7])*xx*x + x;


/** error8() - Calculate maximum error in ULPs
 * @x  - the arguments
 * @co -  C17, C15, C13, C11, C9, C7, C5, C3 
 * @f  - the correctly rounded results in single precision
 * @a  - the expected results in double precision
 * @m  - 16777216.0 raised to the same power of two as @a normalized
 * @n  - number of vectors to test
 * @max_under - pointer to store the maximum underflow (negative, in ULPs) to
 * @max_over  - pointer to store the maximum overflow (positive, in ULPs) to
 * Returns the number of correctly rounded float results, 0..8*n.
*/
size_t error8(const __v8sf *const x, const float *const co,
              const __v8sf *const f, const __v4df *const a, const __v4df *const m,
              const size_t n,
              float *const max_under, float *const max_over)

    const __v8sf c[8] =   co[0], co[0], co[0], co[0], co[0], co[0], co[0], co[0] ,
                           co[1], co[1], co[1], co[1], co[1], co[1], co[1], co[1] ,
                           co[2], co[2], co[2], co[2], co[2], co[2], co[2], co[2] ,
                           co[3], co[3], co[3], co[3], co[3], co[3], co[3], co[3] ,
                           co[4], co[4], co[4], co[4], co[4], co[4], co[4], co[4] ,
                           co[5], co[5], co[5], co[5], co[5], co[5], co[5], co[5] ,
                           co[6], co[6], co[6], co[6], co[6], co[6], co[6], co[6] ,
                           co[7], co[7], co[7], co[7], co[7], co[7], co[7], co[7]  ;
    __v4df min =  0.0, 0.0, 0.0, 0.0 ;
    __v4df max =  0.0, 0.0, 0.0, 0.0 ;
    __v8si eqs =  0, 0, 0, 0, 0, 0, 0, 0 ;
    size_t i;

    for (i = 0; i < n; i++) 
        const __v8sf v = poly8(x[i], c);
        const __v4df d0 =  v[0], v[1], v[2], v[3] ;
        const __v4df d1 =  v[4], v[5], v[6], v[7] ;
        const __v4df err0 = (d0 - a[2*i+0]) * m[2*i+0];
        const __v4df err1 = (d1 - a[2*i+1]) * m[2*i+1];
        eqs -= (__v8si)_mm256_cmp_ps(v, f[i], _CMP_EQ_OQ);
        min = _mm256_min_pd(min, err0);
        max = _mm256_max_pd(max, err1);
        min = _mm256_min_pd(min, err1);
        max = _mm256_max_pd(max, err0);
    

    if (max_under) 
        if (min[0] > min[1]) min[0] = min[1];
        if (min[0] > min[2]) min[0] = min[2];
        if (min[0] > min[3]) min[0] = min[3];
        *max_under = min[0];
    

    if (max_over) 
        if (max[0] < max[1]) max[0] = max[1];
        if (max[0] < max[2]) max[0] = max[2];
        if (max[0] < max[3]) max[0] = max[3];
        *max_over = max[0];
    

    return (size_t)((unsigned int)eqs[0])
         + (size_t)((unsigned int)eqs[1])
         + (size_t)((unsigned int)eqs[2])
         + (size_t)((unsigned int)eqs[3])
         + (size_t)((unsigned int)eqs[4])
         + (size_t)((unsigned int)eqs[5])
         + (size_t)((unsigned int)eqs[6])
         + (size_t)((unsigned int)eqs[7]);


/** precalculate() - Allocate and precalculate tables for error8().
 * @x0   - First argument to precalculate
 * @x1   - Last argument to precalculate
 * @xptr - Pointer to a __v8sf pointer for the arguments
 * @fptr - Pointer to a __v8sf pointer for the correctly rounded results
 * @aptr - Pointer to a __v4df pointer for the comparison results
 * @mptr - Pointer to a __v4df pointer for the difference multipliers
 * Returns the vector count if successful,
 * 0 with errno set otherwise.
*/
size_t precalculate(const float x0, const float x1,
                 __v8sf **const xptr, __v8sf **const fptr,
                 __v4df **const aptr, __v4df **const mptr)

    const size_t align = 64;
    unsigned int i0, i1;
    size_t       n, i, sbytes, dbytes;
    __v8sf      *x = NULL;
    __v8sf      *f = NULL;
    __v4df      *a = NULL;
    __v4df      *m = NULL;

    if (!xptr || !fptr || !aptr || !mptr) 
        errno = EINVAL;
        return (size_t)0;
    

    memcpy(&i0, &x0, sizeof i0);
    memcpy(&i1, &x1, sizeof i1);

    i0 ^= (i0 & 0x80000000U) ? 0xFFFFFFFFU : 0x80000000U;
    i1 ^= (i1 & 0x80000000U) ? 0xFFFFFFFFU : 0x80000000U;

    if (i1 > i0)
        n = (((size_t)i1 - (size_t)i0) | (size_t)7) + (size_t)1;
    else
    if (i0 > i1)
        n = (((size_t)i0 - (size_t)i1) | (size_t)7) + (size_t)1;
    else 
        errno = EINVAL;
        return (size_t)0;
    

    sbytes = n * sizeof (float);
    if (sbytes % align)
        sbytes += align - (sbytes % align);

    dbytes = n * sizeof (double);
    if (dbytes % align)
        dbytes += align - (dbytes % align);

    if (posix_memalign((void **)&x, align, sbytes)) 
        errno = ENOMEM;
        return (size_t)0;
    
    if (posix_memalign((void **)&f, align, sbytes)) 
        free(x);
        errno = ENOMEM;
        return (size_t)0;
    
    if (posix_memalign((void **)&a, align, dbytes)) 
        free(f);
        free(x);
        errno = ENOMEM;
        return (size_t)0;
    
    if (posix_memalign((void **)&m, align, dbytes)) 
        free(a);
        free(f);
        free(x);
        errno = ENOMEM;
        return (size_t)0;
    

    if (x1 > x0) 
        float *const xp = (float *)x;
        float        curr = x0;

        for (i = 0; i < n; i++) 
            xp[i] = curr;
            curr = nextafterf(curr, HUGE_VALF);
        

        i = n;
        while (i-->0 && xp[i] > x1)
            xp[i] = x1;
     else 
        float *const xp = (float *)x;
        float        curr = x0;

        for (i = 0; i < n; i++) 
            xp[i] = curr;
            curr = nextafterf(curr, -HUGE_VALF);
        

        i = n;
        while (i-->0 && xp[i] < x1)
            xp[i] = x1;
    

    
        const float *const xp = (const float *)x;
        float *const       fp = (float *)f;
        double *const      ap = (double *)a;
        double *const      mp = (double *)m;

        for (i = 0; i < n; i++) 
            const float curr = xp[i];
            int         temp;

            fp[i] = atanf(curr);
            ap[i] = atan((double)curr);

            (void)frexp(ap[i], &temp);
            mp[i] = ldexp(16777216.0, temp);
        
    

    *xptr = x;
    *fptr = f;
    *aptr = a;
    *mptr = m;

    errno = 0;
    return n/8;


static int parse_range(const char *const str, float *const range)

    float fmin, fmax;
    char  dummy;

    if (sscanf(str, " %f %f %c",   &fmin, &fmax, &dummy) == 2 ||
        sscanf(str, " %f:%f %c",   &fmin, &fmax, &dummy) == 2 ||
        sscanf(str, " %f,%f %c",   &fmin, &fmax, &dummy) == 2 ||
        sscanf(str, " %f/%f %c",   &fmin, &fmax, &dummy) == 2 ||
        sscanf(str, " %ff %ff %c", &fmin, &fmax, &dummy) == 2 ||
        sscanf(str, " %ff:%ff %c", &fmin, &fmax, &dummy) == 2 ||
        sscanf(str, " %ff,%ff %c", &fmin, &fmax, &dummy) == 2 ||
        sscanf(str, " %ff/%ff %c", &fmin, &fmax, &dummy) == 2) 
        if (range) 
            range[0] = fmin;
            range[1] = fmax;
        
        return 0;
    

    if (sscanf(str, " %f %c",  &fmin, &dummy) == 1 ||
        sscanf(str, " %ff %c", &fmin, &dummy) == 1) 
        if (range) 
            range[0] = fmin;
            range[1] = fmin;
        
        return 0;
    

    return errno = ENOENT;


static int fix_range(float *const f)

    if (f && f[0] > f[1]) 
        const float tmp = f[0];
        f[0] = f[1];
        f[1] = tmp;
    
    return f && isfinite(f[0]) && isfinite(f[1]) && (f[1] >= f[0]);


static const char *f2s(char *const buffer, const size_t size, const float value, const char *const invalid)

    char  format[32];
    float parsed;
    int   decimals, length;

    for (decimals = 0; decimals <= 16; decimals++) 
        length = snprintf(format, sizeof format, "%%.%df", decimals);
        if (length < 1 || length >= (int)sizeof format)
            break;

        length = snprintf(buffer, size, format, value);
        if (length < 1 || length >= (int)size)
            break;

        if (sscanf(buffer, "%f", &parsed) == 1 && parsed == value)
            return buffer;

        decimals++;
    

    for (decimals = 0; decimals <= 16; decimals++) 
        length = snprintf(format, sizeof format, "%%.%dg", decimals);
        if (length < 1 || length >= (int)sizeof format)
            break;

        length = snprintf(buffer, size, format, value);
        if (length < 1 || length >= (int)size)
            break;

        if (sscanf(buffer, "%f", &parsed) == 1 && parsed == value)
            return buffer;

        decimals++;
    

    length = snprintf(buffer, size, "%a", value);
    if (length < 1 || length >= (int)size)
        return invalid;

    if (sscanf(buffer, "%f", &parsed) == 1 && parsed == value)
        return buffer;

    return invalid;


int main(int argc, char *argv[])

    float xrange[2] =  0.75f, 1.00f ;
    float c17range[2], c15range[2], c13range[2], c11range[2];
    float c9range[2], c7range[2], c5range[2], c3range[2];
    float c[8];

    __v8sf *known_x;
    __v8sf *known_f;
    __v4df *known_a;
    __v4df *known_m;
    size_t  known_n;

    if (argc != 10 || !strcmp(argv[1], "-h") || !strcmp(argv[1], "--help")) 
        fprintf(stderr, "\n");
        fprintf(stderr, "Usage: %s [ -h | --help ]\n", argv[0]);
        fprintf(stderr, "       %s C17 C15 C13 C11 C9 C7 C5 C3 x\n", argv[0]);
        fprintf(stderr, "\n");
        fprintf(stderr, "Each of the coefficients can be a constant or a range,\n");
        fprintf(stderr, "for example 0.25 or 0.75:1. x must be a non-empty range.\n");
        fprintf(stderr, "\n");
        return EXIT_FAILURE;
    

    if (parse_range(argv[1], c17range) || !fix_range(c17range)) 
        fprintf(stderr, "%s: Invalid C17 range or constant.\n", argv[1]);
        return EXIT_FAILURE;
    
    if (parse_range(argv[2], c15range) || !fix_range(c15range)) 
        fprintf(stderr, "%s: Invalid C15 range or constant.\n", argv[2]);
        return EXIT_FAILURE;
    
    if (parse_range(argv[3], c13range) || !fix_range(c13range)) 
        fprintf(stderr, "%s: Invalid C13 range or constant.\n", argv[3]);
        return EXIT_FAILURE;
    
    if (parse_range(argv[4], c11range) || !fix_range(c11range)) 
        fprintf(stderr, "%s: Invalid C11 range or constant.\n", argv[4]);
        return EXIT_FAILURE;
    
    if (parse_range(argv[5], c9range) || !fix_range(c9range)) 
        fprintf(stderr, "%s: Invalid C9 range or constant.\n", argv[5]);
        return EXIT_FAILURE;
    
    if (parse_range(argv[6], c7range) || !fix_range(c7range)) 
        fprintf(stderr, "%s: Invalid C7 range or constant.\n", argv[6]);
        return EXIT_FAILURE;
    
    if (parse_range(argv[7], c5range) || !fix_range(c5range)) 
        fprintf(stderr, "%s: Invalid C5 range or constant.\n", argv[7]);
        return EXIT_FAILURE;
    
    if (parse_range(argv[8], c3range) || !fix_range(c3range)) 
        fprintf(stderr, "%s: Invalid C3 range or constant.\n", argv[8]);
        return EXIT_FAILURE;
    

    if (parse_range(argv[9], xrange) || xrange[0] == xrange[1] ||
        !isfinite(xrange[0]) || !isfinite(xrange[1])) 
        fprintf(stderr, "%s: Invalid x range.\n", argv[9]);
        return EXIT_FAILURE;
    

    known_n = precalculate(xrange[0], xrange[1], &known_x, &known_f, &known_a, &known_m);
    if (!known_n) 
        if (errno == ENOMEM)
            fprintf(stderr, "Not enough memory for precalculated tables.\n");
        else
            fprintf(stderr, "Invalid (empty) x range.\n");
        return EXIT_FAILURE;
    

    fprintf(stderr, "Precalculated %lu arctangents to compare to.\n", 8UL * (unsigned long)known_n);
    fprintf(stderr, "\nC17 C15 C13 C11 C9 C7 C5 C3 max-ulps-under max-ulps-above correctly-rounded percentage cycles\n");
    fflush(stderr);

    
        const double  percent = 12.5 / (double)known_n;
        size_t        rounded;
        char          c17buffer[64], c15buffer[64], c13buffer[64], c11buffer[64];
        char          c9buffer[64], c7buffer[64], c5buffer[64], c3buffer[64];
        char          minbuffer[64], maxbuffer[64];
        float         minulps, maxulps;
        unsigned long tsc_start, tsc_stop;

        for (c[0] = c17range[0]; c[0] <= c17range[1]; c[0] = nextafterf(c[0], HUGE_VALF))
        for (c[1] = c15range[0]; c[1] <= c15range[1]; c[1] = nextafterf(c[1], HUGE_VALF))
        for (c[2] = c13range[0]; c[2] <= c13range[1]; c[2] = nextafterf(c[2], HUGE_VALF))
        for (c[3] = c11range[0]; c[3] <= c11range[1]; c[3] = nextafterf(c[3], HUGE_VALF))
        for (c[4] = c9range[0]; c[4] <= c9range[1]; c[4] = nextafterf(c[4], HUGE_VALF))
        for (c[5] = c7range[0]; c[5] <= c7range[1]; c[5] = nextafterf(c[5], HUGE_VALF))
        for (c[6] = c5range[0]; c[6] <= c5range[1]; c[6] = nextafterf(c[6], HUGE_VALF))
        for (c[7] = c3range[0]; c[7] <= c3range[1]; c[7] = nextafterf(c[7], HUGE_VALF)) 
            tsc_start = __builtin_ia32_rdtsc();
            rounded = error8(known_x, c, known_f, known_a, known_m, known_n, &minulps, &maxulps);
            tsc_stop = __builtin_ia32_rdtsc();
            printf("%-13s %-13s %-13s %-13s %-13s %-13s %-13s %-13s %-13s %-13s %lu %.3f %lu\n",
                   f2s(c17buffer, sizeof c17buffer, c[0], "?"),
                   f2s(c15buffer, sizeof c15buffer, c[1], "?"),
                   f2s(c13buffer, sizeof c13buffer, c[2], "?"),
                   f2s(c11buffer, sizeof c11buffer, c[3], "?"),
                   f2s(c9buffer, sizeof c9buffer, c[4], "?"),
                   f2s(c7buffer, sizeof c7buffer, c[5], "?"),
                   f2s(c5buffer, sizeof c5buffer, c[6], "?"),
                   f2s(c3buffer, sizeof c3buffer, c[7], "?"),
                   f2s(minbuffer, sizeof minbuffer, minulps, "?"),
                   f2s(maxbuffer, sizeof maxbuffer, maxulps, "?"),
                   rounded, (double)rounded * percent,
                   (unsigned long)(tsc_stop - tsc_start));
            fflush(stdout);

        
    

    return EXIT_SUCCESS;

代码在 Linux 上使用 GCC-4.8.2 进行编译,但可能需要针对其他编译器和/或操作系统进行修改。 (不过,我很乐意包含/接受修复这些问题的编辑。我自己没有 Windows 或 ICC,所以我可以检查一下。)

要编译这个,我推荐

gcc -Wall -O3 -fomit-frame-pointer -march=native -mtune=native example.c -lm -o example

不带参数运行以查看用法;或

./example 0x1.7ed24ap-9f -0x1.0c2c12p-6f  0x1.61fdd2p-5f -0x1.3556b0p-4f 0x1.b4e138p-4f -0x1.230ae2p-3f  0x1.9978eep-3f -0x1.5554dap-2f 0.75:1

检查它报告的 njuffa 系数集,与标准 C 库 atan() 函数进行比较,并考虑 [0.75, 1] 中所有可能的 x。

除了固定系数,您还可以使用min:max 定义要扫描的范围(扫描所有唯一的单精度浮点值)。对每个可能的系数组​​合进行测试。

因为我更喜欢十进制表示法,但需要保持数值准确,所以我使用f2s() 函数来显示浮点值。这是一个简单的蛮力辅助函数,它使用最短的格式,在解析回浮点数时产生相同的值。

例如,

./example 0x1.7ed248p-9f:0x1.7ed24cp-9f -0x1.0c2c10p-6f:-0x1.0c2c14p-6f 0x1.61fdd0p-5f:0x1.61fdd4p-5f -0x1.3556aep-4f:-0x1.3556b2p-4f 0x1.b4e136p-4f:0x1.b4e13ap-4f -0x1.230ae0p-3f:-0x1.230ae4p-3f 0x1.9978ecp-3f:0x1.9978f0p-3f -0x1.5554d8p-2f:-0x1.5554dcp-2f 0.75:1

计算所有 6561 (38) 系数组合±1 ULP 周围 njuffa 的集合为 x 在 [0.75, 1]。 (事实上​​,它表明将 C17 减少 1 个 ULP 到 0x1.7ed248p-9f 会产生完全相同的结果。)

(在 2.6 GHz 的 Core i5-4200U 上运行需要 90 秒——与我估计的每核每 GHz 每秒 30 个系数集的估计基本一致。虽然这段代码没有线程化,但关键功能是线程化的——安全,所以线程应该不会太难。这个 Core i5-4200U 是一台笔记本电脑,即使只对一个核心施加压力也会很热,所以我没有打扰。)

(我认为上面的代码是在公共领域,或者是 CC0 许可的,在公共领域是不可能的。事实上,我不确定它是否有足够的创意来获得版权。无论如何,请随意以任何你希望的方式在任何地方使用它,只要你不怪我,如果它坏了。)

问题?增强功能?欢迎进行编辑以修复 Linux/GCC 问题!

【讨论】:

以上是关于[-1,1] 上反正切的最佳机器优化多项式极小极大近似?的主要内容,如果未能解决你的问题,请参考以下文章

线性回归

PICOS 中的极小极大优化

梯度下降

我正在尝试实现一个极小极大算法来创建一个井字游戏机器人,但我遇到了递归错误

计算机博弈 期望搜索算法算法 期望极大极小算法

转:极小极大搜索方法负值最大算法和Alpha-Beta搜索方法