diff --git a/src/main/java/com/alibaba/ttl/TransmittableOnlyThreadLocal.java b/src/main/java/com/alibaba/ttl/TransmittableOnlyThreadLocal.java new file mode 100644 index 0000000000000000000000000000000000000000..bd3230ff81f9eb8c3fb13d758c97635ce1086ec9 --- /dev/null +++ b/src/main/java/com/alibaba/ttl/TransmittableOnlyThreadLocal.java @@ -0,0 +1,109 @@ +package com.alibaba.ttl; + +import java.util.Map; +import java.util.logging.Logger; + +/** + * {@link TransmittableOnlyThreadLocal} can transmit value from the thread of submitting task to the thread of executing task. + *

+ * Note: {@link TransmittableOnlyThreadLocal} extends {@link InheritableThreadLocal}, + * so {@link TransmittableOnlyThreadLocal} first is a {@link InheritableThreadLocal}. + * + * @author Jerry Lee (oldratlee at gmail dot com) + * @see TtlRunnable + * @see TtlCallable + * @since 0.10.0 + */ +public class TransmittableOnlyThreadLocal extends InheritableThreadLocal { + private static final Logger logger = Logger.getLogger(TransmittableOnlyThreadLocal.class.getName()); + + /** + * Computes the value for this transmittable thread-local variable + * as a function of the source thread's value at the time the task + * Object is created. This method is called from {@link TtlRunnable} or + * {@link TtlCallable} when it create, before the task is started. + *

+ * This method merely returns reference of its source thread value, and should be overridden + * if a different behavior is desired. + * + * @since 1.0.0 + */ + protected T copy(T parentValue) { + return parentValue; + } + + /** + * Callback method before task object({@link TtlRunnable}/{@link TtlCallable}) execute. + *

+ * Default behavior is do nothing, and should be overridden + * if a different behavior is desired. + *

+ * Do not throw any exception, just ignored. + * + * @since 1.2.0 + */ + protected void beforeExecute() { + } + + /** + * Callback method after task object({@link TtlRunnable}/{@link TtlCallable}) execute. + *

+ * Default behavior is do nothing, and should be overridden + * if a different behavior is desired. + *

+ * Do not throw any exception, just ignored. + * + * @since 1.2.0 + */ + protected void afterExecute() { + } + + /** + * see {@link InheritableThreadLocal#get()} + */ + @Override + public final T get() { + T value = super.get(); + if (null != value) addValue(); + return value; + } + + /** + * see {@link InheritableThreadLocal#set} + */ + @Override + public final void set(T value) { + super.set(value); + // may set null to remove value + if (null == value) removeValue(); + else addValue(); + } + + /** + * see {@link InheritableThreadLocal#remove()} + */ + @Override + public final void remove() { + removeValue(); + super.remove(); + } + + void superRemove() { + super.remove(); + } + + T copyValue() { + return copy(get()); + } + + private void addValue() { + final InheritableThreadLocal> holder = TransmittableThreadLocal.holder; + if (!holder.get().containsKey(this)) { + holder.get().put(this, null); // WeakHashMap supports null value. + } + } + + private void removeValue() { + TransmittableThreadLocal.holder.get().remove(this); + } +} diff --git a/src/main/java/com/alibaba/ttl/TransmittableThreadLocal.java b/src/main/java/com/alibaba/ttl/TransmittableThreadLocal.java index 755a4c6da08fa06a8b30bbe170dc827e821efe1f..ae4b0cabf9f9bcf824f6515481dc96850c45eb14 100644 --- a/src/main/java/com/alibaba/ttl/TransmittableThreadLocal.java +++ b/src/main/java/com/alibaba/ttl/TransmittableThreadLocal.java @@ -105,16 +105,16 @@ public class TransmittableThreadLocal extends InheritableThreadLocal { // 1. The value of holder is type Map, ?> (WeakHashMap implementation), // but it is used as *set*. // 2. WeakHashMap support null value. - private static InheritableThreadLocal, ?>> holder = - new InheritableThreadLocal, ?>>() { + static InheritableThreadLocal> holder = + new InheritableThreadLocal>() { @Override - protected Map, ?> initialValue() { - return new WeakHashMap, Object>(); + protected Map initialValue() { + return new WeakHashMap(); } @Override - protected Map, ?> childValue(Map, ?> parentValue) { - return new WeakHashMap, Object>(parentValue); + protected Map childValue(Map parentValue) { + return new WeakHashMap(parentValue); } }; @@ -129,12 +129,14 @@ public class TransmittableThreadLocal extends InheritableThreadLocal { } private static void doExecuteCallback(boolean isBefore) { - for (Map.Entry, ?> entry : holder.get().entrySet()) { - TransmittableThreadLocal threadLocal = entry.getKey(); + for (Map.Entry entry : holder.get().entrySet()) { + Object threadLocal = entry.getKey(); try { - if (isBefore) threadLocal.beforeExecute(); - else threadLocal.afterExecute(); + if (isBefore) + if (threadLocal instanceof TransmittableThreadLocal) + ((TransmittableThreadLocal) threadLocal).beforeExecute(); + else ((TransmittableOnlyThreadLocal) threadLocal).afterExecute(); } catch (Throwable t) { if (logger.isLoggable(Level.WARNING)) { logger.log(Level.WARNING, "TTL exception when " + (isBefore ? "beforeExecute" : "afterExecute") + ", cause: " + t.toString(), t); @@ -153,9 +155,11 @@ public class TransmittableThreadLocal extends InheritableThreadLocal { System.out.println("Start TransmittableThreadLocal Dump..."); } - for (Map.Entry, ?> entry : holder.get().entrySet()) { - final TransmittableThreadLocal key = entry.getKey(); - System.out.println(key.get()); + for (Map.Entry entry : holder.get().entrySet()) { + final Object threadLocal = entry.getKey(); + if (threadLocal instanceof TransmittableThreadLocal) + System.out.println(((TransmittableThreadLocal) threadLocal).get()); + else System.out.println(((TransmittableOnlyThreadLocal) threadLocal).get()); } System.out.println("TransmittableThreadLocal Dump end!"); } @@ -251,9 +255,11 @@ public class TransmittableThreadLocal extends InheritableThreadLocal { */ @Nonnull public static Object capture() { - Map, Object> captured = new HashMap, Object>(); - for (TransmittableThreadLocal threadLocal : holder.get().keySet()) { - captured.put(threadLocal, threadLocal.copyValue()); + Map captured = new HashMap(); + for (Object threadLocal : holder.get().keySet()) { + if (threadLocal instanceof TransmittableThreadLocal) + captured.put(threadLocal, ((TransmittableThreadLocal) threadLocal).copyValue()); + else captured.put(threadLocal, ((TransmittableOnlyThreadLocal) threadLocal).copyValue()); } return captured; } @@ -271,21 +277,29 @@ public class TransmittableThreadLocal extends InheritableThreadLocal { public static Object replay(@Nonnull Object captured) { @SuppressWarnings("unchecked") Map, Object> capturedMap = (Map, Object>) captured; - Map, Object> backup = new HashMap, Object>(); + Map backup = new HashMap(); - for (Iterator, ?>> iterator = holder.get().entrySet().iterator(); + for (Iterator> iterator = holder.get().entrySet().iterator(); iterator.hasNext(); ) { - Map.Entry, ?> next = iterator.next(); - TransmittableThreadLocal threadLocal = next.getKey(); + Map.Entry next = iterator.next(); + Object threadLocal = next.getKey(); // backup - backup.put(threadLocal, threadLocal.get()); + Object value; + if (threadLocal instanceof TransmittableThreadLocal) { + value = ((TransmittableThreadLocal) threadLocal).get(); + } else value = ((TransmittableOnlyThreadLocal) threadLocal).get(); + + backup.put(threadLocal, value); // clear the TTL values that is not in captured // avoid the extra TTL values after replay when run task if (!capturedMap.containsKey(threadLocal)) { iterator.remove(); - threadLocal.superRemove(); + + if (threadLocal instanceof TransmittableThreadLocal) + ((TransmittableThreadLocal) threadLocal).superRemove(); + else ((TransmittableOnlyThreadLocal) threadLocal).superRemove(); } } @@ -324,16 +338,19 @@ public class TransmittableThreadLocal extends InheritableThreadLocal { // call afterExecute callback doExecuteCallback(false); - for (Iterator, ?>> iterator = holder.get().entrySet().iterator(); + for (Iterator> iterator = holder.get().entrySet().iterator(); iterator.hasNext(); ) { - Map.Entry, ?> next = iterator.next(); - TransmittableThreadLocal threadLocal = next.getKey(); + Map.Entry next = iterator.next(); + Object threadLocal = next.getKey(); // clear the TTL values that is not in backup // avoid the extra TTL values after restore if (!backupMap.containsKey(threadLocal)) { iterator.remove(); - threadLocal.superRemove(); + + if (threadLocal instanceof TransmittableThreadLocal) + ((TransmittableThreadLocal) threadLocal).superRemove(); + else ((TransmittableOnlyThreadLocal) threadLocal).superRemove(); } }