diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index a23704621cd34f67e699d04de2dd93bae1de397e..8ca79d20ec4f6412b00dbf3990068f81b65e2efd 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -543,12 +543,23 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
         MOVE_ONE_STEP;
       }
     } else {
+      // TODO(TJ): unly workaround, clean me
+      std::function<void(T*, const T*, T*, T*)> compute_ctht;
+      if (platform::jit::MayIUse(platform::jit::avx) &&
+          act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
+          act_cell_str == "tanh" && D == 8) {
+        compute_ctht = math::lstm_compute_ctht<T>;
+      } else {
+        compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
+          COMPUTE_CtHt(gates, ct_1, ct, ht);
+        };
+      }
       for (int step = tstart; step < max_seq_len; ++step) {
         const int cur_bs = batch_starts[step + 1] - batch_starts[step];
         GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
         DEFINE_CUR;
         for (int i = 0; i < cur_bs; ++i) {
-          COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
+          compute_ctht(cur_in_data, cur_prev_c_data, cur_c_out_data,
                        cur_h_out_data);
           MOVE_ONE_BATCH;
         }
diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.cc b/paddle/fluid/operators/math/cpu_lstm_compute.cc
index 83094d01cf62d1f7493ac0a0662ce35fad42c549..84e387c974b17f5b4b8cb9c8dae8616f2b3609aa 100644
--- a/paddle/fluid/operators/math/cpu_lstm_compute.cc
+++ b/paddle/fluid/operators/math/cpu_lstm_compute.cc
@@ -13,49 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "paddle/fluid/operators/math/cpu_lstm_compute.h"
-#ifdef __AVX__
-#include <immintrin.h>
-#endif
+
 namespace paddle {
 namespace operators {
-namespace math {
-
-#ifdef __AVX__
-// TODO(TJ): ugly workaround, clean me
-
-namespace detail {
-namespace forward {
-namespace avx {
-__m256 Sigmoid(const __m256 a);
-__m256 Tanh(const __m256 a);
-}  // namespace avx
-}  // namespace forward
-}  // namespace detail
-
-template <>
-void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
-                              float* ht) {
-  namespace act = detail::forward::avx;
-  // gates: W_ch, W_ih, W_fh, W_oh
-  __m256 c, i, f, o;
-  c = _mm256_loadu_ps(gates);
-  i = _mm256_loadu_ps(gates + 8);
-  f = _mm256_loadu_ps(gates + 16);
-  o = _mm256_loadu_ps(gates + 24);
-
-  /* C_t = C_t-1 * fgated + cand_gated * igated*/
-  c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
-  i = _mm256_loadu_ps(ct_1);
-  f = _mm256_mul_ps(i, act::Sigmoid(f));
-  f = _mm256_add_ps(c, f);
-  _mm256_storeu_ps(ct, f);
-
-  /* H_t = act_cell(C_t) * ogated */
-  o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
-  _mm256_storeu_ps(ht, o);
-}
-#endif
-
-}  // namespace math
+namespace math {}  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.h b/paddle/fluid/operators/math/cpu_lstm_compute.h
index fe6c01b7d930eb64409668a30b99aa5039a58689..00e9e4f32ae790dcd4417d47db2b0db510a68549 100644
--- a/paddle/fluid/operators/math/cpu_lstm_compute.h
+++ b/paddle/fluid/operators/math/cpu_lstm_compute.h
@@ -16,6 +16,9 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/operators/math/cpu_vec.h"
 #include "paddle/fluid/platform/cpu_info.h"
+#ifdef __AVX__
+#include <immintrin.h>
+#endif
 
 namespace paddle {
 namespace operators {
@@ -35,13 +38,47 @@ void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
     ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
     // H_t = act_cell(C_t) * ogated
     T tmp = ct[d] * 2;
-    tmp = static_cast<T>(0) - (tmp < min) ? min : ((tmp > max) ? max : tmp);
+    tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
     vec_exp<T>(1, &tmp, &tmp);
     tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
     ht[d] = tmp * o[d];
   }
 }
 
+#ifdef __AVX__
+namespace detail {
+namespace forward {
+namespace avx {
+__m256 Sigmoid(const __m256 a);
+__m256 Tanh(const __m256 a);
+}  // namespace avx
+}  // namespace forward
+}  // namespace detail
+
+template <>
+void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
+                              float* ht) {
+  namespace act = detail::forward::avx;
+  // gates: W_ch, W_ih, W_fh, W_oh
+  __m256 c, i, f, o;
+  c = _mm256_loadu_ps(gates);
+  i = _mm256_loadu_ps(gates + 8);
+  f = _mm256_loadu_ps(gates + 16);
+  o = _mm256_loadu_ps(gates + 24);
+
+  /* C_t = C_t-1 * fgated + cand_gated * igated*/
+  c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
+  i = _mm256_loadu_ps(ct_1);
+  f = _mm256_mul_ps(i, act::Sigmoid(f));
+  f = _mm256_add_ps(c, f);
+  _mm256_storeu_ps(ct, f);
+
+  /* H_t = act_cell(C_t) * ogated */
+  o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
+  _mm256_storeu_ps(ht, o);
+}
+#endif
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle