看看不一样的ConcurrentHashMap

Posted 姓chen的大键哥

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了看看不一样的ConcurrentHashMap相关的知识,希望对你有一定的参考价值。

HashMap是Java中常见的数据结构,它结合了数组和链表的特点,查找和增删改操作均十分高效,但HashMap不适合在多线程环境下使用(非线程安全的集合),在多线程下对HashMap进行操作可能出现各种问题:

  • 多线程put的时候可能导致元素丢失
  • JDK 1.7 扩容采用的是“头插法”,在高并发下会出现死循环
  • HashMap 在并发下存在数据覆盖、遍历的同时进行修改会抛出 ConcurrentModificationException 异常等问题

HashMap虽好,但不适合在多线程环境下使用,那有线程安全的HashMap吗,答案是有的,JDK为我们提供了三种线程安全的HashMap:

  1. HashTable
  2. synchronizedMap
  3. ConcurrentHashMap

synchronizedMap是通过Collections.synchronizedMap()方法得到的,它将所有Map操作通过synchronized块进行修饰,与HashTable类似(HashTable是在所有方法加上synchronized关键字),这两个Map在多线程环境下是线程安全的,但他们的并发性能很差,同一时刻只能有一个线程进行操作,效率太低。
为了解决HashMap线程不安全以及synchronizedMap和HashTable并发效率低下的问题,Doug Lea大师为我们准备了兼具高效和安全的HashMap --> ConcurrentHashMap
本文就来讲讲ConcurrentHashMap是如何做到高效和安全的,由于ConcurrentHashMap在JDK1.7 和JDK1.8的实现不同,本文就分别介绍这两个版本的实现原理
我要开始装逼了

JDK1.7的ConcurrentHashMap

内部结构

ConcurrentHashMap内部是由Segments数组结构和HashEntry数组结构组成,Segment是一种可重入锁(继承自ReentrantLock),在ConcurrentHashMap扮演锁的角色;HashEntry则是真正用于存储数据的数据结构。一个ConcurrentHashMap中包含一个Segments数组,一个Segment中包含一个HashEntry数组,HashEntry是一个链表结构,所以Segment是一个散列表的结构(与HashMap类似)。ConcurrentHashMap的结构如图所示:
ConcurrentHashMap结构
Segments数组实现了分段锁,对ConcurrentHashMap进行访问时需要获取对应Segment的锁,这样多线程在访问容器不同数据段中的数据时就不会互相影响,线程之间锁竞争就会大大减小,从而提高了并发效率,同时也能保证安全访问数据

Segment相关代码如下

	/**
	 * 段是哈希表的专用版本。该子类是ReentrantLock的子类,为了简化一些锁并避免单独构造。
	 */
	static final class Segment<K,V> extends ReentrantLock implements Serializable {

        private static final long serialVersionUID = 2249069246763182397L;

        /**
         * 在可能阻塞获取以准备锁定的段操作之前,尝试在预扫描中尝试锁定的最大次数。
         * 在多处理器上,使用有限数量的重试可以维护在定位节点时获取的缓存。
         */
        static final int MAX_SCAN_RETRIES =
            Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;

        /**
         * 每个段。元素是通过entryAtsetEntryAt访问的,提供可变的语义。
         */
        transient volatile HashEntry<K,V>[] table;

        /**
         * 元素数。
         * 仅在锁内或其他保持可见性的易失性读取中访问。
         */
        transient int count;

        /**
         * HashEntry的操作总数。即使这可能溢出32位,它也为CHM isEmpty()和size()方法中的稳定性检查提供了足够的准确性。
         * 仅在锁内或其他保持可见性的易失性读取中访问。
         */
        transient int modCount;

        /**
         * 当表的大小超过此阈值时,将对其进行扩容并重新哈希处理。 
         * 此字段的值始终为(int)(capacity * loadFactor)
         */
        transient int threshold;

        /**
         * 哈希表的负载因子。即使所有段的该值都相同,也将复制该值以避免需要链接到外部对象。
         */
        final float loadFactor;

        Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
            this.loadFactor = lf;
            this.threshold = threshold;
            this.table = tab;
        }
	}

