diff --git a/paddle/fluid/operators/jit/README.md b/paddle/fluid/operators/jit/README.md index 6b2f2b2848e47fa5fadf0e5874cbbc80ffb2c1a7..28d21f40af37e5298f4645b02049c284796ae556 100644 --- a/paddle/fluid/operators/jit/README.md +++ b/paddle/fluid/operators/jit/README.md @@ -45,6 +45,8 @@ PaddlePaddle/Paddle/paddle/fluid/ - 在`KernelType` 中添加 `your_key` . - 实现Reference 的逻辑,每个jitkernel的Reference 实现是必须的。不要依赖任何第三方库。并在`refer/CmakeLists.txt`中`USE_JITKERNEL_REFER(your_key)`. +- (optional) 实现更多的算法在`more`目录下,可以依赖mkl,openblas,或者mkldnn等第三方库。 +- (optional) 实现基于Xbyak的生成code,在`gen`目下。 - 必要时可以添加新的`KernelTuples`,可以参考`XYZNTuples`,新加的Attr类型需要特例化`JitCodeKey`方法。 - 添加unit test,需要测试float和double - 添加benchmark确保get得到的速度是最快。 diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt index 0c15c7060d2d497c4881b27b72bb9aa205abd55f..ffecb732975a652456b154a61feb8a20a727d306 100644 --- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt +++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt @@ -4,3 +4,5 @@ set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE # use mkl kernels by name and type USE_JITKERNEL_MORE(vmul, mkl) +USE_JITKERNEL_MORE(vadd, mkl) +USE_JITKERNEL_MORE(vscal, mkl) diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index 0ffe1d565f1b2a1639c82b85e9c0825d64461de7..3d963cbf1dd5468afc717178f1a53234d8e14a99 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -13,7 +13,9 @@ * limitations under the License. */ #include "paddle/fluid/operators/jit/more/mkl/mkl.h" +#include "paddle/fluid/operators/jit/refer/refer.h" #include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/dynload/mklml.h" namespace paddle { @@ -32,6 +34,61 @@ void VMul(const double* x, const double* y, double* z, int n) { platform::dynload::vdMul(n, x, y, z); } +template <> +void VAdd(const float* x, const float* y, float* z, int n) { + platform::dynload::vsAdd(n, x, y, z); +} + +template <> +void VAdd(const double* x, const double* y, double* z, int n) { + platform::dynload::vdAdd(n, x, y, z); +} + +template <> +void VScal(const float* a, const float* x, float* y, int n) { + if (x == y) { + platform::dynload::cblas_sscal(n, *a, y, 1); + } else { + refer::VScal(a, x, y, n); + } +} + +template <> +void VScal(const double* a, const double* x, double* y, int n) { + if (x == y) { + platform::dynload::cblas_dscal(n, *a, y, 1); + } else { + refer::VScal(a, x, y, n); + } +} + +// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 +template <> +bool VMulKernel::UseMe(int d) const { + return platform::MayIUse(platform::avx512f) && d > 512; +} + +template <> +bool VAddKernel::UseMe(int d) const { + return platform::MayIUse(platform::avx512f) && d > 512; +} + +template <> +bool VScalKernel::UseMe(int d) const { + return platform::MayIUse(platform::avx512f) && d > 512; +} + +#define AWALYS_USE_ME_WITH_DOUBLE(func) \ + template <> \ + bool func##Kernel::UseMe(int d) const { \ + return true; \ + } + +AWALYS_USE_ME_WITH_DOUBLE(VMul); +AWALYS_USE_ME_WITH_DOUBLE(VAdd); +AWALYS_USE_ME_WITH_DOUBLE(VScal); + +#undef AWALYS_USE_ME_WITH_DOUBLE } // namespace mkl } // namespace more } // namespace jit @@ -40,5 +97,12 @@ void VMul(const double* x, const double* y, double* z, int n) { namespace mkl = paddle::operators::jit::more::mkl; -REGISTER_JITKERNEL_MORE(vmul, mkl, mkl::VMulKernel, - mkl::VMulKernel); +#define REGISTER_MKL_KERNEL(key, func) \ + REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel, \ + mkl::func##Kernel) + +REGISTER_MKL_KERNEL(vmul, VMul); +REGISTER_MKL_KERNEL(vadd, VAdd); +REGISTER_MKL_KERNEL(vscal, VScal); + +#undef REGISTER_MKL_KERNEL diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index 4173d1f3de0ce4a7ee727d0261d2fede86bb72b7..84a93f408f51e444c62ea3b70fba8daab280fed0 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -16,7 +16,6 @@ #include #include "paddle/fluid/operators/jit/kernel_base.h" -#include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { @@ -28,17 +27,27 @@ template void VMul(const T* x, const T* y, T* z, int n); template -class VMulKernel : public KernelImpl> { - public: - VMulKernel() { this->func = VMul; } - bool UseMe(int d) const override { - if (std::is_same::value) { - return platform::MayIUse(platform::avx512f) && d > 512; - } else { - return true; - } +void VAdd(const T* x, const T* y, T* z, int n); + +template +void VScal(const T* a, const T* x, T* y, int n); + +#define DECLARE_MKL_KERNEL(name, tuples) \ + template \ + class name##Kernel : public KernelImpl> { \ + public: \ + name##Kernel() { this->func = name; } \ + bool UseMe(typename tuples::attr_type) const override; \ } -}; + +// XYZN +DECLARE_MKL_KERNEL(VMul, XYZNTuples); +DECLARE_MKL_KERNEL(VAdd, XYZNTuples); + +// AXYN +DECLARE_MKL_KERNEL(VScal, AXYNTuples); + +#undef DECLARE_MKL_KERNEL } // namespace mkl } // namespace more diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 8cf588efba52314650bfd376b95b10e6d4336b2e..682e51e89d67f36f4da14ba4b80416f9ad8a1fa1 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -31,56 +31,6 @@ namespace operators { namespace math { namespace jitkernel { -#ifdef PADDLE_WITH_MKLML -template -void VMulMKL(const T* x, const T* y, T* z, int n); - -template <> -void VMulMKL(const float* x, const float* y, float* z, int n) { - platform::dynload::vsMul(n, x, y, z); -} - -template <> -void VMulMKL(const double* x, const double* y, double* z, int n) { - platform::dynload::vdMul(n, x, y, z); -} - -template -void VAddMKL(const T* x, const T* y, T* z, int n); - -template <> -void VAddMKL(const float* x, const float* y, float* z, int n) { - platform::dynload::vsAdd(n, x, y, z); -} - -template <> -void VAddMKL(const double* x, const double* y, double* z, int n) { - platform::dynload::vdAdd(n, x, y, z); -} - -template -void VScalMKL(const T* a, const T* x, T* y, int n); - -template <> -void VScalMKL(const float* a, const float* x, float* y, int n) { - if (x == y) { - platform::dynload::cblas_sscal(n, *a, y, 1); - } else { - refer::VScal(a, x, y, n); - } -} - -template <> -void VScalMKL(const double* a, const double* x, double* y, int n) { - if (x == y) { - platform::dynload::cblas_dscal(n, *a, y, 1); - } else { - refer::VScal(a, x, y, n); - } -} - -#endif - /* VMUL JitKernel */ template class VMulKernelImpl : public VMulKernel {