From 209e9c3db14a691bdd9f824fbae7cb8568159373 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 26 Sep 2018 15:30:23 +0800 Subject: [PATCH] refine peephole test=develop --- .../fluid/framework/ir/fc_lstm_fuse_pass.cc | 3 ++ paddle/fluid/operators/fusion_lstm_op.cc | 46 +++++++++++-------- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index aa95d3e9f6c..f5c28648652 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -77,10 +77,12 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, const std::string BatchedCellPreAct = patterns::UniqueKey("BatchedCellPreAct"); const std::string BatchedGate = patterns::UniqueKey("BatchedGate"); + const std::string CheckedCell = patterns::UniqueKey("CheckedCell"); scope->Var(BatchedInput)->GetMutable(); scope->Var(BatchedCellPreAct)->GetMutable(); scope->Var(BatchedGate)->GetMutable(); + scope->Var(CheckedCell)->GetMutable(); op_desc.SetInput("H0", {}); op_desc.SetInput("C0", {}); @@ -90,6 +92,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, op_desc.SetOutput("BatchedGate", {BatchedGate}); op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct}); op_desc.SetOutput("BatchedInput", {BatchedInput}); + op_desc.SetOutput("CheckedCell", {CheckedCell}); op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse")); op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes")); // TODO(TJ): get from attr diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 8ca79d20ec4..23e8edd18d0 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -76,12 +76,18 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(b_dims[0], 1, "The first dimension of Input(Bias) should be 1."); - PADDLE_ENFORCE_EQ( - b_dims[1], (ctx->Attrs().Get("use_peepholes") ? 7 : 4) * frame_size, - "The second dimension of Input(Bias) should be " - "7 * %d if enable peepholes connection or" - "4 * %d if disable peepholes", - frame_size, frame_size); + if (ctx->Attrs().Get("use_peepholes")) { + PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size, + "The second dimension of Input(Bias) should be " + "7 * %d if enable peepholes connection", + frame_size); + ctx->SetOutputDim("CheckedCell", {2, frame_size}); + } else { + PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, + "The second dimension of Input(Bias) should be " + "4 * %d if disable peepholes", + frame_size); + } framework::DDim out_dims({x_dims[0], frame_size}); ctx->SetOutputDim("Hidden", out_dims); @@ -173,6 +179,8 @@ void FusionLSTMOpMaker::Make() { AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate(); AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate(); AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate(); + AddOutput("CheckedCell", "(Tensor) (2 x D) only for peephole.") + .AsIntermediate(); AddAttr("use_peepholes", "(bool, defalut: True) " "whether to enable diagonal/peephole connections.") @@ -250,19 +258,19 @@ class FuisonLSTMKernel : public framework::OpKernel { const int D3 = D * 3; \ const int D4 = wh_dims[1]; -#define INIT_BASE_INPUT_DATAS \ - const T* x_data = x->data(); \ - const T* wx_data = wx->data(); \ - const T* wh_data = wh->data(); \ - /* diagonal weight*/ \ - const T* wc_data = bias->data() + D4; \ - /* for peephole only*/ \ - Tensor checked_cell; \ - T* checked_cell_data = nullptr; \ - auto place = ctx.GetPlace(); \ - if (use_peepholes) { \ - /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ - checked_cell_data = checked_cell.mutable_data({2, D}, place); \ +#define INIT_BASE_INPUT_DATAS \ + const T* x_data = x->data(); \ + const T* wx_data = wx->data(); \ + const T* wh_data = wh->data(); \ + /* diagonal weight*/ \ + const T* wc_data = bias->data() + D4; \ + /* for peephole only*/ \ + T* checked_cell_data = nullptr; \ + auto place = ctx.GetPlace(); \ + if (use_peepholes) { \ + /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ + auto* checked_cell = ctx.Output("CheckedCell"); \ + checked_cell_data = checked_cell->mutable_data(place); \ } /// Compute LSTM -- GitLab