diff --git a/src/main/java/com/alibaba/ttl/TransmittableThreadLocal.java b/src/main/java/com/alibaba/ttl/TransmittableThreadLocal.java index 7cf2f9a846037dad02ef8d83b1f3c18c74668c25..9b420636a91993447ee21c17fc32d312cb4cc35a 100644 --- a/src/main/java/com/alibaba/ttl/TransmittableThreadLocal.java +++ b/src/main/java/com/alibaba/ttl/TransmittableThreadLocal.java @@ -1,9 +1,6 @@ package com.alibaba.ttl; -import java.util.HashMap; -import java.util.Iterator; -import java.util.Map; -import java.util.WeakHashMap; +import java.util.*; import java.util.concurrent.Callable; import java.util.function.Supplier; import java.util.logging.Level; @@ -97,22 +94,25 @@ public class TransmittableThreadLocal extends InheritableThreadLocal { return copy(get()); } - private static InheritableThreadLocal, ?>> holder = - new InheritableThreadLocal, ?>>() { + private static InheritableThreadLocal>> holder = + new InheritableThreadLocal>>() { @Override - protected Map, ?> initialValue() { - return new WeakHashMap, Object>(); + protected Set> initialValue() { + return Collections.newSetFromMap(new WeakHashMap, Boolean>()); } @Override - protected Map, ?> childValue(Map, ?> parentValue) { - return new WeakHashMap, Object>(parentValue); + protected Set> childValue(Set> parentValue) { + Set> result = Collections.newSetFromMap( + new WeakHashMap, Boolean>()); + result.addAll(parentValue); + return result; } }; private void addValue() { - if (!holder.get().containsKey(this)) { - holder.get().put(this, null); // WeakHashMap supports null value. + if (!holder.get().contains(this)) { + holder.get().add(this); } } @@ -121,9 +121,7 @@ public class TransmittableThreadLocal extends InheritableThreadLocal { } private static void doExecuteCallback(boolean isBefore) { - for (Map.Entry, ?> entry : holder.get().entrySet()) { - TransmittableThreadLocal threadLocal = entry.getKey(); - + for (TransmittableThreadLocal threadLocal : holder.get()) { try { if (isBefore) { threadLocal.beforeExecute(); @@ -148,8 +146,7 @@ public class TransmittableThreadLocal extends InheritableThreadLocal { System.out.println("Start TransmittableThreadLocal Dump..."); } - for (Map.Entry, ?> entry : holder.get().entrySet()) { - final TransmittableThreadLocal key = entry.getKey(); + for (final TransmittableThreadLocal key : holder.get()) { System.out.println(key.get()); } System.out.println("TransmittableThreadLocal Dump end!"); @@ -246,7 +243,7 @@ public class TransmittableThreadLocal extends InheritableThreadLocal { */ public static Object capture() { Map, Object> captured = new HashMap, Object>(); - for (TransmittableThreadLocal threadLocal : holder.get().keySet()) { + for (TransmittableThreadLocal threadLocal : holder.get()) { captured.put(threadLocal, threadLocal.copyValue()); } return captured; @@ -266,10 +263,8 @@ public class TransmittableThreadLocal extends InheritableThreadLocal { Map, Object> capturedMap = (Map, Object>) captured; Map, Object> backup = new HashMap, Object>(); - for (Iterator, ?>> iterator = holder.get().entrySet().iterator(); - iterator.hasNext(); ) { - Map.Entry, ?> next = iterator.next(); - TransmittableThreadLocal threadLocal = next.getKey(); + for (Iterator> iterator = holder.get().iterator(); iterator.hasNext(); ) { + TransmittableThreadLocal threadLocal = iterator.next(); // backup backup.put(threadLocal, threadLocal.get()); @@ -307,10 +302,8 @@ public class TransmittableThreadLocal extends InheritableThreadLocal { // call afterExecute callback doExecuteCallback(false); - for (Iterator, ?>> iterator = holder.get().entrySet().iterator(); - iterator.hasNext(); ) { - Map.Entry, ?> next = iterator.next(); - TransmittableThreadLocal threadLocal = next.getKey(); + for (Iterator> iterator = holder.get().iterator(); iterator.hasNext(); ) { + TransmittableThreadLocal threadLocal = iterator.next(); // clear the TTL value only in backup // avoid the extra value of backup after restore