HashEntry的相关代码如下:

	/**
     * ConcurrentHashMap列表条目。它永远不会导出为用户可见的Map.Entry
     */
    static final class HashEntry<K,V> {
        final int hash;
        final K key;
        volatile V value;
        volatile HashEntry<K,V> next;

        HashEntry(int hash, K key, V value, HashEntry<K,V> next) {
            this.hash = hash;
            this.key = key;
            this.value = value;
            this.next = next;
        }

        /**
         * 使用易失性写语义设置下一个字段。
         */
        final void setNext(HashEntry<K,V> n) {
            UNSAFE.putOrderedObject(this, nextOffset, n);
        }

        // Unsafe mechanics
        static final sun.misc.Unsafe UNSAFE;
        static final long nextOffset;
        static {
            try {
                UNSAFE = sun.misc.Unsafe.getUnsafe();
                Class k = HashEntry.class;
                nextOffset = UNSAFE.objectFieldOffset
                    (k.getDeclaredField("next"));
            } catch (Exception e) {
                throw new Error(e);
            }
        }
    }

构造函数

先来看看几个重要的参数:

    /**
     * 默认初始容量,在没有在构造函数中另外指定时使用
     */
    static final int DEFAULT_INITIAL_CAPACITY = 16;

    /**
     * 默认加载因子,在没有在构造函数中另外指定时使用
     */
    static final float DEFAULT_LOAD_FACTOR = 0.75f;

    /**
     * 默认并发级别,在没有在构造函数中另外指定时使用。
     */
    static final int DEFAULT_CONCURRENCY_LEVEL = 16;

    /**
     * 最大容量,如果两个构造函数都使用参数隐式指定了更高的值,则使用该容量。
     * 必须是2的幂且小于等于 1 << 30,以确保条目可以使用int进行索引
     */
    static final int MAXIMUM_CAPACITY = 1 << 30;

    /**
     * 每段表的最小容量。必须为2的幂,至少为2的幂,以免在延迟构造后立即调整下次使用时的大小。
     */
    static final int MIN_SEGMENT_TABLE_CAPACITY = 2;

    /**
     * 允许的最大段数;用于绑定构造函数参数。必须是小于1 << 24的2的幂。
     */
    static final int MAX_SEGMENTS = 1 << 16; // slightly conservative

    /**
     * 在锁定整个表之前,size和containsValue()方法的不同步重试次数。
     * 如果表进行连续修改,这将用于避免无限制的重试,这将导致无法获得准确的结果。
     */
    static final int RETRIES_BEFORE_LOCK = 2;
    
    /**
     * 用于编入段的掩码值。密钥的哈希码的高位用于选择段。
     */
    final int segmentMask;

    /**
     * 段内索引的移位值。
     */
    final int segmentShift;

    /**
     * 段,每个段都是一个专用的哈希表
     */
    final Segment<K,V>[] segments;

上面的参数知道大概就行,在接下来的代码中就能理解这些参数的作用的,接下来看看构造方法:

	/**
     * 使用指定的初始容量,负载因子和并发级别创建一个新的空映射。
     *
     * @param 初始容量。该实现执行内部大小调整以容纳许多元素。
     * 
     * @param loadFactor  负载系数阈值,用于控制调整大小。
     * 当每个仓的平均元素数超过此阈值时,可以执行大小调整。
     * 
     * @param concurrencyLevel 估计的并发更新线程数。该实现执行内部大小调整以尝试容纳这么多线程。
     * 
     * @throws IllegalArgumentException 如果初始容量为负,或者负载因子或concurrencyLevel为非正数。
     */
    @SuppressWarnings("unchecked")
    public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) {
    	//对非法输入进行处理
        if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        // 若并发线程数大于最大段数,则等于最大段数    
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
        // 为保证能通过位与运算的散列算法来定位segments数组索引,要保证数组长度为2的幂,查找最适合参数的二乘幂
        int sshift = 0;
        int ssize = 1;
        while (ssize < concurrencyLevel) {
            ++sshift;
            ssize <<= 1;
        }
        this.segmentShift = 32 - sshift;
        this.segmentMask = ssize - 1;
        
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        //初始化每个segment中的HashEntry长度
        int c = initialCapacity / ssize;
        //如果c大于1,cap会取大于等于c的2次方,所以cap要么等于1要么等于2的幂次方
        if (c * ssize < initialCapacity)
            ++c;
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        while (cap < c)
            cap <<= 1;
        // 创建segments数组,并初始化segments[0]
        Segment<K,V> s0 = new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                             (HashEntry<K,V>[])new HashEntry[cap]);
        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
        UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
        this.segments = ss;
    }
    
    public ConcurrentHashMap(int initialCapacity, float loadFactor) {
        this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL);
    }

    public ConcurrentHashMap(int initialCapacity) {
        this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    }

    public ConcurrentHashMap() {
        this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    }

    public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
        this(Math.max((int) (m.size() / DEFAULT_LOAD_FACTOR) + 1, DEFAULT_INITIAL_CAPACITY), DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
        putAll(m);
    }

