diff --git a/paddle/fluid/operators/jitkernels/README.md b/paddle/fluid/operators/jitkernels/README.md index 3401e9be531928091daa0300719ce6c23d35d863..fd6428b43ece8714b2f95e80c07a2a3190e6cb02 100644 --- a/paddle/fluid/operators/jitkernels/README.md +++ b/paddle/fluid/operators/jitkernels/README.md @@ -1,4 +1,46 @@ -TBD +# JIT Kernel + +结合函数模板和JIT生成需要的kernel函数。 +这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。 +目前仅支持CPU上的高性能计算。 + +## 目录结构 + +```txt +PaddlePaddle/Paddle/paddle/fluid/ +├── ... +├── operator/ +│ ├── .../ +└── jit/ + ├── ... + ├── jitcode/ + │ └── ... + |── more/ + │ ├── ... + │ ├── mkl/ + │ │ └── ... + │ └── openblas/ + │ └── ... + └── refer/ + └── ... +``` + +基础class都的根目录下,根目录下包括jitcode,more和refer。每个目录下都是一种实现,每种kernel算子都需要有reference的实现,其他的都是可选的。 +- jitcode: 代表使用jit生成的code,需要依赖xbyak。他关心的是性能。 +- refer:代表reference的实现,每种kernel算子都需要有在CPU上的reference的实现,他主要关心的算法逻辑。 +- more: 下面可以放入跟多实现,包括mkl,mkldnn,openblas等,也可以是自身已有的kernel组合。 -# Use me +## 动态获取 + +提供一个get方法,根据kernel类别获取,每种实现都有自己的使用范围,根据范围动态和当前条件选择需要的kernel函数。 + +## 测试 + +- 逻辑测试 + 所有实现都要与refer的code对比,需要满足精度要求 +- 性能测试 + +# 如何添加新的算子 +TBD +## Use me Add USE_JIT_KERNEL(yourname) to CMakefile. diff --git a/paddle/fluid/operators/jitkernels/jitcode_base.h b/paddle/fluid/operators/jitkernels/jitcode_base.h index ffec62163a70ebb9a1e43d5ab736f4f509719fc9..de8aaf229fe8f2844afd76b7536f5fa13e569bce 100644 --- a/paddle/fluid/operators/jitkernels/jitcode_base.h +++ b/paddle/fluid/operators/jitkernels/jitcode_base.h @@ -62,11 +62,7 @@ class JitBase : public Kernel { }; template -std::unique_ptr CreateJitCode(Attr attr); //{ -// if (UseJitCode) { -// return make_unique(attr, CodeSize()); -// } -// } +std::unique_ptr CreateJitCode(Attr attr); } // namespace jitkernels } // namespace operators diff --git a/paddle/fluid/operators/jitkernels/kernel_base.h b/paddle/fluid/operators/jitkernels/kernel_base.h index eeaa0617cb8c941bb669939b6824309f3730d308..6fbb0f9f7ea71d622ee0e7e2cf29db287ca53cfa 100644 --- a/paddle/fluid/operators/jitkernels/kernel_base.h +++ b/paddle/fluid/operators/jitkernels/kernel_base.h @@ -21,6 +21,13 @@ namespace jitkernels { typedef enum { vmul = 0, vadd = 1, vsub, vexp } KernelType; +template +struct VMulTypes { + typedef T data_type; + typedef int attr_type; + typedef void (*func_type)(const T*, const T*, T*, int); +}; + // Just for adding to kernel pool without template class Kernel { public: @@ -29,10 +36,10 @@ class Kernel { DISABLE_COPY_AND_ASSIGN(Kernel); }; -template // TODO(TJ): use tuple +template class KernelImpl : public Kernel { public: - using ELEMENT_TYPE = T; // TODO(TJ): remove me? + using ELEMENT_TYPE = T; virtual Func GetFunc() const { return func; } virtual bool UseMe(Attr attr) const = 0; @@ -40,7 +47,7 @@ class KernelImpl : public Kernel { Func func{nullptr}; }; -template // TODO(TJ): use tuple +template class ReferKernel : public KernelImpl { public: // Refer code can always be used diff --git a/paddle/fluid/operators/jitkernels/kernel_pool.h b/paddle/fluid/operators/jitkernels/kernel_pool.h index f398093dfe2b1bf66b5e1d547ba824de96ef5ea5..901a891cb38a260f304d0d883229b2ab8ae3bd38 100644 --- a/paddle/fluid/operators/jitkernels/kernel_pool.h +++ b/paddle/fluid/operators/jitkernels/kernel_pool.h @@ -27,8 +27,6 @@ namespace paddle { namespace operators { namespace jitkernels { -// TODO(TJ): rename file to kernel_pool - template class JitCodePool { typedef std::unique_ptr JitBasePtr; @@ -54,14 +52,6 @@ class JitCodePool { DISABLE_COPY_AND_ASSIGN(JitCodePool); }; -// TODO(TJ): std::tuple -// template -// struct KernelAttr { -// typedef T data_type; -// typedef Func return_type; -// typedef Attr attr_type; -// }; - typedef std::unique_ptr KernelPtr; typedef std::unordered_map, KernelKey::Hash> KernelMap; @@ -120,7 +110,6 @@ inline Func GetRefer() { return nullptr; } -// TODO(TJ): make tuple? named KernelAttr template const Func Get(Attr attr) { @@ -130,8 +119,7 @@ const Func Get(Attr attr) { return codes.AllKernels().at(key)->template getCode(); } - if (std::is_same::value) { // TODO(TJ): float - // move to create + if (std::is_same::value) { auto p = CreateJitCode(attr); if (p) { auto f = p->template getCode(); diff --git a/paddle/fluid/operators/jitkernels/more/mkl/mkl.h b/paddle/fluid/operators/jitkernels/more/mkl/mkl.h index 75ed34ef48eec5be41d2a098454f06cba6fa60b1..9cf032db43f4aac700a575bc0aac0ae961c28f9f 100644 --- a/paddle/fluid/operators/jitkernels/more/mkl/mkl.h +++ b/paddle/fluid/operators/jitkernels/more/mkl/mkl.h @@ -27,16 +27,9 @@ namespace mkl { template void VMul(const T* x, const T* y, T* z, int n); -// template -// struct VMulTypes{ -// typedef T date_type; -// typedef void (*func)(const T*, const T*, T*, int) func_type; -// typedef int attr_type; -// }; - template -class VMulKernel - : public KernelImpl { +class VMulKernel : public KernelImpl::func_type, + typename VMulTypes::attr_type> { public: VMulKernel() { this->func = VMul; } bool UseMe(int d) const override { diff --git a/paddle/fluid/operators/jitkernels/refer/refer.h b/paddle/fluid/operators/jitkernels/refer/refer.h index 163c6d73dce4d9257dbaf2ff93dda842e7bb16fb..796f58d40177be87e33bbe540ed1b9cf75e16c92 100644 --- a/paddle/fluid/operators/jitkernels/refer/refer.h +++ b/paddle/fluid/operators/jitkernels/refer/refer.h @@ -29,8 +29,8 @@ void VMul(const T* x, const T* y, T* z, int n) { } template -class VMulKernel - : public ReferKernel { +class VMulKernel : public ReferKernel::func_type, + typename VMulTypes::attr_type> { public: VMulKernel() { this->func = VMul; } }; diff --git a/paddle/fluid/operators/jitkernels/registry.h b/paddle/fluid/operators/jitkernels/registry.h index cd414bb096c6956bd0908ccf405a01cb7879391f..6d817461bec8dcdbcf6e659d98a720afd7b61a95 100644 --- a/paddle/fluid/operators/jitkernels/registry.h +++ b/paddle/fluid/operators/jitkernels/registry.h @@ -26,7 +26,7 @@ namespace paddle { namespace operators { namespace jitkernels { -// make_unique is supported from c++14 +// make_unique is supported since c++14 template inline std::unique_ptr make_unique(Args&&... args) { static_assert(!std::is_array::value, "T must not be array"); diff --git a/paddle/fluid/operators/jitkernels/test.cc b/paddle/fluid/operators/jitkernels/test.cc index eb0d30eecdbfc2cedab8dadc764139a95c1e6846..d27b5d1cbae9b2644382106daaf12a50df7f942c 100644 --- a/paddle/fluid/operators/jitkernels/test.cc +++ b/paddle/fluid/operators/jitkernels/test.cc @@ -69,10 +69,10 @@ TEST(JitKernel, vmul) { namespace jit = paddle::operators::jitkernels; // TODO(TJ): test more vector size for (int d = 1; d < 30; ++d) { - auto ref = jit::GetRefer(); - auto tgt = jit::Get(d); + auto ref = jit::GetRefer::func_type, + jit::VMulTypes::attr_type>(); + auto tgt = jit::Get::func_type, + jit::VMulTypes::attr_type, PlaceType>(d); EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(tgt != nullptr);