diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 6a059968b79189458349e466079cc7a663a8e5ff..0aed253c80fc28560716cbcfa70f74ef9c84f9b6 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -125,10 +125,8 @@ inline void vec_scal(const int n, const float a, } template <> -inline void vec_scal(const int n, - const float a, - const float* x, - float* y) { +inline void vec_scal(const int n, const float a, + const float* x, float* y) { // TODO(TJ): enable me vec_scal(n, a, x, y); } @@ -181,10 +179,10 @@ inline void vec_bias_sub(const int n, const float a, } template <> -inline void vec_bias_sub(const int n, - const float a, - const float* x, - float* y) { +inline void vec_bias_sub(const int n, + const float a, + const float* x, + float* y) { // TODO(TJ): enable me vec_bias_sub(n, a, x, y); } @@ -242,7 +240,7 @@ inline void vec_cross(const int n, const float* x, } template <> -inline void vec_cross( +inline void vec_cross( const int n, const float* x, const float* y, const float* z, float* out) { // TODO(TJ): enable me vec_cross(n, x, y, z, out); @@ -296,10 +294,10 @@ inline void vec_add_bias(const int n, const float a, } template <> -inline void vec_add_bias(const int n, - const float a, - const float* x, - float* y) { +inline void vec_add_bias(const int n, + const float a, + const float* x, + float* y) { // TODO(TJ): enable me vec_add_bias(n, a, x, y); } @@ -390,9 +388,9 @@ inline void vec_sigmoid(const int n, const float* x, } template <> -inline void vec_sigmoid(const int n, - const float* x, - float* y) { +inline void vec_sigmoid(const int n, + const float* x, + float* y) { // TODO(TJ): enable me vec_sigmoid(n, x, y); } @@ -454,9 +452,8 @@ inline void vec_relu(const int n, const float* x, } template <> -inline void vec_relu(const int n, - const float* x, - float* y) { +inline void vec_relu(const int n, const float* x, + float* y) { // TODO(TJ): enable me vec_relu(n, x, y); } diff 
--git a/paddle/fluid/operators/math/cpu_vec_test.cc b/paddle/fluid/operators/math/cpu_vec_test.cc index 3ce66f49ed8354c49e8af26ca6eb48fef654a40b..cd40f1b2f984126663a5711efac24fdf6d680b32 100644 --- a/paddle/fluid/operators/math/cpu_vec_test.cc +++ b/paddle/fluid/operators/math/cpu_vec_test.cc @@ -110,7 +110,7 @@ TEST(CpuVecTest, sigmoid) { TestAndBench(sz, vec_sigmoid, ref_sigmoid); TestAndBench(sz, vec_sigmoid, ref_sigmoid); TestAndBench(sz, vec_sigmoid, ref_sigmoid); - TestAndBench(sz, vec_sigmoid, + TestAndBench(sz, vec_sigmoid, ref_sigmoid); } TestAndBench(30, vec_sigmoid, ref_sigmoid); @@ -123,8 +123,7 @@ TEST(CpuVecTest, tanh) { TestAndBench(sz, vec_tanh, ref_tanh); TestAndBench(sz, vec_tanh, ref_tanh); TestAndBench(sz, vec_tanh, ref_tanh); - TestAndBench(sz, vec_tanh, - ref_tanh); + TestAndBench(sz, vec_tanh, ref_tanh); } TestAndBench(30, vec_tanh, ref_tanh); } @@ -136,8 +135,7 @@ TEST(CpuVecTest, relu) { TestAndBench(sz, vec_relu, ref_relu); TestAndBench(sz, vec_relu, ref_relu); TestAndBench(sz, vec_relu, ref_relu); - TestAndBench(sz, vec_relu, - ref_relu); + TestAndBench(sz, vec_relu, ref_relu); } TestAndBench(30, vec_relu, ref_relu); } @@ -170,7 +168,7 @@ TEST(CpuVecTest, inplace_sigmoid) { TestInplace(sz, vec_sigmoid, ref_sigmoid); TestInplace(sz, vec_sigmoid, ref_sigmoid); TestInplace(sz, vec_sigmoid, ref_sigmoid); - TestInplace(sz, vec_sigmoid, + TestInplace(sz, vec_sigmoid, ref_sigmoid); } TestInplace(30, vec_sigmoid, ref_sigmoid); @@ -183,8 +181,7 @@ TEST(CpuVecTest, inplace_tanh) { TestInplace(sz, vec_tanh, ref_tanh); TestInplace(sz, vec_tanh, ref_tanh); TestInplace(sz, vec_tanh, ref_tanh); - TestInplace(sz, vec_tanh, - ref_tanh); + TestInplace(sz, vec_tanh, ref_tanh); } TestInplace(30, vec_tanh, ref_tanh); } @@ -196,8 +193,7 @@ TEST(CpuVecTest, inplace_relu) { TestInplace(sz, vec_relu, ref_relu); TestInplace(sz, vec_relu, ref_relu); TestInplace(sz, vec_relu, ref_relu); - TestInplace(sz, vec_relu, - ref_relu); + TestInplace(sz, vec_relu, ref_relu); 
} TestInplace(30, vec_relu, ref_relu); } diff --git a/paddle/fluid/operators/math/jit_kernel.cc b/paddle/fluid/operators/math/jit_kernel.cc index 81b56ef2e8aa083da270171be462cc5a7ba73507..71b1ffc6670c15859b81e590fdc46f2e41c567f6 100644 --- a/paddle/fluid/operators/math/jit_kernel.cc +++ b/paddle/fluid/operators/math/jit_kernel.cc @@ -36,35 +36,38 @@ KernelPool& KernelPool::Instance() { static KernelPool g_jit_kernels; return g_jit_kernels; } -#define SEARCH_BLOCK(src, t, isa) \ - if (d < AVX_FLOAT_BLOCK) { \ - Compute = src; \ - } else if (d == AVX_FLOAT_BLOCK) { \ - Compute = src; \ - } else if (d == AVX512_FLOAT_BLOCK) { \ - Compute = src; \ - } else { \ - Compute = src; \ +#define SEARCH_BLOCK(src, t, isa) \ + if (d < AVX_FLOAT_BLOCK) { \ + Compute = src; \ + } else if (d == AVX_FLOAT_BLOCK) { \ + Compute = src; \ + } else if (d > AVX_FLOAT_BLOCK && d < AVX512_FLOAT_BLOCK) { \ + Compute = src; \ + } else if (d == AVX512_FLOAT_BLOCK) { \ + Compute = src; \ + } else { \ + Compute = src; \ } -#define SEARCH_ISA_BLOCK(src, t) \ - if (jit::MayIUse(jit::avx512_common)) { \ - SEARCH_BLOCK(src, t, jit::avx512_common); \ - } else if (jit::MayIUse(jit::avx2)) { \ - SEARCH_BLOCK(src, t, jit::avx2); \ - } else if (jit::MayIUse(jit::avx)) { \ - SEARCH_BLOCK(src, t, jit::avx); \ - } else { \ - SEARCH_BLOCK(src, t, jit::isa_any); \ +#define SEARCH_ISA_BLOCK(src, t) \ + if (jit::MayIUse(jit::avx512f)) { \ + SEARCH_BLOCK(src, t, jit::avx512f); \ + } else if (jit::MayIUse(jit::avx2)) { \ + SEARCH_BLOCK(src, t, jit::avx2); \ + } else if (jit::MayIUse(jit::avx)) { \ + SEARCH_BLOCK(src, t, jit::avx); \ + } else { \ + SEARCH_BLOCK(src, t, jit::isa_any); \ } -#define FOR_EACH_BLOCK(macro_, isa) \ - macro_(isa, kLT8) macro_(isa, kEQ8) macro_(isa, kEQ16) macro_(isa, kGT16) +// do not include lt8, eq8, eq16 +#define FOR_EACH_COMMON_BLOCK(macro_, isa) \ + macro_(isa, kGT8LT16) macro_(isa, kGT16) -#define FOR_EACH_ISA_BLOCK(macro_) \ - FOR_EACH_BLOCK(macro_, jit::avx512_common) \ - 
FOR_EACH_BLOCK(macro_, jit::avx2) \ - FOR_EACH_BLOCK(macro_, jit::avx) \ +#define FOR_EACH_ISA_COMMON_BLOCK(macro_) \ + FOR_EACH_COMMON_BLOCK(macro_, jit::avx512f) \ + FOR_EACH_COMMON_BLOCK(macro_, jit::avx2) \ + FOR_EACH_COMMON_BLOCK(macro_, jit::avx) \ FOR_EACH_BLOCK(macro_, jit::any) #define VMUL_ANY \ @@ -78,24 +81,56 @@ static void VMulCompute(const int n, const T* x, const T* y, T* z) { } #ifdef PADDLE_USE_MKLML -#define DEFINE_VMUL_COMPUTE_FLOAT(isa, block) \ - template <> \ - static void VMulCompute(const int n, const float* x, \ - const float* y, float* z) { \ - platform::dynload::vsMul(n, x, y, z); \ } -#define DEFINE_VMUL_COMPUTE_DOUBLE(isa, block) \ - template <> \ - static void VMulCompute(const int n, const double* x, \ - const double* y, float* z) { \ - platform::dynload::vdMul(n, x, y, z); \ } +#define DEFINE_VMUL_COMPUTE_FLOAT(isa, block) \ + template <> \ + void VMulCompute(const int n, const float* x, \ + const float* y, float* z) { \ + platform::dynload::vsMul(n, x, y, z); \ } +#define DEFINE_VMUL_COMPUTE_DOUBLE(isa, block) \ + template <> \ + void VMulCompute(const int n, const double* x, \ + const double* y, double* z) { \ + platform::dynload::vdMul(n, x, y, z); \ } -FOR_EACH_ISA_BLOCK(DEFINE_VMUL_COMPUTE_FLOAT) -FOR_EACH_ISA_BLOCK(DEFINE_VMUL_COMPUTE_DOUBLE) -// TODO(TJ): add EQ8 +FOR_EACH_ISA_COMMON_BLOCK(DEFINE_VMUL_COMPUTE_FLOAT) +FOR_EACH_ISA_COMMON_BLOCK(DEFINE_VMUL_COMPUTE_DOUBLE) +DEFINE_VMUL_COMPUTE_FLOAT(jit::avx, kLT8) +DEFINE_VMUL_COMPUTE_FLOAT(jit::avx, kEQ16) +#endif + +// mkl > avx > for, ">" means better +#ifdef PADDLE_USE_MKLML +DEFINE_VMUL_COMPUTE_FLOAT(jit::avx, kEQ8) +#elif defined __AVX__ +template <> +void VMulCompute(const int n, const float* x, + const float* y, float* z) { + __m256 tmpx, tmpy; + tmpx = _mm256_loadu_ps(x); + tmpy = _mm256_loadu_ps(y); + tmpx = _mm256_mul_ps(tmpx, tmpy); + _mm256_storeu_ps(z, tmpx); +} +#endif + +// avx2 > mkl > for +#ifdef __AVX2__ +template <> +void VMulCompute(const int n, const float* x, + const float* y,
float* z) { + __m256 tmpx, tmpy; + tmpx = _mm256_loadu_ps(x); + tmpy = _mm256_loadu_ps(y); + tmpx = _mm256_mul_ps(tmpx, tmpy); + _mm256_storeu_ps(z, tmpx); +} +#elif defined PADDLE_USE_MKLML +DEFINE_VMUL_COMPUTE_FLOAT(jit::avx2, kEQ8) #endif +// TODO(TJ): test and complete avx512 #undef DEFINE_VMUL_COMPUTE_FLOAT #undef DEFINE_VMUL_COMPUTE_DOUBLE @@ -142,8 +177,8 @@ LSTMKernel::LSTMKernel(int d, const std::string& act_gate_str, : Kernel(), d_(d) { d2_ = d * 2; d3_ = d * 3; - if (platform::jit::MayIUse(platform::jit::avx512_common)) { - math::VecActivations act_functor; + if (platform::jit::MayIUse(platform::jit::avx512f)) { + math::VecActivations act_functor; act_gate_ = act_functor(act_gate_str); act_cell_ = act_functor(act_cell_str); act_cand_ = act_functor(act_cand_str); diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index b65653498318625d8abb17e9f846aaa2d56ac4eb..6005ea76f415a16b125ba76d5cbfebc787e67fe3 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -36,7 +36,7 @@ namespace jitkernel { #define AVX512_FLOAT_BLOCK 16 #define AVX512_DOUBLE_BLOCK 8 -typedef enum { kLT8, kEQ8, kEQ16, kGT16 } jit_block; +typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block; class Kernel { public: diff --git a/paddle/fluid/platform/cpu_info.cc b/paddle/fluid/platform/cpu_info.cc index 2880c09263f10e9c624e11b77188171f48d9db28..b5f472d20f40fa182a4aa55ff384b0954e4ba9e3 100644 --- a/paddle/fluid/platform/cpu_info.cc +++ b/paddle/fluid/platform/cpu_info.cc @@ -128,7 +128,7 @@ bool MayIUse(const cpu_isa_t cpu_isa) { return cpu.has(Cpu::tAVX); case avx2: return cpu.has(Cpu::tAVX2); - case avx512_common: + case avx512f: return cpu.has(Cpu::tAVX512F); case avx512_core: return true && cpu.has(Cpu::tAVX512F) && cpu.has(Cpu::tAVX512BW) && diff --git a/paddle/fluid/platform/cpu_info.h b/paddle/fluid/platform/cpu_info.h index 
30c8fbcfce92a8b06a175ddf198cde572f72b2a4..6810a1651a14cdb2080af846b21cad242b70bf35 100644 --- a/paddle/fluid/platform/cpu_info.h +++ b/paddle/fluid/platform/cpu_info.h @@ -43,7 +43,7 @@ typedef enum { sse42, avx, avx2, - avx512_common, + avx512f, avx512_core, avx512_core_vnni, avx512_mic, diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 4c99f4be321160caf0ee2f89a655bdfb933408e3..ab91ca5345047f3053eb8771e6a265d2a3011f85 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -116,7 +116,7 @@ void InitDevices(bool init_p2p, const std::vector devices) { platform::SetNumThreads(FLAGS_paddle_num_threads); #endif - if (platform::jit::MayIUse(platform::jit::avx512_common)) { + if (platform::jit::MayIUse(platform::jit::avx512f)) { #ifndef __AVX512F__ LOG(WARNING) << "AVX512F is available, Please re-compile on local machine"; #endif