ThreadLocal/InheritableThreadLocal 设计与源码分析

ThreadLocal 提供了线程本地的变量,每个线程只能通过 get/set 方法访问自己的变量。此类的实例通常声明为类的 private static 属性、用来把状态(比如事务ID)关联到线程上。

InheritableThreadLocal 扩展了 ThreadLocal,为子线程提供从父线程那里继承的值:在创建子线程时,子线程会接收所有可继承的线程局部变量的初始值,以获得父线程所具有的值。

1. 实现思路

如果自行实现一个 ThreadLocal,直接思路可能是:ThreadLocal 内维护一个 Map,以 线程对象为 key,value 为变量。

这个思路的问题有:
1. 当线程终止、JVM 进行垃圾回收时,这个 Map 还持有对线程的引用而没法回收线程的资源;如果 JVM 要能回收,那么必须知道有多少 ThreadLocal 实例持有对线程的引用,这会给 JVM 带来负担。
2. 为了实现 InheritableThreadLocal 时,在创建时还必须找出所有的 InheritableThreadLocal,判断父线程是否有设置变量,有的则进行拷贝变量。

从上述问题来看,实现线程本地变量至少应该考虑:
1. 线程本地变量不应该直接持有对 Thread 对象的引用,避免给 JVM 回收 Thread 带来额外的开销;
2. 为实现 InheritableThreadLocal,一个线程在哪些 InheritableThreadLocal 里设置了变量应该有个集中式的存储,这样才方便把父线程的可继承本地变量拷贝到子线程的。
3. 可能有多个线程同时对 ThreadLocal 进行设置变量,那么对 Map 的访问应当是线程安全的。

再来看下线程本地变量涉及哪些参与者:ThreadLocal 、Thread、变量值。

一个 ThreadLocal 可以持有多个 Thread 的变量,一个 Thread 也可以在多个 ThreadLocal 上设置变量。因此一个 (ThreadLocal, Thread) 的组合才能唯一确定一个线程本地变量值。

Map 只能放在 (ThreadLocal, Thread) 中的一个,前面也说了放在 ThreadLocal 上是不合适的。再来看看放在 Thread 上如何。

每个 Thread 的 Map 属性的 key 是 ThreadLocal 对象,value 是变量值,看来也能实现线程本地变量。

这样反转之后,ThreadLocal 不会持有线程的引用,线程回收不存在问题,线程的 Map 也可以在线程回收时进行回收,Map 里面保存的变量值也可以进行回收。

可继承的线程本地变量可以用另一个 Map 来维护,起到了集中存储的作用。

每个线程都只访问自己的 Map,自然没有并发的竞争。

完美!JDK 从 1.3 开始就是按这个思路去实现的。

2. JDK 里的实现

下面的代码是 JDK 1.8.0_40 的。

ThreadLocalMap 是一个定制的、只适合维护线程本地变量的 ThreadLocal 的内部类。

Thread 类声明了两个 ThreadLocal.ThreadLocalMap 类型的变量 threadLocalsinheritableThreadLocals ,分别用于保存线程本地变量和可继承的线程本地变量。

这样实现 InheritableThreadLocal 就非常简单,在初始化子线程时用父线程的 inheritableThreadLocals 初始子线程的 inheritableThreadLocals 就可以了。

public class Thread implements Runnable {
    // 包级访问级别
    ThreadLocal.ThreadLocalMap threadLocals = null;

    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

    private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc) {
        // 省略了其他初始化代码
        Thread parent = currentThread();

        if (parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);

        // 省略了其他初始化代码
    }
}


public class ThreadLocal<T> {
    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }

    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
}

public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

InheritableThreadLocal 只需重写了 ThreadLocal 的 getMap(Thread) 方法就可以。

3. ThreadLocalMap

3.1 数据模型

ThreadLocalMap 是定制的、只适用于线程本地变量的实现。

ThreadLocalMap 以数组来存放所有的键值对 Entry,数组的长度必须是 2 的 N 次方,这样计算某个哈希值的存放位置时可以用位操作来取代求模%运算,有助于提升性能。扩容的时候按原来的 2 倍进行扩容。

数组里存放的 Entry 超过一定的阈值时就会进行 resize,以保证数组里总是有空闲的位置可以存放新的 Entry 。

对于哈希冲突,采用开放地址法来处理,从预期位置开始向后线性探测,到了数组末尾则从下标 0 继续,直至找到目标 Entry 或可存放的位置。

采用的哈希函数是斐波那契(Fibonacci)散列法,该散列结合前面的位操作能在减少哈希冲突上有很好的效果。具体可以网上找找资料。

相关定义如下:

public class ThreadLocal<T> {
    // 取下一个 斐波那契 散列值
    private final int threadLocalHashCode = nextHashCode();

    // 分配下一个 hash code 的原子整型变量
    private static AtomicInteger nextHashCode = new AtomicInteger();

    // 斐波那契(Fibonacci)散列法在 32 位的有符号乘数
    private static final int HASH_INCREMENT = 0x61c88647;

    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

    // 省略其他代码
}


static class ThreadLocalMap {
    static class Entry extends WeakReference<ThreadLocal<?>> {
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    // 初始化容量 
    private static final int INITIAL_CAPACITY = 16;

