From a79a77eeb5491b088a4291fe717ebefe481477c7 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Fri, 31 Aug 2018 17:24:53 +0800
Subject: [PATCH] refine and clean code

---
 paddle/fluid/operators/fusion_lstm_op.cc | 130 +++++++++--------
 1 file changed, 51 insertions(+), 79 deletions(-)

diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index d9cb75b77d6..1ab73d88db2 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -215,46 +215,53 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
 template <typename T>
 class FuisonLSTMKernel : public framework::OpKernel<T> {
  public:
+#define INIT_VEC_FUNC \
+  std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand; \
+  auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
+  auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); \
+  auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); \
+  if (platform::jit::MayIUse(platform::jit::avx)) { \
+    math::VecActivations<T, platform::jit::avx> act_functor; \
+    act_gate = act_functor(act_gate_str); \
+    act_cell = act_functor(act_cell_str); \
+    act_cand = act_functor(act_cand_str); \
+  } else { \
+    math::VecActivations<T, platform::jit::isa_any> act_functor; \
+    act_gate = act_functor(act_gate_str); \
+    act_cell = act_functor(act_cell_str); \
+    act_cand = act_functor(act_cand_str); \
+  }
+
+#define INIT_BASE_INPUT_OUTPUT \
+  auto* x = ctx.Input<LoDTensor>("X"); \
+  auto* h0 = ctx.Input<Tensor>("H0"); \
+  auto* c0 = ctx.Input<Tensor>("C0"); \
+  auto* wx = ctx.Input<Tensor>("WeightX"); \
+  auto* wh = ctx.Input<Tensor>("WeightH"); \
+  auto* bias = ctx.Input<Tensor>("Bias"); \
+  auto* xx = ctx.Output<LoDTensor>("XX"); \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
+  auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
+  bool is_reverse = ctx.Attr<bool>("is_reverse");
+
+#define INIT_BASE_SIZES \
+  auto x_dims = x->dims();   /* T x M*/ \
+  auto wh_dims = wh->dims(); /* D x 4D*/ \
+  const int M = x_dims[1]; \
+  const int D = wh_dims[0]; \
+  const int D2 = D * 2; \
+  const int D3 = D * 3; \
+  const int D4 = wh_dims[1];
+
   void SeqCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
-    auto* x = ctx.Input<LoDTensor>("X");
-    auto* h0 = ctx.Input<Tensor>("H0");
-    auto* c0 = ctx.Input<Tensor>("C0");
-    auto* wx = ctx.Input<Tensor>("WeightX");
-    auto* wh = ctx.Input<Tensor>("WeightH");
-    auto* bias = ctx.Input<Tensor>("Bias");
-
-    auto* xx = ctx.Output<LoDTensor>("XX");
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
-    auto* cell_out = ctx.Output<LoDTensor>("Cell");
-    bool is_reverse = ctx.Attr<bool>("is_reverse");
-
-    std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand;
-    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
-    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
-    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
-    if (platform::jit::MayIUse(platform::jit::avx)) {
-      math::VecActivations<T, platform::jit::avx> act_functor;
-      act_gate = act_functor(act_gate_str);
-      act_cell = act_functor(act_cell_str);
-      act_cand = act_functor(act_cand_str);
-    } else {
-      math::VecActivations<T, platform::jit::isa_any> act_functor;
-      act_gate = act_functor(act_gate_str);
-      act_cell = act_functor(act_cell_str);
-      act_cand = act_functor(act_cand_str);
-    }
+    INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
+    INIT_VEC_FUNC
 
     auto x_lod = x->lod();
-    auto x_dims = x->dims();    // T x M
-    auto wh_dims = wh->dims();  // D x 4D
     const int total_T = x_dims[0];
     const int N = x_lod[0].size() - 1;  // batch size
-    const int M = x_dims[1];            // x frame size
-    const int D = wh_dims[0];
-    const int D2 = D * 2;
-    const int D3 = D * 3;
-    const int D4 = wh_dims[1];
 
     const T* x_data = x->data<T>();
     const T* h0_data = h0 ? h0->data<T>() : NULL;
@@ -343,52 +350,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
   void BatchCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = platform::CPUDeviceContext;
-    auto* x = ctx.Input<LoDTensor>("X");
-    auto* wx = ctx.Input<Tensor>("WeightX");
-    auto* wh = ctx.Input<Tensor>("WeightH");
-    auto* bias = ctx.Input<Tensor>("Bias");
-    auto* h0 = ctx.Input<Tensor>("H0");
-    auto* c0 = ctx.Input<Tensor>("C0");
-
-    auto* xx = ctx.Output<LoDTensor>("XX");
+    INIT_BASE_INPUT_OUTPUT
+    if (x->lod()[0].size() == 2) {  // batch size == 1
+      SeqCompute(ctx);
+      return;
+    }
+    INIT_BASE_SIZES
+    INIT_VEC_FUNC
+
     auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
     auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
     auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
     auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
     auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
-    auto* cell_out = ctx.Output<LoDTensor>("Cell");
-    bool is_reverse = ctx.Attr<bool>("is_reverse");
-
-    std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand;
-    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
-    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
-    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
-    if (platform::jit::MayIUse(platform::jit::avx)) {
-      math::VecActivations<T, platform::jit::avx> act_functor;
-      act_gate = act_functor(act_gate_str);
-      act_cell = act_functor(act_cell_str);
-      act_cand = act_functor(act_cand_str);
-    } else {
-      math::VecActivations<T, platform::jit::isa_any> act_functor;
-      act_gate = act_functor(act_gate_str);
-      act_cell = act_functor(act_cell_str);
-      act_cand = act_functor(act_cand_str);
-    }
-
-    auto x_dims = x->dims();    // T x M
-    auto wh_dims = wh->dims();  // D x 4D
-
-    // auto x_lod = x->lod();
-    // const int N = x_lod[0].size() - 1;  // batch size
-    // if (N == 1) {
-    //   SeqCompute(ctx);
-    // }
-
-    const int M = x_dims[1];
-    const int D = wh_dims[0];
-    const int D2 = D * 2;
-    const int D3 = D * 3;
-    const int D4 = wh_dims[1];
 
     const T* x_data = x->data<T>();
     const T* wx_data = wx->data<T>();
@@ -485,16 +458,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
         // W_ch, W_ih, W_fh, W_oh
         act_gate(D3, cur_in_data + D, cur_in_data + D);
         act_cand(D, cur_in_data, cur_in_data);
-
         // a = forget * prev_cell
         blas.VMUL(D, cur_in_data + D2, cur_prev_c_data, cur_in_data + D2);
-
         // b = input * tilde
         blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D);
-
         // cell out= a+b
         blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data);
-
         // hidden out= act_state(cellout) * outgate
         act_cell(D, cur_c_out_data, cur_in_data + D2);
         blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
@@ -526,6 +495,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       BatchCompute(ctx);
     }
   }
+#undef INIT_BASE_SIZES
+#undef INIT_BASE_INPUT_OUTPUT
+#undef INIT_VEC_FUNC
 };
 
 }  // namespace operators
--
GitLab