From 2a00969165ae420e33c315ca725cd3e96a4c86ed Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 9 Oct 2018 00:21:30 +0800 Subject: [PATCH] optimize lstm jitkernel keq8 test=develop --- paddle/fluid/operators/math/CMakeLists.txt | 3 +- .../fluid/operators/math/jit_kernel_lstm.cc | 110 +++++++++++++++++- 2 files changed, 111 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 2a389ea1c8..16e1dc40f1 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -77,5 +77,6 @@ endif() cc_test(concat_test SRCS concat_test.cc DEPS concat) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_library(jit_kernel_exp SRCS jit_kernel_exp.cc DEPS cpu_info cblas activation_functions) -cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_lstm.cc DEPS cpu_info cblas jit_kernel_exp) +cc_library(jit_kernel_lstm SRCS jit_kernel_lstm.cc DEPS cpu_info cblas activation_functions) +cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc DEPS cpu_info cblas jit_kernel_exp jit_kernel_lstm) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) diff --git a/paddle/fluid/operators/math/jit_kernel_lstm.cc b/paddle/fluid/operators/math/jit_kernel_lstm.cc index 210b229b28..71531d833d 100644 --- a/paddle/fluid/operators/math/jit_kernel_lstm.cc +++ b/paddle/fluid/operators/math/jit_kernel_lstm.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/jit_kernel_macro.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/macros.h" #ifdef __AVX__ #include @@ -24,10 +25,63 @@ limitations under the License. */ namespace paddle { namespace operators { namespace math { -namespace jitkernel { +#ifdef __AVX__ +namespace detail { +__m256 Exp(__m256 a); +} // namespace detail +#endif +namespace jitkernel { namespace jit = platform::jit; +#ifdef __AVX__ +typedef enum { kSigmoid, kRelu, kTanh, kIdentity } act_type; + +class AVXAct { + public: + virtual ~AVXAct() = default; + virtual __m256 Compute(__m256 x) const = 0; +}; + +template +class AVXActImpl : public AVXAct { + public: + __m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); } +}; + +template <> +__m256 AVXActImpl::Compute(__m256 x) const { + __m256 ones = _mm256_set1_ps(1.0f); + x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); + x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); + x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); + x = detail::Exp(x); + x = _mm256_add_ps(ones, x); + return _mm256_div_ps(ones, x); +} + +template <> +__m256 AVXActImpl::Compute(__m256 x) const { + __m256 ones = _mm256_set1_ps(1.0f); + x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); + x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); + x = detail::Exp(x); + x = _mm256_add_ps(ones, x); + x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); + return _mm256_sub_ps(x, ones); +} + +template <> +__m256 AVXActImpl::Compute(__m256 x) const { + return _mm256_max_ps(x, _mm256_setzero_ps()); +} + +template <> +__m256 AVXActImpl::Compute(__m256 x) const { + return x; +} +#endif + /* LSTM JitKernel */ template class LSTMKernelImpl : public LSTMKernel { @@ -61,6 +115,23 @@ class LSTMKernelImpl : public LSTMKernel { act_cell_d_ = GetActKernel(act_cell, d); vmul_d_ = KernelPool::Instance().template Get>(d); vadd_d_ = KernelPool::Instance().template Get>(d); +#ifdef __AVX__ + auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr { + if (type == "sigmoid") { + return std::unique_ptr(new AVXActImpl()); + } else if (type == "relu") { + return std::unique_ptr(new AVXActImpl()); + } else if (type == "tanh") { + return std::unique_ptr(new AVXActImpl()); + } else if (type == "identity" || type == "") { + return std::unique_ptr(new AVXActImpl()); + } + PADDLE_THROW("Not support type: %s", type); + }; + avx_act_gate_ = GetAVXAct(act_gate); + avx_act_cand_ = GetAVXAct(act_cand); + avx_act_cell_ = GetAVXAct(act_cell); +#endif } void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht) const override { @@ -83,8 +154,44 @@ class LSTMKernelImpl : public LSTMKernel { std::shared_ptr> act_gate_3d_, act_cand_d_, act_cell_d_; std::shared_ptr> vmul_d_; std::shared_ptr> vadd_d_; +#ifdef __AVX__ + std::unique_ptr avx_act_gate_, avx_act_cand_, avx_act_cell_; +#endif }; +#define INTRI8_FLOAT(isa) \ + template <> \ + void LSTMKernelImpl::ComputeCtHt( \ + float* gates, const float* ct_1, float* ct, float* ht) const { \ + /* gates: W_ch, W_ih, W_fh, W_oh */ \ + __m256 c, i, f, o; \ + c = _mm256_loadu_ps(gates); \ + i = _mm256_loadu_ps(gates + 8); \ + f = _mm256_loadu_ps(gates + 16); \ + o = _mm256_loadu_ps(gates + 24); \ + /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ + c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ + i = _mm256_loadu_ps(ct_1); \ + f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ + f = _mm256_add_ps(c, f); \ + _mm256_storeu_ps(ct, f); \ + /* H_t = act_cell(C_t) * ogated */ \ + o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ + _mm256_storeu_ps(ht, o); \ + } + +// TODO(TJ): optimize keq16 + +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +INTRI8_FLOAT(jit::avx512f); +#endif + #define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \ template <> \ std::shared_ptr> \ @@ -104,6 +211,7 @@ class LSTMKernelImpl : public LSTMKernel { REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, JITKERNEL_KEY_LSTM, JITKERNEL_NEW_LSTM_IMPL); +#undef INTRI8_FLOAT #undef JITKERNEL_DECLARE_LSTM #undef JITKERNEL_KEY_LSTM #undef JITKERNEL_NEW_LSTM_IMPL -- GitLab