为啥我的基数排序 python 实现比快速排序慢?

Posted

技术标签:

【中文标题】为啥我的基数排序 python 实现比快速排序慢?【英文标题】:Why is my radix sort python implementation slower than quick sort?为什么我的基数排序 python 实现比快速排序慢? 【发布时间】:2011-11-26 15:53:36 【问题描述】:

我使用 SciPy 中的数组重写了 Wikipedia 中用于 Python 的原始基数排序算法,以提高性能并减少代码长度,我设法做到了。然后我采用了 Literate Programming 中的 classic(内存中,基于枢轴)快速排序算法并比较了它们的性能。

我曾期望基数排序会在超过某个阈值时击败快速排序,但事实并非如此。此外,我发现Erik Gorset's Blog's 提出了一个问题“基数排序比整数数组的快速排序更快吗?”。答案是这样的

.. 基准测试显示 MSB 就地基数排序始终比大型数组的快速排序快 3 倍以上。

很遗憾,我无法重现结果;不同之处在于 (a) Erik 选择了 Java 而不是 Python,并且 (b) 他使用 MSB 就地基数排序,而我只是在 Python 字典中填充 buckets .

根据理论基数排序应该比快速排序更快(线性);但显然这在很大程度上取决于实施。那么我的错误在哪里?

这是比较两种算法的代码:

from sys   import argv
from time  import clock

from pylab import array, vectorize
from pylab import absolute, log10, randint
from pylab import semilogy, grid, legend, title, show

###############################################################################
# radix sort
###############################################################################

