TreeMap 源码分析

Posted zhuxudong

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了TreeMap 源码分析相关的知识,希望对你有一定的参考价值。

TreeMap

1)TreeMap 是基于红黑树的 NavigableMap 接口实现类。
2)TreeMap 根据键的自然顺序或构造时提供的比较器进行键的排序。
3){@code containsKey}, {@code get}, {@code put} and {@code remove} 方法的时间复杂度为 log(n)。
4)TreeMap 不是线程同步的,多线程并发访问 TreeMap 并且至少有一个线程修改了其结构,则它必须在外部实现同步。
SortedMap m = Collections.synchronizedSortedMap(new TreeMap(...));
5)TreeMap 返回的所有集合视图都是快速失败的,在迭代器创建之后,如果不是通过迭代器自身的 remove 方法修改其结构,则将抛出 ConcurrentModificationException 异常。

创建实例

    /**
     * 进行键排序的比较器,如果为 null,则使用自然顺序进行排序
     */
    private final Comparator<? super K> comparator;

    /**
     * 红黑树的根节点
     */
    private transient Entry<K,V> root;

    /**
     * TreeMap 中的元素总个数
     */
    private transient int size = 0;

    /**
     * TreeMap 被结构化修改的次数
     */
    private transient int modCount = 0;

    /**
     * 创建一个使用自然顺序进行键排序的空 TreeMap 实例
     */
    public TreeMap() {
        comparator = null;
    }

    /**
     * 创建一个使用指定比较器进行键排序的空 TreeMap 实例
     */
    public TreeMap(Comparator<? super K> comparator) {
        this.comparator = comparator;
    }

