ThreadLocal.set()方法源码如下:
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
map.set(this, value);
} else {
createMap(t, value);
}
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
可以看到ThreadLocal的set是通过Thread类内的ThreadLocalMap来实现的,接着看下ThreadLocalMap的源码:
static class ThreadLocalMap {
/**
* The entries in this hash map extend WeakReference, using
* its main ref field as the key (which is always a
* ThreadLocal object). Note that null keys (i.e. entry.get()
* == null) mean that the key is no longer referenced, so the
* entry can be expunged from table. Such entries are referred to
* as "stale entries" in the code that follows.
*/
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
/**
* The initial capacity -- MUST be a power of two.
*/
private static final int INITIAL_CAPACITY = 16;
/**
* The table, resized as necessary.
* table.length MUST always be a power of two.
*/
private Entry[] table;
/**
* The number of entries in the table.
*/
private int size = 0;
...
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}
private void set(ThreadLocal<?> key, Object value) {
// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
}
可以发现,ThreadLocalMap中维护了一个名为table的Entry数组,通过构造函数和set()方法可知,table初始大小为16,通过ThreadLocal的HashCode和table的大小-1相与,得出此ThreadLocal在table中存放的位置,然后将ThreadLocal实例对象和值封装进Entry,放入table中,在get()时,通过ThreadLocal实例可计算出其在table中的位置,从而找到对应的Entry,就得到了值。因为每个Thread中有各自的ThreadLocalMap,而ThreadLocal是用一个实例,ThreadLocal在每一个线程内的ThreadLocalMap中的table的存储位置计算方法一致,所以不同线程在不同的ThreadLocalMap中的存储位置是一致的,那么在每一个线程中通过get()获取到的值便是该线程私有的,就不会有线程安全问题。
通过以下代码做下测试:
public void test() {
ThreadLocal<Integer> count = new ThreadLocal<>();
for (int i = 0; i < 5; i++) {
int a = i;
new Thread(() -> {
System.out.println(Thread.currentThread().getName() + "-----" + count.get());
count.set(a);
System.out.println(Thread.currentThread().getName() + "-----" + count.get());
}).start();
}
}
结果如下:
Thread-1-----null
Thread-3-----null
Thread-3-----3
Thread-0-----null
Thread-0-----0
Thread-2-----null
Thread-4-----null
Thread-1-----1
Thread-4-----4
Thread-2-----2
可以看出,每个线程在set前get到的都是null,set之后get到的都是本线程set进去的值,各线程之间互不影响,性能要比使用同步方式解决一定情况下的线程安全问题要好。