diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 06cf82513dfd0a0b3dc139602c59e953983c101a..c3bb60f2a8cb825c9c155b2d13adea4e21a0ecdd 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -25,10 +25,10 @@ namespace gen { using namespace platform::jit; // NOLINT bool VMulJitCode::init(int d) { - // TODO(TJ): maybe one AVX is enough, AVX above would slow down freq - // try more with avx2 or avx512 - if (MayIUse(avx) || MayIUse(avx2)) { - return d % AVX_FLOAT_BLOCK == 0; + // It's not necessary to use avx512 since it would slow down the frequency + // and this kernel is not compute bound. + if (MayIUse(avx)) { + return d % 2 == 0; } else { return false; } @@ -36,12 +36,33 @@ bool VMulJitCode::init(int d) { void VMulJitCode::generate() { // do not need push stack, and do not need save avx512reg if do not use avx512 - int stride = sizeof(float) * AVX_FLOAT_BLOCK; + int offset = 0; for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { - vmovups(ymm_src1, ptr[param1 + i * stride]); - vmovups(ymm_src2, ptr[param2 + i * stride]); + vmovups(ymm_src1, ptr[param1 + offset]); + vmovups(ymm_src2, ptr[param2 + offset]); vmulps(ymm_dst, ymm_src1, ymm_src2); - vmovups(ptr[param3 + stride * i], ymm_dst); + vmovups(ptr[param3 + offset], ymm_dst); + offset += sizeof(float) * AVX_FLOAT_BLOCK; + } + int rest = num_ % AVX_FLOAT_BLOCK; + if (rest >= 4) { + vmovups(xmm_src1, ptr[param1 + offset]); + vmovups(xmm_src2, ptr[param2 + offset]); + vmulps(xmm_dst, xmm_src1, xmm_src2); + vmovups(ptr[param3 + offset], xmm_dst); + offset += sizeof(float) * 4; + rest -= 4; + } + if (rest >= 2) { + mov(tmp, qword[param1 + offset]); + vmovq(xmm_src1, tmp); + mov(tmp, qword[param2 + offset]); + vmovq(xmm_src2, tmp); + vmulps(xmm_dst, xmm_src1, xmm_src2); + vmovq(tmp, xmm_dst); + mov(ptr[param3 + offset], tmp); + offset += sizeof(float) * 2; + rest -= 2; } ret(); } diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index db1a0cd0958d7047b83d1dd234eaf9f0fb748fa6..c77252a326c61a96997b2f5474dee573952540db 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -43,17 +43,15 @@ class VMulJitCode : public JitCode { reg64_t param1{abi_param1}; reg64_t param2{abi_param2}; reg64_t param3{abi_param3}; + reg64_t tmp = rax; xmm_t xmm_src1 = xmm_t(0); - ymm_t ymm_src1 = ymm_t(0); - zmm_t zmm_src1 = zmm_t(0); xmm_t xmm_src2 = xmm_t(1); - ymm_t ymm_src2 = ymm_t(1); - zmm_t zmm_src2 = zmm_t(1); - xmm_t xmm_dst = xmm_t(2); + + ymm_t ymm_src1 = ymm_t(0); + ymm_t ymm_src2 = ymm_t(1); ymm_t ymm_dst = ymm_t(2); - zmm_t zmm_dst = zmm_t(2); }; } // namespace gen diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index cf0d6c60d19933690de62e04926f14e6a3c12e19..593209d42b5ba102738e5419f2cf77c5ebbcd0da 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -578,7 +578,7 @@ void vmul_mkl(const int n, const float* x, const float* y, float* z) { TEST(JitKernel, vmul) { namespace jit = paddle::operators::math::jitkernel; - for (int d : {7, 8, 15, 16, 30, 256, 512, 1000, 1024}) { + for (int d : {7, 8, 15, 16, 20, 30, 256, 512, 1000, 1024}) { std::vector x(d), y(d); std::vector zref(d), ztgt(d); RandomVec(d, x.data());