From ba3eaed7a7426a10f4a394071852c6f5d6ab8e1e Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 16 Nov 2018 09:13:34 +0000 Subject: [PATCH] exp support all size --- paddle/fluid/operators/math/jit_code.cc | 114 ++++++++++++++++-- paddle/fluid/operators/math/jit_code.h | 8 +- .../fluid/operators/math/jit_kernel_test.cc | 5 +- 3 files changed, 113 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index e3b600d4427..9efd4e81748 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -81,10 +81,10 @@ void VXXJitCode::generate() { } if (rest >= 2) { if (scalar_index_ != 1) { - vmovups(xmm_src1, ptr[param1 + offset]); + vmovq(xmm_src1, ptr[param1 + offset]); } if (scalar_index_ != 2) { - vmovups(xmm_src2, ptr[param2 + offset]); + vmovq(xmm_src2, ptr[param2 + offset]); } if (type_ == operand_type::mul) { vmulps(xmm_dst, xmm_src1, xmm_src2); @@ -100,10 +100,10 @@ void VXXJitCode::generate() { } if (rest > 0) { if (scalar_index_ != 1) { - vmovups(xmm_src1, ptr[param1 + offset]); + vmovss(xmm_src1, ptr[param1 + offset]); } if (scalar_index_ != 2) { - vmovups(xmm_src2, ptr[param2 + offset]); + vmovss(xmm_src2, ptr[param2 + offset]); } if (type_ == operand_type::mul) { vmulss(xmm_dst, xmm_src1, xmm_src2); @@ -179,7 +179,7 @@ bool VActJitCode::init(int d, operand_type type) { return ok; } else if (type == operand_type::exp) { // exp is slower than mkl when d >= 256 - return ok && d % 8 == 0 && d < 256; + return ok; //&& d % 4 == 0 && d < 256; } else { // TODO(TJ): support more return ok && d % 8 == 0; @@ -190,6 +190,10 @@ void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) { vmaxps(ymm_dst, ymm_zero, ymm_src); } +void VActJitCode::relu_xmm(xmm_t& xmm_dst, xmm_t& xmm_src, xmm_t& xmm_zero) { + vmaxps(xmm_dst, xmm_zero, xmm_src); +} + void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, int fy_idx, int mask_idx, int tmp_idx) { assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore @@ -271,6 +275,65 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, pop(reg_ptr_global); } +void VActJitCode::exp_xmm(xmm_t& ymm_dst, xmm_t& ymm_src, int fx_idx, + int fy_idx, int mask_idx, int tmp_idx) { + assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore + // check all idx can not equal + xmm_t ymm_fx = xmm_t(fx_idx); + xmm_t ymm_fy = xmm_t(fy_idx); + xmm_t ymm_mask = xmm_t(mask_idx); + xmm_t ymm_tmp = xmm_t(tmp_idx); + reg64_t reg_ptr_global = rax; + push(reg_ptr_global); + mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); + vminps(ymm_src, ymm_src, ymm_tmp); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]); + vmaxps(ymm_src, ymm_src, ymm_tmp); + // express exp(x) as exp(g + n*log(2)) + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]); + vmulps(ymm_fx, ymm_src, ymm_tmp); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]); + vaddps(ymm_fx, ymm_fx, ymm_tmp); + vroundps(ymm_fy, ymm_fx, 0x01); + // if greater, substract 1 + vcmpgtps(ymm_mask, ymm_fy, ymm_fx); + vmovaps(ymm_tmp, ptr[reg_ptr_global]); + vandps(ymm_mask, ymm_mask, ymm_tmp); + vsubps(ymm_fx, ymm_fy, ymm_mask); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]); + vmulps(ymm_fy, ymm_fx, ymm_tmp); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); + xmm_t ymm_z = xmm_t(ymm_mask.getIdx()); + vmulps(ymm_z, ymm_fx, ymm_tmp); + vsubps(ymm_src, ymm_src, ymm_fy); + vsubps(ymm_src, ymm_src, ymm_z); + vmulps(ymm_z, ymm_src, ymm_src); + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]); + vmulps(ymm_dst, ymm_src, ymm_tmp); + for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5; + i += (YMM_FLOAT_BLOCK * sizeof(float))) { + vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4 + vaddps(ymm_dst, ymm_dst, ymm_tmp); + vmulps(ymm_dst, ymm_dst, ymm_src); + } + vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]); + vaddps(ymm_dst, ymm_dst, ymm_tmp); + vmulps(ymm_dst, ymm_dst, ymm_z); + vaddps(ymm_dst, ymm_dst, ymm_src); + vmovaps(ymm_tmp, ptr[reg_ptr_global]); + vaddps(ymm_dst, ymm_dst, ymm_tmp); + // build 2^n + xmm_t ymm_int = ymm_fx; + vcvttps2dq(ymm_int, ymm_fx); + mov(reg_ptr_global, reinterpret_cast(exp_int_0x7f)); + vmovdqa(ymm_tmp, ptr[reg_ptr_global]); + vpaddd(ymm_int, ymm_int, ymm_tmp); + vpslld(ymm_int, ymm_int, 23); + vmulps(ymm_dst, ymm_dst, ymm_int); + pop(reg_ptr_global); +} + void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, int fy_idx, int mask_idx, int tmp_idx) { // y = 1 / (1 + e^-x) @@ -343,7 +406,7 @@ void VActJitCode::generate() { vmovups(ptr[param2 + offset], ymm_dst); offset += sizeof(float) * YMM_FLOAT_BLOCK; } - if (type_ != operand_type::relu) { + if (type_ != operand_type::relu && type_ != operand_type::exp) { // TODO(TJ): remove me ret(); return; @@ -351,21 +414,50 @@ void VActJitCode::generate() { int rest = num_ % YMM_FLOAT_BLOCK; if (rest >= 4) { vmovups(xmm_src, ptr[param1 + offset]); - vmaxps(xmm_dst, xmm_zero, xmm_src); + switch (type_) { + case operand_type::relu: + relu_xmm(xmm_dst, xmm_src, xmm_zero); + break; + case operand_type::exp: + exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5); + break; + default: + break; + } vmovups(ptr[param2 + offset], xmm_dst); offset += sizeof(float) * 4; rest -= 4; } if (rest >= 2) { - vmovups(xmm_src, ptr[param1 + offset]); - vmaxps(xmm_dst, xmm_zero, xmm_src); + vmovq(xmm_src, ptr[param1 + offset]); + switch (type_) { + case operand_type::relu: + relu_xmm(xmm_dst, xmm_src, xmm_zero); + break; + case operand_type::exp: + exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5); + break; + default: + break; + } vmovq(ptr[param2 + offset], xmm_dst); offset += sizeof(float) * 2; rest -= 2; } if (rest > 0) { - vmovups(xmm_src, ptr[param1 + offset]); - vmaxps(xmm_dst, xmm_zero, xmm_src); + // vmovups(); + vmovss(xmm_src, ptr[param1 + offset]); + + switch (type_) { + case operand_type::relu: + relu_xmm(xmm_dst, xmm_src, xmm_zero); + break; + case operand_type::exp: + exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5); + break; + default: + break; + } vmovss(ptr[param2 + offset], xmm_dst); } ret(); diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 71205b211b7..1467978f26c 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -127,13 +127,17 @@ class VActJitCode : public JitCode { void generate() override; protected: - // compute relu with ymm + // compute relu with ymm, xmm void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, const Xbyak::Ymm& zero); + void relu_xmm(const Xbyak::Xmm& dst, const Xbyak::Xmm& src, + const Xbyak::Xmm& zero); - // compute exp with ymm + // compute exp with ymm, xmm void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); + void exp_xmm(const Xbyak::Xmm& dst, const Xbyak::Xmm& src, int fx_idx = 2, + int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); // compute sigmoid with ymm void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 5a6f87fe1f7..178298bf567 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -33,6 +33,9 @@ limitations under the License. */ constexpr int repeat = 20000; +// TODO(TJ): benchmark and test should be seperated, +// benchmark should verify more sizes + inline double GetCurrentUS() { struct timeval time; gettimeofday(&time, NULL); @@ -156,7 +159,7 @@ void vexp_mkl(const int n, const float* x, float* y) { TEST(JitKernel, vexp) { namespace jit = paddle::operators::math::jitkernel; - for (int d : {7, 8, 15, 16, 30, 128, 256}) { + for (int d : {7, 8, 12, 15, 16, 20, 30, 128, 256}) { std::vector x(d); std::vector zref(d), ztgt(d); RandomVec(d, x.data(), -2.f, 2.f); -- GitLab