使用 AVX2 查找元素索引 - 代码优化

Posted

技术标签:

【中文标题】使用 AVX2 查找元素索引 - 代码优化【英文标题】:Find element index with AVX2 - code optimization 【发布时间】:2020-05-31 18:32:59 【问题描述】:

我正在摆弄 AVX2 来编写一些代码,该代码能够在具有 14 个条目的数组中搜索 32 位哈希并返回找到的条目的索引。

因为很可能绝大多数命中都在数组的前 8 个条目中,所以添加 __builtin_expect 的使用已经可以改进此代码,这不是我现在的优先事项。

虽然哈希数组(在由变量 hashes 表示的代码中)总是有 14 个条目长,但它包含在这种结构中

typedef struct chain_ring chain_ring_t;
struct chain_ring 
    uint32_t hashes[14];
    chain_ring_t* next;
    ...other stuff...
 __attribute__((aligned(16)))

这里是代码

int8_t hash32_find_14_avx2(uint32_t hash, volatile uint32_t* hashes) 
    uint32_t compacted_result_mask, leading_zeroes;
    __m256i cmp_vector, ring_vector, result_mask_vector;
    int8_t found_index = -1;

    if (hashes[0] == hash) 
        return 0;
    

    for(uint8_t base_index = 0; base_index < 14; base_index += 8) 
        cmp_vector = _mm256_set1_epi32(hash);
        ring_vector = _mm256_stream_load_si256((__m256i*) (hashes + base_index));

        result_mask_vector = _mm256_cmpeq_epi32(ring_vector, cmp_vector);
        compacted_result_mask = _mm256_movemask_epi8(result_mask_vector);

        if (compacted_result_mask != 0) 
            leading_zeroes = 32 - __builtin_clz(compacted_result_mask);
            found_index = base_index + (leading_zeroes >> 2u) - 1;
            break;
        
    

    return found_index > 13 ? -1 : found_index;

简单解释的逻辑是搜索前 8 个条目,然后搜索后 8 个条目。如果找到的索引大于 13,则意味着它找到了一些不属于数组的内容的匹配项,因此必须将其视为不匹配。

注意事项:

为了加速加载(来自对齐的内存)我正在使用 _mm256_stream_load_si256 由于上述原因,我需要检查返回值是否大于 13,并且我不太喜欢这个特定部分,是否应该使用 _mm256_maskload_epi32? 我使用for循环来避免重复代码,gcc当然会展开循环 我正在使用 __builtin_clz 但我正在使用 -mlzcnt 编译代码,因为据我所知,AMD cpus 运行 bsr 指令的速度要慢得多,gcc 使用 lzcnt 而不是带有标志的 bsr 第一个 IF 引入了平均约 0.30 ns 的延迟,但平均而言它减少了第一次匹配的时间 0.6ns 代码仅适用于 64 位机器 有时我需要针对 aarch64 优化此代码

这里有一个很好的链接到godbolt的生产程序集 https://godbolt.org/z/5bxbN6

我也实现了 SSE 版本(它在要点中),但逻辑是相同的,虽然我不确定它的性能值得

作为参考,我构建了一个简单的线性搜索函数,并使用 google-benchmark 库与它进行了性能比较

int8_t hash32_find_14_loop(uint32_t hash, volatile uint32_t* hashes) 
    for(uint8_t index = 0; index <= 14; index++) 
        if (hashes[index] == hash) 
            return index;
        
    

    return -1;

完整代码可在此网址获得https://gist.github.com/danielealbano/9fcbc1ff0a42cc9ad61be205366bdb5f

除了 google-benchmark 库的必要标志之外,我正在使用 -avx2 -avx -msse4 -O3 -mbmi -mlzcnt 编译它

执行每个元素的工作台(我想比较循环与替代方案)

