diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index e3b600d4427672faa477341e207a5eab2bcf383d..e484e9a3c705c5638fa94010a4513ae1566a8be3 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/jit_code.h" -#include "paddle/fluid/operators/math/jit_kernel.h" -#include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me namespace paddle { namespace operators { @@ -60,257 +59,83 @@ void VXXJitCode::generate() { offset += sizeof(float) * YMM_FLOAT_BLOCK; } int rest = num_ % YMM_FLOAT_BLOCK; - if (rest >= 4) { - if (scalar_index_ != 1) { - vmovups(xmm_src1, ptr[param1 + offset]); - } - if (scalar_index_ != 2) { - vmovups(xmm_src2, ptr[param2 + offset]); - } - if (type_ == operand_type::mul) { - vmulps(xmm_dst, xmm_src1, xmm_src2); - } else if (type_ == operand_type::add) { - vaddps(xmm_dst, xmm_src1, xmm_src2); - } - if (with_relu_) { - vmaxps(xmm_dst, xmm_zero, xmm_dst); - } - vmovups(ptr[param3 + offset], xmm_dst); - offset += sizeof(float) * 4; - rest -= 4; - } - if (rest >= 2) { - if (scalar_index_ != 1) { - vmovups(xmm_src1, ptr[param1 + offset]); + while (rest > 0) { + int block = XMM_FLOAT_BLOCK; + if (rest >= 4) { + block = 4; + if (scalar_index_ != 1) { + vmovups(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(xmm_src2, ptr[param2 + offset]); + } + } else if (rest >= 2) { + block = 2; + if (scalar_index_ != 1) { + vmovq(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovq(xmm_src2, ptr[param2 + offset]); + } + } else { + block = 1; + if (scalar_index_ != 1) { + vmovss(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovss(xmm_src2, ptr[param2 + offset]); + } } - if (scalar_index_ != 2) { - 
vmovups(xmm_src2, ptr[param2 + offset]); - } - if (type_ == operand_type::mul) { - vmulps(xmm_dst, xmm_src1, xmm_src2); - } else if (type_ == operand_type::add) { - vaddps(xmm_dst, xmm_src1, xmm_src2); + switch (type_) { + case operand_type::mul: + vmulps(xmm_dst, xmm_src1, xmm_src2); + break; + case operand_type::add: + vaddps(xmm_dst, xmm_src1, xmm_src2); + break; + default: + break; } if (with_relu_) { vmaxps(xmm_dst, xmm_zero, xmm_dst); } - vmovq(ptr[param3 + offset], xmm_dst); - offset += sizeof(float) * 2; - rest -= 2; - } - if (rest > 0) { - if (scalar_index_ != 1) { - vmovups(xmm_src1, ptr[param1 + offset]); - } - if (scalar_index_ != 2) { - vmovups(xmm_src2, ptr[param2 + offset]); - } - if (type_ == operand_type::mul) { - vmulss(xmm_dst, xmm_src1, xmm_src2); - } else if (type_ == operand_type::add) { - vaddss(xmm_dst, xmm_src1, xmm_src2); + if (rest >= 4) { + vmovups(ptr[param3 + offset], xmm_dst); + } else if (rest >= 2) { + vmovq(ptr[param3 + offset], xmm_dst); + } else { + vmovss(ptr[param3 + offset], xmm_dst); } - if (with_relu_) { - vmaxps(xmm_dst, xmm_zero, xmm_dst); - } - vmovss(ptr[param3 + offset], xmm_dst); + offset += sizeof(float) * block; + rest -= block; } ret(); } -#define ALIGN32 __attribute__((aligned(32))) -#define EXP_HIG 88.3762626647949f -#define EXP_LOW -88.3762626647949f -#define CEPHES_LOG2EF 1.44269504088896341 -#define CEPHES_EXP_C1 0.693359375 -#define CEPHES_EXP_C2 -2.12194440e-4 -#define CEPHES_EXP_P0 1.9875691500E-4 -#define CEPHES_EXP_P1 1.3981999507E-3 -#define CEPHES_EXP_P2 8.3334519073E-3 -#define CEPHES_EXP_P3 4.1665795894E-2 -#define CEPHES_EXP_P4 1.6666665459E-1 -#define CEPHES_EXP_P5 5.0000001201E-1 - -#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val - -#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float) -#define 
OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float) -#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float) - -static const float exp_float_consts[] ALIGN32 = { - REPEAT_8TIMES(1.f), - REPEAT_8TIMES(2.f), - REPEAT_8TIMES(0.5f), - REPEAT_8TIMES(EXP_HIG), - REPEAT_8TIMES(EXP_LOW), - REPEAT_8TIMES(CEPHES_LOG2EF), - REPEAT_8TIMES(CEPHES_EXP_C1), - REPEAT_8TIMES(CEPHES_EXP_C2), - REPEAT_8TIMES(CEPHES_EXP_P0), - REPEAT_8TIMES(CEPHES_EXP_P1), - REPEAT_8TIMES(CEPHES_EXP_P2), - REPEAT_8TIMES(CEPHES_EXP_P3), - REPEAT_8TIMES(CEPHES_EXP_P4), - REPEAT_8TIMES(CEPHES_EXP_P5), - REPEAT_8TIMES(EXP_MAX_INPUT), - REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX), - REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)}; - -static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)}; -static int g_tmp_mem[16] ALIGN32 = {0}; +const float exp_float_consts[] ALIGN32 = {REPEAT_8TIMES(1.f), + REPEAT_8TIMES(2.f), + REPEAT_8TIMES(0.5f), + REPEAT_8TIMES(EXP_HIG), + REPEAT_8TIMES(EXP_LOW), + REPEAT_8TIMES(CEPHES_LOG2EF), + REPEAT_8TIMES(CEPHES_EXP_C1), + REPEAT_8TIMES(CEPHES_EXP_C2), + REPEAT_8TIMES(CEPHES_EXP_P0), + REPEAT_8TIMES(CEPHES_EXP_P1), + REPEAT_8TIMES(CEPHES_EXP_P2), + REPEAT_8TIMES(CEPHES_EXP_P3), + REPEAT_8TIMES(CEPHES_EXP_P4), + REPEAT_8TIMES(CEPHES_EXP_P5), + REPEAT_8TIMES(EXP_MAX_INPUT), + REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX), + 
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)}; + +const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)}; +int g_tmp_mem[16] ALIGN32 = {0}; bool VActJitCode::init(int d, operand_type type) { - bool ok = MayIUse(avx); - if (type == operand_type::relu) { - return ok; - } else if (type == operand_type::exp) { - // exp is slower than mkl when d >= 256 - return ok && d % 8 == 0 && d < 256; - } else { - // TODO(TJ): support more - return ok && d % 8 == 0; - } -} - -void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) { - vmaxps(ymm_dst, ymm_zero, ymm_src); -} - -void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, - int fy_idx, int mask_idx, int tmp_idx) { - assert(ymm_src.getIdx() != ymm_dst.getIdx()); // TODO(TJ): use enfore - // check all idx can not equal - ymm_t ymm_fx = ymm_t(fx_idx); - ymm_t ymm_fy = ymm_t(fy_idx); - ymm_t ymm_mask = ymm_t(mask_idx); - ymm_t ymm_tmp = ymm_t(tmp_idx); - reg64_t reg_ptr_global = rax; - push(reg_ptr_global); - mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); - vminps(ymm_src, ymm_src, ymm_tmp); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]); - vmaxps(ymm_src, ymm_src, ymm_tmp); - // express exp(x) as exp(g + n*log(2)) - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]); - vmulps(ymm_fx, ymm_src, ymm_tmp); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]); - vaddps(ymm_fx, ymm_fx, ymm_tmp); - vroundps(ymm_fy, ymm_fx, 0x01); - // if greater, substract 1 - vcmpgtps(ymm_mask, ymm_fy, ymm_fx); - vmovaps(ymm_tmp, ptr[reg_ptr_global]); - vandps(ymm_mask, ymm_mask, ymm_tmp); - vsubps(ymm_fx, ymm_fy, ymm_mask); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]); - vmulps(ymm_fy, ymm_fx, ymm_tmp); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); - ymm_t ymm_z = ymm_t(ymm_mask.getIdx()); - vmulps(ymm_z, ymm_fx, ymm_tmp); - vsubps(ymm_src, ymm_src, ymm_fy); - vsubps(ymm_src, ymm_src, ymm_z); - vmulps(ymm_z, 
ymm_src, ymm_src); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]); - vmulps(ymm_dst, ymm_src, ymm_tmp); - for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5; - i += (YMM_FLOAT_BLOCK * sizeof(float))) { - vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4 - vaddps(ymm_dst, ymm_dst, ymm_tmp); - vmulps(ymm_dst, ymm_dst, ymm_src); - } - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]); - vaddps(ymm_dst, ymm_dst, ymm_tmp); - vmulps(ymm_dst, ymm_dst, ymm_z); - vaddps(ymm_dst, ymm_dst, ymm_src); - vmovaps(ymm_tmp, ptr[reg_ptr_global]); - vaddps(ymm_dst, ymm_dst, ymm_tmp); - // build 2^n - ymm_t ymm_int = ymm_fx; - vcvttps2dq(ymm_int, ymm_fx); - mov(reg_ptr_global, reinterpret_cast(exp_int_0x7f)); - vmovdqa(ymm_tmp, ptr[reg_ptr_global]); - if (MayIUse(avx2)) { - vpaddd(ymm_int, ymm_int, ymm_tmp); - vpslld(ymm_int, ymm_int, 23); - } else if (MayIUse(avx)) { - xmm_t xtmp1 = xmm_t(ymm_int.getIdx()); - xmm_t xtmp2 = xmm_t(ymm_tmp.getIdx()); - reg64_t reg_ptr_tmp = reg_ptr_global; - mov(reg_ptr_tmp, reinterpret_cast(g_tmp_mem)); - vmovdqa(ptr[reg_ptr_tmp], ymm_int); - vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], ymm_tmp); - vpaddd(xtmp1, xtmp1, xtmp2); - vpslld(xtmp1, xtmp1, 23); - vmovdqa(ptr[reg_ptr_tmp], xtmp1); - // next 128bits - vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]); - vmovdqa(xtmp2, - ptr[reg_ptr_tmp + - (YMM_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]); - vpaddd(xtmp1, xtmp1, xtmp2); - vpslld(xtmp1, xtmp1, 23); - vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1); - // load out - vmovdqa(ymm_int, ptr[reg_ptr_tmp]); - } - vmulps(ymm_dst, ymm_dst, ymm_int); - pop(reg_ptr_global); -} - -void VActJitCode::sigmoid_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, - int fy_idx, int mask_idx, int tmp_idx) { - // y = 1 / (1 + e^-x) - ymm_t ymm_tmp = ymm_t(tmp_idx); - reg64_t reg_ptr_global = rax; - push(reg_ptr_global); - mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); - 
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); - vminps(ymm_src, ymm_src, ymm_tmp); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]); - vmaxps(ymm_src, ymm_src, ymm_tmp); - vxorps(ymm_tmp, ymm_tmp, ymm_tmp); - vsubps(ymm_src, ymm_tmp, ymm_src); - exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); - vaddps(ymm_dst, ymm_dst, ymm_tmp); - vdivps(ymm_dst, ymm_tmp, ymm_dst); - pop(reg_ptr_global); -} - -void VActJitCode::tanh_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx, - int fy_idx, int mask_idx, int tmp_idx) { - // y = 2 / (1 + e^(-2x)) - 1 - ymm_t ymm_tmp = ymm_t(tmp_idx); - ymm_t ymm_zero = ymm_t(mask_idx); - reg64_t reg_ptr_global = rax; - push(reg_ptr_global); - mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); - vxorps(ymm_zero, ymm_zero, ymm_zero); - vsubps(ymm_tmp, ymm_zero, ymm_tmp); - vmulps(ymm_src, ymm_src, ymm_tmp); - exp_ymm(ymm_dst, ymm_src, fx_idx, fy_idx, mask_idx, tmp_idx); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); - vaddps(ymm_dst, ymm_dst, ymm_tmp); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); - vdivps(ymm_dst, ymm_tmp, ymm_dst); - vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); - vsubps(ymm_dst, ymm_dst, ymm_tmp); - pop(reg_ptr_global); + // TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256 + return MayIUse(avx); } void VActJitCode::generate() { @@ -324,16 +149,16 @@ void VActJitCode::generate() { vmovups(ymm_src, ptr[param1 + offset]); switch (type_) { case operand_type::relu: - relu_ymm(ymm_dst, ymm_src, ymm_zero); + relu_jmm(ymm_dst, ymm_src, ymm_zero); break; case operand_type::exp: - exp_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); + exp_jmm(ymm_dst, ymm_src, 2, 3, 4, 5); break; case operand_type::sigmoid: - sigmoid_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); + sigmoid_jmm(ymm_dst, ymm_src, 2, 3, 4, 5); break; case operand_type::tanh: - 
tanh_ymm(ymm_dst, ymm_src, 2, 3, 4, 5); + tanh_jmm(ymm_dst, ymm_src, 2, 3, 4, 5); break; case operand_type::identity: break; @@ -343,30 +168,44 @@ void VActJitCode::generate() { vmovups(ptr[param2 + offset], ymm_dst); offset += sizeof(float) * YMM_FLOAT_BLOCK; } - if (type_ != operand_type::relu) { - // TODO(TJ): remove me - ret(); - return; - } int rest = num_ % YMM_FLOAT_BLOCK; - if (rest >= 4) { - vmovups(xmm_src, ptr[param1 + offset]); - vmaxps(xmm_dst, xmm_zero, xmm_src); - vmovups(ptr[param2 + offset], xmm_dst); - offset += sizeof(float) * 4; - rest -= 4; - } - if (rest >= 2) { - vmovups(xmm_src, ptr[param1 + offset]); - vmaxps(xmm_dst, xmm_zero, xmm_src); - vmovq(ptr[param2 + offset], xmm_dst); - offset += sizeof(float) * 2; - rest -= 2; - } - if (rest > 0) { - vmovups(xmm_src, ptr[param1 + offset]); - vmaxps(xmm_dst, xmm_zero, xmm_src); - vmovss(ptr[param2 + offset], xmm_dst); + while (rest > 0) { + int block = XMM_FLOAT_BLOCK; + if (rest >= 4) { + block = 4; + vmovups(xmm_src, ptr[param1 + offset]); + } else if (rest >= 2) { + block = 2; + vmovq(xmm_src, ptr[param1 + offset]); + } else { + block = 1; + vmovss(xmm_src, ptr[param1 + offset]); + } + switch (type_) { + case operand_type::relu: + relu_jmm(xmm_dst, xmm_src, xmm_zero); + break; + case operand_type::exp: + exp_jmm(xmm_dst, xmm_src, 2, 3, 4, 5); + break; + case operand_type::sigmoid: + sigmoid_jmm(xmm_dst, xmm_src, 2, 3, 4, 5); + break; + case operand_type::tanh: + tanh_jmm(xmm_dst, xmm_src, 2, 3, 4, 5); + break; + default: + break; + } + if (rest >= 4) { + vmovups(ptr[param2 + offset], xmm_dst); + } else if (rest >= 2) { + vmovq(ptr[param2 + offset], xmm_dst); + } else { + vmovss(ptr[param2 + offset], xmm_dst); + } + offset += sizeof(float) * block; + rest -= block; } ret(); } diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 71205b211b7f571f8081640ef60222de051ff49d..65f83ff4846601d1575daa994772cd869d526f56 100644 --- 
a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -16,6 +16,8 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/jit_gen.h" +#include "paddle/fluid/platform/cpu_info.h" + namespace paddle { namespace operators { namespace math { @@ -40,6 +42,51 @@ typedef enum { identity } operand_type; +extern const float exp_float_consts[]; +extern const int exp_int_0x7f[]; +extern int g_tmp_mem[]; + +// TODO(TJ): move these to some proper place +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 +#define EXP_MAX_INPUT 40.0 +#define XMM_FLOAT_BLOCK 4 +#define YMM_FLOAT_BLOCK 8 +#define ZMM_FLOAT_BLOCK 16 + +#define ALIGN32 __attribute__((aligned(32))) +#define EXP_HIG 88.3762626647949f +#define EXP_LOW -88.3762626647949f +#define CEPHES_LOG2EF 1.44269504088896341 +#define CEPHES_EXP_C1 0.693359375 +#define CEPHES_EXP_C2 -2.12194440e-4 +#define CEPHES_EXP_P0 1.9875691500E-4 +#define CEPHES_EXP_P1 1.3981999507E-3 +#define CEPHES_EXP_P2 8.3334519073E-3 +#define CEPHES_EXP_P3 4.1665795894E-2 +#define CEPHES_EXP_P4 1.6666665459E-1 +#define CEPHES_EXP_P5 5.0000001201E-1 + +#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val + +#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * 
sizeof(float) +#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float) +#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float) + // function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) class VXXJitCode : public JitCode { public: @@ -127,21 +174,140 @@ class VActJitCode : public JitCode { void generate() override; protected: - // compute relu with ymm - void relu_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, - const Xbyak::Ymm& zero); + // compute relu with ymm, xmm + template + void relu_jmm(JMM& dst, JMM& src, JMM& zero) { // NOLINT + vmaxps(dst, src, zero); + } - // compute exp with ymm - void exp_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, - int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); + // compute exp with ymm, xmm + template + void exp_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT + int mask_idx = 4, int tmp_idx = 5) { + using namespace platform::jit; // NOLINT + assert(src.getIdx() != dst.getIdx()); // TODO(TJ): use enfore + // check all idx can not equal + JMM jmm_fx = JMM(fx_idx); + JMM jmm_fy = JMM(fy_idx); + JMM jmm_mask = JMM(mask_idx); + JMM jmm_tmp = JMM(tmp_idx); + reg64_t reg_ptr_global = rax; + push(reg_ptr_global); + mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_HIG]); + vminps(src, src, jmm_tmp); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOW]); + vmaxps(src, src, jmm_tmp); + // express exp(x) as exp(g + n*log(2)) + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_LOG2EF]); + vmulps(jmm_fx, src, jmm_tmp); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_0P5]); + vaddps(jmm_fx, jmm_fx, jmm_tmp); + vroundps(jmm_fy, jmm_fx, 0x01); + // if greater, substract 1 + vcmpgtps(jmm_mask, jmm_fy, jmm_fx); + vmovaps(jmm_tmp, ptr[reg_ptr_global]); + vandps(jmm_mask, jmm_mask, 
jmm_tmp); + vsubps(jmm_fx, jmm_fy, jmm_mask); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C1]); + vmulps(jmm_fy, jmm_fx, jmm_tmp); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_C2]); + JMM ymm_z = JMM(jmm_mask.getIdx()); + vmulps(ymm_z, jmm_fx, jmm_tmp); + vsubps(src, src, jmm_fy); + vsubps(src, src, ymm_z); + vmulps(ymm_z, src, src); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]); + vmulps(dst, src, jmm_tmp); + for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5; + i += (YMM_FLOAT_BLOCK * sizeof(float))) { + vmovaps(jmm_tmp, ptr[reg_ptr_global + i]); // P1~P4 + vaddps(dst, dst, jmm_tmp); + vmulps(dst, dst, src); + } + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P5]); + vaddps(dst, dst, jmm_tmp); + vmulps(dst, dst, ymm_z); + vaddps(dst, dst, src); + vmovaps(jmm_tmp, ptr[reg_ptr_global]); + vaddps(dst, dst, jmm_tmp); + // build 2^n + JMM ymm_int = jmm_fx; + vcvttps2dq(ymm_int, jmm_fx); + mov(reg_ptr_global, reinterpret_cast(exp_int_0x7f)); + vmovdqa(jmm_tmp, ptr[reg_ptr_global]); + if (MayIUse(avx2) || std::is_same::value) { + vpaddd(ymm_int, ymm_int, jmm_tmp); + vpslld(ymm_int, ymm_int, 23); + } else if (MayIUse(avx)) { + xmm_t xtmp1 = xmm_t(ymm_int.getIdx()); + xmm_t xtmp2 = xmm_t(jmm_tmp.getIdx()); + reg64_t reg_ptr_tmp = reg_ptr_global; + mov(reg_ptr_tmp, reinterpret_cast(g_tmp_mem)); + vmovdqa(ptr[reg_ptr_tmp], ymm_int); + vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], jmm_tmp); + vpaddd(xtmp1, xtmp1, xtmp2); + vpslld(xtmp1, xtmp1, 23); + vmovdqa(ptr[reg_ptr_tmp], xtmp1); + // next 128bits + vmovdqa(xtmp1, ptr[reg_ptr_tmp + XMM_FLOAT_BLOCK * sizeof(float)]); + vmovdqa(xtmp2, ptr[reg_ptr_tmp + + (YMM_FLOAT_BLOCK + XMM_FLOAT_BLOCK) * sizeof(float)]); + vpaddd(xtmp1, xtmp1, xtmp2); + vpslld(xtmp1, xtmp1, 23); + vmovdqa(ptr[reg_ptr_tmp + XMM_FLOAT_BLOCK * sizeof(float)], xtmp1); + // load out + vmovdqa(ymm_int, ptr[reg_ptr_tmp]); + } + vmulps(dst, dst, ymm_int); + pop(reg_ptr_global); + } - // compute sigmoid with ymm - void 
sigmoid_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, - int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); + // compute sigmoid with ymm, xmm + template + void sigmoid_jmm(JMM& dst, JMM& src, int fx_idx = 2, // NOLINT + int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5) { + // y = 1 / (1 + e^-x) + JMM jmm_tmp = JMM(tmp_idx); + reg64_t reg_ptr_global = rax; + push(reg_ptr_global); + mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MAX]); + vminps(src, src, jmm_tmp); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_SIGMOID_MIN]); + vmaxps(src, src, jmm_tmp); + vxorps(jmm_tmp, jmm_tmp, jmm_tmp); + vsubps(src, jmm_tmp, src); + exp_jmm(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); + vaddps(dst, dst, jmm_tmp); + vdivps(dst, jmm_tmp, dst); + pop(reg_ptr_global); + } - // compute tanh with ymm - void tanh_ymm(const Xbyak::Ymm& dst, const Xbyak::Ymm& src, int fx_idx = 2, - int fy_idx = 3, int mask_idx = 4, int tmp_idx = 5); + // compute tanh with ymm, xmm + template + void tanh_jmm(JMM& dst, JMM& src, int fx_idx = 2, int fy_idx = 3, // NOLINT + int mask_idx = 4, int tmp_idx = 5) { + // y = 2 / (1 + e^(-2x)) - 1 + JMM jmm_tmp = JMM(tmp_idx); + JMM jmm_zero = JMM(mask_idx); + reg64_t reg_ptr_global = rax; + push(reg_ptr_global); + mov(reg_ptr_global, reinterpret_cast(exp_float_consts)); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); + vxorps(jmm_zero, jmm_zero, jmm_zero); + vsubps(jmm_tmp, jmm_zero, jmm_tmp); + vmulps(src, src, jmm_tmp); + exp_jmm(dst, src, fx_idx, fy_idx, mask_idx, tmp_idx); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); + vaddps(dst, dst, jmm_tmp); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_TWO]); + vdivps(dst, jmm_tmp, dst); + vmovaps(jmm_tmp, ptr[reg_ptr_global + OFFSET_EXP_ONE]); + vsubps(dst, dst, jmm_tmp); + pop(reg_ptr_global); + } protected: int num_; diff --git 
a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 665ba24872a09897c4c1cb9bb5fc163b0c564dda..7e163c1349e73d8fe5e436b98c9a8f67e6439506 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -26,6 +26,7 @@ namespace operators { namespace math { namespace jitkernel { +// TODO(TJ): move these to some proper place #define SIGMOID_THRESHOLD_MIN -40.0 #define SIGMOID_THRESHOLD_MAX 13.0 #define EXP_MAX_INPUT 40.0 diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 5a6f87fe1f7d10d65d03d78c168d61719cec772e..b6c62a26348cdc20582cf7465f93026402051587 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -33,6 +33,9 @@ limitations under the License. */ constexpr int repeat = 20000; +// TODO(TJ): benchmark and test should be separated, +// benchmark should verify more sizes + inline double GetCurrentUS() { struct timeval time; gettimeofday(&time, NULL); @@ -66,7 +69,7 @@ void vrelu_intri8(const int n, const float* x, float* y) { TEST(JitKernel, vrelu) { namespace jit = paddle::operators::math::jitkernel; - for (int d : {7, 8, 15, 16, 30, 256, 512}) { + for (int d : {3, 7, 8, 15, 16, 30, 256, 512}) { std::vector x(d); std::vector zref(d), ztgt(d); RandomVec(d, x.data(), -10.f, 1.f); @@ -156,7 +159,7 @@ void vexp_mkl(const int n, const float* x, float* y) { TEST(JitKernel, vexp) { namespace jit = paddle::operators::math::jitkernel; - for (int d : {7, 8, 15, 16, 30, 128, 256}) { + for (int d : {1, 3, 4, 6, 7, 8, 12, 15, 16, 20, 30, 128, 256}) { std::vector x(d); std::vector zref(d), ztgt(d); RandomVec(d, x.data(), -2.f, 2.f); @@ -231,7 +234,7 @@ void vsigmoid_better( TEST(JitKernel, vsigmoid) { namespace jit = paddle::operators::math::jitkernel; - for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { + for (int d : {1, 3, 4, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 
256}) { std::vector x(d); std::vector zref(d), ztgt(d); RandomVec(d, x.data(), -2.f, 2.f); @@ -295,7 +298,7 @@ void vtanh_better( TEST(JitKernel, vtanh) { namespace jit = paddle::operators::math::jitkernel; - for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { + for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) { std::vector x(d); std::vector zref(d), ztgt(d); RandomVec(d, x.data(), -2.f, 2.f); @@ -386,7 +389,7 @@ void lstm_ctht_better( TEST(JitKernel, lstm) { namespace jit = paddle::operators::math::jitkernel; - for (int d : {7, 8, 15, 16, 30, 32, 64, 100}) { + for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100}) { int d4 = d * 4; int d3 = d * 3; std::vector x(d4), xref(d4); @@ -759,7 +762,7 @@ TEST(JitKernel, vaddrelu) { float* zref_data = zref.data(); auto trefs = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - vadd_ref(d, x_data, y_data, zref_data); + vaddrelu_ref(d, x_data, y_data, zref_data); } auto trefe = GetCurrentUS(); auto tmkls = GetCurrentUS();