    // 存放条目的数组,其长度必须是 2 的N次方。在需要时进行 resize
    private Entry[] table;

    // 表里条目的数量
    private int size = 0;

    // The next size value at which to resize.
    private int threshold; // Default to 0

    // 根据长度设置阈值
    private void setThreshold(int len) {
        threshold = len * 2 / 3;
    }
}

为啥 Entry 要继承自 WeakReference ?

因为当 ThreadLocal 不再被引用、可以被垃圾回收时,如果 Thread 的 ThreadLocalMap 属性仍然持有对 ThreadLocal 的强引用,那么就会导致 ThreadLocal 对象的内存泄漏。如果用 WeakReference 来引用 ThreadLocal,那么 ThreadLocal 被回收时,WeakReference 里的引用被设置为 null 。这时也表示这个 Entry 可垃圾回收。ThreadLocalMap 可以在特定的时候回收这个 Entry,防止其产生内存泄漏。

3.2 读取

读取操作的一个优化是首先尝试快速路径访问,也就是根据 ThreadLocal 的哈希值 threadLocalHashCode 计算预期的存放位置,如果命中则直接返回,否则需要进行线性遍历。

如果能直接命中,说明 put 操作时没有冲突。

private Entry getEntry(ThreadLocal<?> key) {
    // 首先尝试快速访问,如果冲突不多,那么效率会很高。
    int i = key.threadLocalHashCode & (table.length - 1); // 相当于 mod 
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
        // 在预期位置没有命中,说明有冲突,需要遍历来查找
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

3.3 put

private void set(ThreadLocal<?> key, Object value) {
    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.
    // 此方法不像 get 方法那样使用快速路径是因为:
    // 创建一个新的 Entry 与替换一个已存在的是一样普遍的。
    // 在这样的情况下,快速路径更容易失败。

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    // 总是存在 tab[n] == null
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            // 命中,覆盖旧值
            e.value = value;
            return;
        }

        if (k == null) {
            // 碰到一个可回收的 Entry,尝试回收可回收的Entry并重用这个位置
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    // i 位置没有存放 Entry
    tab[i] = new Entry(key, value);
    int sz = ++size;

    // 判等需要需要 rehash,保证总是存在 table[n] == null
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // Back up to check for prior stale entry in current run.
    // We clean out whole runs at a time to avoid continual
    // incremental rehashing due to garbage collector freeing
    // up refs in bunches (i.e., whenever the collector runs).
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // Find either the key or trailing null slot of run, whichever
    // occurs first
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // If we find key, then we need to swap it
        // with the stale entry to maintain hash table order.
        // The newly stale slot, or any other stale slot
        // encountered above it, can then be sent to expungeStaleEntry
        // to remove or rehash all of the other entries in run.
        if (k == key) {
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // If we didn't find stale entry on backward scan, the
        // first stale entry seen while scanning for key is the
        // first still present in the run.
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // If key not found, put new entry in stale slot
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // If there are any other stale entries in run, expunge them
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

3.4 rehash/resize

首先尝试回收所有可回收的 Entry,回收后 entry 数量仍然高于阈值则进行 resize 。

private void rehash() {
    expungeStaleEntries();

    // Use lower threshold for doubling to avoid hysteresis
    // 
    if (size >= threshold - threshold / 4)
        resize();
}

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

3.5 回收 entry

// 回收所有的可回收 entry
private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            expungeStaleEntry(j);
    }
}

// 从 staleSlot 位置开始回收 entry,
// 直到遇到 tab[i] == null 时结束
private int expungeStaleEntry(int staleSlot ) {
    Entry[] tab = table;
    int len = tab.length;

    // expunge entry at staleSlot
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            // 找到了一个可回收的entry
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                // k 不在它的最佳位置,把它往最佳位置靠近,
                // 减少后续查找它时需要遍历的位置,提升效率
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

4. 总结

4.1 ThreadLocal 小结

  1. 如果在 ThreadLocal 里维护线程到变量的映射 Map,那么:这个 Map 需要提供线程安全性访问能力;为了防止线程终止时没法被回收,这个 Map 应该把线程封装在 WeakReference 里。
  2. 为了方便实现 InheritableThreadLocal 语义,一个线程在哪些 InheritableThreadLocal 里设置了变量应该有个集中存放的地方。
  3. 综上2点,JDK 采用在线程里维护 ThreadLocal 到变量映射的两个 Map,分别用于存放 ThreadLocalInheritableThreadLocal 到变量的映射。这两个 Map 是线程私有的,天然就具有线程安全性。

4.2 ThreadLocalMap 小结

ThreadLocalMap 是定制的、只适用于 ThreadLocal 的一个 Map 实现。
采用 数组、斐波那契散列法生成 ThreadLocal 的哈希值再对数组求模作为哈希函数、开放地址法解决哈希冲突。

Entry 继承 WeakReference,把 key ThreadLocal 封装在 WeakReference 的里,弱引用被设置为空时表示该 entry 可回收,称为 stale entry。

对该 map 实现进行 put 操作时会触发回收 stale entry 。


欢迎关注我的微信公众号: coderbee笔记,可以更及时回复你的讨论。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据