由上述代码可以看到ConcurrentHashMap的构造方法最终会调用到同一个方法,只要集中看这个方法即可。
ConcurrentHashMap初始化的时候做了以下几个事情:

  1. 处理非法输入
  2. 计算segmentShiftsegmentMask
  3. 计算每个Segments中HashEntry的容量cap
  4. 初始化segments数组同时初始化Segment,并加入到segments[0]的位置上

segments数组的大小ssize是通过concurrencyLevel来计算的,为保证能通过位与运算的散列算法来定位segments数组索引,要保证数组长度为2的幂,需要计算一个大于或等于concurrencyLevel的最小2的N次方值来作为数组的长度。
segmentShift用于定位参与散列运算的位数,segmentShift等于32 - sshift,使用32是因为ConcurrentHashMap的hash() 方法输出值最大是32位的。
segmentMask是散列运算的掩码,segmentMask记录的是ssize - 1的值,ssize是2的幂次方,所以segmentMask每个二进制位都是1.
这两个参数有些难理解,后面讲到get() 方法时就明白这两个参数的含义了。

get()方法

话不多说,先上代码

public V get(Object key) {
        Segment<K,V> s; // manually integrate access methods to reduce overhead
        HashEntry<K,V>[] tab;
        //将key的hashcode进行再散列,减少hash冲突
        int h = hash(key);
        //散列算法,定位元素在segments数组的位置
        long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
        //Segment不为空且Segment内部的HashEntry不为空,则继续查找
        if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null && (tab = s.table) != null) {
        	//散列算法定位HashEntry,遍历HashEntry,直到找到对应key,没有则退出循环
            for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
                 e != null; e = e.next) {
                K k;
                if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                    return e.value;
            }
        }
        return null;
    }

get() 方法主要做了几件事:

  1. 通过散列计算获取元素在segments数组的位置
  2. 获取对应的Segment,若不为空则继续则通过散列算法获取Segment内部的HashEntry数组对应的HashEntry
  3. 若HashEntry不为空则遍历HashEntry找到key对应的值

get() 操作之所以高效是因为ConcurrentHashMap中的共享变量都被定义成volatile类型,从Segment和HashEntry的成员变量中可以看出,这样做能保证所有线程都能看到最新的值,根据Java内存模型的happen before 原则,对volatile类型的写入操作先于读取操作,即两个线程同时修改和获取volatile变量,get操作也能得到最新值,这样get() 无需加锁,进而提高了并发效率。

get() 方法并发高效的原因知道了,那ConcurrentHashMap是如何定位元素位置的呢?get() 方法中出现了hash() 方法,先看一下hash() 方法

	private int hash(Object k) {
        int h = hashSeed;

        if ((0 != h) && (k instanceof String)) {
            return sun.misc.Hashing.stringHash32((String) k);
        }

        h ^= k.hashCode();

        // 使用单字Wang/Jenkins哈希的变体来扩展位,以规范化段和索引位置。
        h += (h <<  15) ^ 0xffffcd7d;
        h ^= (h >>> 10);
        h += (h <<   3);
        h ^= (h >>>  6);
        h += (h <<   2) + (h << 14);
        return h ^ (h >>> 16);
    }