添加元素

    private static final boolean RED   = false;
    private static final boolean BLACK = true;

    static final class Entry<K,V> implements Map.Entry<K,V> {
        /**
         * 节点键
         */
        K key;
        /**
         * 节点值
         */
        V value;
        /**
         * 左子节点
         */
        Entry<K,V> left;
        /**
         * 右子节点
         */
        Entry<K,V> right;
        /**
         * 父节点
         */
        Entry<K,V> parent;
        /**
         * 节点默认为黑色
         */
        boolean color = BLACK;
    }

    /**
     * 往 TreeMap 中添加新的键值对
     */
    @Override
    public V put(K key, V value) {
        // 读取根节点
        Entry<K,V> t = root;
        // 1)TreeMap 为空
        if (t == null) {
            /**
             * 1)使用自然顺序进行键排序时,键是否实现了 Comparable 接口
             * 2)键值是否为 null
             */
            compare(key, key); // type (and possibly null) check
            // 创建根节点
            root = new Entry<>(key, value, null);
            size = 1;
            modCount++;
            return null;
        }
        int cmp;
        Entry<K,V> parent;
        // 读取键比较器
        final Comparator<? super K> cpr = comparator;
        // 1)键比较器不为 null,则使用比较器进行排序
        if (cpr != null) {
            do {
                parent = t;
                cmp = cpr.compare(key, t.key);
                if (cmp < 0) {
                    t = t.left;
                } else if (cmp > 0) {
                    t = t.right;
                } else {
                    return t.setValue(value);
                }
            } while (t != null);
        }
        // 2)键必须实现 Comparable 接口,并使用 Comparable#compareTo 方法进行比较
        else {
            if (key == null) {
                throw new NullPointerException();
            }
            @SuppressWarnings("unchecked")
            // 将键转换为 Comparable 实例
            final
            Comparable<? super K> k = (Comparable<? super K>) key;
            do {
                parent = t;
                // 将键和当前节点的键进行比较
                cmp = k.compareTo(t.key);
                // 1)目标键小于节点键
                if (cmp < 0) {
                    // 1-1)尝试比较其左子节点
                    t = t.left;
                    // 2)目标键大于节点键
                } else if (cmp > 0) {
                    // 2-1)尝试比较其右子节点
                    t = t.right;
                    // 3)目标键已经存在
                } else {
                    // 设置新值,返回旧值
                    return t.setValue(value);
                }
                // 已经无节点可比较
            } while (t != null);
        }
        // 创建新的节点,并设置其父节点
        final Entry<K,V> e = new Entry<>(key, value, parent);
        if (cmp < 0) {
            // parent 的左子节点更新为新增节点
            parent.left = e;
        } else {
            // parent 的右子节点更新为新增节点
            parent.right = e;
        }
        // 插入新元素后,平衡红黑树结构
        fixAfterInsertion(e);
        // 递增计数值
        size++;
        // 递增结构化修改计数值
        modCount++;
        return null;
    }

    // 平衡红黑树结构
    private void fixAfterInsertion(Entry<K,V> x) {
        // 新增的节点总是红色的
        x.color = TreeMap.RED;
        // 如果新增节点和其父节点都是红色的,则需要执行旋转
        while (x != null && x != root && x.parent.color == TreeMap.RED) {
            /**
             *    xpp
             *    /
             * [红]xp
             *
             * 父节点在祖父节点的左侧
             */
            if (TreeMap.parentOf(x) == TreeMap.leftOf(TreeMap.parentOf(TreeMap.parentOf(x)))) {
                // 读取父节点的兄弟节点
                final Entry<K,V> y = TreeMap.rightOf(TreeMap.parentOf(TreeMap.parentOf(x)));
                /**
                 *      xpp
                 *    /                      * [红]xp    xpr[红色]
                 *
                 * 祖父节点的右子节点为 红色,则需要执行变色
                 */
                if (TreeMap.colorOf(y) == TreeMap.RED) {
                    // 父节点设置为黑色
                    TreeMap.setColor(TreeMap.parentOf(x), TreeMap.BLACK);
                    // 父节点的兄弟节点设置为黑色
                    TreeMap.setColor(y, TreeMap.BLACK);
                    // 祖父节点设置为红色
                    TreeMap.setColor(TreeMap.parentOf(TreeMap.parentOf(x)), TreeMap.RED);
                    // 处理祖父节点,因为【父节点、叔叔节点、子节点已经是红黑树结构】
                    x = TreeMap.parentOf(TreeMap.parentOf(x));
                    // 2)兄弟节点不存在或为黑色
                } else {
                    /**
                     *    xpp
                     *    /
                     * [红]xp
                     *                         *     x[红]
                     *
                     * x 在父节点的右侧
                     */
                    if (x == TreeMap.rightOf(TreeMap.parentOf(x))) {
                        // 读取父节点
                        x = TreeMap.parentOf(x);
                        // 基于父节点执行左旋
                        rotateLeft(x);
                    }
                    // 设置父节点为黑色
                    TreeMap.setColor(TreeMap.parentOf(x), TreeMap.BLACK);
                    // 设置祖父节点为红色
                    TreeMap.setColor(TreeMap.parentOf(TreeMap.parentOf(x)), TreeMap.RED);
                    // 基于祖父节点执行右旋
                    rotateRight(TreeMap.parentOf(TreeMap.parentOf(x)));
                }
                /**
                 *   xpp
                 *                      *      xp[红]
                 * 
                 * 父节点在祖父节点的右侧
                 */
            } else {
                // 读取祖父节点的左孩子
                final Entry<K,V> y = TreeMap.leftOf(TreeMap.parentOf(TreeMap.parentOf(x)));
                /**
                 *      xpp
                 *    /                      * [红]xp    xpr[红色]
                 *
                 * 祖父节点的左子节点为红色,则需要执行变色
                 */
                if (TreeMap.colorOf(y) == TreeMap.RED) {
                    TreeMap.setColor(TreeMap.parentOf(x), TreeMap.BLACK);
                    TreeMap.setColor(y, TreeMap.BLACK);
                    TreeMap.setColor(TreeMap.parentOf(TreeMap.parentOf(x)), TreeMap.RED);
                    x = TreeMap.parentOf(TreeMap.parentOf(x));
                } else {
                    /**
                     *   xpp
                     *                          *      xp[红]
                     *      /
                     *     x[红]
                     * x 在父节点的左侧    
                     */
                    if (x == TreeMap.leftOf(TreeMap.parentOf(x))) {
                        x = TreeMap.parentOf(x);
                        // 基于父节点执行右旋
                        rotateRight(x);
                    }
                    // 设置父节点为黑色
                    TreeMap.setColor(TreeMap.parentOf(x), TreeMap.BLACK);
                    // 设置祖父节点为红色
                    TreeMap.setColor(TreeMap.parentOf(TreeMap.parentOf(x)), TreeMap.RED);
                    // 基于祖父节点执行左旋
                    rotateLeft(TreeMap.parentOf(TreeMap.parentOf(x)));
                }
            }
        }
        // 根节点永远是黑色
        root.color = TreeMap.BLACK;
    }

    // https://www.cs.usfca.edu/~galles/visualization/RedBlack.html
    private void rotateLeft(Entry<K,V> p) {
        if (p != null) {
            /**
             * 读取中心节点的右侧子节点
             */
            final Entry<K,V> r = p.right;
            /**
             *    20
             *   /               * 10    40[红]
             *      /                *     30    50
             *          /               *       [红]45  60[红]
             *                       *           47[红]
             * 先执行 45、60、50 的变色,之后由于 40、50 都为红色,
             * 执行以 20 为中心的左旋
             */
            p.right = r.left;
            if (r.left != null) {
                r.left.parent = p;
            }
            r.parent = p.parent;
            /**
             * 1[p]
             *               *   2[红]
             *                 *     3[红]
             * 1)以 1 为中心的单节点左旋
             */
            if (p.parent == null) {
                root = r;
                /**
                 *   3
                 *  /
                 * 1[红][p]
                 *                   *   2[红]
                 * 2)以 1 为中心的单节点左旋
                 */
            } else if (p.parent.left == p) {
                // p的右侧子节点上升为p的父节点的左节点
                p.parent.left = r;
                /**
                 *   20
                 *  /                  * 10  30[p]
                 *                       *       40[红]
                 *                         *         50[红]
                 * 以 30 为中心的左旋
                 */
            } else {
                // p的右侧子节点上升为p的父节点的右节点
                p.parent.right = r;
            }
            // p下沉为p右侧子节点的左节点
            r.left = p;
            p.parent = r;
        }
    }

    // https://www.cs.usfca.edu/~galles/visualization/RedBlack.html
    private void rotateRight(Entry<K,V> p) {
        if (p != null) {
            final Entry<K,V> l = p.left;
            p.left = l.right;
            /**
             *
             *        50
             *       /                *  [红]30     60
             *     /               *    10   40
             *    /              *[红]5   20[红]
             *  /
             * 3[红]
             * 插入 3 时,先执行 5、20、10 的变色,变色后 10、30 都为红色,
             * 执行以 50 为中心的右旋
             */
            if (l.right != null) {
                l.right.parent = p;
            }
            l.parent = p.parent;
            /**
             *     3
             *    /
             *   2[红]
             *  /
             * 1[红]
             * 以根节点为中心的单节点右旋
             */
            if (p.parent == null) {
                root = l;
                /**
                 *  1
                 *                    *   3[红]
                 *  /
                 * 2[红]
                 * 以3为中心的单节点右旋
                 */
            } else if (p.parent.right == p) {
                // p的左侧子节点上升为p的父节点的右节点
                p.parent.right = l;
            } else {
                // p的左侧子节点上升为p的父节点的左节点
                p.parent.left = l;
            }
            // p节点下沉为p的左侧子节点的右节点
            l.right = p;
            p.parent = l;
        }
    }