def splitmerge0 (ls, digit): ## python (pure!)

    seq = map (lambda n: ((n // 10 ** digit) % 10, n), ls)
    buf = 0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]

    return reduce (lambda acc, key: acc.extend(buf[key]) or acc,
        reduce (lambda _, (d,n): buf[d].append (n) or buf, seq, buf), [])

def splitmergeX (ls, digit): ## python & numpy

    seq = array (vectorize (lambda n: ((n // 10 ** digit) % 10, n)) (ls)).T
    buf = 0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]

    return array (reduce (lambda acc, key: acc.extend(buf[key]) or acc,
        reduce (lambda _, (d,n): buf[d].append (n) or buf, seq, buf), []))

def radixsort (ls, fn = splitmergeX):

    return reduce (fn, xrange (int (log10 (absolute (ls).max ()) + 1)), ls)

###############################################################################
# quick sort
###############################################################################

def partition (ls, start, end, pivot_index):

    lower = start
    upper = end - 1

    pivot = ls[pivot_index]
    ls[pivot_index] = ls[end]

    while True:

        while lower <= upper and ls[lower] <  pivot: lower += 1
        while lower <= upper and ls[upper] >= pivot: upper -= 1
        if lower > upper: break

        ls[lower], ls[upper] = ls[upper], ls[lower]

    ls[end] = ls[lower]
    ls[lower] = pivot

    return lower

def qsort_range (ls, start, end):

    if end - start + 1 < 32:
        insertion_sort(ls, start, end)
    else:
        pivot_index = partition (ls, start, end, randint (start, end))
        qsort_range (ls, start, pivot_index - 1)
        qsort_range (ls, pivot_index + 1, end)

    return ls

def insertion_sort (ls, start, end):

    for idx in xrange (start, end + 1):
        el = ls[idx]
        for jdx in reversed (xrange(0, idx)):
            if ls[jdx] <= el:
                ls[jdx + 1] = el
                break
            ls[jdx + 1] = ls[jdx]
        else:
            ls[0] = el

    return ls

def quicksort (ls):

    return qsort_range (ls, 0, len (ls) - 1)

###############################################################################
if __name__ == "__main__":
###############################################################################

    lower = int (argv [1]) ## requires: >= 2
    upper = int (argv [2]) ## requires: >= 2
    color = dict (enumerate (3*['r','g','b','c','m','k']))

    rslbl = "radix sort"
    qslbl = "quick sort"

    for value in xrange (lower, upper):

        #######################################################################

        ls = randint (1, value, size=value)

        t0 = clock ()
        rs = radixsort (ls)
        dt = clock () - t0

        print "%06d -- t0:%0.6e, dt:%0.6e" % (value, t0, dt)
        semilogy (value, dt, '%s.' % color[int (log10 (value))], label=rslbl)

        #######################################################################

        ls = randint (1, value, size=value)

        t0 = clock ()
        rs = quicksort (ls)
        dt = clock () - t0

        print "%06d -- t0:%0.6e, dt:%0.6e" % (value, t0, dt)
        semilogy (value, dt, '%sx' % color[int (log10 (value))], label=qslbl)

    grid ()
    legend ((rslbl,qslbl), numpoints=3, shadow=True, prop='size':'small')
    title ('radix & quick sort: #(integer) vs duration [s]')
    show ()

###############################################################################
###############################################################################

这是比较大小在 2 到 1250(横轴)范围内的整数数组的排序持续时间(以秒为单位)(对数纵轴)的结果;下面的曲线属于快速排序:

Radix vs Quick Sort Comparison

快速排序在功率变化时很平滑(例如,在 10、100 或 1000 时),但基数排序只是跳跃一点,但在质量上遵循与快速排序相同的路径,只是慢得多!

【问题讨论】:

包含 1250 个元素的数组并不是真正的大数组。对 1000000 个元素进行排序会得到什么结果? 当你抛出一个包含 1,000,000 或 10,000,000 个值的列表时会发生什么?看起来你有一个非常低效的基数排序实现(比如在内部循环中计算 10**digit 和许多不必要的函数调用),所以当你只有 1250 个元素时,理论上的效率可能不可见进行排序。 顺便说一句,您的代码很难阅读,但在我看来,您似乎依赖于迭代 dict 以数字顺序为您提供键。这意味着您依赖于未定义的行为,因此您的代码可能随时中断。 @Duncan:嗯,我已经预先计算数字的能力,并且正在使用查找来提高性能,但这并没有帮助;我没有发现任何显着的改进。 不必要的函数调用是什么意思? lambda 表达式? @sth:好吧,问题是实现显然太慢了,以至于我可以设法对多达 10000 个列表进行排序,之后非常慢。尽管算法之间的相对距离似乎缩小了,但即使是 10000 快速排序也更快。 【参考方案1】:

这里有几个问题。

首先,正如 cmets 中所指出的,您的数据集太小,理论上的复杂性无法克服代码中的开销。

接下来,您的所有那些不必要的函数调用和复制列表的实现效率非常低。以简单的程序方式编写代码几乎总是比函数式解决方案更快(对于 Python,其他语言在此处会有所不同)。您有一个快速排序的程序实现,因此如果您以相同的样式编写基数排序,即使对于小列表,它也可能会更快。

最后,当您尝试大型列表时,内存管理的开销可能开始占主导地位。这意味着您在实现效率是主要因素的小列表和内存管理是主要因素的大列表之间的窗口有限。

这里有一些使用快速排序的代码,但它是一个简单的基数排序,它是按程序编写的,但试图避免大量复制数据。您会看到,即使对于短列表,它也优于快速排序,但更有趣的是,随着数据大小的增加,快速排序和基数排序之间的比率也在增加,然后随着内存管理开始占主导地位,它又开始下降(简单的事情,比如释放1,000,000 个项目的列表需要很长时间):

from random import randint
from math import log10
from time import clock
from itertools import chain

def splitmerge0 (ls, digit): ## python (pure!)

    seq = map (lambda n: ((n // 10 ** digit) % 10, n), ls)
    buf = 0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]

    return reduce (lambda acc, key: acc.extend(buf[key]) or acc,
        reduce (lambda _, (d,n): buf[d].append (n) or buf, seq, buf), [])

def splitmerge1 (ls, digit): ## python (readable!)
    buf = [[] for i in range(10)]
    divisor = 10 ** digit
    for n in ls:
        buf[(n//divisor)%10].append(n)
    return chain(*buf)

def radixsort (ls, fn = splitmerge1):
    return list(reduce (fn, xrange (int (log10 (max(abs(val) for val in ls)) + 1)), ls))

###############################################################################
# quick sort
###############################################################################

def partition (ls, start, end, pivot_index):

    lower = start
    upper = end - 1

    pivot = ls[pivot_index]
    ls[pivot_index] = ls[end]

    while True:

        while lower <= upper and ls[lower] <  pivot: lower += 1
        while lower <= upper and ls[upper] >= pivot: upper -= 1
        if lower > upper: break

        ls[lower], ls[upper] = ls[upper], ls[lower]

    ls[end] = ls[lower]
    ls[lower] = pivot

    return lower

def qsort_range (ls, start, end):

    if end - start + 1 < 32:
        insertion_sort(ls, start, end)
    else:
        pivot_index = partition (ls, start, end, randint (start, end))
        qsort_range (ls, start, pivot_index - 1)
        qsort_range (ls, pivot_index + 1, end)

    return ls

def insertion_sort (ls, start, end):

    for idx in xrange (start, end + 1):
        el = ls[idx]
        for jdx in reversed (xrange(0, idx)):
            if ls[jdx] <= el:
                ls[jdx + 1] = el
                break
            ls[jdx + 1] = ls[jdx]
        else:
            ls[0] = el

    return ls

def quicksort (ls):

    return qsort_range (ls, 0, len (ls) - 1)

if __name__=='__main__':
    for value in 1000, 10000, 100000, 1000000, 10000000:
        ls = [randint (1, value) for _ in range(value)]
        ls2 = list(ls)
        last = -1
        start = clock()
        ls = radixsort(ls)
        end = clock()
        for i in ls:
            assert last <= i
            last = i
        print("rs %d: %0.2fs" % (value, end-start))
        tdiff = end-start
        start = clock()
        ls2 = quicksort(ls2)
        end = clock()
        last = -1
        for i in ls2:
            assert last <= i
            last = i
        print("qs %d: %0.2fs %0.2f%%" % (value, end-start, ((end-start)/tdiff*100)))

我运行时的输出是:

C:\temp>c:\python27\python radixsort.py
rs 1000: 0.00s
qs 1000: 0.00s 212.98%
rs 10000: 0.02s
qs 10000: 0.05s 291.28%
rs 100000: 0.19s
qs 100000: 0.58s 311.98%
rs 1000000: 2.47s
qs 1000000: 7.07s 286.33%
rs 10000000: 31.74s
qs 10000000: 86.04s 271.08%

编辑: 只是为了澄清。这里的快速排序实现对内存非常友好,它就地排序,所以无论列表有多大,它只是在不复制数据的情况下打乱数据。原始的 radixsort 有效地将每个数字的列表复制两次:一次复制到较小的列表中,然后在连接列表时再次复制。使用itertools.chain 可以避免第二次复制,但仍有大量内存分配/释放正在进行。 (此外,“两次”是近似的,因为列表附加确实涉及额外的复制,即使它是摊销 O(1),所以我应该说“与两次成比例”。)

【讨论】:

我不知道 itertools 的链功能,但会检查一下。我的基本假设是,在 Python 中 for 循环 实际上很慢,因此我故意以函数式的方式表达所有内容。为了进一步避免复制列表,我使用了 acc.extend(buf[key]) 或 acc 之类的技巧,我将累加器返回以减少。我想我需要阅读更多关于函数式 Python 内部结构的信息才能看到相应的陷阱。 一般假设:快速:mapreduce 使用 C 编写的函数。较慢:Python for 循环。更慢:mapreduce 调用 Python 编码的函数(编组参数是一项昂贵的操作)。另外,请记住lambdadef 之间没有区别,def foo(acc, key): acc.extend(buf[key]); return acc 比在表达式之间使用lambdaor 混淆更简单。【参考方案2】:

您的数据表示非常昂贵。为什么要为您的存储桶使用 hashmap?为什么要使用需要计算对数(= 计算成本高)的 base10 表示?

避免使用 lambda 表达式等,我认为 python 还不能很好地优化它们。

也许从为基准测试排序 10 字节字符串开始。并且:没有 Hashmap 和类似的昂贵数据结构。

【讨论】:

我不确定基准的 10 字节字符串是否合适,当声称 radixsort 优于 integer 数组的快速排序时。 感谢有关 lambda 表达式的提示,我会尝试一下.. 虽然我有点失望 Python 在使用 lambda 时表现如此糟糕。 :( @Anonymous:当我想对整数进行排序并因此想避免将它们转换为 10 字节字符串 时,你会建议什么而不是 hashmap?进一步:据我所知,对数只计算一次,所以避免这样做不会带来太多收益,对吗? 使用例如一个数组而不是 hashmap。 试试 Wikipedia 中的示例代码,en.wikipedia.org/wiki/Radix_sort#Example_in_Python,看看它与你的相比表现如何。

以上是关于为啥我的基数排序 python 实现比快速排序慢?的主要内容,如果未能解决你的问题,请参考以下文章

为啥我的分拣程序这么慢? (java中的基数/桶排序)

为啥在我的情况下快速排序总是比冒泡排序慢?

为啥 R 使用基数排序?

经典排序算法和python详解:归并排序快速排序堆排序计数排序桶排序和基数排序

Python八大算法的实现,插入排序希尔排序冒泡排序快速排序直接选择排序堆排序归并排序基数排序。

快速排序及优化