未验证 提交 e6cfdf6c 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #14274 from tensor-tang/fix/jit

fix jit on mac
...@@ -75,7 +75,12 @@ if(WITH_GPU) ...@@ -75,7 +75,12 @@ if(WITH_GPU)
endif() endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel
SRCS jit_kernel.cc jit_gen.cc jit_code.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc)
DEPS cpu_info cblas gflags enforce) set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
if(WITH_XBYAK)
list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)
list(APPEND JIT_KERNEL_DEPS xbyak)
endif()
cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
...@@ -14,10 +14,13 @@ limitations under the License. */ ...@@ -14,10 +14,13 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/jit_kernel.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_code.h"
#include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h" #include "paddle/fluid/platform/dynload/mklml.h"
#endif #endif
...@@ -64,6 +67,7 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -64,6 +67,7 @@ class VMulKernelImpl : public VMulKernel<T> {
static inline bool useMKL(int d) { return false; } static inline bool useMKL(int d) { return false; }
explicit VMulKernelImpl(int d) : VMulKernel<T>() { explicit VMulKernelImpl(int d) : VMulKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) { if (useJIT(d)) {
// roughly estimate the size of code // roughly estimate the size of code
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
...@@ -72,6 +76,7 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -72,6 +76,7 @@ class VMulKernelImpl : public VMulKernel<T> {
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>(); jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return; return;
} }
#endif
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
if (useMKL(d)) { if (useMKL(d)) {
this->Compute = VMulMKL<T>; this->Compute = VMulMKL<T>;
...@@ -81,15 +86,21 @@ class VMulKernelImpl : public VMulKernel<T> { ...@@ -81,15 +86,21 @@ class VMulKernelImpl : public VMulKernel<T> {
this->Compute = VMulRefer<T>; this->Compute = VMulRefer<T>;
} }
#ifdef PADDLE_WITH_XBYAK
private: private:
std::unique_ptr<gen::VMulJitCode> jitcode_{nullptr}; std::unique_ptr<gen::VMulJitCode> jitcode_{nullptr};
#endif
}; };
#ifdef PADDLE_WITH_XBYAK
template <> template <>
bool VMulKernelImpl<float>::useJIT(int d) { bool VMulKernelImpl<float>::useJIT(int d) {
return gen::VMulJitCode::init(d); return gen::VMulJitCode::init(d);
} }
#endif
#ifdef PADDLE_WITH_MKLML
template <> template <>
bool VMulKernelImpl<float>::useMKL(int d) { bool VMulKernelImpl<float>::useMKL(int d) {
return jit::MayIUse(jit::avx512f) && d > 512; return jit::MayIUse(jit::avx512f) && d > 512;
...@@ -99,6 +110,7 @@ template <> ...@@ -99,6 +110,7 @@ template <>
bool VMulKernelImpl<double>::useMKL(int d) { bool VMulKernelImpl<double>::useMKL(int d) {
return true; return true;
} }
#endif
REGISTER_JITKERNEL(vmul, VMulKernel); REGISTER_JITKERNEL(vmul, VMulKernel);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册