----------------------------------------------------------------------------------------------------
Benchmark                                                          Time             CPU   Iterations
----------------------------------------------------------------------------------------------------
bench_template_hash32_find_14_loop/0/iterations:100000000       0.610 ns        0.610 ns    100000000
bench_template_hash32_find_14_loop/1/iterations:100000000        1.16 ns         1.16 ns    100000000
bench_template_hash32_find_14_loop/2/iterations:100000000        1.18 ns         1.18 ns    100000000
bench_template_hash32_find_14_loop/3/iterations:100000000        1.19 ns         1.19 ns    100000000
bench_template_hash32_find_14_loop/4/iterations:100000000        1.28 ns         1.28 ns    100000000
bench_template_hash32_find_14_loop/5/iterations:100000000        1.26 ns         1.26 ns    100000000
bench_template_hash32_find_14_loop/6/iterations:100000000        1.52 ns         1.52 ns    100000000
bench_template_hash32_find_14_loop/7/iterations:100000000        2.15 ns         2.15 ns    100000000
bench_template_hash32_find_14_loop/8/iterations:100000000        1.66 ns         1.66 ns    100000000
bench_template_hash32_find_14_loop/9/iterations:100000000        1.67 ns         1.67 ns    100000000
bench_template_hash32_find_14_loop/10/iterations:100000000       1.90 ns         1.90 ns    100000000
bench_template_hash32_find_14_loop/11/iterations:100000000       1.89 ns         1.89 ns    100000000
bench_template_hash32_find_14_loop/12/iterations:100000000       2.13 ns         2.13 ns    100000000
bench_template_hash32_find_14_loop/13/iterations:100000000       2.20 ns         2.20 ns    100000000
bench_template_hash32_find_14_loop/14/iterations:100000000       2.32 ns         2.32 ns    100000000
bench_template_hash32_find_14_loop/15/iterations:100000000       2.53 ns         2.53 ns    100000000
bench_template_hash32_find_14_sse/0/iterations:100000000        0.531 ns        0.531 ns    100000000
bench_template_hash32_find_14_sse/1/iterations:100000000         1.42 ns         1.42 ns    100000000
bench_template_hash32_find_14_sse/2/iterations:100000000         2.53 ns         2.53 ns    100000000
bench_template_hash32_find_14_sse/3/iterations:100000000         1.45 ns         1.45 ns    100000000
bench_template_hash32_find_14_sse/4/iterations:100000000         2.26 ns         2.26 ns    100000000
bench_template_hash32_find_14_sse/5/iterations:100000000         1.90 ns         1.90 ns    100000000
bench_template_hash32_find_14_sse/6/iterations:100000000         1.90 ns         1.90 ns    100000000
bench_template_hash32_find_14_sse/7/iterations:100000000         1.93 ns         1.93 ns    100000000
bench_template_hash32_find_14_sse/8/iterations:100000000         2.07 ns         2.07 ns    100000000
bench_template_hash32_find_14_sse/9/iterations:100000000         2.05 ns         2.05 ns    100000000
bench_template_hash32_find_14_sse/10/iterations:100000000        2.08 ns         2.08 ns    100000000
bench_template_hash32_find_14_sse/11/iterations:100000000        2.08 ns         2.08 ns    100000000
bench_template_hash32_find_14_sse/12/iterations:100000000        2.55 ns         2.55 ns    100000000
bench_template_hash32_find_14_sse/13/iterations:100000000        2.53 ns         2.53 ns    100000000
bench_template_hash32_find_14_sse/14/iterations:100000000        2.37 ns         2.37 ns    100000000
bench_template_hash32_find_14_sse/15/iterations:100000000        2.59 ns         2.59 ns    100000000
bench_template_hash32_find_14_avx2/0/iterations:100000000       0.537 ns        0.537 ns    100000000
bench_template_hash32_find_14_avx2/1/iterations:100000000        1.37 ns         1.37 ns    100000000
bench_template_hash32_find_14_avx2/2/iterations:100000000        1.38 ns         1.38 ns    100000000
bench_template_hash32_find_14_avx2/3/iterations:100000000        1.36 ns         1.36 ns    100000000
bench_template_hash32_find_14_avx2/4/iterations:100000000        1.37 ns         1.37 ns    100000000
bench_template_hash32_find_14_avx2/5/iterations:100000000        1.38 ns         1.38 ns    100000000
bench_template_hash32_find_14_avx2/6/iterations:100000000        1.40 ns         1.40 ns    100000000
bench_template_hash32_find_14_avx2/7/iterations:100000000        1.39 ns         1.39 ns    100000000
bench_template_hash32_find_14_avx2/8/iterations:100000000        1.99 ns         1.99 ns    100000000
bench_template_hash32_find_14_avx2/9/iterations:100000000        2.02 ns         2.02 ns    100000000
bench_template_hash32_find_14_avx2/10/iterations:100000000       1.98 ns         1.98 ns    100000000
bench_template_hash32_find_14_avx2/11/iterations:100000000       1.98 ns         1.98 ns    100000000
bench_template_hash32_find_14_avx2/12/iterations:100000000       2.03 ns         2.03 ns    100000000
bench_template_hash32_find_14_avx2/13/iterations:100000000       1.98 ns         1.98 ns    100000000
bench_template_hash32_find_14_avx2/14/iterations:100000000       1.96 ns         1.96 ns    100000000
bench_template_hash32_find_14_avx2/15/iterations:100000000       1.97 ns         1.97 ns    100000000

