diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
index 5852705b6b8d1c650faeae3dc810aac65353b459..024397067c50fd2d82c6efd07a250fef4d9a2187 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/lod_tensor.h"
 
 namespace paddle {
 namespace framework {
@@ -35,7 +37,6 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-
     auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node"));
     marked_nodes.insert(id);
   };
@@ -73,12 +74,31 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
     op_desc.SetOutput("Hidden", {hidden_n->Name()});
     op_desc.SetOutput("Cell", {cell_n->Name()});
     op_desc.SetOutput("XX", {xx_n->Name()});
-    op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
-    op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
+    op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
     op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
     op_desc.SetAttr("use_peepholes", false);
+
+#define TMP_NAME(x) "at.new.tmp." #x
+#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
+    OP_SET_OUT(BatchedCell);
+    OP_SET_OUT(BatchedHidden);
+    OP_SET_OUT(ReorderedH0);
+    OP_SET_OUT(ReorderedC0);
+#undef OP_SET_OUT
 
     auto* op = graph->CreateOpNode(&op_desc);
+    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
+    auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+
+#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable<LoDTensor>()
+    TMP_NEW(BatchedCell);
+    TMP_NEW(BatchedHidden);
+    TMP_NEW(ReorderedH0);
+    TMP_NEW(ReorderedC0);
+
+#undef TMP_NEW
+#undef TMP_NAME
+
 #define LINK_TO(a, b)      \
   a->outputs.push_back(b); \
   b->inputs.push_back(a);
@@ -89,7 +109,6 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
     LINK_TO(op, hidden_n);
 #undef LINK_TO
     return op;
-
   };
 
   lstm_creator(16, 12, 14, 18, 17, 22, 21, 19);
@@ -105,14 +124,16 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
     for (auto it = node->inputs.begin(); it != node->inputs.end();) {
       if (marked_nodes.count(*it)) {
         it = const_cast<Node*>(node)->inputs.erase(it);
-      } else
+      } else {
         it++;
+      }
     }
     for (auto it = node->outputs.begin(); it != node->outputs.end();) {
       if (marked_nodes.count(*it)) {
         it = const_cast<Node*>(node)->outputs.erase(it);
-      } else
+      } else {
         it++;
+      }
     }
   }
 
diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index 1ab73d88db22a319391772d27bb75fab2783d086..d6439acf272f0be225beb54c2b85e335e79fc4e7 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -22,7 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
-DEFINE_bool(seq_mode, false, "Use sequence mode");
+DEFINE_bool(seq_mode, true, "Use sequence mode");
 
 namespace paddle {
 namespace operators {
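
Reviewer note on the macro trick above: `TMP_NAME` relies on preprocessor stringification (`#x`) plus adjacent-string-literal concatenation, and `OP_SET_OUT` reuses it so each extra output of the fused op gets a matching unique variable name. Below is a minimal standalone sketch of the expansion; `MockOpDesc` is a hypothetical stand-in for `framework::OpDesc`, used only so the example compiles outside the Paddle tree.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for framework::OpDesc, just enough to demonstrate
// the SetOutput(name, {var_name}) call shape used in the pass.
struct MockOpDesc {
  void SetOutput(const std::string& name, const std::vector<std::string>& args) {
    std::cout << name << " -> " << args.front() << "\n";
  }
};

// Same definitions as in the patch: #x stringizes the token, and the two
// adjacent literals concatenate, so TMP_NAME(BatchedCell) is the single
// string "at.new.tmp.BatchedCell".
#define TMP_NAME(x) "at.new.tmp." #x
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})

int main() {
  MockOpDesc op_desc;
  OP_SET_OUT(BatchedCell);  // prints: BatchedCell -> at.new.tmp.BatchedCell
  OP_SET_OUT(ReorderedH0);  // prints: ReorderedH0 -> at.new.tmp.ReorderedH0
  return 0;
}
```

So `OP_SET_OUT(BatchedCell)` expands to `op_desc.SetOutput("BatchedCell", {"at.new.tmp.BatchedCell"});`, and the companion `TMP_NEW` macro in the patch creates a `LoDTensor` variable under the same name in the parameter scope, so the fused op finds its workspace tensors at runtime. The `fusion_lstm_op.cc` change only flips the default of the gflags flag, so sequence mode is now on unless the binary is run with `--seq_mode=false`.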