最大化 Collat​​z 猜想程序 Python 的效率

Posted

技术标签:

【中文标题】最大化 Collat​​z 猜想程序 Python 的效率【英文标题】:Maximizing Efficiency of Collatz Conjecture Program Python 【发布时间】:2022-01-14 10:36:42 【问题描述】:

我的问题很简单。

我写这个程序纯粹是为了娱乐。它接受一个数字输入并找到每个 Collat​​z 序列的长度,直到并包括该数字。

我想在算法或数学上让它更快(即我知道我可以通过并行运行多个版本或用 C++ 编写它来让它更快,但那有什么乐趣呢?)。

欢迎任何帮助,谢谢!

编辑: 在 dankal444 的帮助下进一步优化代码

from matplotlib import pyplot as plt
import numpy as np
import numba as nb

# Get Range to Check
top_range = int(input('Top Range: '))

@nb.njit('int64[:](int_)')
def collatz(top_range):
    # Initialize mem
    mem = np.zeros(top_range + 1, dtype = np.int64)
    for start in range(2, top_range + 1):
        # If mod4 == 1: (3x + 1)/4
        if start % 4 == 1:
            mem[start] = mem[(start + (start >> 1) + 1) // 2] + 3
        
        # If 4mod == 3: 3(3x + 1) + 1 and continue
        elif start % 4 == 3:
            num = start + (start >> 1) + 1
            num += (num >> 1) + 1
            count = 4

            while num >= start:
                if num % 2:
                    num += (num >> 1) + 1
                    count += 2
                else:
                    num //= 2
                    count += 1
            mem[start] = mem[num] + count

        # If 4mod == 2 or 0: x/2
        else:
            mem[start] = mem[(start // 2)] + 1

    return mem

mem = collatz(top_range)

# Plot each starting number with the length of it's sequence
plt.scatter([*range(1, len(mem) + 1)], mem, color = 'black', s = 1)
plt.show()

【问题讨论】:

也许这更适合Code Review(这是另一个 Stack Exchange 社区)。 【参考方案1】:

在您的代码上应用 numba 确实有很大帮助。

我删除了 tqdm,因为它对性能没有帮助。

import time
from matplotlib import pyplot as plt
from tqdm import tqdm

import numpy as np
import numba as nb
@nb.njit('int64[:](int_)')
def collatz2(top_range):
    mem = np.zeros(top_range + 1, dtype=np.int64)
    for start in range(2, top_range + 1):
        # If mod(4) == 1: Value 2 or 3 Cached
        if start % 4 == 1:
            mem[start] = mem[(start + (start >> 1) + 1) // 2] + 3
        # If mod(4) == 3: Use Algorithm
        elif start % 4 == 3:
            num = start
            count = 0
            while num >= start:
                if num % 2:
                    num += (num >> 1) + 1
                    count += 2
                else:
                    num //= 2
                    count += 1
            mem[start] = mem[num] + count
        # If mod(4) == 2 or 4: Value 1 Cached
        else:
            mem[start] = mem[(start // 2)] + 1
    return mem


def collatz(top_range):
    mem = [0] * (top_range + 1)
    for start in range(2, top_range + 1):
        # If mod(4) == 1: Value 2 or 3 Cached
        if start % 4 == 1:
            mem[start] = mem[(start + (start >> 1) + 1) // 2] + 3
        # If mod(4) == 3: Use Algorithm
        elif start % 4 == 3:
            num = start
            count = 0
            while num >= start:
                if num % 2:
                    num += (num >> 1) + 1
                    count += 2
                else:
                    num //= 2
                    count += 1
            mem[start] = mem[num] + count
        # If mod(4) == 2 or 4: Value 1 Cached
        else:
            mem[start] = mem[(start // 2)] + 1
    return mem

# profiling here
def main():

    top_range = 1_000_000
    mem = collatz(top_range)
    mem2 = collatz2(top_range)
    assert np.allclose(np.array(mem), mem2)


对于 top_range = 1_000,优化后的函数要快约 100 倍。对于 top_range = 1_000_000,优化后的函数大约快 600 倍:

    79                                           def main():
    81         1          3.0      3.0      0.0      top_range = 1_000_000
    83         1   24633045.0 24633045.0     98.7      mem = collatz(top_range)
    85         1      39311.0  39311.0      0.2      mem2 = collatz2(top_range)

【讨论】:

这真是太好了。我对 Numba 不熟悉,这是在做什么?从 int32 转换为 int64? Numba 是一个 jit(即时)编译器,简而言之 - 将给定的函数编译为优化的机器代码。如果答案适合你,请采纳。

以上是关于最大化 Collat​​z 猜想程序 Python 的效率的主要内容,如果未能解决你的问题,请参考以下文章

python Collat​​z序列

python 最长的Collat​​z序列(Euler#14)

Collat​​z C++ 代码的问题

Erlang中的Collat​​z序列

c_cpp 使用libgmp进行Collat​​z。

简单的 nodejs 应用程序中的内存泄漏