jdk8源码TimSort算法——从头看到脚

Posted 小兀哥

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了jdk8源码TimSort算法——从头看到脚相关的知识,希望对你有一定的参考价值。

      首先,在Java 6中Arrays.sort()和Collections.sort()使用的是MergeSort,而在Java7以后,内部实现换成了TimSort。我们通过看jdk8的Collections.sort()源码,来了解一下TimSort算法

简介

      Timsort是一个自适应的、混合的、稳定的排序算法,融合了归并算法和二分插入排序算法的精髓,在现实世界的数据中有着特别优秀的表现。它是由Tim Peter于2002年发明的,用在Python这个编程语言里面。这个算法之所以快,是因为它充分利用了现实世界的待排序数据里面,有很多子串是已经排好序的不需要再重新排序,利用这个特性并且加上合适的合并规则可以更加高效的排序剩下的待排序序列。

      当Timsort运行在部分排序好的数组里面的时候,需要的比较次数要远小于nlogn,也是远小于相同情况下的归并排序算法需要的比较次数。但是和其他的归并排序算法一样,最坏情况下的时间复杂度是O(nlogn)的水平。但是在最坏的情况下,Timsort需要的临时存储空间只有n/2,在最好的情况下,需要的额外空间是常数级别的。从各个方面都能够击败需要O(n)空间和稳定O(nlogn)时间的归并算法。

jdk源码

Collections

public static <T extends Comparable<? super T>> void sort(List<T> list) 
    list.sort(null);


List

default void sort(Comparator<? super E> c) 
        Object[] a = this.toArray();
         //在这里真正确定了使用的是TimSort算法
         //默认的Array. sort(int[] a)这里用的是双轴排序,以后再说
        Arrays.sort(a, (Comparator) c);
        ListIterator<E> i = this.listIterator();
        for (Object e : a) 
            i.next();
            i.set((E) e);
        
    

Arrays

	public static <T> void sort(T[] a, Comparator<? super T> c) 
        if (c == null) 
        	//这里数组里的元素如果是引用类型必须要实现Comparator<T>接口
        	//并对其排序内部的比较函数compare()进行重写,以便于我们按照我们的排序要求对引用对象数组极性排序,默认是升序排序,但可以自己自定义成降序排序
            sort(a); 
         else 
            if (LegacyMergeSort.userRequested)
            	//这是兼容1.6之前旧版本,采用的是冒泡排序和归并排序
                legacyMergeSort(a, c);
            else
                TimSort.sort(a, 0, a.length, c, null, 0, 0);
        
    
    public static void sort(Object[] a) 
        if (LegacyMergeSort.userRequested)
            legacyMergeSort(a);
        else
        	//这里跟 TimSort.sort的思想是一样的
            ComparableTimSort.sort(a, 0, a.length, null, 0, 0);
    

TimSort

	static <T> void sort(T[] a, int lo, int hi, Comparator<? super T> c,
                         T[] work, int workBase, int workLen) 
        assert c != null && a != null && lo >= 0 && lo <= hi && hi <= a.length;

        int nRemaining  = hi - lo;
        if (nRemaining < 2)
            return;   // 长度是0或者1 就不需要排序了。

        // 1 如果小于32,就用二分插入排序算法
        if (nRemaining < MIN_MERGE) 
        	// 1.1 先找自然升序序列(如果是倒序,会颠倒为正序排列),返回自然序列大小
            // 这里的自然序列就是数组中从lo以后,已经排好的序列
            int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
            // 1.2 二分插入排序
            binarySort(a, lo, hi, lo + initRunLen, c);
            return;
        

        // 2 归并排序
        // 新建TimSort对象,保存栈的状态
        TimSort<T> ts = new TimSort<>(a, c, work, workBase, workLen);
        // 2.1 切分数组,返回大小在[16,32)之间的数,作为最小的分割槽大小
        int minRun = minRunLength(nRemaining);
        // 2.2 循环拆分数组,形成大小相差不多的run分割曹
        do 
            // 2.2.1 先找自然升序序列(如果是倒序,会颠倒为正序排列),返回自然序列大小
            int runLen = countRunAndMakeAscending(a, lo, hi, c);

            // 2.2.2 如果自然升序序列小于minRun ,需要按照minRun大小进行拆分并排序
            if (runLen < minRun) 
                int force = nRemaining <= minRun ? nRemaining : minRun;
                //把短的自然升序序列通过二分插入排序
                binarySort(a, lo, lo + force, lo + runLen, c);
                runLen = force;
            

            // 2.2.3 把已经排好序的数列压入栈中,检查是不是需要合并
            ts.pushRun(lo, runLen);
            // 2.2.3 检查是不是需要合并
            ts.mergeCollapse();

            //把指针后移runLen距离,准备开始下一轮片段的排序
            lo += runLen;
            //剩下待排序的数量相应的减少 runLen
            nRemaining -= runLen;
         while (nRemaining != 0);

        // 3 合并栈中所有待合并的序列
        assert lo == hi;
        ts.mergeForceCollapse();
        assert ts.stackSize == 1;
    