读取元素

    /**
     * 读取指定键关联的值,键不存在或其关联值为 null 时都返回 null。
     */
    @Override
    public V get(Object key) {
        final Entry<K,V> p = getEntry(key);
        return p==null ? null : p.value;
    }

    final Entry<K,V> getEntry(Object key) {
        if (key == null) {
            throw new NullPointerException();
        }
        // 1)使用指定的键比较器进行读取
        if (comparator != null) {
            return getEntryUsingComparator(key);
        }
        @SuppressWarnings("unchecked")
        // 键必须实现 Comparable 接口
        final
        Comparable<? super K> k = (Comparable<? super K>) key;
        // 读取根节点
        Entry<K,V> p = root;
        while (p != null) {
            // 将目标键和当前树节点进行比较
            final int cmp = k.compareTo(p.key);
            if (cmp < 0) {
                // 遍历左子树
                p = p.left;
            } else if (cmp > 0) {
                // 遍历右子树
                p = p.right;
            } else {
                // 找到目标键则直接返回
                return p;
            }
        }
        return null;
    }

    /**
     * 使用键比较器读取目标键
     */
    final Entry<K,V> getEntryUsingComparator(Object key) {
        @SuppressWarnings("unchecked")
        final
        K k = (K) key;
        final Comparator<? super K> cpr = comparator;
        // 读取根节点
        Entry<K,V> p = root;
        while (p != null) {
            // 使用键比较器比较目标键和节点键
            final int cmp = cpr.compare(k, p.key);
            if (cmp < 0) {
                p = p.left;
            } else if (cmp > 0) {
                p = p.right;
            } else {
                return p;
            }
        }
        return null;
    }

