From 4b28fab8c94863d5ff24ce4c59ff31bb5d06b4ee Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Sun, 26 Aug 2018 18:24:00 +0800 Subject: [PATCH] enable more acts --- paddle/fluid/operators/fusion_lstm_op.cc | 34 ++++++++++++------- .../tests/unittests/test_fusion_lstm_op.py | 2 +- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 604c6f1839..97852e2928 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -230,6 +230,22 @@ class FuisonLSTMKernel : public framework::OpKernel { auto* hidden_out = ctx.Output("Hidden"); auto* cell_out = ctx.Output("Cell"); + std::function act_gate, act_cell, act_cand; + auto& act_gate_str = ctx.Attr("gate_activation"); + auto& act_cell_str = ctx.Attr("cell_activation"); + auto& act_cand_str = ctx.Attr("candidate_activation"); + if (platform::jit::MayIUse(platform::jit::avx)) { + math::VecActivations act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } else { + math::VecActivations act_functor; + act_gate = act_functor(act_gate_str); + act_cell = act_functor(act_cell_str); + act_cand = act_functor(act_cand_str); + } + auto x_lod = x->lod(); auto x_dims = x->dims(); // T x M auto wh_dims = wh->dims(); // D x 4D @@ -263,15 +279,12 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_cell_data = c0_data + i * D; } else { // W_ch, W_ih, W_fh, W_oh - // actgate - math::vec_sigmoid(D3, xx_data + D, xx_data + D); - // ch gate - math::vec_tanh(D, xx_data, xx_data); + act_gate(D3, xx_data + D, xx_data + D); + act_cand(D, xx_data, xx_data); // cell out= input*tilde blas.VMUL(D, xx_data, xx_data + D, cell_out_data); // hidden out= act_state(cellout) * outgate - // act state - math::vec_tanh(D, cell_out_data, xx_data + D2); + act_cell(D, cell_out_data, xx_data + D2); blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev @@ -290,10 +303,8 @@ class FuisonLSTMKernel : public framework::OpKernel { D4); // W_ch, W_ih, W_fh, W_oh - // actgate - math::vec_sigmoid(D3, xx_data + D, xx_data + D); - // ch gate - math::vec_tanh(D, xx_data, xx_data); + act_gate(D3, xx_data + D, xx_data + D); + act_cand(D, xx_data, xx_data); // a = forget * prev_cell blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2); @@ -305,8 +316,7 @@ class FuisonLSTMKernel : public framework::OpKernel { blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); // hidden out= act_state(cellout) * outgate - // act state - math::vec_tanh(D, cell_out_data, xx_data + D2); + act_cell(D, cell_out_data, xx_data + D2); blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index 9d8bef677f..d807f0a8b6 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -45,7 +45,7 @@ def fusion_lstm( class TestLstmOp(OpTest): def set_argument(self): - self.lod = [[2, 3, 2]] + pass def setUp(self): self.op_type = 'fusion_lstm' -- GitLab