AVX2 列填充计数算法分别在每个位列上

Posted

技术标签:

【中文标题】AVX2 列填充计数算法分别在每个位列上【英文标题】:AVX2 column population count algorithm over each bit-column separately 【发布时间】:2019-10-21 12:18:01 【问题描述】:

对于我正在处理的项目,我需要计算撕裂的PDF 图像数据中每列设置的位数。

我正在尝试获取整个 PDF 作业(所有页面)中每一列的总设置位数。

数据一旦被翻录,就会存储在MemoryMappedFile 中,没有后备文件(在内存中)。

PDF 页面尺寸为 13952 像素 x 15125 像素。生成的翻录数据的总大小可以通过将PDF 的长度(高度)(以像素为单位)乘以宽度(以字节为单位)来计算。翻录的数据是1 bit == 1 pixel。所以一个翻录页面的大小(以字节为单位)是(13952 / 8) * 15125

请注意,宽度始终是64 bits 的倍数。

我必须在被翻录后计算PDF(可能是数万页)的每一页中每一列的设置位。

我首先通过一个基本的解决方案来解决这个问题,即循环每个字节并计算设置位的数量并将结果放在vector 中。从那以后,我将算法缩减到如下所示。我的执行时间从约 350 毫秒变为约 120 毫秒。

static void count_dots( )

    using namespace diag;
    using namespace std::chrono;

    std::vector<std::size_t> dot_counts( 13952, 0 );
    uint64_t* ptr_dot_counts dot_counts.data( ) ;

    std::vector<uint64_t> ripped_pdf_data( 3297250, 0xFFFFFFFFFFFFFFFFUL );
    const uint64_t* ptr_data ripped_pdf_data.data( ) ;

    std::size_t line_count 0 ;
    std::size_t counter ripped_pdf_data.size( ) ;

    stopwatch sw;
    sw.start( );

    while( counter > 0 )
    
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000000001UL ) >> 0;

        ++ptr_data;
        --counter;
        if( ++line_count >= 218 )
        
            ptr_dot_counts = dot_counts.data( );
            line_count = 0;
        
       

    sw.stop( );
    std::cout << sw.elapsed<milliseconds>( ) << "ms\n";

不幸的是,这仍然会增加很多额外的处理时间,这是不可接受的。

上面的代码很丑,不会赢得任何选美比赛,但它有助于减少执行时间。 从我写的原始版本开始,我做了以下事情:

使用pointers 而不是indexersuint64 的块而不是uint8 处理数据 手动展开for循环以遍历byte的每个bit中的每个uint64 使用最终的bit shift 而不是__popcnt64 来计算屏蔽后的集合bit

对于这个测试,我正在生成伪造的翻录数据,其中每个 bit 都设置为 1。测试完成后,dot_counts vector 应该包含每个 element15125

我希望这里的一些人可以帮助我将算法的平均执行时间控制在 100 毫秒以下。 我不关心这里的可移植性。

目标机器的CPU:Xeon E5-2680 v4 - Intel 编译器:MSVC++ 14.23 操作系统:Windows 10 C++ 版本:C++17 编译器标志:/O2/arch:AVX2

大约 8 年前有人问过一个非常相似的问题: How to quickly count bits into separate bins in a series of ints on Sandy Bridge?

(编者注:也许您错过了Count each bit-position separately over many 64-bit bitmasks, with AVX but not AVX2,它有一些更新更快的答案,至少对于在连续内存中沿着一列而不是沿着一行向下移动。也许您可以向下移动 1 或 2 个缓存线宽列,这样您就可以使您的计数器在 SIMD 寄存器中保持热状态。)

当我将迄今为止的答案与已接受的答案进行比较时,我已经相当接近了。我已经在处理uint64 的块而不是uint8。我只是想知道是否还有更多我可以做的事情,无论是使用内在函数、程序集还是简单的事情,比如更改我正在使用的数据结构。

【问题讨论】:

稍后我可能会更深入地了解一下,但从快速浏览来看,您应该能够通过简单地缓存按位结果来提高速度。您为每个字节重复右移 8 次,这是完全没有必要的。 @PickleRick 当你说“缓存按位结果”时,你的意思是缓存*ptr_data &gt;&gt; 7 的结果并重用它吗? 是的,完全正确。保存所有 8 个字节的移位 ptr_data 结果,这样就不必在每次迭代时重新计算 (8 * 8 = 64) 次。 只是为了澄清:您的图像以行为主存储?尺寸总是一样的,还是只能保证宽度是64的倍数?您需要将结果作为uint64 的向量还是uint16 的向量也可以?你能对对齐做出任何假设吗? @PickleRick 好主意,我试试看 【参考方案1】:

它可以用 AVX2 完成,如标记。

为了使这项工作正常进行,我推荐vector&lt;uint16_t&gt; 进行计数。增加计数是最大的问题,我们需要扩大的越多,问题就越大。 uint16_t 足以计算一页,因此您可以一次计算一页并将计数器添加到一组更宽的计数器中以获得总计。这是一些开销,但比必须在主循环中扩大更多的开销要少得多。

