diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc index ba04d030b94c0924311dcff5c6a34270a764f877..e0eb919bd896d73a557001982a436fc93f087a74 100644 --- a/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc @@ -18,12 +18,12 @@ namespace paddle { namespace inference { using namespace framework; // NOLINT +static std::vector result_data; struct DataRecord { std::vector>> link_step_data_all; std::vector lod; std::vector> rnn_link_data; - std::vector result_data; size_t num_samples; // total number of samples size_t batch_iter{0}; size_t batch_size{1}; @@ -57,6 +57,7 @@ struct DataRecord { std::ifstream file(path); std::string line; int num_lines = 0; + result_data.clear(); while (std::getline(file, line)) { num_lines++; std::vector data; @@ -135,13 +136,12 @@ TEST(Analyzer_rnn2, profile) { if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { // the first inference result - DataRecord data(FLAGS_infer_data, FLAGS_batch_size); PADDLE_ENFORCE_GT(outputs.size(), 0); size_t size = GetSize(outputs[0]); PADDLE_ENFORCE_GT(size, 0); float *result = static_cast(outputs[0].data.data()); for (size_t i = 0; i < size; i++) { - EXPECT_NEAR(result[i], data.result_data[i], 1e-3); + EXPECT_NEAR(result[i], result_data[i], 1e-3); } } } diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 7365bfeeb8edf09a8ad5e1cb2c61300e86bdf518..c7bdec354735773a15b4c99baf9f7798f2d92564 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -76,5 +76,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc - DEPS cpu_info cblas activation_functions) + DEPS cpu_info cblas) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) diff --git a/paddle/fluid/operators/math/jit_kernel_exp.cc b/paddle/fluid/operators/math/jit_kernel_exp.cc index b62e130c43743f542e2074868fc01598047d6b19..c4247580f491a7ca26259528ca74dd92e35785a9 100644 --- a/paddle/fluid/operators/math/jit_kernel_exp.cc +++ b/paddle/fluid/operators/math/jit_kernel_exp.cc @@ -27,13 +27,6 @@ limitations under the License. */ namespace paddle { namespace operators { namespace math { - -#ifdef __AVX__ -namespace detail { -__m256 Exp(__m256 a); -} // namespace detail -#endif - namespace jitkernel { namespace jit = platform::jit; @@ -69,37 +62,186 @@ FOR_EACH_ISA(MKL_FLOAT, kGT16); FOR_EACH_ISA_BLOCK(MKL_DOUBLE); #endif -#define INTRI8_FLOAT(isa) \ +namespace detail { + +#ifdef __AVX__ + +#define ALIGN32 __attribute__((aligned(32))) + +#define _PS256_CONST(Name, Val) \ + static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \ + Val, Val, Val, Val} + +#define _PI256_CONST(Name, Val) \ + static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \ + Val, Val, Val, Val} + +_PI256_CONST(0x7f, 0x7f); +_PS256_CONST(one, 1.f); +_PS256_CONST(0p5, 0.5f); +_PS256_CONST(exp_hi, 88.3762626647949f); +_PS256_CONST(exp_lo, -88.3762626647949f); +_PS256_CONST(cephes_LOG2EF, 1.44269504088896341); +_PS256_CONST(cephes_exp_C1, 0.693359375); +_PS256_CONST(cephes_exp_C2, -2.12194440e-4); +_PS256_CONST(cephes_exp_p0, 1.9875691500E-4); +_PS256_CONST(cephes_exp_p1, 1.3981999507E-3); +_PS256_CONST(cephes_exp_p2, 8.3334519073E-3); +_PS256_CONST(cephes_exp_p3, 4.1665795894E-2); +_PS256_CONST(cephes_exp_p4, 1.6666665459E-1); +_PS256_CONST(cephes_exp_p5, 5.0000001201E-1); + +typedef union imm_xmm_union { + __m256i imm; + __m128i xmm[2]; +} imm_xmm_union; + +#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \ + { \ + imm_xmm_union u ALIGN32; \ + u.imm = imm_; \ + xmm0_ = u.xmm[0]; \ + xmm1_ = u.xmm[1]; \ + } + +#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \ + { \ + imm_xmm_union u ALIGN32; \ + u.xmm[0] = xmm0_; \ + u.xmm[1] = xmm1_; \ + imm_ = u.imm; \ + } + +#define AVX2_BITOP_USING_SSE2(fn) \ + static inline __m256i avx2_mm256_##fn(__m256i x, int y) { \ + /* use SSE2 to perform the bitop AVX2 */ \ + __m128i x1, x2; \ + __m256i ret; \ + COPY_IMM_TO_XMM(x, x1, x2); \ + x1 = _mm_##fn(x1, y); \ + x2 = _mm_##fn(x2, y); \ + COPY_XMM_TO_IMM(x1, x2, ret); \ + return ret; \ + } + +#define AVX2_INTOP_USING_SSE2(fn) \ + static inline __m256i avx2_mm256_add_epi32(__m256i x, __m256i y) { \ + /* use SSE2 to perform the AVX2 integer operation */ \ + __m128i x1, x2; \ + __m128i y1, y2; \ + __m256i ret; \ + COPY_IMM_TO_XMM(x, x1, x2); \ + COPY_IMM_TO_XMM(y, y1, y2); \ + x1 = _mm_##fn(x1, y1); \ + x2 = _mm_##fn(x2, y2); \ + COPY_XMM_TO_IMM(x1, x2, ret); \ + return ret; \ + } + +AVX2_BITOP_USING_SSE2(slli_epi32); +AVX2_INTOP_USING_SSE2(add_epi32); + +#define AVXEXP_BASE \ + __m256 tmp = _mm256_setzero_ps(), fx; \ + __m256 one = *reinterpret_cast(_ps256_one); \ + __m256i imm0; \ + x = _mm256_min_ps(x, *reinterpret_cast(_ps256_exp_hi)); \ + x = _mm256_max_ps(x, *reinterpret_cast(_ps256_exp_lo)); \ + /* express exp(x) as exp(g + n*log(2)) */ \ + fx = _mm256_mul_ps(x, \ + *reinterpret_cast(_ps256_cephes_LOG2EF)); \ + fx = _mm256_add_ps(fx, *reinterpret_cast(_ps256_0p5)); \ + tmp = _mm256_floor_ps(fx); \ + /* if greater, substract 1 */ \ + __m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); \ + mask = _mm256_and_ps(mask, one); \ + fx = _mm256_sub_ps(tmp, mask); \ + tmp = _mm256_mul_ps(fx, \ + *reinterpret_cast(_ps256_cephes_exp_C1)); \ + __m256 z = _mm256_mul_ps( \ + fx, *reinterpret_cast(_ps256_cephes_exp_C2)); \ + x = _mm256_sub_ps(x, tmp); \ + x = _mm256_sub_ps(x, z); \ + z = _mm256_mul_ps(x, x); \ + __m256 y = *reinterpret_cast(_ps256_cephes_exp_p0); \ + y = _mm256_mul_ps(y, x); \ + y = _mm256_add_ps(y, \ + *reinterpret_cast(_ps256_cephes_exp_p1)); \ + y = _mm256_mul_ps(y, x); \ + y = _mm256_add_ps(y, \ + *reinterpret_cast(_ps256_cephes_exp_p2)); \ + y = _mm256_mul_ps(y, x); \ + y = _mm256_add_ps(y, \ + *reinterpret_cast(_ps256_cephes_exp_p3)); \ + y = _mm256_mul_ps(y, x); \ + y = _mm256_add_ps(y, \ + *reinterpret_cast(_ps256_cephes_exp_p4)); \ + y = _mm256_mul_ps(y, x); \ + y = _mm256_add_ps(y, \ + *reinterpret_cast(_ps256_cephes_exp_p5)); \ + y = _mm256_mul_ps(y, z); \ + y = _mm256_add_ps(y, x); \ + y = _mm256_add_ps(y, one); \ + /* build 2^n */ \ + imm0 = _mm256_cvttps_epi32(fx) + +__m256 ExpAVX(__m256 x) { + AVXEXP_BASE; + // two AVX2 instructions using SSE2 + imm0 = avx2_mm256_add_epi32(imm0, + *reinterpret_cast(_pi256_0x7f)); + imm0 = avx2_mm256_slli_epi32(imm0, 23); + __m256 pow2n = _mm256_castsi256_ps(imm0); + y = _mm256_mul_ps(y, pow2n); + return y; +} +#endif + +#ifdef __AVX2__ +__m256 ExpAVX2(__m256 x) { + AVXEXP_BASE; + // two AVX2 instructions + imm0 = _mm256_add_epi32(imm0, *reinterpret_cast(_pi256_0x7f)); + imm0 = _mm256_slli_epi32(imm0, 23); + __m256 pow2n = _mm256_castsi256_ps(imm0); + y = _mm256_mul_ps(y, pow2n); + return y; +} +#endif + +} // namespace detail + +#define INTRI8_FLOAT(isa, expisa) \ template <> \ void VExpKernelImpl::Compute(const float* x, float* y) \ const { \ __m256 tmp = _mm256_loadu_ps(x); \ - _mm256_storeu_ps(y, detail::Exp(tmp)); \ + _mm256_storeu_ps(y, expisa(tmp)); \ } -#define INTRI16_FLOAT(isa) \ +#define INTRI16_FLOAT(isa, expisa) \ template <> \ void VExpKernelImpl::Compute(const float* x, float* y) \ const { \ __m256 tmp0 = _mm256_loadu_ps(x); \ __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - tmp0 = detail::Exp(tmp0); \ - tmp1 = detail::Exp(tmp1); \ + tmp0 = expisa(tmp0); \ + tmp1 = expisa(tmp1); \ _mm256_storeu_ps(y, tmp0); \ _mm256_storeu_ps(y + 8, tmp1); \ } #ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI16_FLOAT(jit::avx); +INTRI8_FLOAT(jit::avx, detail::ExpAVX); +INTRI16_FLOAT(jit::avx, detail::ExpAVX); #endif #ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI16_FLOAT(jit::avx2); +INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); +INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); #endif #ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); -INTRI16_FLOAT(jit::avx512f); +INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2); +INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2); #endif // TODO(TJ): eq16 test and complete avx512 @@ -135,26 +277,27 @@ class VSigmoidKernelImpl : public VSigmoidKernel { std::shared_ptr> vexp_; }; -#define INTRI_SIGMOID(tmp, min, max) \ +#define INTRI_SIGMOID(tmp, min, max, expisa) \ tmp = _mm256_max_ps(tmp, min); \ tmp = _mm256_min_ps(tmp, max); \ tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \ - tmp = detail::Exp(tmp); \ + tmp = expisa(tmp); \ tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \ tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp) -#define INTRI8_FLOAT(isa) \ +#define INTRI8_FLOAT(isa, expisa) \ template <> \ void VSigmoidKernelImpl::Compute(const float* x, float* y) \ const { \ + /* TODO(TJ): try to use static const*/ \ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_SIGMOID(tmp, min, max); \ + INTRI_SIGMOID(tmp, min, max, expisa); \ _mm256_storeu_ps(y, tmp); \ } -#define INTRI16_FLOAT(isa) \ +#define INTRI16_FLOAT(isa, expisa) \ template <> \ void VSigmoidKernelImpl::Compute(const float* x, \ float* y) const { \ @@ -162,13 +305,13 @@ class VSigmoidKernelImpl : public VSigmoidKernel { __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ __m256 tmp0 = _mm256_loadu_ps(x); \ __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - INTRI_SIGMOID(tmp0, min, max); \ - INTRI_SIGMOID(tmp1, min, max); \ + INTRI_SIGMOID(tmp0, min, max, expisa); \ + INTRI_SIGMOID(tmp1, min, max, expisa); \ _mm256_storeu_ps(y, tmp0); \ _mm256_storeu_ps(y + 8, tmp1); \ } -#define INTRI_GT8LT16_FLOAT(isa) \ +#define INTRI_GT8LT16_FLOAT(isa, expisa) \ template <> \ VSigmoidKernelImpl::VSigmoidKernelImpl(int d) \ : VSigmoidKernel() { \ @@ -184,7 +327,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel { __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \ __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_SIGMOID(tmp, min, max); \ + INTRI_SIGMOID(tmp, min, max, expisa); \ _mm256_storeu_ps(y, tmp); \ const float min_ = SIGMOID_THRESHOLD_MIN; \ const float max_ = SIGMOID_THRESHOLD_MAX; \ @@ -198,7 +341,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel { } \ } -#define INTRI_GT16_FLOAT(isa) \ +#define INTRI_GT16_FLOAT(isa, expisa) \ template <> \ VSigmoidKernelImpl::VSigmoidKernelImpl(int d) \ : VSigmoidKernel() { \ @@ -215,7 +358,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel { __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \ for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ __m256 tmp = _mm256_loadu_ps(x + i); \ - INTRI_SIGMOID(tmp, min, max); \ + INTRI_SIGMOID(tmp, min, max, expisa); \ _mm256_storeu_ps(y + i, tmp); \ } \ const float min_ = SIGMOID_THRESHOLD_MIN; \ @@ -231,22 +374,20 @@ class VSigmoidKernelImpl : public VSigmoidKernel { } #ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI16_FLOAT(jit::avx); -INTRI_GT8LT16_FLOAT(jit::avx); -INTRI_GT16_FLOAT(jit::avx); +INTRI8_FLOAT(jit::avx, detail::ExpAVX); +INTRI16_FLOAT(jit::avx, detail::ExpAVX); +INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX); +INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX); #endif #ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI16_FLOAT(jit::avx2); -// INTRI_GT8LT16_FLOAT(jit::avx2); -// INTRI_GT16_FLOAT(jit::avx2); +INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); +INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); +// maybe use avx at gt8lt16 and gt16 #endif #ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); -INTRI16_FLOAT(jit::avx512f); -// INTRI_GT8LT16_FLOAT(jit::avx512f); -// INTRI_GT16_FLOAT(jit::avx512f); +INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2); +INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2); +// maybe use avx2 at gt8lt16 and gt16 #endif #undef INTRI8_FLOAT @@ -280,36 +421,36 @@ class VTanhKernelImpl : public VTanhKernel { std::shared_ptr> vaddbias_; }; -#define INTRI_VTANH(tmp) \ +#define INTRI_VTANH(tmp, expisa) \ tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \ tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \ - tmp = detail::Exp(tmp); \ + tmp = expisa(tmp); \ tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \ tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \ tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f)) -#define INTRI8_FLOAT(isa) \ +#define INTRI8_FLOAT(isa, expisa) \ template <> \ void VTanhKernelImpl::Compute(const float* x, float* y) \ const { \ __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_VTANH(tmp); \ + INTRI_VTANH(tmp, expisa); \ _mm256_storeu_ps(y, tmp); \ } -#define INTRI16_FLOAT(isa) \ +#define INTRI16_FLOAT(isa, expisa) \ template <> \ void VTanhKernelImpl::Compute(const float* x, float* y) \ const { \ __m256 tmp0 = _mm256_loadu_ps(x); \ __m256 tmp1 = _mm256_loadu_ps(x + 8); \ - INTRI_VTANH(tmp0); \ - INTRI_VTANH(tmp1); \ + INTRI_VTANH(tmp0, expisa); \ + INTRI_VTANH(tmp1, expisa); \ _mm256_storeu_ps(y, tmp0); \ _mm256_storeu_ps(y + 8, tmp1); \ } -#define INTRI_GT8LT16_FLOAT(isa) \ +#define INTRI_GT8LT16_FLOAT(isa, expisa) \ template <> \ VTanhKernelImpl::VTanhKernelImpl(int d) \ : VTanhKernel() { \ @@ -327,7 +468,7 @@ class VTanhKernelImpl : public VTanhKernel { void VTanhKernelImpl::Compute(const float* x, \ float* y) const { \ __m256 tmp = _mm256_loadu_ps(x); \ - INTRI_VTANH(tmp); \ + INTRI_VTANH(tmp, expisa); \ _mm256_storeu_ps(y, tmp); \ x += AVX_FLOAT_BLOCK; \ y += AVX_FLOAT_BLOCK; \ @@ -337,7 +478,7 @@ class VTanhKernelImpl : public VTanhKernel { vaddbias_->Compute(-1.f, y, y); \ } -#define INTRI_GT16_FLOAT(isa) \ +#define INTRI_GT16_FLOAT(isa, expisa) \ template <> \ VTanhKernelImpl::VTanhKernelImpl(int d) \ : VTanhKernel() { \ @@ -356,7 +497,7 @@ class VTanhKernelImpl : public VTanhKernel { const { \ for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ __m256 tmp = _mm256_loadu_ps(x + i); \ - INTRI_VTANH(tmp); \ + INTRI_VTANH(tmp, expisa); \ _mm256_storeu_ps(y + i, tmp); \ } \ x += this->end_; \ @@ -368,19 +509,19 @@ class VTanhKernelImpl : public VTanhKernel { } #ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI16_FLOAT(jit::avx); -INTRI_GT8LT16_FLOAT(jit::avx); -INTRI_GT16_FLOAT(jit::avx); +INTRI8_FLOAT(jit::avx, detail::ExpAVX); +INTRI16_FLOAT(jit::avx, detail::ExpAVX); +INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX); +INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX); #endif #ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI16_FLOAT(jit::avx2); +INTRI8_FLOAT(jit::avx2, detail::ExpAVX2); +INTRI16_FLOAT(jit::avx2, detail::ExpAVX2); // maybe use avx at gt8lt16 and gt16 #endif #ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); -INTRI16_FLOAT(jit::avx512f); +INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2); +INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2); // maybe use avx at gt8lt16 and gt16 #endif diff --git a/paddle/fluid/operators/math/jit_kernel_lstm.cc b/paddle/fluid/operators/math/jit_kernel_lstm.cc index 42a2b96fd945c516f8c26ca51ecb452345a9a86f..26bd26e2e171feea569fbd646a9caf03bebbaa46 100644 --- a/paddle/fluid/operators/math/jit_kernel_lstm.cc +++ b/paddle/fluid/operators/math/jit_kernel_lstm.cc @@ -25,13 +25,18 @@ limitations under the License. */ namespace paddle { namespace operators { namespace math { -#ifdef __AVX__ +namespace jitkernel { namespace detail { -__m256 Exp(__m256 a); -} // namespace detail +#ifdef __AVX__ +__m256 ExpAVX(__m256 x); #endif -namespace jitkernel { +#ifdef __AVX2__ +__m256 ExpAVX2(__m256 x); +#endif + +} // namespace detail + namespace jit = platform::jit; #ifdef __AVX__ @@ -43,43 +48,72 @@ class AVXAct { virtual __m256 Compute(__m256 x) const = 0; }; -template +template class AVXActImpl : public AVXAct { public: __m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); } }; -template <> -__m256 AVXActImpl::Compute(__m256 x) const { - __m256 ones = _mm256_set1_ps(1.0f); - x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); - x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); - x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); - x = detail::Exp(x); - x = _mm256_add_ps(ones, x); - return _mm256_div_ps(ones, x); -} +#define AVX_SIGMOID(isa, expisa) \ + template <> \ + __m256 AVXActImpl::Compute(__m256 x) const { \ + __m256 ones = _mm256_set1_ps(1.0f); \ + x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); \ + x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); \ + x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); \ + x = expisa(x); \ + x = _mm256_add_ps(ones, x); \ + return _mm256_div_ps(ones, x); \ + } -template <> -__m256 AVXActImpl::Compute(__m256 x) const { - __m256 ones = _mm256_set1_ps(1.0f); - x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); - x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); - x = detail::Exp(x); - x = _mm256_add_ps(ones, x); - x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); - return _mm256_sub_ps(x, ones); -} +#define AVX_TANH(isa, expisa) \ + template <> \ + __m256 AVXActImpl::Compute(__m256 x) const { \ + __m256 ones = _mm256_set1_ps(1.0f); \ + x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); \ + x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); \ + x = expisa(x); \ + x = _mm256_add_ps(ones, x); \ + x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); \ + return _mm256_sub_ps(x, ones); \ + } -template <> -__m256 AVXActImpl::Compute(__m256 x) const { - return _mm256_max_ps(x, _mm256_setzero_ps()); -} +#define AVX_RELU(isa) \ + template <> \ + __m256 AVXActImpl::Compute(__m256 x) const { \ + return _mm256_max_ps(x, _mm256_setzero_ps()); \ + } + +#define AVX_IDENTITY(isa) \ + template <> \ + __m256 AVXActImpl::Compute(__m256 x) const { \ + return x; \ + } + +#define FOR_EACH_AVX_ISA(macro_) \ + macro_(jit::avx); \ + macro_(jit::avx2); \ + macro_(jit::avx512f) + +FOR_EACH_AVX_ISA(AVX_RELU); +FOR_EACH_AVX_ISA(AVX_IDENTITY); + +AVX_SIGMOID(jit::avx, detail::ExpAVX); +AVX_TANH(jit::avx, detail::ExpAVX); + +#ifdef __AVX2__ +AVX_SIGMOID(jit::avx2, detail::ExpAVX2); +AVX_SIGMOID(jit::avx512f, detail::ExpAVX2); +AVX_TANH(jit::avx2, detail::ExpAVX2); +AVX_TANH(jit::avx512f, detail::ExpAVX2); +#endif + +#undef FOR_EACH_AVX_ISA +#undef AVX_IDENTITY +#undef AVX_RELU +#undef AVX_TANH +#undef AVX_SIGMOID -template <> -__m256 AVXActImpl::Compute(__m256 x) const { - return x; -} #endif template @@ -119,23 +153,6 @@ class LSTMKernelImpl : public LSTMKernel { act_cell_d_ = GetActKernel(act_cell, d); vmul_d_ = KernelPool::Instance().template Get>(d); vadd_d_ = KernelPool::Instance().template Get>(d); -#ifdef __AVX__ - auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr { - if (type == "sigmoid") { - return std::unique_ptr(new AVXActImpl()); - } else if (type == "relu") { - return std::unique_ptr(new AVXActImpl()); - } else if (type == "tanh") { - return std::unique_ptr(new AVXActImpl()); - } else if (type == "identity" || type == "") { - return std::unique_ptr(new AVXActImpl()); - } - PADDLE_THROW("Not support type: %s", type); - }; - avx_act_gate_ = GetAVXAct(act_gate); - avx_act_cand_ = GetAVXAct(act_cand); - avx_act_cell_ = GetAVXAct(act_cell); -#endif } void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data, @@ -175,26 +192,61 @@ class LSTMKernelImpl : public LSTMKernel { #endif }; -#define INTRI8_FLOAT(isa) \ - template <> \ - void LSTMKernelImpl::ComputeCtHt( \ - float* gates, const float* ct_1, float* ct, float* ht, \ - const float* wp_data, float* checked) const { \ - /* gates: W_ch, W_ih, W_fh, W_oh */ \ - __m256 c, i, f, o; \ - c = _mm256_loadu_ps(gates); \ - i = _mm256_loadu_ps(gates + 8); \ - f = _mm256_loadu_ps(gates + 16); \ - o = _mm256_loadu_ps(gates + 24); \ - /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ - c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ - i = _mm256_loadu_ps(ct_1); \ - f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ - f = _mm256_add_ps(c, f); \ - _mm256_storeu_ps(ct, f); \ - /* H_t = act_cell(C_t) * ogated */ \ - o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ - _mm256_storeu_ps(ht, o); \ +#define INTRI8_FLOAT(isa) \ + template <> \ + LSTMKernelImpl::LSTMKernelImpl( \ + const std::string& act_gate, const std::string& act_cand, \ + const std::string& act_cell, int d) \ + : LSTMKernel() { \ + auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr { \ + if (type == "sigmoid") { \ + return std::unique_ptr(new AVXActImpl()); \ + } else if (type == "relu") { \ + return std::unique_ptr(new AVXActImpl()); \ + } else if (type == "tanh") { \ + return std::unique_ptr(new AVXActImpl()); \ + } else if (type == "identity" || type == "") { \ + return std::unique_ptr(new AVXActImpl()); \ + } \ + PADDLE_THROW("Not support type: %s", type); \ + }; \ + avx_act_gate_ = GetAVXAct(act_gate); \ + avx_act_cand_ = GetAVXAct(act_cand); \ + avx_act_cell_ = GetAVXAct(act_cell); \ + } \ + template <> \ + void LSTMKernelImpl::ComputeCtHt( \ + float* gates, const float* ct_1, float* ct, float* ht, \ + const float* wp_data, float* checked) const { \ + /* gates: W_ch, W_ih, W_fh, W_oh */ \ + __m256 c, i, f, o; \ + c = _mm256_loadu_ps(gates); \ + i = _mm256_loadu_ps(gates + 8); \ + f = _mm256_loadu_ps(gates + 16); \ + o = _mm256_loadu_ps(gates + 24); \ + /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ + c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ + i = _mm256_loadu_ps(ct_1); \ + f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ + f = _mm256_add_ps(c, f); \ + _mm256_storeu_ps(ct, f); \ + /* H_t = act_cell(C_t) * ogated */ \ + o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ + _mm256_storeu_ps(ht, o); \ + } \ + template <> \ + void LSTMKernelImpl::ComputeC1H1( \ + float* gates, float* ct, float* ht, const float* wp_data) const { \ + __m256 c, i, o; \ + c = _mm256_loadu_ps(gates); \ + i = _mm256_loadu_ps(gates + 8); \ + o = _mm256_loadu_ps(gates + 24); \ + /* C_t = igated * cgated*/ \ + c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \ + _mm256_storeu_ps(ct, c); \ + /* H_t = act_cell(C_t) * ogated */ \ + o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \ + _mm256_storeu_ps(ht, o); \ } // TODO(TJ): optimize keq16