From 4a1b7fec967837ca283210aa669da3513c9f1aa1 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Thu, 16 May 2019 10:11:03 +0800
Subject: [PATCH] Add setting Scope function for the graph class (#17417)

* add set_not_owned function for graph

* add scope set. test=develop

* add scope_ptr enforce not null before setting.test=develop
---
 .../framework/ir/attention_lstm_fuse_pass.cc  | 40 +++++++++----------
 .../ir/embedding_fc_lstm_fuse_pass.cc         |  4 +-
 paddle/fluid/framework/ir/fc_gru_fuse_pass.cc | 11 +++--
 .../fluid/framework/ir/fc_lstm_fuse_pass.cc   |  4 +-
 paddle/fluid/framework/ir/fuse_pass_base.cc   |  6 ++-
 .../conv_bias_mkldnn_fuse_pass_tester.cc      |  2 +-
 .../ir/mkldnn/cpu_quantize_pass_tester.cc     |  2 +-
 .../mkldnn/cpu_quantize_squash_pass_tester.cc |  2 +-
 .../ir/seqconv_eltadd_relu_fuse_pass.cc       |  4 +-
 .../inference/analysis/ir_pass_manager.cc     |  6 +--
 .../analysis/passes/ir_graph_build_pass.cc    |  8 ++--
 .../fluid/inference/api/mkldnn_quantizer.cc   |  5 ++-
 paddle/fluid/pybind/ir.cc                     |  9 +++++
 13 files changed, 58 insertions(+), 45 deletions(-)

diff --git a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
index 5a82d7927f4..c4ffb2a9de4 100644
--- a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
@@ -136,22 +136,22 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
 void PrepareParameters(Graph* graph, const Param& param) {
   // Check parameters
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+  auto& scope = graph->Get<Scope>(kParamScopeAttr);
 
   // Create new parameters.
-  scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
-  scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
-  scope->Var(param.Hidden)->GetMutable<LoDTensor>();
-  scope->Var(param.Cell)->GetMutable<LoDTensor>();
-  scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
-  scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
-  scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
-  scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
+  scope.Var(param.LSTMWeight)->GetMutable<LoDTensor>();
+  scope.Var(param.LSTMBias)->GetMutable<LoDTensor>();
+  scope.Var(param.Hidden)->GetMutable<LoDTensor>();
+  scope.Var(param.Cell)->GetMutable<LoDTensor>();
+  scope.Var(param.AttentionedX)->GetMutable<LoDTensor>();
+  scope.Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
+  scope.Var(param.LSTMX)->GetMutable<LoDTensor>();
+  scope.Var(param.LSTMOUT)->GetMutable<LoDTensor>();
 
 #define GATE_W(name__)                                                \
-  auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0");             \
-  auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1");             \
-  auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0");             \
+  auto* W_##name__##_w0 = scope.FindVar(#name__ ".w_0");              \
+  auto* W_##name__##_w1 = scope.FindVar(#name__ ".w_1");              \
+  auto* W_##name__##_b0 = scope.FindVar(#name__ ".b_0");              \
   CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0);        \
   VLOG(4) << #name__ "_w0"                                            \
           << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims();  \
@@ -169,26 +169,26 @@ void PrepareParameters(Graph* graph, const Param& param) {
   GATE_W(c);
 #undef GATE_W
 
-  auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
-  auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
-  auto* attention_output_w = scope->FindVar("attention_output.w_0");
-  auto* attention_output_b = scope->FindVar("attention_output.b_0");
+  auto* attention_fc_w = scope.FindVar("attention_fc.w_0");
+  auto* attention_fc_b = scope.FindVar("attention_fc.b_0");
+  auto* attention_output_w = scope.FindVar("attention_output.w_0");
+  auto* attention_output_b = scope.FindVar("attention_output.b_0");
   CHECK_P4(attention_fc_w, attention_fc_b,
            attention_output_w, attention_output_b);
 
-  auto* lstm_weight = scope->Var(param.LSTMWeight);
+  auto* lstm_weight = scope.Var(param.LSTMWeight);
   auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
 
-  auto* lstm_bias = scope->Var(param.LSTMBias);
+  auto* lstm_bias = scope.Var(param.LSTMBias);
   auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();
 
   // reshape attention_bias
   auto* attention_bias_t =
-      scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
+      scope.FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
   PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
   attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
 
   auto* attention_scalar_bias_t =
-      scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
+      scope.FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
   attention_scalar_bias_t->Resize(
       make_ddim({1, attention_scalar_bias_t->dims()[0]}));
diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
index 3a6bbe65b36..6462e7bf4c0 100644
--- a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
@@ -151,11 +151,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   op_desc.SetAttr("use_seq", true);
 
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+  auto& scope = graph->Get<Scope>(kParamScopeAttr);
 #define OP_SET_OUT(x)                            \
   const std::string x = patterns::UniqueKey(#x); \
   op_desc.SetOutput(#x, {x});                    \
-  scope->Var(x)->GetMutable<LoDTensor>()
+  scope.Var(x)->GetMutable<LoDTensor>()
   OP_SET_OUT(BatchedCell);
   OP_SET_OUT(BatchedHidden);
   OP_SET_OUT(ReorderedH0);
diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
index 5f660c6d366..10cbe319ac8 100644
--- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
@@ -69,16 +69,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   auto* op = graph->CreateOpNode(&op_desc);
 
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto* scope = graph->Get<Scope*>(kParamScopeAttr);
-  PADDLE_ENFORCE(scope);
+  auto& scope = graph->Get<Scope>(kParamScopeAttr);
 
   if (with_fc_bias) {
     // Fusion GRU bias = fcbias + grubias
-    auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias->Name());
+    auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name());
     auto* out_bias_tensor = fusion_bias_var->GetMutable<LoDTensor>();
     PADDLE_ENFORCE(fusion_bias_var);
-    auto* gru_bias_var = scope->FindVar(bias->Name());
-    auto* fc_bias_var = scope->FindVar(fc_bias->Name());
+    auto* gru_bias_var = scope.FindVar(bias->Name());
+    auto* fc_bias_var = scope.FindVar(fc_bias->Name());
     PADDLE_ENFORCE(gru_bias_var);
     PADDLE_ENFORCE(fc_bias_var);
     const auto& gru_bias_tenosr = gru_bias_var->Get<LoDTensor>();
@@ -94,7 +93,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 #undef GET_NODE
 
 #define NEW_IMTERMEDIATE_OUT(key) \
-  scope->Var(NEW_NAME(key))->GetMutable<LoDTensor>()
+  scope.Var(NEW_NAME(key))->GetMutable<LoDTensor>()
   NEW_IMTERMEDIATE_OUT(ReorderedH0);
   NEW_IMTERMEDIATE_OUT(XX);
   NEW_IMTERMEDIATE_OUT(BatchedInput);
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
index babeba96149..6858a98be39 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -100,11 +100,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
   op_desc.SetAttr("use_seq", true);
 
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+  auto& scope = graph->Get<Scope>(kParamScopeAttr);
 #define OP_SET_OUT(x)                            \
   const std::string x = patterns::UniqueKey(#x); \
   op_desc.SetOutput(#x, {x});                    \
-  scope->Var(x)->GetMutable<LoDTensor>()
+  scope.Var(x)->GetMutable<LoDTensor>()
   OP_SET_OUT(BatchedCell);
   OP_SET_OUT(BatchedHidden);
   OP_SET_OUT(ReorderedH0);
diff --git a/paddle/fluid/framework/ir/fuse_pass_base.cc b/paddle/fluid/framework/ir/fuse_pass_base.cc
index d70010089e4..5e2523607d6 100644
--- a/paddle/fluid/framework/ir/fuse_pass_base.cc
+++ b/paddle/fluid/framework/ir/fuse_pass_base.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include 
 
 namespace paddle {
 namespace framework {
@@ -25,7 +26,8 @@ void FusePassBase::Init(const std::string& repr, Graph* graph) const {
 
 Scope* FusePassBase::param_scope() const {
   PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
-  return graph_->Get<framework::Scope*>(kParamScopeAttr);
+  auto& scope = graph_->Get<framework::Scope>(kParamScopeAttr);
+  return &scope;
 }
 
 void FusePassBase::AddStatis(int count_of_fused) const {
@@ -55,7 +57,7 @@ FuseOptions FusePassBase::FindFuseOption(const Node& node1,
 #else
   return FUSE_NATIVE;
 #endif
-};
+}
 
 }  // namespace ir
 }  // namespace framework
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
index ff7f9190fde..9f618176747 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
@@ -97,7 +97,7 @@ void MainTest(bool convWithExistingBias) {
     InitTensorHolder(&scope, place, "conv_bias");
     InitTensorHolder(&scope, place, "eltwise_bias");
   }
 
-  graph->Set(kParamScopeAttr, new framework::Scope*(&scope));
+  graph->SetNotOwned(kParamScopeAttr, &scope);
 
   auto pass = PassRegistry::Instance().Get("conv_bias_mkldnn_fuse_pass");
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
index 8716a412e4d..7d9d0ead0fe 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
@@ -132,7 +132,7 @@ void MainTest(const ProgramDesc& prog, int conv_count, int pool_count,
     (*scales)[v] = std::make_pair(false, std::move(tensor));
   }
 
-  graph->Set(kParamScopeAttr, new framework::Scope*(&scope));
+  graph->SetNotOwned(kParamScopeAttr, &scope);
 
   auto pass = PassRegistry::Instance().Get("cpu_quantize_pass");
   pass->Set("quant_var_scales", scales);
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
index fda337066f4..94cb42633f4 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
@@ -119,7 +119,7 @@ void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
     InitTensorHolder(&scope, place, v.c_str());
   }
 
-  graph->Set(kParamScopeAttr, new framework::Scope*(&scope));
+  graph->SetNotOwned(kParamScopeAttr, &scope);
 
   auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass");
 
diff --git a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
index 3fd368741fb..556d28a42ae 100644
--- a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc
@@ -43,11 +43,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
   op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
   op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+  auto& scope = graph->Get<Scope>(kParamScopeAttr);
   const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
   op_desc.SetOutput("ColMat", {ColMat});
   op_desc.SetOutput("Out", {relu_out->Name()});
-  scope->Var(ColMat)->GetMutable<LoDTensor>();
+  scope.Var(ColMat)->GetMutable<LoDTensor>();
 
   auto* op = graph->CreateOpNode(&op_desc);
   IR_NODE_LINK_TO(input, op);
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index 4714c30507c..371118ffaf2 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -38,9 +38,9 @@ IRPassManager::IRPassManager(Argument *argument) {
   ARGUMENT_CHECK_FIELD(argument, main_program);
   graph_ = std::unique_ptr<Graph>(new Graph(argument->main_program()));
   if (argument->Has("scope")) {
-    graph_->Set(framework::ir::kParamScopeAttr,
-                new framework::Scope *(
-                    const_cast<framework::Scope *>(&argument->scope())));
+    auto *scope_ptr = argument->scope_ptr();
+    PADDLE_ENFORCE(scope_ptr);
+    graph_->SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
   }
 
   ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
index c6e923c0048..970ecdbbeb0 100644
--- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
@@ -13,9 +13,10 @@
 // limitations under the License.
 
 #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
-#include 
+#include 
 #include 
 #include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -56,8 +57,9 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
   auto graph = std::unique_ptr<Graph>(new Graph(argument->main_program()));
   argument->SetMainGraph(graph.release());
-  argument->main_graph().Set(framework::ir::kParamScopeAttr,
-                             new framework::Scope *(argument->scope_ptr()));
+  auto *scope_ptr = argument->scope_ptr();
+  PADDLE_ENFORCE(scope_ptr);
+  argument->main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
 }
 
 std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc
index de75e884f53..0765d300f45 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer.cc
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc
@@ -353,8 +353,9 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
   arg.SetMainProgramNotOwned(predictor_.inference_program_.get());
   auto graph = std::unique_ptr<Graph>(new Graph(arg.main_program()));
   arg.SetMainGraph(graph.release());
-  arg.main_graph().Set(framework::ir::kParamScopeAttr,
-                       new framework::Scope*(arg.scope_ptr()));
+  auto* scope_ptr = arg.scope_ptr();
+  PADDLE_ENFORCE(scope_ptr);
+  arg.main_graph().SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
 
   auto* builder = predictor_.config_.pass_builder();
   builder->SetPasses({
diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc
index 798e488f5b0..abc10765e4a 100644
--- a/paddle/fluid/pybind/ir.cc
+++ b/paddle/fluid/pybind/ir.cc
@@ -24,6 +24,7 @@
"paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/var_desc.h" #include "pybind11/stl.h" @@ -37,6 +38,7 @@ using paddle::framework::ir::TopologySortOperations; using paddle::framework::ir::BuildOperationAdjList; using paddle::framework::OpDesc; using paddle::framework::ProgramDesc; +using paddle::framework::Scope; using paddle::framework::VarDesc; using pybind11::return_value_policy; @@ -57,12 +59,15 @@ void BindGraph(py::module *m) { .def(py::init()) .def("clone", &Graph::Clone) .def("has", &Graph::Has) + .def("get_bool", &Graph::Get) .def("get_int", &Graph::Get) .def("get_float", &Graph::Get) .def("get_double", &Graph::Get) .def("get_string", &Graph::Get) .def("get_marked_nodes", &Graph::Get>, return_value_policy::reference) + .def("set", [](Graph &self, const std::string &attr_name, + bool attr) { return self.Set(attr_name, new bool(attr)); }) .def("set", [](Graph &self, const std::string &attr_name, int attr) { return self.Set(attr_name, new int(attr)); }) .def("set", @@ -90,6 +95,10 @@ void BindGraph(py::module *m) { return self.Set(attr_name, new std::unordered_set(attr)); }) + .def("set_not_owned", + [](Graph &self, const std::string &attr_name, Scope &attr) { + self.SetNotOwned(attr_name, &attr); + }) .def("erase", &Graph::Erase) .def("nodes", &Graph::Nodes, return_value_policy::reference) .def("create_var_node", -- GitLab