TransmittableThreadLocal.java 13.8 KB
Newer Older
1 2
package com.alibaba.ttl;

3
import java.util.*;
4 5
import java.util.concurrent.Callable;
import java.util.function.Supplier;
J
Jerry Lee 已提交
6 7
import java.util.logging.Level;
import java.util.logging.Logger;
8 9 10 11

/**
 * {@link TransmittableThreadLocal} can transmit values from the thread submitting a task to the thread executing the task.
 * <p>
 * Note: {@link TransmittableThreadLocal} extends {@link java.lang.InheritableThreadLocal},
 * so a {@link TransmittableThreadLocal} is first of all an {@link java.lang.InheritableThreadLocal}.
 *
 * @author Jerry Lee (oldratlee at gmail dot com)
 * @see TtlRunnable
 * @see TtlCallable
 * @since 0.10.0
 */
public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> {
J
Jerry Lee 已提交
21 22
    private static final Logger logger = Logger.getLogger(TransmittableThreadLocal.class.getName());

23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
    /**
     * Computes the value for this transmittable thread-local variable
     * as a function of the source thread's value at the time the task
     * object is created. This method is called from {@link TtlRunnable} or
     * {@link TtlCallable} when it is created, before the task is started.
     * <p>
     * This method merely returns the reference of its source thread value, and should be overridden
     * if a different behavior is desired (for example, to hand over a copy instead of the same object).
     *
     * @param parentValue the value of this variable in the source thread
     * @return the value to transmit; by default the very same reference as {@code parentValue}
     * @since 1.0.0
     */
    protected T copy(T parentValue) {
        return parentValue;
    }

    /**
     * Callback method invoked before the task object ({@link TtlRunnable}/{@link TtlCallable}) executes.
     * <p>
     * The default implementation does nothing; override it
     * if a different behavior is desired.
     * <p>
     * Implementations should not throw: anything thrown from this callback is caught,
     * logged at {@code WARNING} level and otherwise ignored.
     *
     * @since 1.2.0
     */
    protected void beforeExecute() {
    }

    /**
     * Callback method invoked after the task object ({@link TtlRunnable}/{@link TtlCallable}) executes.
     * <p>
     * The default implementation does nothing; override it
     * if a different behavior is desired.
     * <p>
     * Implementations should not throw: anything thrown from this callback is caught,
     * logged at {@code WARNING} level and otherwise ignored.
     *
     * @since 1.2.0
     */
    protected void afterExecute() {
    }

    /**
     * Returns the current thread's value of this variable.
     * <p>
     * When the value is non-null, this instance is also registered in the
     * current thread's {@code holder} set so that it takes part in capture/replay.
     *
     * @return the current thread's value, possibly {@code null}
     */
    @Override
    public final T get() {
        final T current = super.get();
        if (current == null) {
            // nothing to transmit, so do not register this instance
            return null;
        }
        addValue();
        return current;
    }

    /**
     * Sets the current thread's value of this variable and keeps the
     * {@code holder} registration in sync with it.
     *
     * @param value the new value; {@code null} is stored as the value and
     *              additionally unregisters this instance from the {@code holder} set
     */
    @Override
    public final void set(T value) {
        super.set(value);
        if (value != null) {
            addValue();
            return;
        }
        // may set null to remove value
        removeValue();
    }

    /**
     * Removes the current thread's value of this variable.
     * <p>
     * This instance is first unregistered from the current thread's {@code holder} set,
     * then the underlying thread-local value is removed.
     */
    @Override
    public final void remove() {
        removeValue();
        super.remove();
    }

89
    void superRemove() {
90 91 92 93 94 95 96
        super.remove();
    }

    /**
     * Reads the current thread's value via {@link #get()} and passes it through {@link #copy(Object)}.
     *
     * @return the result of {@code copy} applied to the current thread's value
     */
    T copyValue() {
        return copy(get());
    }

97 98
    private static InheritableThreadLocal<Set<TransmittableThreadLocal<?>>> holder =
            new InheritableThreadLocal<Set<TransmittableThreadLocal<?>>>() {
99
                @Override
100 101
                protected Set<TransmittableThreadLocal<?>> initialValue() {
                    return Collections.newSetFromMap(new WeakHashMap<TransmittableThreadLocal<?>, Boolean>());
102
                }
103 104

                @Override
105 106 107 108 109
                protected Set<TransmittableThreadLocal<?>> childValue(Set<TransmittableThreadLocal<?>> parentValue) {
                    Set<TransmittableThreadLocal<?>> result = Collections.newSetFromMap(
                            new WeakHashMap<TransmittableThreadLocal<?>, Boolean>());
                    result.addAll(parentValue);
                    return result;
110
                }
111 112
            };

oldratlee's avatar
oldratlee 已提交
113
    private void addValue() {
114 115
        if (!holder.get().contains(this)) {
            holder.get().add(this);
116 117 118
        }
    }

oldratlee's avatar
oldratlee 已提交
119
    private void removeValue() {
120 121 122 123
        holder.get().remove(this);
    }

    private static void doExecuteCallback(boolean isBefore) {
124
        for (TransmittableThreadLocal<?> threadLocal : holder.get()) {
125 126 127 128 129 130 131
            try {
                if (isBefore) {
                    threadLocal.beforeExecute();
                } else {
                    threadLocal.afterExecute();
                }
            } catch (Throwable t) {
J
Jerry Lee 已提交
132 133 134
                if (logger.isLoggable(Level.WARNING)) {
                    logger.log(Level.WARNING, "TTL exception when " + (isBefore ? "beforeExecute" : "afterExecute") + ", cause: " + t.toString(), t);
                }
135 136 137
            }
        }
    }
oldratlee's avatar
oldratlee 已提交
138 139 140 141 142 143 144 145 146 147 148

    /**
     * Debug only method!
     */
    static void dump(String title) {
        if (title != null && title.length() > 0) {
            System.out.printf("Start TransmittableThreadLocal[%s] Dump...\n", title);
        } else {
            System.out.println("Start TransmittableThreadLocal Dump...");
        }

149
        for (final TransmittableThreadLocal<?> key : holder.get()) {
oldratlee's avatar
oldratlee 已提交
150 151 152 153 154 155 156 157 158 159 160
            System.out.println(key.get());
        }
        System.out.println("TransmittableThreadLocal Dump end!");
    }

    /**
     * Debug only method!
     * <p>
     * Same as {@link #dump(String)} without a title.
     */
    static void dump() {
        dump(null);
    }
161 162

    /**
oldratlee's avatar
oldratlee 已提交
163 164
     * {@link Transmitter} transmits all {@link TransmittableThreadLocal} values of the current thread to
     * another thread via the static methods {@link #capture()} =&gt; {@link #replay(Object)} =&gt; {@link #restore(Object)} (aka {@code CRR} operation).
165
     * <p>
oldratlee's avatar
oldratlee 已提交
166 167 168
     * {@link Transmitter} is <b><i>internal</i></b> manipulation api for <b><i>framework/middleware integration</i></b>;
     * In general, you will <b><i>never</i></b> use it in the <i>biz/application code</i>!
     * <p>
169
     * Below is the example code:
oldratlee's avatar
oldratlee 已提交
170 171 172
     *
     * <pre><code>
     * ///////////////////////////////////////////////////////////////////////////
oldratlee's avatar
oldratlee 已提交
173
     * // in thread A, capture all TransmittableThreadLocal values of thread A
oldratlee's avatar
oldratlee 已提交
174 175
     * ///////////////////////////////////////////////////////////////////////////
     *
176
     * Object captured = Transmitter.capture(); // (1)
oldratlee's avatar
oldratlee 已提交
177 178
     *
     * ///////////////////////////////////////////////////////////////////////////
179
     * // in thread B
oldratlee's avatar
oldratlee 已提交
180 181
     * ///////////////////////////////////////////////////////////////////////////
     *
182 183
     * // replay all TransmittableThreadLocal values from thread A
     * Object backup = Transmitter.replay(captured); // (2)
oldratlee's avatar
oldratlee 已提交
184
     * try {
185 186 187 188
     *     // your biz logic, run with the TransmittableThreadLocal values of thread A
     *     System.out.println("Hello");
     *     // ...
     *     return "World";
oldratlee's avatar
oldratlee 已提交
189
     * } finally {
190 191
     *     // restore the TransmittableThreadLocal values of thread B as they were before replay
     *     Transmitter.restore(backup); // (3)
oldratlee's avatar
oldratlee 已提交
192 193 194 195
     * }
     * </code></pre>
     * <p>
     * see the implementation code of {@link TtlRunnable} and {@link TtlCallable} for more actual code sample.
196 197
     * <hr>
     * Of course, {@link #replay(Object)} and {@link #restore(Object)} operation can be simplified
oldratlee's avatar
oldratlee 已提交
198
     * by util methods {@link #runCallableWithCaptured(Object, Callable)} or {@link #runSupplierWithCaptured(Object, Supplier)}
199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
     * and the adorable {@code Java 8 lambda syntax}.
     * <p>
     * Below is the example code:
     *
     * <pre><code>
     * ///////////////////////////////////////////////////////////////////////////
     * // in thread A, capture all TransmittableThreadLocal values of thread A
     * ///////////////////////////////////////////////////////////////////////////
     *
     * Object captured = Transmitter.capture(); // (1)
     *
     * ///////////////////////////////////////////////////////////////////////////
     * // in thread B
     * ///////////////////////////////////////////////////////////////////////////
     *
     * String result = runSupplierWithCaptured(captured, () -&gt; {
     *      // your biz logic, run with the TransmittableThreadLocal values of thread A
     *      System.out.println("Hello");
     *      ...
     *      return "World";
     * }); // (2) + (3)
     * </code></pre>
     * <p>
oldratlee's avatar
oldratlee 已提交
222
     * Two util methods are provided because the biz logic ({@code lambda}) may declare different {@code throws} clauses:
     * <ol>
     * <li>{@link #runSupplierWithCaptured(Object, Supplier)}: no {@code throws}</li>
     * <li>{@link #runCallableWithCaptured(Object, Callable)}: {@code throws Exception}</li>
     * </ol>
     * <p>
     * If your biz logic throws a different exception type,
     * you can define your own util method with a functional interface ({@code lambda}) declaring that {@code throws} type.
230 231
     *
     * @author Yang Fang (snoop dot fy at gmail dot com)
232
     * @author Jerry Lee (oldratlee at gmail dot com)
233 234 235 236 237 238 239 240
     * @see TtlRunnable
     * @see TtlCallable
     * @since 2.3.0
     */
    public static class Transmitter {
        /**
         * Capture all {@link TransmittableThreadLocal} values in current thread.
         *
oldratlee's avatar
oldratlee 已提交
241
         * @return the captured {@link TransmittableThreadLocal} values
242
         * @since 2.3.0
243 244
         */
        public static Object capture() {
245
            Map<TransmittableThreadLocal<?>, Object> captured = new HashMap<TransmittableThreadLocal<?>, Object>();
246
            for (TransmittableThreadLocal<?> threadLocal : holder.get()) {
247 248 249 250 251 252
                captured.put(threadLocal, threadLocal.copyValue());
            }
            return captured;
        }

        /**
oldratlee's avatar
oldratlee 已提交
253 254
         * Replay the captured {@link TransmittableThreadLocal} values from {@link #capture()},
         * and return the backup {@link TransmittableThreadLocal} values in current thread before replay.
255
         *
oldratlee's avatar
oldratlee 已提交
256 257
         * @param captured captured {@link TransmittableThreadLocal} values from other thread from {@link #capture()}
         * @return the backup {@link TransmittableThreadLocal} values before replay
258
         * @see #capture()
259
         * @since 2.3.0
260 261 262 263
         */
        public static Object replay(Object captured) {
            @SuppressWarnings("unchecked")
            Map<TransmittableThreadLocal<?>, Object> capturedMap = (Map<TransmittableThreadLocal<?>, Object>) captured;
264
            Map<TransmittableThreadLocal<?>, Object> backup = new HashMap<TransmittableThreadLocal<?>, Object>();
265

266 267
            for (Iterator<TransmittableThreadLocal<?>> iterator = holder.get().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<?> threadLocal = iterator.next();
268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293

                // backup
                backup.put(threadLocal, threadLocal.get());

                // clear the TTL value only in captured
                // avoid extra TTL value in captured, when run task.
                if (!capturedMap.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }

            // set value to captured TTL
            for (Map.Entry<TransmittableThreadLocal<?>, Object> entry : capturedMap.entrySet()) {
                @SuppressWarnings("unchecked")
                TransmittableThreadLocal<Object> threadLocal = (TransmittableThreadLocal<Object>) entry.getKey();
                threadLocal.set(entry.getValue());
            }

            // call beforeExecute callback
            doExecuteCallback(true);

            return backup;
        }

        /**
oldratlee's avatar
oldratlee 已提交
294
         * Restore the backup {@link TransmittableThreadLocal} values from {@link Transmitter#replay(Object)}.
295
         *
oldratlee's avatar
oldratlee 已提交
296
         * @param backup the backup {@link TransmittableThreadLocal} values from {@link Transmitter#replay(Object)}
297
         * @since 2.3.0
298 299 300 301 302 303 304
         */
        public static void restore(Object backup) {
            @SuppressWarnings("unchecked")
            Map<TransmittableThreadLocal<?>, Object> backupMap = (Map<TransmittableThreadLocal<?>, Object>) backup;
            // call afterExecute callback
            doExecuteCallback(false);

305 306
            for (Iterator<TransmittableThreadLocal<?>> iterator = holder.get().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<?> threadLocal = iterator.next();
307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322

                // clear the TTL value only in backup
                // avoid the extra value of backup after restore
                if (!backupMap.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }

            // restore TTL value
            for (Map.Entry<TransmittableThreadLocal<?>, Object> entry : backupMap.entrySet()) {
                @SuppressWarnings("unchecked")
                TransmittableThreadLocal<Object> threadLocal = (TransmittableThreadLocal<Object>) entry.getKey();
                threadLocal.set(entry.getValue());
            }
        }
323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365

        /**
         * Runs the given biz logic between {@link #replay(Object)} and {@link #restore(Object)}.
         * <p>
         * The backup returned by {@code replay} is always restored, whether the biz logic
         * returns normally or throws.
         *
         * @param captured captured {@link TransmittableThreadLocal} values from other thread from {@link #capture()}
         * @param bizLogic biz logic
         * @param <R>      the return type of biz logic
         * @return the return value of biz logic
         * @see #capture()
         * @see #replay(Object)
         * @see #restore(Object)
         * @since 2.3.1
         */
        public static <R> R runSupplierWithCaptured(Object captured, Supplier<R> bizLogic) {
            final Object previous = replay(captured);
            try {
                final R outcome = bizLogic.get();
                return outcome;
            } finally {
                restore(previous);
            }
        }

        /**
         * Runs the given biz logic between {@link #replay(Object)} and {@link #restore(Object)}.
         * <p>
         * The backup returned by {@code replay} is always restored, whether the biz logic
         * returns normally or throws.
         *
         * @param captured captured {@link TransmittableThreadLocal} values from other thread from {@link #capture()}
         * @param bizLogic biz logic
         * @param <R>      the return type of biz logic
         * @return the return value of biz logic
         * @throws Exception exception thrown by biz logic
         * @see #capture()
         * @see #replay(Object)
         * @see #restore(Object)
         * @since 2.3.1
         */
        public static <R> R runCallableWithCaptured(Object captured, Callable<R> bizLogic) throws Exception {
            final Object previous = replay(captured);
            try {
                final R outcome = bizLogic.call();
                return outcome;
            } finally {
                restore(previous);
            }
        }
366 367 368 369

        /**
         * Not instantiable: this class only exposes static methods.
         * Throws {@link InstantiationError} if it is invoked anyway.
         */
        private Transmitter() {
            throw new InstantiationError("Must not instantiate this class");
        }
370
    }
371
}