这个方法官方的解释是应用于hashcode上,使hashcode的高位和低位充分参与散列算法,减少散列冲突。这么做可以使元素均匀分布在不同的Segment中从而提高存取效率。如果散列的质量差到极点,那所有的元素都会位于同一个Segment中,分段锁的意义就没了,而且并发效率会大大降低。
让hashcode充分散列之后就可以使用散列算法定位元素的位置了,定位HashEntry和定位Segment的算法是一样,但有些细微的差别

//定位Segment
(h >>> segmentShift) & segmentMask

//定位HashEntry
(tab.length - 1) & h

上面说过segmentMask记录的是ssize-1的值,即segments长度-1的值,segmentShift是散列值向右偏移的位数,即实际是向右偏移,让散列值的高位参与运算;定位Segment是用再散列的值的高位进行运算,而定位HashEntry则是用再散列的值直接进行运算,这么做的目的是避免两次散列后的值一样,使元素在Segment散列开了,而没有在HashEntry内部散列开,进而增加了冲突的可能。

看完了get() 方法,接下来看看put() 方法,看看大师的手法o( ̄▽ ̄)o

put()方法

废话不多说,还是先看代码

	public V put(K key, V value) {
        Segment<K,V> s;
        if (value == null)
            throw new NullPointerException();
        int hash = hash(key);
        int j = (hash >>> segmentShift) & segmentMask;
        //易失性,在ensureSegment重新检查
        if ((s = (Segment<K,V>)UNSAFE.getObject(segments, (j << SSHIFT) + SBASE)) == null) 
            s = ensureSegment(j);
        return s.put(key, hash, value, false);
    }

    private Segment<K,V> ensureSegment(int k) {
        final Segment<K,V>[] ss = this.segments;
        long u = (k << SSHIFT) + SBASE; // raw offset
        Segment<K,V> seg;
        //Segment不存在,新建一个
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
        	//使用segments[0]的参数创建新的Segment
            Segment<K,V> proto = ss[0]; // use segment 0 as prototype
            int cap = proto.table.length;
            float lf = proto.loadFactor;
            int threshold = (int)(cap * lf);
            HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
            //二次检查,防止其他线程先创建了Segment,而覆盖其创建的Segment
            if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { 
            	// recheck
                Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
                while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
                	//通过CAS将新建的Segment加到segments数组中
                    if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                        break;
                }
            }
        }
        return seg;
    }

put() 方法做了几件事:

  1. 通过散列算法获取key对应的Segment的下标
  2. 获取该Segment,若Segment不存在则调用ensureSegment() 方法重新确认 (创建Segment在这个方法中,Segment的参数从Segment[0]获取)
  3. 调用Segment的put() 将元素加入到Segment中

Segment.put()

put() 方法较为简单(其他的可以看看注释),详细的接下来看Segment中put() 方法,这个方法才是核心

	/**
	 * ConcurrentHashMap.Segment
	 */
    final V put(K key, int hash, V value, boolean onlyIfAbsent) {
    	// 尝试锁定,若无法获取锁,则先扫描Segment并不断尝试获取锁直到成功
        HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value);
        V oldValue;
        try {
            HashEntry<K,V>[] tab = table;
            int index = (tab.length - 1) & hash;
            // 查找key对应的HashEntry
            HashEntry<K,V> first = entryAt(tab, index);
            // 遍历查找key对应的元素所在位置
            for (HashEntry<K,V> e = first;;) {
            	// 如果查找的HashEntry不为空,即存在hash冲突,则先比较key是否相同,
            	// 相同则根据onlyIfAbsent决定是否要更新,不相同则将指针指向next,继续查找,
            	// 直到找到相同的key或是遍历到链表最后一个节点
                if (e != null) {
                    K k;
                    if ((k = e.key) == key ||
               

以上是关于看看不一样的ConcurrentHashMap的主要内容,如果未能解决你的问题,请参考以下文章

看看不一样的ConcurrentHashMap

看看不一样的ConcurrentHashMap

看看不一样的ConcurrentHashMap

java并发编程之ConcurrentHashMap

Hashtable与ConcurrentHashMap入门——简单使用

JDK源码那些事儿之并发ConcurrentHashMap上篇