From 4dbdfa60ef6d13568880fb2de5ee31a469080ab7 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 16 Nov 2018 17:29:36 +0000 Subject: [PATCH] sigmoid and tanh support all size test=develop --- paddle/fluid/operators/math/jit_code.cc | 67 ++++--------------------- paddle/fluid/operators/math/jit_code.h | 50 +++++++++++++++--- 2 files changed, 54 insertions(+), 63 deletions(-) diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index fd18256b0c9..a080079a2de 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -132,56 +132,8 @@ const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)}; int g_tmp_mem[16] ALIGN32 = {0}; bool VActJitCode::init(int d, operand_type type) { - bool ok = MayIUse(avx); - if (type == operand_type::relu || type == operand_type::exp) { - // TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256 - return ok; - } else { - // TODO(TJ): support more - return ok && d % 8 == 0; - } -} - -void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, - int fy_idx, int mask_idx, int tmp_idx) { - // y = 1 / (1 + e^-x) - ymm_t ymm_tmp = ymm_t(tmp_idx); - reg64_t reg_ptr_global = rax; - push(reg_ptr_global); - mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); - vminps(ymm_src, ymm_src, ymm_tmp); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]); - vmaxps(ymm_src, ymm_src, ymm_tmp); - vxorps(ymm_tmp, ymm_tmp, ymm_tmp); - vsubps(ymm_src, ymm_tmp, ymm_src); - exp_jmm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); - vaddps(ymm_dst, ymm_dst, ymm_tmp); - vdivps(ymm_dst, ymm_tmp, ymm_dst); - pop(reg_ptr_global); -} - -void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, - int fy_idx, int mask_idx, int tmp_idx) { - // y = 2 / (1 + e^(-2x)) - 1 - ymm_t ymm_tmp = ymm_t(tmp_idx); - ymm_t ymm_zero = ymm_t(mask_idx); - reg64_t reg_ptr_global = rax; - push(reg_ptr_global); - mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); - vxorps(ymm_zero, ymm_zero, ymm_zero); - vsubps(ymm_tmp, ymm_zero, ymm_tmp); - vmulps(ymm_src, ymm_src, ymm_tmp); - exp_jmm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); - vaddps(ymm_dst, ymm_dst, ymm_tmp); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); - vdivps(ymm_dst, ymm_tmp, ymm_dst); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); - vsubps(ymm_dst, ymm_dst, ymm_tmp); - pop(reg_ptr_global); + // TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256 + return MayIUse(avx); } void VActJitCode::generate() { @@ -201,10 +153,10 @@ void VActJitCode::generate() { exp_jmm(ymm_dst, ymm_src, 2, 3, 4, 5); break; case operand_type::sigmoid: - sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); + sigmoid_jmm(ymm_dst, ymm_src, 2, 3, 4, 5); break; case operand_type::tanh: - tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); + tanh_jmm(ymm_dst, ymm_src, 2, 3, 4, 5); break; case operand_type::identity: break; @@ -214,11 +166,6 @@ void VActJitCode::generate() { vmovups(ptr[param2 + offset], ymm_dst); offset += sizeof(float) * YMM_FLOAT_BLOCK; } - if (type_ != operand_type::relu && type_ != operand_type::exp) { - // TODO(TJ): remove me - ret(); - return; - } int rest = num_ % YMM_FLOAT_BLOCK; int block = XMM_FLOAT_BLOCK; while (rest > 0) { @@ -236,6 +183,12 @@ void VActJitCode::generate() { case operand_type::exp: exp_jmm(xmm_dst, xmm_src, 2, 3, 4, 5); break; + case operand_type::sigmoid: + sigmoid_jmm(xmm_dst, xmm_src, 2, 3, 4, 5); + break; + case operand_type::tanh: + tanh_jmm(xmm_dst, xmm_src, 2, 3, 4, 5); + break; default: break; } diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 534398f4a42..65f83ff4846 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -263,13 +263,51 @@ class VActJitCode : public JitCode { pop(reg_ptr_global); } - // compute sigmoid with ymm - void sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, - int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); + // compute sigmoid with ymm, xmm + template + void sigmoid_jmm(JMM& dst, JMM& src, int fx_idx = 2, // NOLINT + int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5) { + // y = 1 / (1 + e^-x) + JMM jmm_tmp = JMM(tmp_idx); + reg64_t reg_ptr_global = rax; + push(reg_ptr_global); + mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); + vminps(src, src, jmm_tmp); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]); + vmaxps(src, src, jmm_tmp); + vxorps(jmm_tmp, jmm_tmp, jmm_tmp); + vsubps(src, jmm_tmp, src); + exp_jmm(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); + vaddps(dst, dst, jmm_tmp); + vdivps(dst, jmm_tmp, dst); + pop(reg_ptr_global); + } - // compute tanh with ymm - void tanh_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, - int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); + // compute tanh with ymm, xmm + template + void tanh_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT + int mask_idx = 4, int tmp_idx = 5) { + // y = 2 / (1 + e^(-2x)) - 1 + JMM jmm_tmp = JMM(tmp_idx); + JMM jmm_zero = JMM(mask_idx); + reg64_t reg_ptr_global = rax; + push(reg_ptr_global); + mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); + vxorps(jmm_zero, jmm_zero, jmm_zero); + vsubps(jmm_tmp, jmm_zero, jmm_tmp); + vmulps(src, src, jmm_tmp); + exp_jmm(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); + vaddps(dst, dst, jmm_tmp); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); + vdivps(dst, jmm_tmp, dst); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); + vsubps(dst, dst, jmm_tmp); + pop(reg_ptr_global); + } protected: int num_; -- GitLab