diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index c1d4cc1b889700079091f859dbbdb46f626dbb0f..868a7a706471717ce0c8f268d5eaa6dc4789588c 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -75,7 +75,12 @@ if(WITH_GPU) endif() cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) -cc_library(jit_kernel - SRCS jit_kernel.cc jit_gen.cc jit_code.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc - DEPS cpu_info cblas gflags enforce) + +set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc) +set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce) +if(WITH_XBYAK) + list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc) + list(APPEND JIT_KERNEL_DEPS xbyak) +endif() +cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS}) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 7d38d511723ab3e6edfd4aa853bd7f2521ec98e2..8a988f8f482e4a4963f70c39bccd89387c1e0059 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -14,10 +14,13 @@ limitations under the License. */ #include "paddle/fluid/operators/math/jit_kernel.h" #include -#include "paddle/fluid/operators/math/jit_code.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/platform/enforce.h" +#ifdef PADDLE_WITH_XBYAK +#include "paddle/fluid/operators/math/jit_code.h" +#endif + #ifdef PADDLE_WITH_MKLML #include "paddle/fluid/platform/dynload/mklml.h" #endif @@ -64,6 +67,7 @@ class VMulKernelImpl : public VMulKernel { static inline bool useMKL(int d) { return false; } explicit VMulKernelImpl(int d) : VMulKernel() { +#ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { // roughly estimate the size of code size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; @@ -72,6 +76,7 @@ class VMulKernelImpl : public VMulKernel { jitcode_->getCode(); return; } +#endif #ifdef PADDLE_WITH_MKLML if (useMKL(d)) { this->Compute = VMulMKL; @@ -81,15 +86,21 @@ class VMulKernelImpl : public VMulKernel { this->Compute = VMulRefer; } +#ifdef PADDLE_WITH_XBYAK + private: std::unique_ptr jitcode_{nullptr}; +#endif }; +#ifdef PADDLE_WITH_XBYAK template <> bool VMulKernelImpl::useJIT(int d) { return gen::VMulJitCode::init(d); } +#endif +#ifdef PADDLE_WITH_MKLML template <> bool VMulKernelImpl::useMKL(int d) { return jit::MayIUse(jit::avx512f) && d > 512; @@ -99,6 +110,7 @@ template <> bool VMulKernelImpl::useMKL(int d) { return true; } +#endif REGISTER_JITKERNEL(vmul, VMulKernel);