1.1 先找自然升序序列

	/**
     * 这一段代码是TimSort算法中的一个小优化,它利用了数组中前面一段已有的顺序。
     * 如果是升序,直接返回统计结果;如果是降序,在返回之前,将这段数列倒置,
     * 以确保这断序列从首个位置到此位置的序列都是升序的。
     * 返回的结果是这种两种形式的,lo是这段序列的开始位置。
     * 为了保证排序的稳定性,这里要使用严格的降序,这样才能保证相等的元素不参与倒置子序列的过程,
     * 保证它们原本的顺序不被打乱。
     *
     * @param a  参与排序的数组
     * @param lo run中首个元素的位置
     * @param hi run中最后一个元素的后面一个位置,需要确保lo<hi
     * @param c  本次排序的比较器
     * @return 从首个元素开始的最长升序子序列的结尾位置+1 or 严格的降序子序列的结尾位置+1。
     */
	private static <T> int countRunAndMakeAscending(T[] a, int lo, int hi,
                                                    Comparator<? super T> c) 
        assert lo < hi;
        int runHi = lo + 1;
        if (runHi == hi)
            return 1;

        // 找出最长升序序的子序列,如果降序,倒置之
        if (c.compare(a[runHi++], a[lo]) < 0)  // 前两个元素是降序,就按照降序统计
            while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) < 0)
                runHi++;
            reverseRange(a, lo, runHi);
         else                               // 前两个元素是升序,按照升序统计
            while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) >= 0)
                runHi++;
        

        return runHi - lo;
    

1.2 二分插入排序

 	/**
     * 被优化的二分插入排序
     * 使用二分插入排序算法给指定一部分数组排序。这是给小数组排序的最佳方案。最差情况下
     * 它需要 O(n log n) 次比较和 O(n^2)次数据移动。
     * 如果开始的部分数据是有序的那么我们可以利用它们。这个方法默认数组中的位置lo(包括在内)到
     * start(不包括在内)的范围内是已经排好序的。
     *
     * @param a     被排序的数组
     * @param lo    待排序范围内的首个元素的位置
     * @param hi    待排序范围内最后一个元素的后一个位置
     * @param start 待排序范围内的第一个没有排好序的位置,确保 (lo <= start <= hi)
     * @param c     本次排序的比较器
     */
	private static <T> void binarySort(T[] a, int lo, int hi, int start,
                                       Comparator<? super T> c) 
        assert lo <= start && start <= hi;
        if (start == lo)
            start++;
        for ( ; start < hi; start++) 
        	//pivot 代表正在参与排序的值
            T pivot = a[start];

            //如果start 从起点开始,做下预处理;也就是原本就是无序的。
            int left = lo;
            int right = start;
            assert left <= right;
            /*
             * 利用二分查找,找到需要插入的位置,保证的逻辑:
             *   pivot >= all in [lo, left).
             *   pivot <  all in [right, start).
             */
            while (left < right) 
                int mid = (left + right) >>> 1;
                if (c.compare(pivot, a[mid]) < 0)
                    right = mid;
                else
                    left = mid + 1;
            
            assert left == right;

             /**
             * 此时,仍然能保证:pivot >= [lo, left) && pivot < [left,start)
             * 所以,pivot的值应当在left所在的位置,然后需要把[left,start)范围内的内容整体右移一位腾出空间。
             * 如果pivot与区间中的某个值相等,left指正会指向重复的值的后一位(从left = mid + 1;这里可以看出),
             * 所以这里的排序是稳定的。
             */
            int n = start - left;  //需要移动的范围的长度
            // switch语句是一条小优化,1-2个元素的移动就不需要System.arraycopy了。
            // (这代码写的真是简洁,switch原来可以这样用)
            switch (n) 
                case 2:  a[left + 2] = a[left + 1];
                case 1:  a[left + 1] = a[left];
                         break;
                default: System.arraycopy(a, left, a, left + 1, n);
            
            //移动过之后,把pivot的值放到应该插入的位置,就是left的位置了
            a[left] = pivot;
        
    

