diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 7710525717e33497db77b9b7da8e846602dc8501..15f8bf7145cf9db24be2173b6eb8d134bd13ad89 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -75,25 +75,24 @@ namespace jit = platform::jit; DEFINE_WITH_DTYPE(ker_key, ker_class, float, f); \ DEFINE_WITH_DTYPE(ker_key, ker_class, double, d) -// do not include lt8, eq8, eq16 -#define FOR_EACH_COMMON_BLOCK(macro_, isa) \ - macro_(isa, kGT8LT16) macro_(isa, kGT16) - -#define FOR_EACH_ISA_COMMON_BLOCK(macro_) \ - FOR_EACH_COMMON_BLOCK(macro_, jit::avx512f) \ - FOR_EACH_COMMON_BLOCK(macro_, jit::avx2) \ - FOR_EACH_COMMON_BLOCK(macro_, jit::avx) \ - FOR_EACH_COMMON_BLOCK(macro_, jit::isa_any) - -#define FOR_EACH_ALL_BLOCK(macro_, isa) \ - macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kGT8LT16) macro_(isa, kEQ16) \ - macro_(isa, kGT16) - -#define FOR_EACH_ISA_ALL_BLOCK(macro_) \ - FOR_EACH_ALL_BLOCK(macro_, jit::avx512f) \ - FOR_EACH_ALL_BLOCK(macro_, jit::avx2) \ - FOR_EACH_ALL_BLOCK(macro_, jit::avx) \ - FOR_EACH_ALL_BLOCK(macro_, jit::isa_any) +#define FOR_EACH_ISA(macro_, block) \ + macro_(jit::avx512f, block); \ + macro_(jit::avx2, block); \ + macro_(jit::avx, block); \ + macro_(jit::isa_any, block) + +#define FOR_EACH_BLOCK(macro_, isa) \ + macro_(isa, kLT8); \ + macro_(isa, kEQ8); \ + macro_(isa, kGT8LT16); \ + macro_(isa, kEQ16); \ + macro_(isa, kGT16) + +#define FOR_EACH_ISA_BLOCK(macro_) \ + FOR_EACH_BLOCK(macro_, jit::avx512f); \ + FOR_EACH_BLOCK(macro_, jit::avx2); \ + FOR_EACH_BLOCK(macro_, jit::avx); \ + FOR_EACH_BLOCK(macro_, jit::isa_any) /* VMUL JitKernel */ template @@ -121,8 +120,8 @@ class VMulKernelImpl : public VMulKernel { platform::dynload::vdMul(n, x, y, z); \ } -FOR_EACH_ISA_COMMON_BLOCK(VMUL_MKL_FLOAT); -FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE); +FOR_EACH_ISA(VMUL_MKL_FLOAT, kGT16); +FOR_EACH_ISA_BLOCK(VMUL_MKL_DOUBLE); #endif #define VMUL_INTRI8_FLOAT(isa) \ @@ -178,8 +177,8 @@ class VAddKernelImpl : public VAddKernel { platform::dynload::vdAdd(n, x, y, z); \ } -FOR_EACH_ISA_COMMON_BLOCK(VADD_MKL_FLOAT); -FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE); +FOR_EACH_ISA(VADD_MKL_FLOAT, kGT16); +FOR_EACH_ISA_BLOCK(VADD_MKL_DOUBLE); #endif #define VADD_INTRI8_FLOAT(isa) \ @@ -210,10 +209,9 @@ VADD_INTRI8_FLOAT(jit::avx512f); REGISTER_BLAS_JITKERNEL(vmul, VMulKernel); REGISTER_BLAS_JITKERNEL(vadd, VAddKernel); -#undef FOR_EACH_ISA_ALL_BLOCK -#undef FOR_EACH_ALL_BLOCK -#undef FOR_EACH_ISA_COMMON_BLOCK -#undef FOR_EACH_COMMON_BLOCK +#undef FOR_EACH_ISA +#undef FOR_EACH_BLOCK +#undef FOR_EACH_ISA_BLOCK #undef REGISTER_BLAS_JITKERNEL #undef DEFINE_WITH_DTYPE #undef SEARCH_ISA_BLOCK diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index f57fd665a6d0e59a6d1fcf03fb7fe01cf1d0a94e..88437a050bf68583aef7900d176d0c1889d1aca1 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -79,12 +79,10 @@ TEST(JitKernel, vmul) { RandomVec(d, y.data()); const auto& ker = jit::KernelPool::Instance().template Get>(d); - const float* x_data = x.data(); const float* y_data = y.data(); float* ztgt_data = ztgt.data(); float* zref_data = zref.data(); - auto trefs = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { vmul_ref(d, x_data, y_data, zref_data); @@ -129,6 +127,85 @@ TEST(JitKernel, vmul) { } } +void vadd_ref(const int n, const float* x, const float* y, float* z) { + for (int i = 0; i < n; ++i) { + z[i] = x[i] + y[i]; + } +} + +#if defined __AVX__ || defined __AVX2__ +void vadd_intri8(const int n, const float* x, const float* y, float* z) { + __m256 tmpx, tmpy; + tmpx = _mm256_loadu_ps(x); + tmpy = _mm256_loadu_ps(y); + tmpx = _mm256_add_ps(tmpx, tmpy); + _mm256_storeu_ps(z, tmpx); +} +#endif + +#ifdef PADDLE_WITH_MKLML +void vadd_mkl(const int n, const float* x, const float* y, float* z) { + paddle::platform::dynload::vsAdd(n, x, y, z); +} +#endif + +TEST(JitKernel, vadd) { + namespace jit = paddle::operators::math::jitkernel; + for (int d : {7, 8, 15, 16, 30, 256, 512}) { + std::vector x(d), y(d); + std::vector zref(d), ztgt(d); + RandomVec(d, x.data()); + RandomVec(d, y.data()); + const auto& ker = + jit::KernelPool::Instance().template Get>(d); + const float* x_data = x.data(); + const float* y_data = y.data(); + float* ztgt_data = ztgt.data(); + float* zref_data = zref.data(); + auto trefs = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vadd_ref(d, x_data, y_data, zref_data); + } + auto trefe = GetCurrentUS(); + +#ifdef PADDLE_WITH_MKLML + auto tmkls = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vadd_mkl(d, x_data, y_data, zref_data); + } + auto tmkle = GetCurrentUS(); +#endif + +#if defined __AVX__ || defined __AVX2__ + if (d == 8) { + auto si0 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vadd_intri8(d, x_data, y_data, zref_data); + } + auto si1 = GetCurrentUS(); + VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; + } +#endif + + auto ttgts = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + ker->Compute(d, x_data, y_data, ztgt_data); + } + auto ttgte = GetCurrentUS(); + + VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat +#ifdef PADDLE_WITH_MKLML + << " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, " +#else + << " us, " +#endif + << "tgt takes: " << (ttgte - ttgts) / repeat; + for (int i = 0; i < d; ++i) { + EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); + } + } +} + TEST(JitKernel, pool) { namespace jit = paddle::operators::math::jitkernel; const int frame_size = 4;