TtlForkJoinTransformlet.java 4.6 KB
Newer Older
1
package com.alibaba.ttl.threadpool.agent.internal.transformlet.impl;
2

oldratlee's avatar
oldratlee 已提交
3
import com.alibaba.ttl.spi.TtlEnhanced;
4
import com.alibaba.ttl.threadpool.agent.internal.logging.Logger;
oldratlee's avatar
oldratlee 已提交
5
import com.alibaba.ttl.threadpool.agent.internal.transformlet.ClassInfo;
6
import com.alibaba.ttl.threadpool.agent.internal.transformlet.JavassistTransformlet;
7
import edu.umd.cs.findbugs.annotations.NonNull;
8 9 10 11
import javassist.*;

import java.io.IOException;

12
import static com.alibaba.ttl.threadpool.agent.internal.transformlet.impl.Utils.*;
oldratlee's avatar
oldratlee 已提交
13

14 15 16 17 18 19 20 21 22 23
/**
 * TTL {@link JavassistTransformlet} for {@link java.util.concurrent.ForkJoinTask}.
 *
 * @author Jerry Lee (oldratlee at gmail dot com)
 * @author wuwen5 (wuwen.55 at aliyun dot com)
 * @see java.util.concurrent.ForkJoinPool
 * @see java.util.concurrent.ForkJoinTask
 * @since 2.5.1
 */
public class TtlForkJoinTransformlet implements JavassistTransformlet {
24
    private static final Logger logger = Logger.getLogger(TtlForkJoinTransformlet.class);
oldratlee's avatar
oldratlee 已提交
25

26
    private static final String FORK_JOIN_TASK_CLASS_NAME = "java.util.concurrent.ForkJoinTask";
27 28 29
    private static final String FORK_JOIN_POOL_CLASS_NAME = "java.util.concurrent.ForkJoinPool";
    private static final String FORK_JOIN_WORKER_THREAD_FACTORY_CLASS_NAME = "java.util.concurrent.ForkJoinPool$ForkJoinWorkerThreadFactory";

30
    private final boolean disableInheritableForThreadPool;
31

32 33
    public TtlForkJoinTransformlet(boolean disableInheritableForThreadPool) {
        this.disableInheritableForThreadPool = disableInheritableForThreadPool;
34
    }
35 36

    @Override
37
    public void doTransform(@NonNull final ClassInfo classInfo) throws IOException, NotFoundException, CannotCompileException {
oldratlee's avatar
oldratlee 已提交
38 39 40
        if (FORK_JOIN_TASK_CLASS_NAME.equals(classInfo.getClassName())) {
            updateForkJoinTaskClass(classInfo.getCtClass());
            classInfo.setModified();
41
        } else if (disableInheritableForThreadPool && FORK_JOIN_POOL_CLASS_NAME.equals(classInfo.getClassName())) {
oldratlee's avatar
oldratlee 已提交
42 43
            updateConstructorDisableInheritable(classInfo.getCtClass());
            classInfo.setModified();
oldratlee's avatar
oldratlee 已提交
44
        }
45 46
    }

oldratlee's avatar
oldratlee 已提交
47 48 49
    /**
     * @see Utils#doCaptureWhenNotTtlEnhanced(java.lang.Object)
     */
50
    private void updateForkJoinTaskClass(@NonNull final CtClass clazz) throws CannotCompileException, NotFoundException {
51 52
        final String className = clazz.getName();

oldratlee's avatar
oldratlee 已提交
53
        // add new field
oldratlee's avatar
oldratlee 已提交
54
        final String capturedFieldName = "captured$field$added$by$ttl";
oldratlee's avatar
oldratlee 已提交
55
        final CtField capturedField = CtField.make("private final Object " + capturedFieldName + ";", clazz);
56
        clazz.addField(capturedField, "com.alibaba.ttl.threadpool.agent.internal.transformlet.impl.Utils.doCaptureWhenNotTtlEnhanced(this);");
57 58
        logger.info("add new field " + capturedFieldName + " to class " + className);

oldratlee's avatar
oldratlee 已提交
59
        final CtMethod doExecMethod = clazz.getDeclaredMethod("doExec", new CtClass[0]);
60
        final String doExec_renamed_method_name = renamedMethodNameByTtl(doExecMethod);
61

oldratlee's avatar
oldratlee 已提交
62
        final String beforeCode = "if (this instanceof " + TtlEnhanced.class.getName() + ") {\n" + // if the class is already TTL enhanced(eg: com.alibaba.ttl.TtlRecursiveTask)
63
                "    return " + doExec_renamed_method_name + "($$);\n" +                           // return directly/do nothing
64
                "}\n" +
65 66 67 68
                "Object backup = com.alibaba.ttl.TransmittableThreadLocal.Transmitter.replay(" + capturedFieldName + ");";

        final String finallyCode = "com.alibaba.ttl.TransmittableThreadLocal.Transmitter.restore(backup);";

69
        doTryFinallyForMethod(doExecMethod, doExec_renamed_method_name, beforeCode, finallyCode);
70
    }
71

72
    private void updateConstructorDisableInheritable(@NonNull final CtClass clazz) throws NotFoundException, CannotCompileException {
73 74 75 76 77 78 79 80 81 82 83 84 85 86
        for (CtConstructor constructor : clazz.getDeclaredConstructors()) {
            final CtClass[] parameterTypes = constructor.getParameterTypes();
            final StringBuilder insertCode = new StringBuilder();
            for (int i = 0; i < parameterTypes.length; i++) {
                final String paramTypeName = parameterTypes[i].getName();
                if (FORK_JOIN_WORKER_THREAD_FACTORY_CLASS_NAME.equals(paramTypeName)) {
                    String code = String.format("$%d = com.alibaba.ttl.threadpool.TtlForkJoinPoolHelper.getDisableInheritableForkJoinWorkerThreadFactory($%<d);", i + 1);
                    logger.info("insert code before method " + signatureOfMethod(constructor) + " of class " + constructor.getDeclaringClass().getName() + ": " + code);
                    insertCode.append(code);
                }
            }
            if (insertCode.length() > 0) constructor.insertBefore(insertCode.toString());
        }
    }
87
}