diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc
index 35f0bdb9b31d472e8f2fbdc7b76a5aa4ac8175e3..a92e5d351e71a55bca2845ce275780950d096031 100644
--- a/paddle/fluid/operators/math/jit_code.cc
+++ b/paddle/fluid/operators/math/jit_code.cc
@@ -24,51 +24,14 @@ namespace gen {
 
 using namespace platform::jit;  // NOLINT
 
-bool VMulJitCode::init(int d) {
+bool VVVJitCode::init(int d) {
   // It's not necessary to use avx512 since it would slow down the frequency
   // and this kernel is not compute bound.
   return MayIUse(avx);
 }
 
-void VMulJitCode::generate() {
+void VVVJitCode::generate() {
   // do not need push stack, and do not need save avx512reg if do not use avx512
-  int offset = 0;
-  for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
-    vmovups(ymm_src1, ptr[param1 + offset]);
-    vmovups(ymm_src2, ptr[param2 + offset]);
-    vmulps(ymm_dst, ymm_src1, ymm_src2);
-    vmovups(ptr[param3 + offset], ymm_dst);
-    offset += sizeof(float) * AVX_FLOAT_BLOCK;
-  }
-  int rest = num_ % AVX_FLOAT_BLOCK;
-  if (rest >= 4) {
-    vmovups(xmm_src1, ptr[param1 + offset]);
-    vmovups(xmm_src2, ptr[param2 + offset]);
-    vmulps(xmm_dst, xmm_src1, xmm_src2);
-    vmovups(ptr[param3 + offset], xmm_dst);
-    offset += sizeof(float) * 4;
-    rest -= 4;
-  }
-  if (rest >= 2) {
-    vmovq(xmm_src1, ptr[param1 + offset]);
-    vmovq(xmm_src2, ptr[param2 + offset]);
-    vmulps(xmm_dst, xmm_src1, xmm_src2);
-    vmovq(ptr[param3 + offset], xmm_dst);
-    offset += sizeof(float) * 2;
-    rest -= 2;
-  }
-  if (rest > 0) {
-    vmovss(xmm_src1, ptr[param1 + offset]);
-    vmovss(xmm_src2, ptr[param2 + offset]);
-    vmulss(xmm_dst, xmm_src1, xmm_src2);
-    vmovss(ptr[param3 + offset], xmm_dst);
-  }
-  ret();
-}
-
-bool VAddJitCode::init(int d) { return MayIUse(avx); }
-
-void VAddJitCode::generate() {
   int offset = 0;
   if (with_relu_) {
     vxorps(ymm_zero, ymm_zero, ymm_zero);
@@ -76,7 +39,11 @@ void VAddJitCode::generate() {
   for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
     vmovups(ymm_src1, ptr[param1 + offset]);
     vmovups(ymm_src2, ptr[param2 + offset]);
-    vaddps(ymm_dst, ymm_src1, ymm_src2);
+    if (type_ == operand_type::mul) {
+      vmulps(ymm_dst, ymm_src1, ymm_src2);
+    } else if (type_ == operand_type::add) {
+      vaddps(ymm_dst, ymm_src1, ymm_src2);
+    }
     if (with_relu_) {
       vmaxps(ymm_dst, ymm_zero, ymm_dst);
     }
@@ -87,7 +54,11 @@ void VAddJitCode::generate() {
   if (rest >= 4) {
     vmovups(xmm_src1, ptr[param1 + offset]);
     vmovups(xmm_src2, ptr[param2 + offset]);
-    vaddps(xmm_dst, xmm_src1, xmm_src2);
+    if (type_ == operand_type::mul) {
+      vmulps(xmm_dst, xmm_src1, xmm_src2);
+    } else if (type_ == operand_type::add) {
+      vaddps(xmm_dst, xmm_src1, xmm_src2);
+    }
     if (with_relu_) {
       vmaxps(xmm_dst, xmm_zero, xmm_dst);
     }
@@ -98,7 +69,11 @@ void VAddJitCode::generate() {
   if (rest >= 2) {
     vmovq(xmm_src1, ptr[param1 + offset]);
     vmovq(xmm_src2, ptr[param2 + offset]);
-    vaddps(xmm_dst, xmm_src1, xmm_src2);
+    if (type_ == operand_type::mul) {
+      vmulps(xmm_dst, xmm_src1, xmm_src2);
+    } else if (type_ == operand_type::add) {
+      vaddps(xmm_dst, xmm_src1, xmm_src2);
+    }
     if (with_relu_) {
       vmaxps(xmm_dst, xmm_zero, xmm_dst);
     }
@@ -109,7 +84,11 @@ void VAddJitCode::generate() {
   if (rest > 0) {
     vmovss(xmm_src1, ptr[param1 + offset]);
     vmovss(xmm_src2, ptr[param2 + offset]);
-    vaddss(xmm_dst, xmm_src1, xmm_src2);
+    if (type_ == operand_type::mul) {
+      vmulss(xmm_dst, xmm_src1, xmm_src2);
+    } else if (type_ == operand_type::add) {
+      vaddss(xmm_dst, xmm_src1, xmm_src2);
+    }
     if (with_relu_) {
       vmaxps(xmm_dst, xmm_zero, xmm_dst);
     }
diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h
index 6bfed4b22d251b656145ba9d69142b862a82656a..73692ebc67c71f6190f2d18bd50071a28a35d4c9 100644
--- a/paddle/fluid/operators/math/jit_code.h
+++ b/paddle/fluid/operators/math/jit_code.h
@@ -14,8 +14,8 @@ limitations under the License. */
 
 #pragma once
 
+#include <string>
 #include "paddle/fluid/operators/math/jit_gen.h"
-
 namespace paddle {
 namespace operators {
 namespace math {
@@ -29,41 +29,33 @@ using ymm_t = const Xbyak::Ymm;
 using zmm_t = const Xbyak::Zmm;
 using Label = Xbyak::Label;
 
-class VMulJitCode : public JitCode {
- public:
-  DECLARE_JIT_CODE(VMulJitCode);
-  explicit VMulJitCode(int d, size_t code_size = 256 * 1024,
-                       void* code_ptr = nullptr)
-      : JitCode(code_size, code_ptr), num_(d) {}
-  static bool init(int d);
-  void generate() override;
-
- private:
-  int num_;
-  reg64_t param1{abi_param1};
-  reg64_t param2{abi_param2};
-  reg64_t param3{abi_param3};
-
-  xmm_t xmm_src1 = xmm_t(0);
-  xmm_t xmm_src2 = xmm_t(1);
-  xmm_t xmm_dst = xmm_t(1);
-
-  ymm_t ymm_src1 = ymm_t(0);
-  ymm_t ymm_src2 = ymm_t(1);
-  ymm_t ymm_dst = ymm_t(1);
-};
+// function: vec = Operand(vec, vec) (maybe with relu)
+typedef enum { mul = 0, add } operand_type;
 
-class VAddJitCode : public JitCode {
+class VVVJitCode : public JitCode {
  public:
-  DECLARE_JIT_CODE(VAddJitCode);
-  explicit VAddJitCode(int d, bool with_relu, size_t code_size = 256 * 1024,
-                       void* code_ptr = nullptr)
-      : JitCode(code_size, code_ptr), num_(d), with_relu_(with_relu) {}
+  const char* name() const override {
+    std::string base = "VVVJitCode";
+    if (type_ == operand_type::mul) {
+      base += "_Mul";
+    } else if (type_ == operand_type::add) {
+      base += "_Add";
+    }
+    base += (with_relu_ ? "_relu" : "");
+    return base.c_str();
+  }
+  explicit VVVJitCode(int d, operand_type type, bool with_relu,
+                      size_t code_size = 256 * 1024, void* code_ptr = nullptr)
+      : JitCode(code_size, code_ptr),
+        num_(d),
+        type_(type),
+        with_relu_(with_relu) {}
   static bool init(int d);
   void generate() override;
 
  private:
   int num_;
+  operand_type type_;
   bool with_relu_;
   reg64_t param1{abi_param1};
   reg64_t param2{abi_param2};
diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc
index 27801f4c63ac5809a502912844460c45b55a6b62..9acb349f663cca1d38fa7840c3308dfa17a43ab1 100644
--- a/paddle/fluid/operators/math/jit_kernel_blas.cc
+++ b/paddle/fluid/operators/math/jit_kernel_blas.cc
@@ -102,7 +102,8 @@ class VMulKernelImpl : public VMulKernel<T> {
     if (useJIT(d)) {
       // roughly estimate the size of code
       size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
-      jitcode_.reset(new gen::VMulJitCode(d, sz > 4096 ? sz : 4096));
+      jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false,
+                                         sz > 4096 ? sz : 4096));
       this->Compute =
           jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
       return;
@@ -120,14 +121,14 @@ class VMulKernelImpl : public VMulKernel<T> {
 #ifdef PADDLE_WITH_XBYAK
 
  private:
-  std::unique_ptr<gen::VMulJitCode> jitcode_{nullptr};
+  std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
 #endif
 };
 
 #ifdef PADDLE_WITH_XBYAK
 template <>
 bool VMulKernelImpl<float>::useJIT(int d) {
-  return gen::VMulJitCode::init(d);
+  return gen::VVVJitCode::init(d);
 }
 #endif
 
@@ -149,13 +150,16 @@ class VAddKernelImpl : public VAddKernel<T> {
  public:
   DECLARE_STATIC_FUNC;
   explicit VAddKernelImpl(int d) : VAddKernel<T>() {
+#ifdef PADDLE_WITH_XBYAK
     if (useJIT(d)) {
       size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
-      jitcode_.reset(new gen::VAddJitCode(d, false, sz > 4096 ? sz : 4096));
+      jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false,
+                                         sz > 4096 ? sz : 4096));
       this->Compute =
           jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
       return;
     }
+#endif
 #ifdef PADDLE_WITH_MKLML
     if (useMKL(d)) {
       this->Compute = VAddMKL<T>;
       return;
@@ -166,14 +170,17 @@ class VAddKernelImpl : public VAddKernel<T> {
   }
 
  private:
-  std::unique_ptr<gen::VAddJitCode> jitcode_{nullptr};
+  std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
 };
 
+#ifdef PADDLE_WITH_XBYAK
 template <>
 bool VAddKernelImpl<float>::useJIT(int d) {
-  return gen::VAddJitCode::init(d);
+  return gen::VVVJitCode::init(d);
 }
+#endif
 
+#ifdef PADDLE_WITH_MKLML
 template <>
 bool VAddKernelImpl<float>::useMKL(int d) {
   return d > 512;
@@ -183,6 +190,7 @@ template <>
 bool VAddKernelImpl<double>::useMKL(int d) {
   return true;
 }
+#endif
 
 /* VAddRelu JitKernel */
 template <typename T>
@@ -190,24 +198,29 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
  public:
   DECLARE_STATIC_FUNC;
   explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
+#ifdef PADDLE_WITH_XBYAK
     if (useJIT(d)) {
       size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
-      jitcode_.reset(new gen::VAddJitCode(d, true, sz > 4096 ? sz : 4096));
+      jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true,
+                                         sz > 4096 ? sz : 4096));
       this->Compute =
           jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
       return;
     }
+#endif
     this->Compute = VAddReluRefer<T>;
   }
 
  private:
-  std::unique_ptr<gen::VAddJitCode> jitcode_{nullptr};
+  std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
 };
 
+#ifdef PADDLE_WITH_XBYAK
 template <>
 bool VAddReluKernelImpl<float>::useJIT(int d) {
-  return gen::VAddJitCode::init(d);
+  return gen::VVVJitCode::init(d);
 }
+#endif
 
 #undef DECLARE_STATIC_FUNC
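
Reference note (not part of the patch): the merged VVVJitCode emits code for a plain elementwise
vector-vector op with an optional ReLU on the result, selected by operand_type and with_relu_.
A minimal scalar sketch of that contract in standalone C++ follows; the function name VVVRefer
and the OperandType enum are illustrative only, the enum simply mirrors gen::operand_type
{mul, add} from jit_code.h.

    #include <algorithm>

    enum class OperandType { kMul, kAdd };  // mirrors gen::operand_type { mul, add }

    // Reference semantics of what VVVJitCode::generate() emits:
    // z[i] = op(x[i], y[i]), optionally followed by ReLU (max with 0).
    void VVVRefer(OperandType type, bool with_relu, const float* x,
                  const float* y, float* z, int n) {
      for (int i = 0; i < n; ++i) {
        float v = (type == OperandType::kMul) ? x[i] * y[i] : x[i] + y[i];
        z[i] = with_relu ? std::max(v, 0.0f) : v;
      }
    }

The JIT path produces the same result but processes AVX_FLOAT_BLOCK (8) floats per ymm
iteration and handles the remaining 4/2/1-element tail with xmm loads, as the hunks above show.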