2.1 TimSort.minRunLength()切分数组,返回大小在[16,32)之间。

private static int minRunLength(int n) 
        assert n >= 0;
        int r = 0;      // Becomes 1 if any 1 bits are shifted off
        while (n >= MIN_MERGE) 
            r |= (n & 1); //计算最近一次n的末位数是1还是0
            n >>= 1; //缩小二倍,除以2
        
        return n + r;
    

2.2.3 检查是不是需要合并

 	/**
     * 检查栈中待归并的升序序列,如果他们不满足下列条件就把相邻的两个序列合并,
     * 直到他们满足下面的条件
     *
     * 1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
     * 2. runLen[i - 2] > runLen[i - 1]
     *
     * 每次添加新序列到栈中的时候都会执行一次这个操作。所以栈中的需要满足的条件
     * 需要靠调用这个方法来维护。
     *
     * 最差情况下,有点像玩2048。
     */
	private void mergeCollapse() 
        while (stackSize > 1) 
            int n = stackSize - 2;
            if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) 
                if (runLen[n - 1] < runLen[n + 1])
                    n--;
                mergeAt(n);
             else if (runLen[n] <= runLen[n + 1]) 
                mergeAt(n);
             else 
                break; // Invariant is established
            
        
    

3 合并栈中所有待合并的序列

 	/**
     * 合并栈中所有待合并的序列,最后剩下一个序列。这个方法在整次排序中只执行一次
     */
    private void mergeForceCollapse() 
        while (stackSize > 1) 
            int n = stackSize - 2;
            if (n > 0 && runLen[n - 1] < runLen[n + 1])
                n--;
            mergeAt(n);
        
    

归并排序mergeAt

   /**
     * 合并在栈中位于i和i+1的两个相邻的升序序列。 i必须为从栈顶数,第二和第三个元素。
     * 换句话说i == stackSize - 2 || i == stackSize - 3
     *
     * @param i 待合并的第一个序列所在的位置
     */
	private void mergeAt(int i) 
		//校验
        assert stackSize >= 2;
        assert i >= 0;
        assert i == stackSize - 2 || i == stackSize - 3;
		//内部初始化
        int base1 = runBase[i];
        int len1 = runLen[i];
        int base2 = runBase[i + 1];
        int len2 = runLen[i + 1];
        assert len1 > 0 && len2 > 0;
        assert base1 + len1 == base2;

        /*
         * 记录合并后的序列的长度;如果i == stackSize - 3 就把最后一个序列的信息
         * 往前移一位,因为本次合并不关它的事。i+1对应的序列被合并到i序列中了,所以
         * i+1 数列可以消失了
         */
        runLen[i] = len1 + len2;
        if (i == stackSize - 3) 
            runBase[i + 1] = runBase[i + 2];
            runLen[i + 1] = runLen[i + 2];
        
        //i+1消失了,所以长度也减下来了
        stackSize--;

        /*
         * 找出第二个序列的首个元素可以插入到第一个序列的什么位置,因为在此位置之前的序列已经就位了。
         * 它们可以被忽略,不参加归并。
         */
        int k = gallopRight(a[base2], a, base1, len1, 0, c);
        assert k >= 0;
        // 因为要忽略前半部分元素,所以起点和长度相应的变化
        base1 += k;
        len1 -= k;
        // 如果序列2 的首个元素要插入到序列1的后面,那就直接结束了,
        // !!! 因为序列2在数组中的位置本来就在序列1后面,也就是整个范围本来就是有序的!!!
        if (len1 == 0)
            return;

        /*
         * 跟上面相似,看序列1的最后一个元素(a[base1+len1-1])可以插入到序列2的什么位置(相对第二个序列起点的位置,非在数组中的位置),
         * 这个位置后面的元素也是不需要参与归并的。所以len2直接设置到这里,后面的元素直接忽略。
         */
        len2 = gallopLeft(a[base1 + len1 - 1], a, base2, len2, len2 - 1, c);
        assert len2 >= 0;
        if (len2 == 0)
            return;

        // 合并剩下的两个有序序列,并且这里为了节省空间,临时数组选用 min(len1,len2)的长度
        // 优化的很细呢
        if (len1 <= len2)
            mergeLo(base1, len1, base2, len2);
        else
            mergeHi(base1, len1, base2, len2);
    

