diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 0433cfc23eb11b7589c7041addc1d43aea2686c1..56269f051861a87c45dfce7e556edb81be0ea684 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -118,40 +118,6 @@ void VXXJitCode::generate() { ret(); } -bool ReluJitCode::init(int d) { return MayIUse(avx); } - -void ReluJitCode::generate() { - int offset = 0; - vxorps(ymm_zero, ymm_zero, ymm_zero); - for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { - vmovups(ymm_src, ptr[param1 + offset]); - vmaxps(ymm_dst, ymm_zero, ymm_src); - vmovups(ptr[param2 + offset], ymm_dst); - offset += sizeof(float) * AVX_FLOAT_BLOCK; - } - int rest = num_ % AVX_FLOAT_BLOCK; - if (rest >= 4) { - vmovups(xmm_src, ptr[param1 + offset]); - vmaxps(xmm_dst, xmm_zero, xmm_src); - vmovups(ptr[param2 + offset], xmm_dst); - offset += sizeof(float) * 4; - rest -= 4; - } - if (rest >= 2) { - vmovups(xmm_src, ptr[param1 + offset]); - vmaxps(xmm_dst, xmm_zero, xmm_src); - vmovq(ptr[param2 + offset], xmm_dst); - offset += sizeof(float) * 2; - rest -= 2; - } - if (rest > 0) { - vmovups(xmm_src, ptr[param1 + offset]); - vmaxps(xmm_dst, xmm_zero, xmm_src); - vmovss(ptr[param2 + offset], xmm_dst); - } - ret(); -} - #define ALIGN32 __attribute__((aligned(32))) #define EXP_HIG 88.3762626647949f #define EXP_LOW -88.3762626647949f @@ -207,18 +173,28 @@ static const float exp_float_consts[] ALIGN32 = { static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)}; static int g_tmp_mem[16] ALIGN32 = {0}; -bool VExpJitCode::init(int d) { - return MayIUse(avx) && d == 8; // only 8 yet +bool VActJitCode::init(int d, operand_type type) { + bool ok = MayIUse(avx); + if (type == operand_type::relu) { + return ok; + } else { + return ok && d == 8; // only 8 yet + } } -void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) { - // use reg rax and ymm 2~5 - reg64_t reg_ptr_global = rax; - ymm_t ymm_fx = ymm_t(2); - ymm_t ymm_fy = ymm_t(3); - ymm_t ymm_mask = ymm_t(4); - ymm_t ymm_tmp = ymm_t(5); +void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) { + vmaxps(ymm_dst, ymm_zero, ymm_src); +} + +void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, + int fy_idx, int mask_idx, int tmp_idx) { assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore + // check all idx can not equal + ymm_t ymm_fx = ymm_t(fx_idx); + ymm_t ymm_fy = ymm_t(fy_idx); + ymm_t ymm_mask = ymm_t(mask_idx); + ymm_t ymm_tmp = ymm_t(tmp_idx); + reg64_t reg_ptr_global = rax; push(reg_ptr_global); mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); @@ -291,22 +267,11 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) { pop(reg_ptr_global); } -void VExpJitCode::generate() { - int offset = 0; - vmovups(ymm_src, ptr[param1 + offset]); - exp_ymm(ymm_src, ymm_dst); - vmovups(ptr[param2 + offset], ymm_dst); - ret(); -} - -bool VSigmoidJitCode::init(int d) { - return MayIUse(avx) && d == 8; // only 8 yet -} - -void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) { - // use ymm2 +void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, + int fy_idx, int mask_idx, int tmp_idx) { + // y = 1 / (1 + e^-x) + ymm_t ymm_tmp = ymm_t(tmp_idx); reg64_t reg_ptr_global = rax; - ymm_t ymm_tmp = ymm_t(2); push(reg_ptr_global); mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); @@ -315,38 +280,26 @@ void VSigmoidJitCode::sigmoid_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) { vmaxps(ymm_src, ymm_src, ymm_tmp); vxorps(ymm_tmp, ymm_tmp, ymm_tmp); vsubps(ymm_src, ymm_tmp, ymm_src); - exp_ymm(ymm_src, ymm_dst); + exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); vaddps(ymm_dst, ymm_dst, ymm_tmp); vdivps(ymm_dst, ymm_tmp, ymm_dst); pop(reg_ptr_global); } -void VSigmoidJitCode::generate() { - int offset = 0; - vmovups(ymm_src, ptr[param1 + offset]); - sigmoid_ymm(ymm_src, ymm_dst); - vmovups(ptr[param2 + offset], ymm_dst); - ret(); -} - -bool VTanhJitCode::init(int d) { - return MayIUse(avx) && d == 8; // only 8 yet -} - -void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) { +void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, + int fy_idx, int mask_idx, int tmp_idx) { // y = 2 / (1 + e^(-2x)) - 1 - // use ymm2, ymm3 + ymm_t ymm_tmp = ymm_t(tmp_idx); + ymm_t ymm_zero = ymm_t(mask_idx); reg64_t reg_ptr_global = rax; - ymm_t ymm_tmp = ymm_t(2); - ymm_t ymm_zero = ymm_t(3); push(reg_ptr_global); mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); vxorps(ymm_zero, ymm_zero, ymm_zero); vsubps(ymm_tmp, ymm_zero, ymm_tmp); vmulps(ymm_src, ymm_src, ymm_tmp); - exp_ymm(ymm_src, ymm_dst); + exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); vaddps(ymm_dst, ymm_dst, ymm_tmp); vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); @@ -356,11 +309,61 @@ void VTanhJitCode::vtanh_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) { pop(reg_ptr_global); } -void VTanhJitCode::generate() { +void VActJitCode::generate() { + xmm_t xmm_zero = xmm_t(2); + ymm_t ymm_zero = ymm_t(2); + if (type_ == operand_type::relu) { + vxorps(ymm_zero, ymm_zero, ymm_zero); + } int offset = 0; - vmovups(ymm_src, ptr[param1 + offset]); - vtanh_ymm(ymm_src, ymm_dst); - vmovups(ptr[param2 + offset], ymm_dst); + for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { + vmovups(ymm_src, ptr[param1 + offset]); + switch (type_) { + case operand_type::relu: + relu_ymm(ymm_dst, ymm_src, ymm_zero); + break; + case operand_type::exp: + exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); + break; + case operand_type::sigmoid: + sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); + break; + case operand_type::tanh: + tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); + break; + case operand_type::identity: + break; + default: + break; + } + vmovups(ptr[param2 + offset], ymm_dst); + offset += sizeof(float) * AVX_FLOAT_BLOCK; + } + if (type_ != operand_type::relu) { + // TODO(TJ): remove me + ret(); + return; + } + int rest = num_ % AVX_FLOAT_BLOCK; + if (rest >= 4) { + vmovups(xmm_src, ptr[param1 + offset]); + vmaxps(xmm_dst, xmm_zero, xmm_src); + vmovups(ptr[param2 + offset], xmm_dst); + offset += sizeof(float) * 4; + rest -= 4; + } + if (rest >= 2) { + vmovups(xmm_src, ptr[param1 + offset]); + vmaxps(xmm_dst, xmm_zero, xmm_src); + vmovq(ptr[param2 + offset], xmm_dst); + offset += sizeof(float) * 2; + rest -= 2; + } + if (rest > 0) { + vmovups(xmm_src, ptr[param1 + offset]); + vmaxps(xmm_dst, xmm_zero, xmm_src); + vmovss(ptr[param2 + offset], xmm_dst); + } ret(); } diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 685ab8750ed84d610685e9b4490881e3c864c795..71205b211b7f571f8081640ef60222de051ff49d 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm; using zmm_t = const Xbyak::Zmm; using Label = Xbyak::Label; -typedef enum { mul = 0, add } operand_type; +typedef enum { + mul = 0, + add, + sub, + relu, + exp, + sigmoid, + tanh, + identity +} operand_type; // function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) class VXXJitCode : public JitCode { @@ -85,87 +94,65 @@ class VXXJitCode : public JitCode { ymm_t ymm_zero = ymm_t(3); }; -class ReluJitCode : public JitCode { +class VActJitCode : public JitCode { public: - DECLARE_JIT_CODE(ReluJitCode); - explicit ReluJitCode(int d, size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), num_(d) {} - static bool init(int d); - void generate() override; - - private: - int num_; - reg64_t param1{abi_param1}; - reg64_t param2{abi_param2}; - - xmm_t xmm_zero = xmm_t(0); - xmm_t xmm_src = xmm_t(1); - xmm_t xmm_dst = xmm_t(1); - - ymm_t ymm_zero = ymm_t(0); - ymm_t ymm_src = ymm_t(1); - ymm_t ymm_dst = ymm_t(1); -}; + const char* name() const override { + std::string base = "VActJitCode"; + switch (type_) { + case operand_type::relu: + base += "_Relu"; + break; + case operand_type::exp: + base += "_Exp"; + break; + case operand_type::sigmoid: + base += "_Sigmoid"; + break; + case operand_type::tanh: + base += "_Tanh"; + break; + case operand_type::identity: + base += "_Identity"; + break; + default: + break; + } + return base.c_str(); + } -class VExpJitCode : public JitCode { - public: - DECLARE_JIT_CODE(VExpJitCode); - explicit VExpJitCode(int d, size_t code_size = 256 * 1024, + explicit VActJitCode(int d, operand_type type, size_t code_size = 256 * 1024, void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), num_(d) {} - static bool init(int d); + : JitCode(code_size, code_ptr), num_(d), type_(type) {} + static bool init(int d, operand_type type); void generate() override; protected: - // compute exp with ymm - void exp_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst); + // compute relu with ymm + void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, + const Xbyak::Ymm& zero); - private: - int num_; - reg64_t param1{abi_param1}; - reg64_t param2{abi_param2}; - ymm_t ymm_src = ymm_t(0); - ymm_t ymm_dst = ymm_t(1); -}; - -class VSigmoidJitCode : public VExpJitCode { - public: - DECLARE_JIT_CODE(VSigmoidJitCode); - explicit VSigmoidJitCode(int d, size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : VExpJitCode(d, code_size, code_ptr), num_(d) {} - static bool init(int d); - void generate() override; + // compute exp with ymm + void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, + int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); // compute sigmoid with ymm - void sigmoid_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst); + void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, + int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); - private: - int num_; - reg64_t param1{abi_param1}; - reg64_t param2{abi_param2}; - ymm_t ymm_src = ymm_t(0); - ymm_t ymm_dst = ymm_t(1); -}; + // compute tanh with ymm + void tanh_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, + int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); -class VTanhJitCode : public VExpJitCode { - public: - DECLARE_JIT_CODE(VTanhJitCode); - explicit VTanhJitCode(int d, size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : VExpJitCode(d, code_size, code_ptr), num_(d) {} - static bool init(int d); - void generate() override; - - // compute sigmoid with ymm - void vtanh_ymm(const Xbyak::Ymm& src, const Xbyak::Ymm& dst); - - private: + protected: int num_; + operand_type type_; reg64_t param1{abi_param1}; reg64_t param2{abi_param2}; + + xmm_t xmm_src = xmm_t(0); ymm_t ymm_src = ymm_t(0); + + xmm_t xmm_dst = xmm_t(1); ymm_t ymm_dst = ymm_t(1); }; diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index d96d5f15ea7cd984d9f84f0943094b6c9abfc045..05af7432c5787db28f919858a4319f9e989f5038 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -352,7 +352,8 @@ class VReluKernelImpl : public VReluKernel { size_t sz = 96 /* init size */ + d / AVX_FLOAT_BLOCK * 4 /* instructions */ * 8 /* average bytes for each instruction */; - jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096)); + jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu, + sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); return; } @@ -366,14 +367,14 @@ class VReluKernelImpl : public VReluKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VReluKernelImpl::useJIT(int d) { - return gen::ReluJitCode::init(d); + return gen::VActJitCode::init(d, gen::operand_type::relu); } #endif diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index f0431be5816bfc156c0b200d1f3461065927b2cc..28059ad270fd13ff6008464dd5914d2f53cbb223 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -116,7 +116,8 @@ class VExpKernelImpl : public VExpKernel { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change - jitcode_.reset(new gen::VExpJitCode(d, sz > 4096 ? sz : 4096)); + jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::exp, + sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); return; } @@ -135,14 +136,14 @@ class VExpKernelImpl : public VExpKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VExpKernelImpl::useJIT(int d) { - return gen::VExpJitCode::init(d); + return gen::VActJitCode::init(d, gen::operand_type::exp); } #endif @@ -169,7 +170,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change - jitcode_.reset(new gen::VSigmoidJitCode(d, sz > 4096 ? sz : 4096)); + jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::sigmoid, + sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); return; } @@ -190,14 +192,14 @@ class VSigmoidKernelImpl : public VSigmoidKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VSigmoidKernelImpl::useJIT(int d) { - return gen::VSigmoidJitCode::init(d); + return gen::VActJitCode::init(d, gen::operand_type::sigmoid); } #endif @@ -223,7 +225,8 @@ class VTanhKernelImpl : public VTanhKernel { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change - jitcode_.reset(new gen::VTanhJitCode(d, sz > 4096 ? sz : 4096)); + jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::tanh, + sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); return; } @@ -244,14 +247,14 @@ class VTanhKernelImpl : public VTanhKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VTanhKernelImpl::useJIT(int d) { - return gen::VTanhJitCode::init(d); + return gen::VActJitCode::init(d, gen::operand_type::tanh); } #endif