读取最小键

    /**
     * 读取键最小的键值对
     */
    public Map.Entry<K,V> firstEntry() {
        return exportEntry(getFirstEntry());
    }

    /**
     * 返回简单的不可变键值对,以防止影响源 TreeMap 中的键值对
     */
    static <K,V> Map.Entry<K,V> exportEntry(TreeMap.Entry<K,V> e) {
        return e == null ? null :
            new AbstractMap.SimpleImmutableEntry<>(e);
    }

    /**
     * 最小键位于左子树的最左侧叶子节点
     */
    final Entry<K,V> getFirstEntry() {
        Entry<K,V> p = root;
        if (p != null) {
            while (p.left != null) {
                p = p.left;
            }
        }
        return p;
    }

读取最大键

    /**
     * 读取最大键关联的键值对
     */
    public Map.Entry<K,V> lastEntry() {
        return exportEntry(getLastEntry());
    }

    /**
     * 最大键位于右子树的最右叶子节点
     */
    final Entry<K,V> getLastEntry() {
        Entry<K,V> p = root;
        if (p != null) {
            while (p.right != null) {
                p = p.right;
            }
        }
        return p;
    }

小于或小于等于指定键的最大键:lowerEntry、floorEntry

    /**
     *  获取小于目标 key 的最大键
     */
    public Map.Entry<K,V> lowerEntry(K key) {
        return exportEntry(getLowerEntry(key));
    }

    /**
     *  返回小于目标 key 的最大键,如果不存在,则返回 null
     */
    final Entry<K,V> getLowerEntry(K key) {
        // 读取根节点
        Entry<K,V> p = root;
        // 树不为空
        while (p != null) {
            // 比较目标键与当前键
            final int cmp = compare(key, p.key);
            // 目标键比较大
            if (cmp > 0) {
                // 读取当前节点的右子节点
                if (p.right != null) {
                    // 往右侧查找
                    p = p.right;
                } else {
                    // 已经没有更大的值,则直接返回
                    return p;
                }
                // 目标键比较小
            } else {
                // 读取当前节点的左子节点
                if (p.left != null) {
                    p = p.left;
                } else {
                    // 目标键小于当前节点并且已经没有左子节点,则需要回溯到父节点
                    Entry<K,V> parent = p.parent;
                    Entry<K,V> ch = p;
                    // 当前节点在父节点的左侧,则一直回溯到其位于父节点的右侧为止
                    while (parent != null && ch == parent.left) {
                        ch = parent;
                        parent = parent.parent;
                    }
                    // 当前节点在父节点的右侧,则直接返回父节点
                    return parent;
                }
            }
        }
        return null;
    }

    /**
     *  获取小于等于目标 key 的最大键
     */
    public Map.Entry<K,V> floorEntry(K key) {
        return exportEntry(getFloorEntry(key));
    }

    /**
     *  返回小于等于目标 key 的最大键,如果不存在,则返回 null
     */
    final Entry<K,V> getFloorEntry(K key) {
        Entry<K,V> p = root;
        while (p != null) {
            final int cmp = compare(key, p.key);
            if (cmp > 0) {
                if (p.right != null) {
                    p = p.right;
                } else {
                    return p;
                }
            } else if (cmp < 0) {
                if (p.left != null) {
                    p = p.left;
                } else {
                    Entry<K,V> parent = p.parent;
                    Entry<K,V> ch = p;
                    while (parent != null && ch == parent.left) {
                        ch = parent;
                        parent = parent.parent;
                    }
                    return parent;
                }
            // 如果目标键和当前节点相等,则直接返回   
            } else {
                return p;
            }
        }
        return null;
    }

