diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index f853497804c3be07515f7f596911e5dd417f1081..6b3eecfbd11471b5d95dcb10c91acc536f78cb85 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -24,21 +24,30 @@ namespace gen { using namespace platform::jit; // NOLINT -bool VVVJitCode::init(int d) { +bool VXXJitCode::init(int d, int scalar_index) { // It's not necessary to use avx512 since it would slow down the frequency // and this kernel is not compute bound. - return MayIUse(avx); + return MayIUse(avx) && scalar_index >= 0 && scalar_index <= 2; } -void VVVJitCode::generate() { +void VXXJitCode::generate() { // do not need push stack, and do not need save avx512reg if do not use avx512 int offset = 0; if (with_relu_) { vxorps(ymm_zero, ymm_zero, ymm_zero); } + if (scalar_index_ == 1) { + vbroadcastss(ymm_src1, ptr[param1]); + } else if (scalar_index_ == 2) { + vbroadcastss(ymm_src2, ptr[param2]); + } for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { - vmovups(ymm_src1, ptr[param1 + offset]); - vmovups(ymm_src2, ptr[param2 + offset]); + if (scalar_index_ != 1) { + vmovups(ymm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(ymm_src2, ptr[param2 + offset]); + } if (type_ == operand_type::mul) { vmulps(ymm_dst, ymm_src1, ymm_src2); } else if (type_ == operand_type::add) { @@ -52,8 +61,12 @@ void VVVJitCode::generate() { } int rest = num_ % AVX_FLOAT_BLOCK; if (rest >= 4) { - vmovups(xmm_src1, ptr[param1 + offset]); - vmovups(xmm_src2, ptr[param2 + offset]); + if (scalar_index_ != 1) { + vmovups(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(xmm_src2, ptr[param2 + offset]); + } if (type_ == operand_type::mul) { vmulps(xmm_dst, xmm_src1, xmm_src2); } else if (type_ == operand_type::add) { @@ -67,8 +80,12 @@ void VVVJitCode::generate() { rest -= 4; } if (rest >= 2) { - vmovq(xmm_src1, ptr[param1 + offset]); - vmovq(xmm_src2, ptr[param2 + offset]); + if (scalar_index_ != 1) { + vmovups(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(xmm_src2, ptr[param2 + offset]); + } if (type_ == operand_type::mul) { vmulps(xmm_dst, xmm_src1, xmm_src2); } else if (type_ == operand_type::add) { @@ -82,8 +99,12 @@ void VVVJitCode::generate() { rest -= 2; } if (rest > 0) { - vmovss(xmm_src1, ptr[param1 + offset]); - vmovss(xmm_src2, ptr[param2 + offset]); + if (scalar_index_ != 1) { + vmovups(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(xmm_src2, ptr[param2 + offset]); + } if (type_ == operand_type::mul) { vmulss(xmm_dst, xmm_src1, xmm_src2); } else if (type_ == operand_type::add) { @@ -97,40 +118,6 @@ void VVVJitCode::generate() { ret(); } -bool VScalJitCode::init(int d) { return MayIUse(avx); } - -void VScalJitCode::generate() { - int offset = 0; - vbroadcastss(ymm_src1, ptr[param1]); - for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { - vmovups(ymm_src2, ptr[param2 + offset]); - vmulps(ymm_dst, ymm_src1, ymm_src2); - vmovups(ptr[param3 + offset], ymm_dst); - offset += sizeof(float) * AVX_FLOAT_BLOCK; - } - int rest = num_ % AVX_FLOAT_BLOCK; - if (rest >= 4) { - vmovups(xmm_src2, ptr[param2 + offset]); - vmulps(xmm_dst, xmm_src1, xmm_src2); - vmovups(ptr[param3 + offset], xmm_dst); - offset += sizeof(float) * 4; - rest -= 4; - } - if (rest >= 2) { - vmovq(xmm_src2, ptr[param2 + offset]); - vmulps(xmm_dst, xmm_src1, xmm_src2); - vmovq(ptr[param3 + offset], xmm_dst); - offset += sizeof(float) * 2; - rest -= 2; - } - if (rest > 0) { - vmovss(xmm_src2, ptr[param2 + offset]); - vmulss(xmm_dst, xmm_src1, xmm_src2); - vmovss(ptr[param3 + offset], xmm_dst); - } - ret(); -} - } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index d87831c5798f92b71fd6c57f7e9c89e8e9d2fbfc..939d9897e6cc6c8958caf54e0404ce524f01f8d5 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -31,11 +31,11 @@ using Label = Xbyak::Label; typedef enum { mul = 0, add } operand_type; -// function: vec = Operand(vec, vec) (maybe with relu) -class VVVJitCode : public JitCode { +// function: vec = Operand(vec(scalar), vec(scalar)) (maybe with relu) +class VXXJitCode : public JitCode { public: const char* name() const override { - std::string base = "VVVJitCode"; + std::string base = "VXXJitCode"; if (type_ == operand_type::mul) { base += "_Mul"; } else if (type_ == operand_type::add) { @@ -44,18 +44,21 @@ class VVVJitCode : public JitCode { base += (with_relu_ ? "_Relu" : ""); return base.c_str(); } - explicit VVVJitCode(int d, operand_type type, bool with_relu, - size_t code_size = 256 * 1024, void* code_ptr = nullptr) + explicit VXXJitCode(int d, operand_type type, int scalar_index, + bool with_relu, size_t code_size = 256 * 1024, + void* code_ptr = nullptr) : JitCode(code_size, code_ptr), num_(d), type_(type), + scalar_index_(scalar_index), with_relu_(with_relu) {} - static bool init(int d); + static bool init(int d, int scalar_index = 0); void generate() override; private: int num_; operand_type type_; + int scalar_index_; bool with_relu_; reg64_t param1{abi_param1}; reg64_t param2{abi_param2}; @@ -63,39 +66,13 @@ class VVVJitCode : public JitCode { xmm_t xmm_src1 = xmm_t(0); xmm_t xmm_src2 = xmm_t(1); - xmm_t xmm_dst = xmm_t(1); - xmm_t xmm_zero = xmm_t(2); + xmm_t xmm_dst = xmm_t(2); + xmm_t xmm_zero = xmm_t(3); ymm_t ymm_src1 = ymm_t(0); ymm_t ymm_src2 = ymm_t(1); - ymm_t ymm_dst = ymm_t(1); - ymm_t ymm_zero = ymm_t(2); -}; - -class VScalJitCode : public JitCode { - public: - DECLARE_JIT_CODE(VScalJitCode); - explicit VScalJitCode(int d, size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), num_(d) {} - static bool init(int d); - void generate() override; - - private: - int num_; - reg64_t param1{abi_param1}; - reg64_t param2{abi_param2}; - reg64_t param3{abi_param3}; - - xmm_t xmm_src1 = xmm_t(0); - xmm_t xmm_src2 = xmm_t(1); - xmm_t xmm_dst = xmm_t(1); - xmm_t xmm_zero = xmm_t(2); - - ymm_t ymm_src1 = ymm_t(0); - ymm_t ymm_src2 = ymm_t(1); - ymm_t ymm_dst = ymm_t(1); - ymm_t ymm_zero = ymm_t(2); + ymm_t ymm_dst = ymm_t(2); + ymm_t ymm_zero = ymm_t(3); }; } // namespace gen diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index a9537ab0969febcf2057f19af920004cfeff467f..ead4385cdb5def030c8f5f3528c97b1146879667 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -131,7 +131,7 @@ class VMulKernelImpl : public VMulKernel { if (useJIT(d)) { // roughly estimate the size of code size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false, + jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); @@ -150,14 +150,14 @@ class VMulKernelImpl : public VMulKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VMulKernelImpl::useJIT(int d) { - return gen::VVVJitCode::init(d); + return gen::VXXJitCode::init(d); } #endif @@ -182,7 +182,7 @@ class VAddKernelImpl : public VAddKernel { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false, + jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); @@ -200,14 +200,14 @@ class VAddKernelImpl : public VAddKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VAddKernelImpl::useJIT(int d) { - return gen::VVVJitCode::init(d); + return gen::VXXJitCode::init(d); } #endif @@ -232,7 +232,7 @@ class VAddReluKernelImpl : public VAddReluKernel { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true, + jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true, sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); @@ -244,14 +244,14 @@ class VAddReluKernelImpl : public VAddReluKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VAddReluKernelImpl::useJIT(int d) { - return gen::VVVJitCode::init(d); + return gen::VXXJitCode::init(d); } #endif @@ -264,7 +264,8 @@ class VScalKernelImpl : public VScalKernel { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VScalJitCode(d, sz > 4096 ? sz : 4096)); + jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false, + sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); return; @@ -281,14 +282,14 @@ class VScalKernelImpl : public VScalKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VScalKernelImpl::useJIT(int d) { - return gen::VScalJitCode::init(d); + return gen::VXXJitCode::init(d, 1); } #endif