ThreadLocal原理和用法

ThreadLocal.set()方法源码如下:

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        map.set(this, value);
    } else {
        createMap(t, value);
    }
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

可以看到ThreadLocal的set是通过Thread类内的ThreadLocalMap来实现的,接着看下ThreadLocalMap的源码:

static class ThreadLocalMap {

    /**
     * The entries in this hash map extend WeakReference, using
     * its main ref field as the key (which is always a
     * ThreadLocal object).  Note that null keys (i.e. entry.get()
     * == null) mean that the key is no longer referenced, so the
     * entry can be expunged from table.  Such entries are referred to
     * as "stale entries" in the code that follows.
     */
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    /**
     * The initial capacity -- MUST be a power of two.
     */
    private static final int INITIAL_CAPACITY = 16;

    /**
     * The table, resized as necessary.
     * table.length MUST always be a power of two.
     */
    private Entry[] table;

    /**
     * The number of entries in the table.
     */
    private int size = 0;
    
    
    ...
    
    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        table = new Entry[INITIAL_CAPACITY];
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        table[i] = new Entry(firstKey, firstValue);
        size = 1;
        setThreshold(INITIAL_CAPACITY);
    }
    
    private void set(ThreadLocal<?> key, Object value) {

        // We don't use a fast path as with get() because it is at
        // least as common to use set() to create new entries as
        // it is to replace existing ones, in which case, a fast
        // path would fail more often than not.

        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);

        for (Entry e = tab[i];
             e != null;
             e = tab[i = nextIndex(i, len)]) {
            ThreadLocal<?> k = e.get();

            if (k == key) {
                e.value = value;
                return;
            }

            if (k == null) {
                replaceStaleEntry(key, value, i);
                return;
            }
        }

        tab[i] = new Entry(key, value);
        int sz = ++size;
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
            rehash();
    }
    
}

可以发现,ThreadLocalMap中维护了一个名为table的Entry数组,通过构造函数和set()方法可知,table初始大小为16,通过ThreadLocal的HashCode和table的大小-1相与,得出此ThreadLocal在table中存放的位置,然后将ThreadLocal实例对象和值封装进Entry,放入table中,在get()时,通过ThreadLocal实例可计算出其在table中的位置,从而找到对应的Entry,就得到了值。因为每个Thread中有各自的ThreadLocalMap,而ThreadLocal是用一个实例,ThreadLocal在每一个线程内的ThreadLocalMap中的table的存储位置计算方法一致,所以不同线程在不同的ThreadLocalMap中的存储位置是一致的,那么在每一个线程中通过get()获取到的值便是该线程私有的,就不会有线程安全问题。

通过以下代码做下测试:

public void test() {
    ThreadLocal<Integer> count = new ThreadLocal<>();
    for (int i = 0; i < 5; i++) {
        int a = i;
        new Thread(() -> {
            System.out.println(Thread.currentThread().getName() + "-----" + count.get());
            count.set(a);
            System.out.println(Thread.currentThread().getName() + "-----" + count.get());
        }).start();
    }
}

结果如下:

Thread-1-----null
Thread-3-----null
Thread-3-----3
Thread-0-----null
Thread-0-----0
Thread-2-----null
Thread-4-----null
Thread-1-----1
Thread-4-----4
Thread-2-----2

可以看出,每个线程在set前get到的都是null,set之后get到的都是本线程set进去的值,各线程之间互不影响,性能要比使用同步方式解决一定情况下的线程安全问题要好。

ThreadLocal