From 2f3b498949c4bcfec6e4ced49f61745f76e78eef Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 5 Sep 2018 11:30:51 +0800 Subject: [PATCH] refine fusion seq lstm peephole --- .../fluid/framework/ir/fc_lstm_fuse_pass.cc | 1 + paddle/fluid/operators/fusion_lstm_op.cc | 126 ++++++++---------- .../tests/unittests/test_fusion_lstm_op.py | 44 +----- 3 files changed, 58 insertions(+), 113 deletions(-) diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 55153ecc3e..00f5e7fad2 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" #include #include "paddle/fluid/framework/lod_tensor.h" diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index c473e2593e..f9761d6ec4 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -78,13 +78,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(b_dims[0], 1, "The first dimension of Input(Bias) should be 1."); - - auto use_peepholes = ctx->Attrs().Get("use_peepholes"); - PADDLE_ENFORCE_EQ(b_dims[1], (use_peepholes ? 7 : 4) * frame_size, - "The second dimension of Input(Bias) should be " - "7 * %d if enable peepholes connection or" - "4 * %d if disable peepholes", - frame_size, frame_size); + PADDLE_ENFORCE_EQ( + b_dims[1], (ctx->Attrs().Get("use_peepholes") ? 7 : 4) * frame_size, + "The second dimension of Input(Bias) should be " + "7 * %d if enable peepholes connection or" + "4 * %d if disable peepholes", + frame_size, frame_size); framework::DDim out_dims({x_dims[0], frame_size}); ctx->SetOutputDim("Hidden", out_dims); @@ -231,18 +230,18 @@ class FuisonLSTMKernel : public framework::OpKernel { act_cand = act_functor(act_cand_str); \ } -#define INIT_BASE_INPUT_OUTPUT \ - auto* x = ctx.Input("X"); \ - auto* h0 = ctx.Input("H0"); \ - auto* c0 = ctx.Input("C0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* wh = ctx.Input("WeightH"); \ - auto* bias = ctx.Input("Bias"); \ - auto* xx = ctx.Output("XX"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - auto* cell_out = ctx.Output("Cell"); \ - bool use_peepholes = ctx.Attr("use_peepholes"); \ - bool is_reverse = ctx.Attr("is_reverse"); +#define INIT_BASE_INPUT_OUTPUT \ + auto* x = ctx.Input("X"); \ + auto* h0 = ctx.Input("H0"); \ + auto* c0 = ctx.Input("C0"); \ + auto* wx = ctx.Input("WeightX"); \ + auto* wh = ctx.Input("WeightH"); \ + auto* bias = ctx.Input("Bias"); \ + auto* xx = ctx.Output("XX"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + auto* cell_out = ctx.Output("Cell"); \ + bool is_reverse = ctx.Attr("is_reverse"); \ + bool use_peepholes = ctx.Attr("use_peepholes"); #define INIT_BASE_SIZES \ auto x_dims = x->dims(); /* T x M*/ \ @@ -261,25 +260,24 @@ class FuisonLSTMKernel : public framework::OpKernel { auto x_lod = x->lod(); const int total_T = x_dims[0]; - const int N = x_lod[0].size() - 1; // batch size - + const int N = x_lod[0].size() - 1; const T* x_data = x->data(); const T* h0_data = h0 ? h0->data() : nullptr; const T* c0_data = c0 ? c0->data() : nullptr; - const T* bias_data = bias->data(); - const T* wc_data = bias_data + D4; // w_ic, w_fc, w_oc const T* wx_data = wx->data(); const T* wh_data = wh->data(); - - T* xx_data = xx->mutable_data(ctx.GetPlace()); - T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); - T* cell_out_data = cell_out->mutable_data(ctx.GetPlace()); - - // use local variable - framework::DDim check_dims({3, D}); - Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct - auto checked_cell_data = - checked_cell.mutable_data(check_dims, ctx.GetPlace()); + const T* wc_data = bias->data() + D4; // diagonal weight + auto place = ctx.GetPlace(); + T* xx_data = xx->mutable_data(place); + T* hidden_out_data = hidden_out->mutable_data(place); + T* cell_out_data = cell_out->mutable_data(place); + + Tensor checked_cell; + T* checked_cell_data = nullptr; + if (use_peepholes) { + // w_ic * Ct-1, w_fc * Ct-1 // , w_oc * Ct => ih + checked_cell_data = checked_cell.mutable_data({2, D}, place); + } auto blas = math::GetBlas(ctx); math::FCCompute(blas, total_T, D4, M, x_data, wx_data, @@ -306,44 +304,31 @@ class FuisonLSTMKernel : public framework::OpKernel { int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; const T* prev_c_data = nullptr; const T* prev_h_data = nullptr; - int tstart = 0; if (h0_data) { prev_h_data = h0_data + bid * D; prev_c_data = c0_data + bid * D; } else { - // If step == 0 and there is no initialized hidden state, that is to say - // the H0 is zeros. Then W_h * H_t-1 can be skipped - - // ~C_t + // W_ch, W_ih, W_fh, W_oh + act_gate(D, xx_data + D, xx_data + D); act_cand(D, xx_data, xx_data); - if (use_peepholes) { - // I_t, F_t - act_gate(D2, xx_data + D, xx_data + D); - } else { - // I_t, F_t, O_t - act_gate(D3, xx_data + D, xx_data + D); - } - // C_t = I_t * ~C_t + // C_t = input * tilde blas.VMUL(D, xx_data, xx_data + D, cell_out_data); + // H_t = act_state(cellout) * outgate if (use_peepholes) { // + W_oc * C_t for peephole connection - blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2); - blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3); - // O_t - act_gate(D, xx_data + D3, xx_data + D3); + // put result on W_ih + blas.VMUL(D, wc_data + D2, cell_out_data, xx_data + D); + blas.VADD(D, xx_data + D, xx_data + D3, xx_data + D3); } - - // hidden out= act_state(cellout) * outgate + act_gate(D, xx_data + D3, xx_data + D3); act_cell(D, cell_out_data, xx_data + D2); - // H_t = O_t * act_state(C_t) blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev prev_h_data = hidden_out_data; prev_c_data = cell_out_data; - tstart = 1; move_step(); } @@ -353,39 +338,32 @@ class FuisonLSTMKernel : public framework::OpKernel { blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), prev_h_data, D, wh_data, D4, static_cast(1), xx_data, D4); - // ~C_t - act_cand(D, xx_data, xx_data); - + // W_ch, W_ih, W_fh, W_oh if (use_peepholes) { // + W_ic|W_fc * C_t-1 for peephole connection blas.VMUL(D, wc_data, prev_c_data, checked_cell_data); blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D); - blas.VADD(D2, xx_data + D, checked_cell_data, xx_data + D); - // I_t, F_t + blas.VADD(D2, checked_cell_data, xx_data + D, xx_data + D); act_gate(D2, xx_data + D, xx_data + D); } else { - // I_t, F_t, O_t act_gate(D3, xx_data + D, xx_data + D); } - - // F_t * C_t-1 - blas.VMUL(D, xx_data + D2, prev_c_data, xx_data + D2); - // I_t * ~C_t + // a = I_t * act_cand(ch) + act_cand(D, xx_data, xx_data); blas.VMUL(D, xx_data, xx_data + D, xx_data + D); - // C_t = F_t * C_t-1 + I_t * ~C_t + // b = C_t-1 * F_t + blas.VMUL(D, prev_c_data, xx_data + D2, xx_data + D2); + // C_t = a + b blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); + // H_t = act_cell(C_t) * act_gate(O_c += C_t * W_oc) if (use_peepholes) { - // + W_oc * C_t for peephole connection - blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2); - blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3); - // O_t + // put result on W_ih + blas.VMUL(D, wc_data + D2, cell_out_data, xx_data + D); + blas.VADD(D, xx_data + D, xx_data + D3, xx_data + D3); act_gate(D, xx_data + D3, xx_data + D3); } - - // hidden out= act_state(cellout) * outgate act_cell(D, cell_out_data, xx_data + D2); - // H_t = O_t * act_state(C_t) blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev @@ -393,8 +371,8 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_c_data = cell_out_data; move_step(); - } // for each step in batch - } // for each batch + } // for seqlen + } // for batch } void BatchCompute(const framework::ExecutionContext& ctx) const { diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index 4767e9433e..6ffb52185f 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -53,12 +53,11 @@ class TestFusionLSTMOp(OpTest): self.M = 8 self.D = 16 self.has_initial_state = False + self.use_peepholes = False self.is_reverse = False self.act_gate = 'sigmoid' self.act_cell = 'tanh' self.act_cand = 'tanh' - self.use_peepholes = False - self.use_seq = False self.set_conf() T = sum(self.lod[0]) @@ -108,7 +107,6 @@ class TestFusionLSTMOp(OpTest): } self.attrs = { 'use_peepholes': self.use_peepholes, - 'use_seq': self.use_seq, 'is_reverse': self.is_reverse, 'gate_activation': self.act_gate, 'cell_activation': self.act_cell, @@ -178,50 +176,18 @@ class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp): self.is_reverse = True -class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp): +class TestFusionLSTMOpPeepholesInitReverse(TestFusionLSTMOp): def set_conf(self): self.use_peepholes = True - self.lod = [[3]] - self.D = 16 - - -class TestFusionLSTMOpSeqInit(TestFusionLSTMOp): - def set_conf(self): - self.use_seq = True - self.has_initial_state = True - - -class TestFusionLSTMOpSeqReverse(TestFusionLSTMOp): - def set_conf(self): - self.use_seq = True - self.is_reverse = True - - -class TestFusionLSTMOpSeqInitReverse(TestFusionLSTMOp): - def set_conf(self): - self.use_seq = True self.has_initial_state = True self.is_reverse = True -class TestFusionLSTMOpSeqPeepholes(TestFusionLSTMOp): - def set_conf(self): - self.use_seq = True - self.use_peepholes = True - - -class TestFusionLSTMOpSeqPeepholesInit(TestFusionLSTMOp): - def set_conf(self): - self.use_seq = True - self.use_peepholes = True - self.has_initial_state = True - - -class TestFusionLSTMOpSeqPeepholesReverse(TestFusionLSTMOp): +class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp): def set_conf(self): - self.use_seq = True self.use_peepholes = True - self.is_reverse = True + self.lod = [[2]] + self.D = 8 if __name__ == '__main__': -- GitLab