diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 604c6f183979e8f21c32b51ff6690ab04f06d4e4..97852e292834b22dd1aebb6ff8f5dd2b7cb39dd6 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -230,6 +230,22 @@ class FuisonLSTMKernel : public framework::OpKernel { auto* hidden_out = ctx.Output("Hidden"); auto* cell_out = ctx.Output("Cell"); + std::function act_gate, act_cell, act_cand; + auto& act_gate_str = ctx.Attr("gate_activation"); + auto& act_cell_str = ctx.Attr("cell_activation"); + auto& act_cand_str = ctx.Attr("candidate_activation"); + if (platform::jit::MayIUse(platform::jit::avx)) { + math::VecActivations act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } else { + math::VecActivations act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } + auto x_lod = x->lod(); auto x_dims = x->dims(); // T x M auto wh_dims = wh->dims(); // D x 4D @@ -263,15 +279,12 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_cell_data = c0_data + i * D; } else { // W_ch, W_ih, W_fh, W_oh - // actgate - math::vec_sigmoid(D3, xx_data + D, xx_data + D); - // ch gate - math::vec_tanh(D, xx_data, xx_data); + act_gate(D3, xx_data + D, xx_data + D); + act_cand(D, xx_data, xx_data); // cell out= input*tilde blas.VMUL(D, xx_data, xx_data + D, cell_out_data); // hidden out= act_state(cellout) * outgate - // act state - math::vec_tanh(D, cell_out_data, xx_data + D2); + act_cell(D, cell_out_data, xx_data + D2); blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev @@ -290,10 +303,8 @@ class FuisonLSTMKernel : public framework::OpKernel { D4); // W_ch, W_ih, W_fh, W_oh - // actgate - math::vec_sigmoid(D3, xx_data + D, xx_data + D); - // ch gate - math::vec_tanh(D, xx_data, xx_data); + act_gate(D3, xx_data + D, xx_data + D); + act_cand(D, xx_data, xx_data); // a = forget * prev_cell blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2); @@ -305,8 +316,7 @@ class FuisonLSTMKernel : public framework::OpKernel { blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); // hidden out= act_state(cellout) * outgate - // act state - math::vec_tanh(D, cell_out_data, xx_data + D2); + act_cell(D, cell_out_data, xx_data + D2); blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index 9d8bef677fd16fb6bdc20b929137b4d885f4efd1..d807f0a8b6342bdc166ac7dd898aae9331e8c513 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -45,7 +45,7 @@ def fusion_lstm( class TestLstmOp(OpTest): def set_argument(self): - self.lod = [[2, 3, 2]] + pass def setUp(self): self.op_type = 'fusion_lstm'