感谢您的任何建议!

--- 更新

我已经用@chtz 制作的无分支实现更新了要点,并将 __lzcnt32 替换为 _tzcnt_u32,当返回 32 而不是 -1 时,我不得不稍微改变行为以考虑未找到,但这并不重要。

他们运行的 CPU 是 Intel Core i7 8700(6c/12t,3.20GHZ)。

bench 使用 cpu-pinning,使用比物理或逻辑 cpu 内核更多的线程,并执行一些额外的操作,特别是 for 循环,所以会有开销,但在两个测试之间是相同的,所以它应该对它们产生相同的影响方式。

如果您想运行测试,您需要调整 CPU_CORE_LOGICAL_COUNT 以手动匹配您的 cpu 的逻辑 cpu 核心数。

有趣的是,当存在更多争用(从单线程到 64 线程)时,性能提升如何从 +17% 跃升至 +41%。 我已经用 128 和 256 线程进行了一些测试,发现使用 AVX2 时速度提高了 +60%,但我没有包括以下数字。

(bench_template_hash32_find_14_avx2 是对无分支版本进行基准测试,我已缩短名称以使帖子更具可读性)

------------------------------------------------------------------------------------------
Benchmark                                                                 CPU   Iterations
------------------------------------------------------------------------------------------
bench_template_hash32_find_14_loop/iterations:10000000/threads:1      45.2 ns     10000000
bench_template_hash32_find_14_loop/iterations:10000000/threads:2      50.4 ns     20000000
bench_template_hash32_find_14_loop/iterations:10000000/threads:4      52.1 ns     40000000
bench_template_hash32_find_14_loop/iterations:10000000/threads:8      70.9 ns     80000000
bench_template_hash32_find_14_loop/iterations:10000000/threads:16     86.8 ns    160000000
bench_template_hash32_find_14_loop/iterations:10000000/threads:32     87.3 ns    320000000
bench_template_hash32_find_14_loop/iterations:10000000/threads:64     92.9 ns    640000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:1      38.4 ns     10000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:2      42.1 ns     20000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:4      46.5 ns     40000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:8      52.6 ns     80000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:16     60.0 ns    160000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:32     62.1 ns    320000000
bench_template_hash32_find_14_avx2/iterations:10000000/threads:64     65.8 ns    640000000

【问题讨论】:

_mm256_stream_load_si256?您的数据是否在视频 RAM 中,或者您是否以某种方式将内存页面映射为 WC,而不是普通的 WB 可缓存?如果不是,那么vmovntdqa 加载只是正常加载的慢版本。此外,在 movemask_epi8 之前将_mm256_movemask_pspackssdw / packsswb 一起使用您的双字向量,这样您就可以在每个分支中获得更多数据。 __builtin_clz 对于 0 是未定义的,实际上 gcc 很乐意将 31 - __builtin_clz(x) 优化为 bsr(对于零输入也是未定义的)。 由于您需要前导零计数,您可能需要 _lzcnt_u32 而不是 GNU C 内置函数。我认为所有 AVX2 机器也有 lzcnt(以及 BMI1 的其余部分),所以你也不会因为需要 BMI1 而错过任何东西。除非你真的想要32 - clz 而不是31-clz 注:如果我正确理解了您的基准代码,您的loop 结果很可能存在缺陷,因为每次测试相同的索引都会为您提供近乎完美的分支预测(实际上也在您的分支中) avx2 代码)。当然,除非您实际上期望在实践中出现这种行为。 您应该在本地计算机上使用-march=native 进行编译以适当地设置调整选项,并让编译器使用您的所有 CPU 功能(如 cmpxchg16b、FMA 和 BMI2)。 【参考方案1】:

