ThreadLocal 源码
Posted codingbug
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了ThreadLocal 源码相关的知识,希望对你有一定的参考价值。
ThreadLocal 源码分析
* ThreadLocal的内存泄露问题,使用完后,执行remove操作
* 在开放定址算法中,线性探测法是散列解决冲突的一种方法,当hash一个关键字的时候,发现没有冲突,
就保存关键字,如果有冲突,就探测冲突地址的下一个地址,如此循环,知道有空地址为止,从而解决冲突
package java.lang;
public class ThreadLocal<T> {
// 每当创建ThreadLocal是此值增加0x61c88647,
// 是为了能让哈希码均匀的分布在2的n次方的数组里
private final int threadLocalHashCode = nextHashCode();
private static AtomicInteger nextHashCode =
new AtomicInteger();
private static final int HASH_INCREMENT = 0x61c88647;
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
// 初始值为null
protected T initialValue() {
return null;
}
// 用java8的语法包装了下,有参构造器
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);
}
// 无参构造器
public ThreadLocal() {
}
// 取值
public T get() {
// 获取当前线程
Thread t = Thread.currentThread();
// 获取ThreadLocal的内部map,ThreadLocal的值是存在内部map中的
ThreadLocalMap map = getMap(t);
if (map != null) {
// ThreadLocalMap内部是一个entry[]数组
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
// 如果没有map把当前线程对象加入map中
return setInitialValue();
}
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
// 没有map创建一个map
createMap(t, value);
return value;
}
// 往ThreadLocal里加值
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}
// 此方法在InheritableThreadLocal有被重写
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
return new ThreadLocalMap(parentMap);
}
T childValue(T parentValue) {
throw new UnsupportedOperationException();
}
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
private final Supplier<? extends T> supplier;
SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}
@Override
protected T initialValue() {
return supplier.get();
}
}
static class ThreadLocalMap {
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;
// k是一个ThreadLocal对象
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
// 初始容量16,必须是2的n次方
private static final int INITIAL_CAPACITY = 16;
// Entry数组,大小必须是2的n次方
private Entry[] table;
// entry在数组中的下标
private int size = 0;
// 扩容时的阈值
private int threshold; // Default to 0
// 扩容时的阈值设为总长度的2/3
private void setThreshold(int len) {
threshold = len * 2 / 3;
}
// 下一个索引
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
// 上一个索引
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}
// cap是2^n,所以cap-1的二进制就是低位连续的n个1,
// threadLocalHashCode & (INITIAL_CAPACITY - 1)的值就是threadLocalHashCode的低n位
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
// 通过散列哈希确认存放的小标,此散列能保存map的值均匀的分布在2^n的数组里
// 因此数组的cap也必须是2^n
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}
// 从父ThreadMap中copy一份到子map
private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];
for (int j = 0; j < len; j++) {
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}
// 取值
private Entry getEntry(ThreadLocal<?> key) {
// 首先通过散列确定值在table中的下标
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
// 异常情况,通过entry没获取到值
return getEntryAfterMiss(key, i, e);
}
// Miss entry单独处理
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
// entry不为null
while (e != null) {
ThreadLocal<?> k = e.get();
// 存在
if (k == key)
return e;
// 说明当前threadLocal已被回收,对应的entry应该被清除
if (k == null)
// 具体清除逻辑
expungeStaleEntry(i);
else
// 开放寻址法,也叫线性探测,闭散列,entry[]逻辑上是一个环形
// 找到下一个数组索引取出entry循环遍历
i = nextIndex(i, len);
e = tab[i];
}
// entry都为null了,当然值为null
return null;
}
// 设值
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
// 获取当前entry
for (Entry e = tab[i];
e != null;
// 不为null继续获取下个entry
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
// 如果已有key并和要设置的key相等,覆盖value退出
if (k == key) {
e.value = value;
return;
}
// 如果当前坐标系的k为null,说明该ThreadLocal已被回收,调用替换方法
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
// 如果没有清除数据但是size大于阈值,通过重新hash、扩容来清除非法数据
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
// 删除值
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
// 清除软引用
e.clear();
// 清除entry
expungeStaleEntry(i);
return;
}
}
}
//
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
// 向前遍历,如果已被gc回收,slotToExpunge标记为当前位置
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;
// 向后遍历查找是否已有要存的这个元素
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// 如果有,替换value,更新下标
if (k == key) {
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
if (slotToExpunge == staleSlot)
slotToExpunge = i;
// 向前继续清除null元素
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
// k为null,说明当前位置就是需要开始清理的位置
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// If key not found, put new entry in stale slot
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
// 不等,说明有还有被gc过的元素,继续清除
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
// 通过下标清理
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// 当前下标的entry value置空,entry指控,size -1
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
// 重新hash
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
// 为null的置空同上
e.value = null;
tab[i] = null;
size--;
} else {
int h = k.threadLocalHashCode & (len - 1);
// 如果h不是当前位置i,则往后遍历找到空节点,把当前entry索引过去
if (h != i) {
tab[i] = null;
// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
// 从i开始往后清理无效entry,n为扫描次数
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
// 发现无效数据就扩大清理范围
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}
// 重新hash散列
private void rehash() {
// 先清理一遍无效数据
expungeStaleEntries();
// threshold = 2/3 len, - threshold/4 = len/2
if (size >= threshold - threshold / 4)
// 所以是超过len/2,就扩容
resize();
}
// 扩容,双倍
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
// 双倍
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
// 二次清理
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
// 如果新表不为空,往后寻址为空的插入
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
// 设置阈值为之前的两倍
setThreshold(newLen);
size = count;
table = newTab;
}
// 全局清除无效数据
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
// 满足条件
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}
}
}
以上是关于ThreadLocal 源码的主要内容,如果未能解决你的问题,请参考以下文章