大于或大于等于指定键的最小键:higherEntry、ceilingEntry

    /**
     *  读取大于目标键的最小键
     */
    public Map.Entry<K,V> higherEntry(K key) {
        return exportEntry(getHigherEntry(key));
    }

    /**
     *  读取大于目标键的最小键
     */
    final Entry<K,V> getHigherEntry(K key) {
        Entry<K,V> p = root;
        while (p != null) {
            final int cmp = compare(key, p.key);
            // 目标键小于当前节点
            if (cmp < 0) {
                // 往左读取其子节点
                if (p.left != null) {
                    p = p.left;
                } else {
                    // 已经没有左子节点,则直接返回
                    return p;
                }
                // 目标键大于当前节点
            } else {
                // 往右读取其子节点
                if (p.right != null) {
                    p = p.right;
                } else {
                    // 目标键大于当前节点并且已经没有右子节点,则需要回溯到父节点
                    Entry<K,V> parent = p.parent;
                    Entry<K,V> ch = p;
                    // 当前节点在父节点的右侧,则一直回溯到其位于父节点的左侧为止
                    while (parent != null && ch == parent.right) {
                        ch = parent;
                        parent = parent.parent;
                    }
                    // 当前节点在父节点的左侧,则直接返回其父节点
                    return parent;
                }
            }
        }
        return null;
    }

    /**
     *  读取大于等于目标键的最小键
     */
    public Map.Entry<K,V> ceilingEntry(K key) {
        return exportEntry(getCeilingEntry(key));
    }

    /**
     *  读取大于等于目标键的最小键,如果不存在,则返回 null
     */
    final Entry<K,V> getCeilingEntry(K key) {
        Entry<K,V> p = root;
        while (p != null) {
            final int cmp = compare(key, p.key);
            if (cmp < 0) {
                if (p.left != null) {
                    p = p.left;
                } else {
                    return p;
                }
            } else if (cmp > 0) {
                if (p.right != null) {
                    p = p.right;
                } else {
                    Entry<K,V> parent = p.parent;
                    Entry<K,V> ch = p;
                    while (parent != null && ch == parent.right) {
                        ch = parent;
                        parent = parent.parent;
                    }
                    return parent;
                }
            // 目标键和当前节点相等,则直接返回 
            } else {
                return p;
            }
        }
        return null;
    }

替换值

    /**
     *  如果键存在则替换值,并返回旧值,键不存在则返回 null
     * created by ZXD at 29 Nov 2018 T 22:36:47
     * @param key
     * @param value
     * @return
     */
    @Override
    public V replace(K key, V value) {
        final Entry<K,V> p = getEntry(key);
        if (p!=null) {
            final V oldValue = p.value;
            p.value = value;
            return oldValue;
        }
        return null;
    }

    /**
     *  如果目标键存在,并且关联值为 oldValue,则将其替换为 newValue,
     *  替换成功返回 true,否则返回 false。
     * created by ZXD at 29 Nov 2018 T 22:38:11
     * @param key
     * @param oldValue
     * @param newValue
     * @return
     */
    @Override
    public boolean replace(K key, V oldValue, V newValue) {
        final Entry<K,V> p = getEntry(key);
        if (p!=null && Objects.equals(oldValue, p.value)) {
            p.value = newValue;
            return true;
        }
        return false;
    }

    /**
     *  使用函数式接口基于旧键和旧值计算新值,并替换旧值
     * created by ZXD at 29 Nov 2018 T 22:41:51
     * @param function
     */
    @Override
    public void replaceAll(BiFunction<? super K, ? super V, ? extends V> function) {
        Objects.requireNonNull(function);
        final int expectedModCount = modCount;
        // 从小到大遍历每一个节点
        for (Entry<K, V> e = getFirstEntry(); e != null; e = TreeMap.successor(e)) {
            // 使用函数式接口基于旧键和旧值计算新值,并替换旧值
            e.value = function.apply(e.key, e.value);
            // fast-fail 机制保证并发修改时快速失败
            if (expectedModCount != modCount) {
                throw new ConcurrentModificationException();
            }
        }
    }

    /**
     *  读取当前节点的后继节点
     */
    static <K,V> TreeMap.Entry<K,V> successor(Entry<K,V> t) {
        // 1)当前节点为 null
        if (t == null) {
            return null;
        // 2)当前节点存在右侧子节点,则找到右子节点的最左侧子节点,如果没有左子节点,则直接返回  
        } else if (t.right != null) {
            Entry<K,V> p = t.right;
            while (p.left != null) {
                p = p.left;
            }
            return p;
        // 3)当前节点无右子节点,则需要往左上回溯到其父节点    
        } else {
            Entry<K,V> p = t.parent;
            Entry<K,V> ch = t;
            while (p != null && ch == p.right) {
                ch = p;
                p = p.parent;
            }
            // 当前节点是父节点的左子节点,则返回其父节点
            return p;
        }
    }

以上是关于TreeMap 源码分析的主要内容,如果未能解决你的问题,请参考以下文章

JDK1.8源码分析之TreeMap

TreeMap 源码分析

源码分析TreeMap和TreeSet

TreeMap源码分析

TreeMap源码分析,看了都说好

死磕 java集合之TreeMap源码分析