归并排序gallopLeft

/**
     * 在一个序列中,将一个指定的key,从左往右查找它应当插入的位置;如果序列中存在
     * 与key相同的值(一个或者多个),那返回这些值中最左边的位置。
     *
     * 推断: 统计概率的原因,随机数字来说,两个待合并的序列的尾假设是差不多大的,从尾开始
     * 做查找找到的概率高一些。仔细算一下,最差情况下,这种查找也是 log(n),所以这里没有
     * 用简单的二分查找。
     * 这里先简单的做了一个大概的范围锁定lastOfs到ofs,然后再从这个区间中用二分查找法去查
     *
     * @param key  准备插入的key
     * @param a    参与排序的数组
     * @param base 序列范围的第一个元素的位置
     * @param len  整个范围的长度,一定有len > 0
     * @param hint 开始查找的位置,有0 <= hint <= len;越接近结果查找越快
     * @param c    排序,查找使用的比较器
     * @return 返回一个整数 k, 有 0 <= k <=n, 它满足 a[b + k - 1] < a[b + k]
     * 就是说key应当被放在 a[base + k],
     * 有 a[base,base+k) < key && key <=a [base + k, base + len)
     */
    private static <T> int gallopLeft(T key, T[] a, int base, int len, int hint,
                                      Comparator<? super T> c) 
        assert len > 0 && hint >= 0 && hint < len;
        int lastOfs = 0;
        int ofs = 1;
        if (c.compare(key, a[base + hint]) > 0)  // key > a[base+hint]
            // 遍历右边,直到 a[base+hint+lastOfs] < key <= a[base+hint+ofs]
            int maxOfs = len - hint;
            while (ofs < maxOfs && c.compare(key, a[base + hint + ofs]) > 0) 
                lastOfs = ofs;
                ofs = (ofs << 1) + 1;
                if (ofs <= 0)   // int overflow
                    ofs = maxOfs;
            
            if (ofs > maxOfs)
                ofs = maxOfs;

            // 最终的ofs是这样确定的,满足条件 a[base+hint+lastOfs] < key <= a[base+hint+ofs]
            // 的一组
            // ofs:     1   3   7  15  31  63 2^n-1 ... maxOfs
            // lastOfs: 0   1   3   7  15  31 2^(n-1)-1  < ofs


            // 因为目前的offset是相对hint的,所以做相对变换
            lastOfs += hint;
            ofs += hint;
         else  // key <= a[base + hint]
            // 遍历左边,直到[base+hint-ofs] < key <= a[base+hint-lastOfs]
            final int maxOfs = hint + 1;
            while (ofs < maxOfs && c.compare(key, a[base + hint - ofs]) <= 0) 
                lastOfs = ofs;
                ofs = (ofs << 1) + 1;
                if (ofs <= 0)   // int overflow
                    ofs = maxOfs;
            
            if (ofs > maxOfs)
                ofs = maxOfs;
            // 确定ofs的过程与上面相同
            // ofs:     1   3   7  15  31  63 2^n-1 ... maxOfs
            // lastOfs: 0   1   3   7  15  31 2^(n-1)-1  < ofs

            // Make offsets relative to base
            int tmp = lastOfs;
            lastOfs = hint - ofs;
            ofs = hint - tmp;
        
        assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;

        /*
         * 现在的情况是 a[base+lastOfs] < key <= a[base+ofs], 所以,key应当在lastOfs的
         * 右边,又不超过ofs。在base+lastOfs-1到 base+ofs范围内做一次二叉查找。
         */
        lastOfs++;
        while (lastOfs < ofs) 
            int m = lastOfs + ((ofs - lastOfs) >>> 1);

            if (c.compare(key, a[base + m]) > 0)以上是关于jdk8源码TimSort算法——从头看到脚的主要内容,如果未能解决你的问题,请参考以下文章

TimSort算法分析

Java - 源码之 Arrays 内部排序 TimSort 实现

从头用脚分析FFmpeg源码 - avformat_write_header

一文了解 Python 中的 Timsort 排序算法

简易版的TimSort排序算法

TimSort排序算法及一个问题分析