diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 9efd4e81748c203b4790916f3fd78006cbefb104..a5eef019c891e07ded96ef002309c63dc70a6bfc 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -60,60 +60,53 @@ void VXXJitCode::generate() { offset += sizeof(float) * YMM_FLOAT_BLOCK; } int rest = num_ % YMM_FLOAT_BLOCK; - if (rest >= 4) { - if (scalar_index_ != 1) { - vmovups(xmm_src1, ptr[param1 + offset]); - } - if (scalar_index_ != 2) { - vmovups(xmm_src2, ptr[param2 + offset]); - } - if (type_ == operand_type::mul) { - vmulps(xmm_dst, xmm_src1, xmm_src2); - } else if (type_ == operand_type::add) { - vaddps(xmm_dst, xmm_src1, xmm_src2); - } - if (with_relu_) { - vmaxps(xmm_dst, xmm_zero, xmm_dst); - } - vmovups(ptr[param3 + offset], xmm_dst); - offset += sizeof(float) * 4; - rest -= 4; - } - if (rest >= 2) { - if (scalar_index_ != 1) { - vmovq(xmm_src1, ptr[param1 + offset]); - } - if (scalar_index_ != 2) { - vmovq(xmm_src2, ptr[param2 + offset]); + int block = XMM_FLOAT_BLOCK; + while (rest > 0) { + if (rest >= 4) { + if (scalar_index_ != 1) { + vmovups(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovups(xmm_src2, ptr[param2 + offset]); + } + } else if (rest >= 2) { + if (scalar_index_ != 1) { + vmovq(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovq(xmm_src2, ptr[param2 + offset]); + } + } else { + if (scalar_index_ != 1) { + vmovss(xmm_src1, ptr[param1 + offset]); + } + if (scalar_index_ != 2) { + vmovss(xmm_src2, ptr[param2 + offset]); + } } - if (type_ == operand_type::mul) { - vmulps(xmm_dst, xmm_src1, xmm_src2); - } else if (type_ == operand_type::add) { - vaddps(xmm_dst, xmm_src1, xmm_src2); + switch (type_) { + case operand_type::mul: + vmulps(xmm_dst, xmm_src1, xmm_src2); + break; + case operand_type::add: + vaddps(xmm_dst, xmm_src1, xmm_src2); + break; + default: + break; } if (with_relu_) { vmaxps(xmm_dst, xmm_zero, xmm_dst); } - vmovq(ptr[param3 + offset], xmm_dst); - offset += sizeof(float) * 2; - rest -= 2; - } - if (rest > 0) { - if (scalar_index_ != 1) { - vmovss(xmm_src1, ptr[param1 + offset]); - } - if (scalar_index_ != 2) { - vmovss(xmm_src2, ptr[param2 + offset]); - } - if (type_ == operand_type::mul) { - vmulss(xmm_dst, xmm_src1, xmm_src2); - } else if (type_ == operand_type::add) { - vaddss(xmm_dst, xmm_src1, xmm_src2); + if (rest >= 4) { + vmovups(ptr[param3 + offset], xmm_dst); + } else if (rest >= 2) { + vmovq(ptr[param3 + offset], xmm_dst); + } else { + vmovss(ptr[param3 + offset], xmm_dst); } - if (with_relu_) { - vmaxps(xmm_dst, xmm_zero, xmm_dst); - } - vmovss(ptr[param3 + offset], xmm_dst); + offset += sizeof(float) * block; + rest -= block; + block /= 2; } ret(); } @@ -175,11 +168,9 @@ static int g_tmp_mem[16] ALIGN32 = {0}; bool VActJitCode::init(int d, operand_type type) { bool ok = MayIUse(avx); - if (type == operand_type::relu) { + if (type == operand_type::relu || type == operand_type::exp) { + // TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256 return ok; - } else if (type == operand_type::exp) { - // exp is slower than mkl when d >= 256 - return ok; //&& d % 4 == 0 && d < 256; } else { // TODO(TJ): support more return ok && d % 8 == 0; @@ -412,24 +403,15 @@ void VActJitCode::generate() { return; } int rest = num_ % YMM_FLOAT_BLOCK; - if (rest >= 4) { - vmovups(xmm_src, ptr[param1 + offset]); - switch (type_) { - case operand_type::relu: - relu_xmm(xmm_dst, xmm_src, xmm_zero); - break; - case operand_type::exp: - exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5); - break; - default: - break; + int block = XMM_FLOAT_BLOCK; + while (rest > 0) { + if (rest >= 4) { + vmovups(xmm_src, ptr[param1 + offset]); + } else if (rest >= 2) { + vmovq(xmm_src, ptr[param1 + offset]); + } else { + vmovss(xmm_src, ptr[param1 + offset]); } - vmovups(ptr[param2 + offset], xmm_dst); - offset += sizeof(float) * 4; - rest -= 4; - } - if (rest >= 2) { - vmovq(xmm_src, ptr[param1 + offset]); switch (type_) { case operand_type::relu: relu_xmm(xmm_dst, xmm_src, xmm_zero); @@ -440,25 +422,16 @@ void VActJitCode::generate() { default: break; } - vmovq(ptr[param2 + offset], xmm_dst); - offset += sizeof(float) * 2; - rest -= 2; - } - if (rest > 0) { - // vmovups(); - vmovss(xmm_src, ptr[param1 + offset]); - - switch (type_) { - case operand_type::relu: - relu_xmm(xmm_dst, xmm_src, xmm_zero); - break; - case operand_type::exp: - exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5); - break; - default: - break; + if (rest >= 4) { + vmovups(ptr[param2 + offset], xmm_dst); + } else if (rest >= 2) { + vmovq(ptr[param2 + offset], xmm_dst); + } else { + vmovss(ptr[param2 + offset], xmm_dst); } - vmovss(ptr[param2 + offset], xmm_dst); + offset += sizeof(float) * block; + rest -= block; + block /= 2; } ret(); }