通过比较数组的两个重叠部分,对它们进行位或运算,并使用单个 lzcnt 获取最后一位位置,您可以完全不使用分支来实现这一点。此外,使用 vmovmskps 而不是 vpmovmskb 可以节省除以 4 的结果(不过我不确定这是否会导致任何跨域延迟)。

int8_t hash32_find_14_avx2(uint32_t hash, volatile uint32_t* hashes) 
    uint32_t compacted_result_mask = 0;
    __m256i cmp_vector = _mm256_set1_epi32(hash);
    for(uint8_t base_index = 0; base_index < 12; base_index += 6) 
        __m256i ring_vector = _mm256_loadu_si256((__m256i*) (hashes + base_index));

        __m256i result_mask_vector = _mm256_cmpeq_epi32(ring_vector, cmp_vector);
        compacted_result_mask |= _mm256_movemask_ps(_mm256_castsi256_ps(result_mask_vector)) << (base_index);
    
    int32_t leading_zeros = __lzcnt32(compacted_result_mask);
    return (31 - leading_zeros);

正如 Peter 在 cmets 中已经指出的那样,在大多数情况下,_mm256_stream_load_si256 比正常负载差。另外,请注意,在 gcc 中使用未对齐加载时,您必须使用 -mno-avx256-split-unaligned-load(或者实际上只是使用 -march=native)进行编译 -- see this post for details。

Godbolt-Link 与简单的测试代码(请注意,如果数组中有多个匹配值,则循环版本和 avx2 版本的行为会有所不同): https://godbolt.org/z/2jNWqK

【讨论】:

感谢@chtz,如果没有这些跳转,代码会快得多! 我已经用您提供的代码更新了要点以供参考,非常感谢! gist.github.com/danielealbano/9fcbc1ff0a42cc9ad61be205366bdb5f 啊,是的,2x packssdw -> packsswb 适用于 16 字节向量。在这种情况下,最后 2 个元素可以通过 movq 加载完成。或者可能是 SSE3 movddup 加载以在前 2 个元素中重复相同的比较,以防哈希码可以是 0。等一下,我刚刚意识到这是使用lzcnt/bsr 而不是tzcnt/bsf,所以它找到了与标量代码相反的 latest 匹配项,而不是最早的匹配项。如果我们从最低位开始扫描,那么让位 14,15 重复 12,13 意味着 tzcnt 要么在它们之前停止,要么因为它们是 0 而扫描过去。 我猜 OP 不希望有重复的哈希值 - 或者在这种情况下返回任何有效索引都可以(否则原始版本已经无法正常工作)。 lzcnttzcnt 的优势在于,如果没有设置位,则更容易产生 -1 结果(是否返回 -1,而不是检查 Z 标志实际上更好是另一个问题。那是可能比稍后在索引号上进行另一个比较+分支便宜)。 @DanieleSalvatoreAlbano:那么您需要根据其中设置位的位置来索引其他内容吗?听起来你想遍历设置的位并找到它们的位置,你需要分支。从最低设置位到最高设置位迭代更容易,因为您可以使用单个 BMI1 指令mask &amp;= mask-1 i.e. BLSR 清除最低设置位。 (然后是tzcnt)。哦,刚刚看到你最后的评论。如果您必须重做比较,那么也许您没有这样做。但无论如何,您可能不需要实际的-1 并且可以 tzcnt

以上是关于使用 AVX2 查找元素索引 - 代码优化的主要内容,如果未能解决你的问题,请参考以下文章

算法递归算法 ② ( 使用递归实现二分法 | if else 编码优化 )

英特尔 AVX2 组装开发

MySQL优化-索引

appium-代码优化--H5页面点击后元素变更,查找元素时,找不到元素

MySQL索引基础补充以及优化笔记-上

索引优化之:创建填充和查找