diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 00213841c3334f099f75d2db32487045ac390182..15889850c6ea6baec49e55e59af7e3e0662a7ab6 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/fluid/operators/math/jit_kernel.h" #include - #ifdef PADDLE_WITH_MKLML #include "paddle/fluid/platform/dynload/mklml.h" #endif @@ -62,7 +61,7 @@ namespace jit = platform::jit; FOR_EACH_COMMON_BLOCK(macro_, jit::avx512f) \ FOR_EACH_COMMON_BLOCK(macro_, jit::avx2) \ FOR_EACH_COMMON_BLOCK(macro_, jit::avx) \ - FOR_EACH_COMMON_BLOCK(macro_, jit::any) + FOR_EACH_COMMON_BLOCK(macro_, jit::isa_any) #define FOR_EACH_ALL_BLOCK(macro_, isa) \ macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kGT8LT16) macro_(isa, kEQ16) \ @@ -72,7 +71,7 @@ namespace jit = platform::jit; FOR_EACH_ALL_BLOCK(macro_, jit::avx512f) \ FOR_EACH_ALL_BLOCK(macro_, jit::avx2) \ FOR_EACH_ALL_BLOCK(macro_, jit::avx) \ - FOR_EACH_ALL_BLOCK(macro_, jit::any) + FOR_EACH_ALL_BLOCK(macro_, jit::isa_any) #define BIND_KERNEL_WITH_DTYPE(ker_class, ker_func, ker_dtype) \ template <> \ @@ -92,7 +91,7 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) { } } -#ifdef PADDLE_USE_MKLML +#ifdef PADDLE_WITH_MKLML #define VMUL_MKL_FLOAT(isa, block) \ template <> \ void VMulCompute(const int n, const float* x, \ @@ -103,7 +102,7 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) { #define VMUL_MKL_DOUBLE(isa, block) \ template <> \ void VMulCompute(const int n, const double* x, \ - const double* y, float* z) { \ + const double* y, double* z) { \ platform::dynload::vdMul(n, x, y, z); \ } @@ -112,7 +111,7 @@ FOR_EACH_ISA_ALL_BLOCK(VMUL_MKL_DOUBLE) #endif /// lt8 -#ifdef PADDLE_USE_MKLML +#ifdef PADDLE_WITH_MKLML VMUL_MKL_FLOAT(jit::avx2, kLT8) VMUL_MKL_FLOAT(jit::avx512f, kLT8) #endif @@ -130,21 +129,21 @@ VMUL_MKL_FLOAT(jit::avx512f, kLT8) } // mkl > avx > for, ">" means better -#ifdef PADDLE_USE_MKLML -VMUL_MKL_FLOAT(jit::avx, kEQ8) +#ifdef PADDLE_WITH_MKLML +VMUL_MKL_FLOAT(jit::avx, kEQ8); #elif defined __AVX__ -VMUL_INTRI8_FLOAT(jit::avx) +VMUL_INTRI8_FLOAT(jit::avx); #endif // avx2 > mkl > for #ifdef __AVX2__ VMUL_INTRI8_FLOAT(jit::avx2) -#elif defined PADDLE_USE_MKLML +#elif defined PADDLE_WITH_MKLML VMUL_MKL_FLOAT(jit::avx2, kEQ8) #endif // TODO(TJ): test and complete avx512 /// eq16 -#ifdef PADDLE_USE_MKLML +#ifdef PADDLE_WITH_MKLML // TODO(TJ): test and complete me VMUL_MKL_FLOAT(jit::avx, kEQ16) VMUL_MKL_FLOAT(jit::avx2, kEQ16) @@ -163,7 +162,7 @@ static void VAddCompute(const int n, const T* x, const T* y, T* z) { } } -#ifdef PADDLE_USE_MKLML +#ifdef PADDLE_WITH_MKLML #define VADD_MKL_FLOAT(isa, block) \ template <> \ void VAddCompute(const int n, const float* x, \ @@ -174,7 +173,7 @@ static void VAddCompute(const int n, const T* x, const T* y, T* z) { #define VADD_MKL_DOUBLE(isa, block) \ template <> \ void VAddCompute(const int n, const double* x, \ - const double* y, float* z) { \ + const double* y, double* z) { \ platform::dynload::vdAdd(n, x, y, z); \ } @@ -183,7 +182,7 @@ FOR_EACH_ISA_ALL_BLOCK(VADD_MKL_DOUBLE) #endif /// lt8 -#ifdef PADDLE_USE_MKLML +#ifdef PADDLE_WITH_MKLML VADD_MKL_FLOAT(jit::avx, kLT8) VADD_MKL_FLOAT(jit::avx2, kLT8) VADD_MKL_FLOAT(jit::avx512f, kLT8) @@ -210,13 +209,13 @@ VADD_INTRI8_FLOAT(jit::avx) // avx2 > mkl > for #ifdef __AVX2__ VADD_INTRI8_FLOAT(jit::avx2) -#elif defined PADDLE_USE_MKLML +#elif defined PADDLE_WITH_MKLML VADD_MKL_FLOAT(jit::avx2, kEQ8) #endif // TODO(TJ): test and complete avx512 /// eq16 -#ifdef PADDLE_USE_MKLML +#ifdef PADDLE_WITH_MKLML // TODO(TJ): test and complete me VADD_MKL_FLOAT(jit::avx, kEQ16) VADD_MKL_FLOAT(jit::avx2, kEQ16) diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index d9c8bb6d430c874133cc75efa71bb779d9e11f95..0e2ea06f764fd09dd9808b955246655ff4613fb4 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -20,6 +20,14 @@ limitations under the License. */ #include "glog/logging.h" #include "gtest/gtest.h" +#ifdef PADDLE_WITH_MKLML +#include "paddle/fluid/platform/dynload/mklml.h" +#endif + +#ifdef __AVX__ +#include +#endif + inline double GetCurrentUS() { struct timeval time; gettimeofday(&time, NULL); @@ -38,17 +46,26 @@ void RandomVec(const int n, T* a) { } } -constexpr int repeat = 10000; +constexpr int repeat = 20000; -TEST(JitKernel, vmul) { - namespace jit = paddle::operators::math::jitkernel; +#if defined __AVX__ || defined __AVX2__ +void vmul_intri(const int n, const float* x, const float* y, float* z) { + __m256 tmpx, tmpy; + tmpx = _mm256_loadu_ps(x); + tmpy = _mm256_loadu_ps(y); + tmpx = _mm256_mul_ps(tmpx, tmpy); + _mm256_storeu_ps(z, tmpx); +} +#endif - auto ref = [](const int n, const float* x, const float* y, float* z) { - for (int i = 0; i < n; ++i) { - z[i] = x[i] * y[i]; - } - }; +void vmul_ref(const int n, const float* x, const float* y, float* z) { + for (int i = 0; i < n; ++i) { + z[i] = x[i] * y[i]; + } +} +TEST(JitKernel, vmul) { + namespace jit = paddle::operators::math::jitkernel; for (int d : {7, 8, 15, 16, 30, 256}) { std::vector x(d), y(d); std::vector zref(d), ztgt(d); @@ -61,18 +78,42 @@ TEST(JitKernel, vmul) { const float* y_data = y.data(); float* ztgt_data = ztgt.data(); float* zref_data = zref.data(); + +#ifdef PADDLE_WITH_MKLML + auto s0 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + paddle::platform::dynload::vsMul(d, x_data, y_data, zref_data); + } +#endif + auto st = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { ker->Compute(d, x_data, y_data, ztgt_data); } auto mt = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ref(d, x_data, y_data, zref_data); + vmul_ref(d, x_data, y_data, zref_data); } auto et = GetCurrentUS(); +#if defined __AVX__ || defined __AVX2__ + if (d == 8) { + auto si0 = GetCurrentUS(); + for (int i = 0; i < repeat; ++i) { + vmul_intri(d, x_data, y_data, zref_data); + } + auto si1 = GetCurrentUS(); + VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat; + } +#endif + VLOG(3) << "Vec size " << d << ": refer takes: " << (et - mt) / repeat - << " us, tgt takes: " << (mt - st) / repeat; + << " us, tgt takes: " << (mt - st) / repeat +#ifdef PADDLE_WITH_MKLML + << " us, mkl takes: " << (st - s0) / repeat << " us"; +#else + << " us"; +#endif for (int i = 0; i < d; ++i) { EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3); }