diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 6949cf55c09e25ce2695af54ccdfd0a5f386815c..a23704621cd34f67e699d04de2dd93bae1de397e 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -396,15 +396,15 @@ class FuisonLSTMKernel : public framework::OpKernel { } } else { // TODO(TJ): unly workaround, clean me - std::function compute_ctht; + std::function compute_ctht; if (platform::jit::MayIUse(platform::jit::avx) && act_gate_str == "sigmoid" && act_cand_str == "tanh" && act_cell_str == "tanh" && D == 8) { compute_ctht = math::lstm_compute_ctht; } else { - compute_ctht = [&](const T* gates, const T* ct_1, T* ct, T* ht) { + compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) { COMPUTE_CtHt(gates, ct_1, ct, ht); - } + }; } for (int i = 0; i < N; ++i) { PROCESS_H0C0 diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.cc b/paddle/fluid/operators/math/cpu_lstm_compute.cc index 7e487079db5a7c1ce9d4d0a2d380a19f509eed4d..83094d01cf62d1f7493ac0a0662ce35fad42c549 100644 --- a/paddle/fluid/operators/math/cpu_lstm_compute.cc +++ b/paddle/fluid/operators/math/cpu_lstm_compute.cc @@ -25,12 +25,15 @@ namespace math { namespace detail { namespace forward { -namespace avx {} // namespace avx +namespace avx { +__m256 Sigmoid(const __m256 a); +__m256 Tanh(const __m256 a); +} // namespace avx } // namespace forward } // namespace detail template <> -void lstm_compute_ctht(const float* gates, const float* ct_1, float* ct, +void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, float* ht) { namespace act = detail::forward::avx; // gates: W_ch, W_ih, W_fh, W_oh @@ -52,6 +55,7 @@ void lstm_compute_ctht(const float* gates, const float* ct_1, float* ct, _mm256_storeu_ps(ht, o); } #endif + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.h b/paddle/fluid/operators/math/cpu_lstm_compute.h index 7b803b6c8adcbe55bfdcbf1c17ae12b33c6501f1..fe6c01b7d930eb64409668a30b99aa5039a58689 100644 --- a/paddle/fluid/operators/math/cpu_lstm_compute.h +++ b/paddle/fluid/operators/math/cpu_lstm_compute.h @@ -23,22 +23,19 @@ namespace math { // TODO(TJ): ugly workaround, clean me template -void lstm_compute_ctht(const T* gates, const T* ct_1, T* ct, T* ht) { +void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) { // gates: W_ch, W_ih, W_fh, W_oh vec_sigmoid(24, gates + 8, gates + 8); vec_tanh(8, gates, gates); const T *i = gates + 8, *f = gates + 16, *o = gates + 24; + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; for (int d = 0; d < 8; ++d) { // C_t = C_t-1 * fgated + cand_gated * igated ct[d] = ct_1[d] * f[d] + gates[d] * i[d]; - // H_t = act_cell(C_t) * ogated T tmp = ct[d] * 2; - tmp = static_cast(0) - (tmp < static_cast(SIGMOID_THRESHOLD_MIN)) - ? min - : ((tmp > static_cast(SIGMOID_THRESHOLD_MAX)) - ? static_cast(SIGMOID_THRESHOLD_MAX) - : tmp); + tmp = static_cast(0) - (tmp < min) ? min : ((tmp > max) ? max : tmp); vec_exp(1, &tmp, &tmp); tmp = static_cast(2) / (static_cast(1) + tmp) - static_cast(1); ht[d] = tmp * o[d];