计数的大端顺序非常烦人,引入了更多的洗牌以使其正确。所以我建议弄错错误,然后再重新排序(也许在将它们加到总数中时?)。 “先右移 7,再右移 6,再右移 5”的顺序可以免费维护,因为我们可以随意选择 64 位块的移位计数。所以在下面的代码中,实际的计数顺序是:

最低有效字节的第 7 位, 第二个字节的第7位 ... 最高有效字节的第 7 位, 最低有效字节的第 6 位, ...

所以每组 8 人都颠倒过来。 (至少这是我打算做的,AVX2 unpacks 令人困惑)

代码(未测试):

while( counter > 0 )

    __m256i data = _mm256_set1_epi64x(*ptr_data);        
    __m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
    __m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
    data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
    data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));

    __m256i zero = _mm256_setzero_si256();

    __m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
    c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);

    c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
    c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);

    c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
    c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);

    c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
    c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);

    ptr_dot_counts += 64;
    ++ptr_data;
    --counter;
    if( ++line_count >= 218 )
    
        ptr_dot_counts = dot_counts.data( );
        line_count = 0;
    

这可以进一步展开,一次处理多行。这很好,因为如前所述,对计数器求和是最大的问题,而按行展开会减少这种情况,而在寄存器中进行更简单的求和。

使用了一些内在函数:

_mm256_set1_epi64x,将一个 int64_t 复制到向量的所有 4 个 64 位元素。 uint64_t 也可以。 _mm256_set_epi64x,将 4 个 64 位值转换为向量。 _mm256_srlv_epi64,逻辑右移,计数可变(每个元素可以是不同的计数)。 _mm256_and_si256,按位与。 _mm256_add_epi16,另外,适用于 16 位元素。 _mm256_unpacklo_epi8_mm256_unpackhi_epi8,可能最好通过该页面上的图表来解释

可以“垂直”求和,使用一个 uint64_t 保存 64 个单独和的所有第 0 位,另一个 uint64_t 保存总和的所有第 1 位等。加法可以通过用按位算术模拟全加器(电路组件)。然后不是只向计数器添加 0 或 1,而是一次添加更大的数字。

垂直和也可以向量化,但这会大大增加将垂直和添加到列和的代码,所以我在这里没有这样做。它应该有帮助,但它只是很多代码。

示例(未测试):

size_t y;
// sum 7 rows at once
for (y = 0; (y + 6) < 15125; y += 7) 
    ptr_dot_counts = dot_counts.data( );
    ptr_data = ripped_pdf_data.data( ) + y * 218;
    for (size_t x = 0; x < 218; x++) 
        uint64_t dataA = ptr_data[0];
        uint64_t dataB = ptr_data[218];
        uint64_t dataC = ptr_data[218 * 2];
        uint64_t dataD = ptr_data[218 * 3];
        uint64_t dataE = ptr_data[218 * 4];
        uint64_t dataF = ptr_data[218 * 5];
        uint64_t dataG = ptr_data[218 * 6];
        // vertical sums, 7 bits to 3
        uint64_t abc0 = (dataA ^ dataB) ^ dataC;
        uint64_t abc1 = (dataA ^ dataB) & dataC | (dataA & dataB);
        uint64_t def0 = (dataD ^ dataE) ^ dataF;
        uint64_t def1 = (dataD ^ dataE) & dataF | (dataD & dataE);
        uint64_t bit0 = (abc0 ^ def0) ^ dataG;
        uint64_t c1   = (abc0 ^ def0) & dataG | (abc0 & def0);
        uint64_t bit1 = (abc1 ^ def1) ^ c1;
        uint64_t bit2 = (abc1 ^ def1) & c1 | (abc1 & def1);
        // add vertical sums to column counts
        __m256i bit0v = _mm256_set1_epi64x(bit0);
        __m256i data01 = _mm256_srlv_epi64(bit0v, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data02 = _mm256_srlv_epi64(bit0v, _mm256_set_epi64x(0, 2, 1, 3));
        data01 = _mm256_and_si256(data01, _mm256_set1_epi8(1));
        data02 = _mm256_and_si256(data02, _mm256_set1_epi8(1));
        __m256i bit1v = _mm256_set1_epi64x(bit1);
        __m256i data11 = _mm256_srlv_epi64(bit1v, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data12 = _mm256_srlv_epi64(bit1v, _mm256_set_epi64x(0, 2, 1, 3));
        data11 = _mm256_and_si256(data11, _mm256_set1_epi8(1));
        data12 = _mm256_and_si256(data12, _mm256_set1_epi8(1));
        data11 = _mm256_add_epi8(data11, data11);
        data12 = _mm256_add_epi8(data12, data12);
        __m256i bit2v = _mm256_set1_epi64x(bit2);
        __m256i data21 = _mm256_srlv_epi64(bit2v, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data22 = _mm256_srlv_epi64(bit2v, _mm256_set_epi64x(0, 2, 1, 3));
        data21 = _mm256_and_si256(data21, _mm256_set1_epi8(1));
        data22 = _mm256_and_si256(data22, _mm256_set1_epi8(1));
        data21 = _mm256_slli_epi16(data21, 2);
        data22 = _mm256_slli_epi16(data22, 2);
        __m256i data1 = _mm256_add_epi8(_mm256_add_epi8(data01, data11), data21);
        __m256i data2 = _mm256_add_epi8(_mm256_add_epi8(data02, data12), data22);

        __m256i zero = _mm256_setzero_si256();

        __m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);


        ptr_dot_counts += 64;
        ++ptr_data;
    

// leftover rows
for (; y < 15125; y++) 
    ptr_dot_counts = dot_counts.data( );
    ptr_data = ripped_pdf_data.data( ) + y * 218;
    for (size_t x = 0; x < 218; x++) 
        __m256i data = _mm256_set1_epi64x(*ptr_data);
        __m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
        data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
        data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));

        __m256i zero = _mm256_setzero_si256();

        __m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);


        ptr_dot_counts += 64;
        ++ptr_data;
    


