diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index a73ea09f1e120356793a6fd198ce54e68d162cc8..8bab37c5830dfdcd5d6ccf1cc049387b496b0d04 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -299,11 +299,21 @@ class AttentionLSTMKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D); fc_out->Resize({max_seq_len, 1}); - math::VecActivations act_functor; std::function act_gate, act_cell, act_cand; - act_gate = act_functor(ctx.Attr("gate_activation")); - act_cell = act_functor(ctx.Attr("cell_activation")); - act_cand = act_functor(ctx.Attr("candidate_activation")); + auto& act_gate_str = ctx.Attr("gate_activation"); + auto& act_cell_str = ctx.Attr("cell_activation"); + auto& act_cand_str = ctx.Attr("candidate_activation"); + if (platform::jit::MayIUse(platform::jit::avx)) { + math::VecActivations act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } else { + math::VecActivations act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } const T* x_data = x->data(); const T* h0_data = h0 ? h0->data() : NULL;