diff --git a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
index c4ffb2a9de4970abd147ce2fd709977e26eb626b..a56fcd1a523391ce801bb2b8c3e9dfa424abdd54 100644
--- a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
@@ -41,7 +41,7 @@ struct Param {
   std::string LSTMOUT = "at.lstmout.new";
 };
 
-void PrepareParameters(Graph* graph, const Param& param);
+void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op);
 
 void FindWhileOp(Graph* graph) {
   GraphPatternDetector gpd;
@@ -98,7 +98,7 @@ void FindWhileOp(Graph* graph) {
   auto* hidden_init = graph->RetrieveNode(8);
 
   auto* lstm_op = graph->CreateOpNode(&op_desc);
-  PrepareParameters(graph, param);
+  PrepareParameters(graph, param, lstm_op);
 
   IR_NODE_LINK_TO(X, lstm_op);
   IR_NODE_LINK_TO(cell_init, lstm_op);
@@ -133,20 +133,29 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
                      const LoDTensor& B_output, const LoDTensor& B_cell,
                      LoDTensor* out);
 
-void PrepareParameters(Graph* graph, const Param& param) {
+void PrepareParameters(Graph* graph, const Param& param, ir::Node* lstm_op) {
   // Check parameters
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
   auto& scope = graph->Get<Scope>(kParamScopeAttr);
 
   // Create new parameters.
+  // AddInput
   scope.Var(param.LSTMWeight)->GetMutable<LoDTensor>();
   scope.Var(param.LSTMBias)->GetMutable<LoDTensor>();
-  scope.Var(param.Hidden)->GetMutable<LoDTensor>();
-  scope.Var(param.Cell)->GetMutable<LoDTensor>();
-  scope.Var(param.AttentionedX)->GetMutable<LoDTensor>();
-  scope.Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
-  scope.Var(param.LSTMX)->GetMutable<LoDTensor>();
-  scope.Var(param.LSTMOUT)->GetMutable<LoDTensor>();
+// AddOutput
+#define IR_NODE(x)                                 \
+  VarDesc key_##x(param.x);                        \
+  key_##x.SetPersistable(false);                   \
+  auto* node_##x = graph->CreateVarNode(&key_##x); \
+  IR_NODE_LINK_TO(lstm_op, node_##x);
+
+  IR_NODE(Hidden);
+  IR_NODE(Cell);
+  IR_NODE(AttentionedX);
+  IR_NODE(AttentionFCOut);
+  IR_NODE(LSTMX);
+  IR_NODE(LSTMOUT);
+#undef IR_NODE
 
 #define GATE_W(name__)                                   \
   auto* W_##name__##_w0 = scope.FindVar(#name__ ".w_0"); \
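The pattern in this pass repeats in the files below: the fused op's temporary outputs used to exist only as variables in the parameter Scope, so the graph itself had no var nodes (and no edges) for them. The new `IR_NODE` macro creates a non-persistable var node per temporary and links it as an output of the fused op. Expanded by hand for one variable, it is equivalent to the following sketch (`graph`, `param`, and `lstm_op` are the locals of `PrepareParameters` above):

```cpp
// Hand expansion of IR_NODE(Hidden) from the hunk above, for clarity.
VarDesc key_Hidden(param.Hidden);          // desc for the temp var "at.hidden.new"
key_Hidden.SetPersistable(false);          // temp output: recreated per run, never saved
auto* node_Hidden = graph->CreateVarNode(&key_Hidden);  // graph now owns a node for it
IR_NODE_LINK_TO(lstm_op, node_Hidden);     // edge lstm_op -> node_Hidden marks it an output
```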
diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
index 21ceec7927e4a9f5f9e29aeffbf31e473cf0237e..4df09b828a7f94304b7ce03ec2fe5d695a6e11e0 100644
--- a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
@@ -127,35 +127,24 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
                 embedding_data, k, weightx_data, n, beta, embeddings_data, n);
     op_desc.SetInput("Embeddings", {embeddings});
 
-    // Create temp variables.
-    const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
-    const std::string BatchedCellPreAct =
-        patterns::UniqueKey("BatchedCellPreAct");
-    const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
-
-    scope->Var(BatchedInput)->GetMutable<LoDTensor>();
-    scope->Var(BatchedCellPreAct)->GetMutable<LoDTensor>();
-    scope->Var(BatchedGate)->GetMutable<LoDTensor>();
-
     op_desc.SetInput("H0", {});
     op_desc.SetInput("C0", {});
     op_desc.SetOutput("Hidden", {hidden->Name()});
     op_desc.SetOutput("Cell", {cell->Name()});
     op_desc.SetOutput("XX", {xx->Name()});
-    op_desc.SetOutput("BatchedGate", {BatchedGate});
-    op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
-    op_desc.SetOutput("BatchedInput", {BatchedInput});
     op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
     op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
     // TODO(TJ): get from attr
     op_desc.SetAttr("use_seq", true);
 
-    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-    auto& scope = graph->Get<Scope>(kParamScopeAttr);
+// Create temp variables.
 #define OP_SET_OUT(x)                            \
   const std::string x = patterns::UniqueKey(#x); \
-  op_desc.SetOutput(#x, {x});                    \
-  scope.Var(x)->GetMutable<LoDTensor>()
+  op_desc.SetOutput(#x, {x});
+
+    OP_SET_OUT(BatchedGate);
+    OP_SET_OUT(BatchCellPreAct);
+    OP_SET_OUT(BatchedInput);
     OP_SET_OUT(BatchedCell);
     OP_SET_OUT(BatchedHidden);
     OP_SET_OUT(ReorderedH0);
@@ -163,11 +152,28 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 #undef OP_SET_OUT
 
     auto* op = graph->CreateOpNode(&op_desc);
+
     IR_NODE_LINK_TO(input, op);
     IR_NODE_LINK_TO(weight_x, op);
     IR_NODE_LINK_TO(weight_h, op);
     IR_NODE_LINK_TO(bias, op);
     IR_NODE_LINK_TO(op, hidden);
+
+#define IR_NODE(x)                                 \
+  VarDesc key_##x(x);                              \
+  key_##x.SetPersistable(false);                   \
+  auto* node_##x = graph->CreateVarNode(&key_##x); \
+  IR_NODE_LINK_TO(op, node_##x);
+
+    IR_NODE(BatchedGate);
+    IR_NODE(BatchCellPreAct);
+    IR_NODE(BatchedInput);
+    IR_NODE(BatchedCell);
+    IR_NODE(BatchedHidden);
+    IR_NODE(ReorderedH0);
+    IR_NODE(ReorderedC0);
+#undef IR_NODE
+
     return op;
   };
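`OP_SET_OUT` is reduced to generating a unique name and registering it on `op_desc`; the old `scope.Var(x)->GetMutable<LoDTensor>()` allocation is dropped, and the graph var node is created afterwards via `IR_NODE`, once the op node exists. Expanded, one invocation is now just the following sketch:

```cpp
// What OP_SET_OUT(BatchedGate) now expands to:
const std::string BatchedGate = patterns::UniqueKey("BatchedGate");  // collision-free name
op_desc.SetOutput("BatchedGate", {BatchedGate});
// No eager scope allocation here: the var node created later by IR_NODE is
// marked non-persistable, so the executor can create the tensor in its run
// scope when the fused op actually executes.
```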
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
index a5a72e875e49a732ae27f2f4e949ef893011a2a4..89fa5a75e954dbb05932f311a8d77e3f97be86f3 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -74,38 +74,25 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
       op_desc.SetInput("Bias", {new_bias_var});
     }
 
-    // Create temp variables.
-    const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
-    const std::string BatchedCellPreAct =
-        patterns::UniqueKey("BatchedCellPreAct");
-    const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
-    const std::string CheckedCell = patterns::UniqueKey("CheckedCell");
-
-    scope->Var(BatchedInput)->GetMutable<LoDTensor>();
-    scope->Var(BatchedCellPreAct)->GetMutable<LoDTensor>();
-    scope->Var(BatchedGate)->GetMutable<LoDTensor>();
-    scope->Var(CheckedCell)->GetMutable<LoDTensor>();
-
     op_desc.SetInput("H0", {});
     op_desc.SetInput("C0", {});
     op_desc.SetOutput("Hidden", {hidden->Name()});
     op_desc.SetOutput("Cell", {cell->Name()});
     op_desc.SetOutput("XX", {xx->Name()});
-    op_desc.SetOutput("BatchedGate", {BatchedGate});
-    op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
-    op_desc.SetOutput("BatchedInput", {BatchedInput});
-    op_desc.SetOutput("CheckedCell", {CheckedCell});
     op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
     op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
     // TODO(TJ): get from attr
     op_desc.SetAttr("use_seq", true);
 
-    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-    auto& scope = graph->Get<Scope>(kParamScopeAttr);
+// Create temp variables.
 #define OP_SET_OUT(x)                            \
   const std::string x = patterns::UniqueKey(#x); \
-  op_desc.SetOutput(#x, {x});                    \
-  scope.Var(x)->GetMutable<LoDTensor>()
+  op_desc.SetOutput(#x, {x});
+
+    OP_SET_OUT(BatchedGate);
+    OP_SET_OUT(BatchedCellPreAct);
+    OP_SET_OUT(BatchedInput);
+    OP_SET_OUT(CheckedCell);
     OP_SET_OUT(BatchedCell);
     OP_SET_OUT(BatchedHidden);
     OP_SET_OUT(ReorderedH0);
@@ -113,11 +100,29 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
 #undef OP_SET_OUT
 
     auto* op = graph->CreateOpNode(&op_desc);
+
     IR_NODE_LINK_TO(input, op);
     IR_NODE_LINK_TO(weight_x, op);
     IR_NODE_LINK_TO(weight_h, op);
     IR_NODE_LINK_TO(bias, op);
     IR_NODE_LINK_TO(op, hidden);
+
+#define IR_NODE(x)                                 \
+  VarDesc key_##x(x);                              \
+  key_##x.SetPersistable(false);                   \
+  auto* node_##x = graph->CreateVarNode(&key_##x); \
+  IR_NODE_LINK_TO(op, node_##x);
+
+    IR_NODE(BatchedGate);
+    IR_NODE(BatchedCellPreAct);
+    IR_NODE(BatchedInput);
+    IR_NODE(CheckedCell);
+    IR_NODE(BatchedCell);
+    IR_NODE(BatchedHidden);
+    IR_NODE(ReorderedH0);
+    IR_NODE(ReorderedC0);
+#undef IR_NODE
+
     return op;
   };
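Both LSTM passes lean on `IR_NODE_LINK_TO` to record the new edges. For context, the macro from `graph_pattern_detector.h` is essentially the following (quoted from memory, so treat it as a sketch rather than the authoritative definition):

```cpp
// Links two graph nodes by updating both adjacency lists: b is appended to
// a's outputs, and a is appended to b's inputs.
#define IR_NODE_LINK_TO(a, b) \
  a->outputs.push_back(b);    \
  b->inputs.push_back(a);
```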
diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
index b230c50167136d2616068078ce619e8362c38fde..bd826709b1d88abefbfdf487603b5c157ca7bd95 100644
--- a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
@@ -214,7 +214,9 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
     op_desc.SetInput("FCWeight", {fc_w->Name()});
     op_desc.SetInput("FCBias", {fc_bias->Name()});
     const std::string fc_out_tmp = fc_out->Name() + ".tmp";
-    param_scope()->Var(fc_out_tmp)->GetMutable<LoDTensor>();
+    VarDesc fc_out_key(fc_out_tmp);
+    fc_out_key.SetPersistable(false);
+    auto* fc_out_node = graph->CreateVarNode(&fc_out_key);
     op_desc.SetOutput("FCOut", {fc_out_tmp});
     op_desc.SetOutput("Out", {fc_out->Name()});
     op_desc.SetAttr("fc_activation", act->Op()->Type());
@@ -227,6 +229,7 @@ void SeqConcatFcFusePass::ApplyImpl(ir::Graph* graph) const {
     IR_NODE_LINK_TO(sequence_expand0_in, op_node);
     IR_NODE_LINK_TO(sequence_expand1_in, op_node);
     IR_NODE_LINK_TO(op_node, fc_out);
+    IR_NODE_LINK_TO(op_node, fc_out_node);
 
     // Clean nodes.
     std::unordered_set<const Node*> marked_nodes;
diff --git a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
index 556d28a42ae8d664712417add43732cb57f67355..1485a84d001acef8542a9dda5436cfeb57518d69 100644
--- a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
@@ -42,18 +42,19 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
     op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength"));
     op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
     op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
-    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-    auto& scope = graph->Get<Scope>(kParamScopeAttr);
     const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
     op_desc.SetOutput("ColMat", {ColMat});
     op_desc.SetOutput("Out", {relu_out->Name()});
-    scope.Var(ColMat)->GetMutable<LoDTensor>();
+    VarDesc key(ColMat);
+    key.SetPersistable(false);
+    auto* key_col_mat = graph->CreateVarNode(&key);
 
     auto* op = graph->CreateOpNode(&op_desc);
     IR_NODE_LINK_TO(input, op);
     IR_NODE_LINK_TO(seqconv_weight, op);
     IR_NODE_LINK_TO(eltadd_bias, op);
     IR_NODE_LINK_TO(op, relu_out);
+    IR_NODE_LINK_TO(op, key_col_mat);
 
     return op;
   };
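The same three-step pattern (build a `VarDesc`, mark it non-persistable, create and link the var node) now appears in all four passes. A follow-up could factor it into a shared helper; a minimal sketch, assuming the usual IR headers and a hypothetical name `CreateTmpOutputVarNode`:

```cpp
#include <string>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"  // IR_NODE_LINK_TO
#include "paddle/fluid/framework/var_desc.h"

namespace paddle {
namespace framework {
namespace ir {

// Hypothetical helper, not part of this patch: creates a non-persistable
// var node named `name` and wires it up as an output of `op`.
static Node* CreateTmpOutputVarNode(Graph* graph, Node* op,
                                    const std::string& name) {
  VarDesc desc(name);
  desc.SetPersistable(false);               // temp var, not a saved parameter
  Node* var = graph->CreateVarNode(&desc);  // node is owned by the graph
  IR_NODE_LINK_TO(op, var);                 // record the op -> var edge
  return var;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle
```

With such a helper, the seqconv change above would collapse to `auto* key_col_mat = CreateTmpOutputVarNode(graph, op, ColMat);` once the op node exists.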