到目前为止,第二好的方法是一种更简单的方法,更像第一个版本,除了一次运行 yloopLen 行以利用快速 8 位求和:

size_t yloopLen = 32;
size_t yblock = yloopLen * 1;
size_t yy;
for (yy = 0; yy < 15125; yy += yblock) 
    for (size_t x = 0; x < 218; x++) 
        ptr_data = ripped_pdf_data.data() + x;
        ptr_dot_counts = dot_counts.data() + x * 64;
        __m256i zero = _mm256_setzero_si256();

        __m256i c1 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
        __m256i c2 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
        __m256i c3 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
        __m256i c4 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);

        size_t end = std::min(yy + yblock, size_t(15125));
        size_t y;
        for (y = yy; y < end; y += yloopLen) 
            size_t len = std::min(size_t(yloopLen), end - y);
            __m256i count1 = zero;
            __m256i count2 = zero;

            for (size_t t = 0; t < len; t++) 
                __m256i data = _mm256_set1_epi64x(ptr_data[(y + t) * 218]);
                __m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
                __m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
                data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
                data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));
                count1 = _mm256_add_epi8(count1, data1);
                count2 = _mm256_add_epi8(count2, data2);
            

            c1 = _mm256_add_epi16(_mm256_unpacklo_epi8(count1, zero), c1);
            c2 = _mm256_add_epi16(_mm256_unpackhi_epi8(count1, zero), c2);
            c3 = _mm256_add_epi16(_mm256_unpacklo_epi8(count2, zero), c3);
            c4 = _mm256_add_epi16(_mm256_unpackhi_epi8(count2, zero), c4);
        

        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c1);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c2);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c3);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c4);
    

之前有一些测量问题,最后这实际上并没有更好,但也没有比上面更花哨的“垂直和”版本差多少。

【讨论】:

我还会考虑在字节级别(甚至分层的 2 位、4 位、8 位、16 位)上进行最内层的添加。我认为要考虑的主要问题之一是如何优化使用寄存器和缓存。 @harold 哇,这令人印象深刻。执行时间从平均约 110 毫秒降至约 10 毫秒。推进指针后,dot_counts 计数正确。现在我需要做我的工作,并在实现之前逐行了解代码在做什么。感谢您的时间和精力。 版本 2 将执行时间从 ~10ms(版本 1)降低到 ~5ms。你给了我很多功课要做。 @WBuck 和哈罗德:我对Count each bit-position separately over many 64-bit bitmasks, with AVX but not AVX2 的回答有一个使用标量uint64_t 的逐渐扩大版本。我没有花时间手动对其进行矢量化;您可能想同时向下列 2 个__m256i 向量来执行整个缓存行。如果您想在没有 AVX512 的情况下更宽,套准压力可能是个问题。缓存阻塞有助于避免浪费缓存流量(尤其是在冲突未命中之前从未加载的 L2 空间预取)。 @WBuck 和哈罗德:哦,另请参阅github.com/mklarqvist/positional-popcount。它有一个 SSE 混合版本和 AVX512 harley-seal,列出了显着的加速。 (但它在 k 位字中是定位的,而不是在许多字的整行中,除非它适用于非常大的 k)。 Harley-Seal 是我认为这就是 Harold 在不同寄存器中实现具有不同位置值位的全加器所做的事情。它将 3 个输入减少到 2 个,并且可以重复应用。所以,是的,这样做下去列应该会很好。

以上是关于AVX2 列填充计数算法分别在每个位列上的主要内容,如果未能解决你的问题,请参考以下文章

计数排序,桶排序,基数排序的python实现

十大排序算法之计数排序

算法-排序-计数排序

根据一维计数器数组填充二维数组列

排序算法八:计数排序(Counting Sort)

累积总和数据帧的条件计数 - 遍历列