From d74bfefe3cd5fbc606351361b05bfe2de975421e Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Sat, 2 Sep 2023 19:54:42 +0800
Subject: [PATCH] polish code of pass and executor (#56886)

* polish code of pass and executor

* update ut
---
 .../new_executor/program_interpreter.cc       | 74 ++++++++++---------
 paddle/fluid/pybind/ir.cc                     |  2 +-
 .../transforms/dead_code_elimination_pass.cc  |  2 +-
 test/ir/new_ir/test_pass_manager.py           |  3 +-
 4 files changed, 41 insertions(+), 40 deletions(-)

diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc
index c94a326f698..43e5301476d 100644
--- a/paddle/fluid/framework/new_executor/program_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/program_interpreter.cc
@@ -123,9 +123,8 @@ void ProgramInterpreter::RunImpl() {
 #endif
 }
 
-FetchList ProgramInterpreter::Run(
-    const std::vector<std::string>& feed_names,
-    const std::vector<phi::DenseTensor>& feed_tensors) {
+FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
+                                  bool need_fetch) {
   SetDeviceId(place_);
   CheckCUDAGraphBeforeRun(feed_names);
 
@@ -133,10 +132,32 @@ FetchList ProgramInterpreter::Run(
   platform::AttachPointerHashToMKLDNNKey(this, place_);
 #endif
 
-  bool is_build = is_build_;
-  Prepare(feed_names, feed_tensors, is_build);
+  if (!is_build_) {
+    LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
+    paddle::framework::interpreter::BuildVariableScope(
+        block_, execution_config_, &var_scope_);
 
-  if (is_build) {
+    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
+    paddle::framework::interpreter::BuildOpFuncList(
+        place_,
+        block_,
+        execution_config_.skip_gc_vars,
+        &op_func_nodes,
+        &var_scope_,
+        execution_config_,
+        HasLocalScope(),
+        static_build_);
+    SetFeedVarsInplaceSkip(feed_names);
+    // convert vec func_list to graph
+    Convert(&op_func_nodes);
+    UpdateSyncOpNum();
+    if (static_build_) {
+      VLOG(4) << "RUN impl";
+      RunImpl();
+    }
+    is_build_ = true;
+    is_shared_results_build_ = true;
+  } else {
     RunImpl();
   }
 
@@ -145,8 +166,10 @@ FetchList ProgramInterpreter::Run(
   }
 
   // return Fetch Tensors
-  auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
-  if (fetch_var) {
+  Scope* inner_scope =
+      HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
+  auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
+  if (fetch_var && need_fetch) {
     auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
 #ifdef PADDLE_WITH_CUDA
     if (platform::IsCUDAGraphCapturing()) {
@@ -162,8 +185,9 @@ FetchList ProgramInterpreter::Run(
   }
 }
 
-FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
-                                  bool need_fetch) {
+FetchList ProgramInterpreter::Run(
+    const std::vector<std::string>& feed_names,
+    const std::vector<phi::DenseTensor>& feed_tensors) {
   SetDeviceId(place_);
   CheckCUDAGraphBeforeRun(feed_names);
 
@@ -171,32 +195,10 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
   platform::AttachPointerHashToMKLDNNKey(this, place_);
 #endif
 
-  if (!is_build_) {
-    LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
-    paddle::framework::interpreter::BuildVariableScope(
-        block_, execution_config_, &var_scope_);
+  bool is_build = is_build_;
+  Prepare(feed_names, feed_tensors, is_build);
 
-    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
-    paddle::framework::interpreter::BuildOpFuncList(
-        place_,
-        block_,
-        execution_config_.skip_gc_vars,
-        &op_func_nodes,
-        &var_scope_,
-        execution_config_,
-        HasLocalScope(),
-        static_build_);
-    SetFeedVarsInplaceSkip(feed_names);
-    // convert vec func_list to graph
-    Convert(&op_func_nodes);
-    UpdateSyncOpNum();
-    if (static_build_) {
-      VLOG(4) << "RUN impl";
-      RunImpl();
-    }
-    is_build_ = true;
-    is_shared_results_build_ = true;
-  } else {
+  if (is_build) {
     RunImpl();
   }
 
@@ -208,7 +210,7 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
   Scope* inner_scope =
       HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
   auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
-  if (fetch_var && need_fetch) {
+  if (fetch_var) {
     auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
 #ifdef PADDLE_WITH_CUDA
     if (platform::IsCUDAGraphCapturing()) {
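Note on the executor change above: the two Run overloads swap bodies, so the lazy build-and-cache path now lives in the (feed_names, need_fetch) overload and fetching becomes opt-in. The following self-contained sketch illustrates that build-once pattern in isolation; TinyInterpreter and all of its members are illustrative stand-ins, not Paddle APIs.

#include <iostream>
#include <string>
#include <vector>

// Minimal sketch (assumed names, not Paddle API) of the lazy-build pattern
// in ProgramInterpreter::Run: the first call converts the program into an
// executable plan and caches it; later calls go straight to execution, and
// results are only fetched when the caller asks for them.
class TinyInterpreter {
 public:
  std::vector<std::string> Run(const std::vector<std::string>& feed_names,
                               bool need_fetch) {
    if (!is_build_) {
      Build(feed_names);  // one-time: build variable scope + op list
      is_build_ = true;
    }
    Execute();  // every call: run the cached plan
    // Mirror the patch: only touch the fetch results when asked to.
    return need_fetch ? fetch_results_ : std::vector<std::string>{};
  }

 private:
  void Build(const std::vector<std::string>& feeds) {
    for (const auto& f : feeds) fetch_results_.push_back("out_of_" + f);
  }
  void Execute() { std::cout << "running cached plan\n"; }

  bool is_build_ = false;
  std::vector<std::string> fetch_results_;
};

int main() {
  TinyInterpreter interp;
  auto r1 = interp.Run({"x"}, /*need_fetch=*/true);   // builds, runs, fetches
  auto r2 = interp.Run({"x"}, /*need_fetch=*/false);  // reuses build, no fetch
  std::cout << r1.size() << " vs " << r2.size() << "\n";  // 1 vs 0
}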
diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc
index 675e6f2acd2..67da2ba77db 100644
--- a/paddle/fluid/pybind/ir.cc
+++ b/paddle/fluid/pybind/ir.cc
@@ -507,7 +507,7 @@ void BindPassManager(pybind11::module *m) {
            },
            py::arg("opt_level") = 2)
       .def("add_pass",
-           [](PassManager &self, std::string pass_name) {
+           [](PassManager &self, const std::string &pass_name) {
              self.AddPass(
                  std::move(ir::PassRegistry::Instance().Get(pass_name)));
            })
diff --git a/paddle/ir/transforms/dead_code_elimination_pass.cc b/paddle/ir/transforms/dead_code_elimination_pass.cc
index d56b83b8446..461ab9c6708 100644
--- a/paddle/ir/transforms/dead_code_elimination_pass.cc
+++ b/paddle/ir/transforms/dead_code_elimination_pass.cc
@@ -26,7 +26,7 @@ namespace {
 // Now just a naive implementation.
 class DeadCodeEliminationPass : public ir::Pass {
  public:
-  DeadCodeEliminationPass() : ir::Pass("DeadCodeEliminationPass", 0) {}
+  DeadCodeEliminationPass() : ir::Pass("dead_code_elimination", 0) {}
 
   void Run(ir::Operation *op) override {
     auto module_op = op->dyn_cast<ir::ModuleOp>();
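For context on the renamed pass: dead code elimination drops ops whose results are never consumed, repeating until nothing more can be removed. Below is a naive fixed-point sketch under simplified assumptions; Op is a toy struct, and matching value names stands in for the real use-list walk over ir::Operation.

#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Op {
  std::string name;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Remove ops whose outputs nobody reads, iterating to a fixed point,
// since deleting one dead op can orphan the producers that fed it.
void EliminateDeadOps(std::vector<Op>* ops) {
  bool changed = true;
  while (changed) {
    changed = false;
    std::set<std::string> used;
    for (const auto& op : *ops) {
      used.insert(op.inputs.begin(), op.inputs.end());
    }
    auto dead = std::remove_if(ops->begin(), ops->end(), [&](const Op& op) {
      if (op.name == "fetch") return false;  // side-effecting sink: keep
      for (const auto& out : op.outputs) {
        if (used.count(out)) return false;   // some op consumes this value
      }
      return true;  // no consumer anywhere: dead
    });
    if (dead != ops->end()) {
      ops->erase(dead, ops->end());
      changed = true;
    }
  }
}

int main() {
  std::vector<Op> ops = {
      {"pd.uniform", {}, {"u"}},        // dead: "u" has no consumer
      {"pd.matmul", {"x", "w"}, {"y"}},
      {"fetch", {"y"}, {}},
  };
  EliminateDeadOps(&ops);
  for (const auto& op : ops) {
    std::cout << op.name << "\n";  // prints pd.matmul and fetch only
  }
}

This mirrors the updated unit test below, where 'pd.uniform' disappears from the block after the pass runs.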
diff --git a/test/ir/new_ir/test_pass_manager.py b/test/ir/new_ir/test_pass_manager.py
index 2f31e945f31..761baaea13b 100644
--- a/test/ir/new_ir/test_pass_manager.py
+++ b/test/ir/new_ir/test_pass_manager.py
@@ -56,8 +56,7 @@ class TestShadowOutputSlice(unittest.TestCase):
         pm.run(new_program)
         op_names = [op.name() for op in new_program.block().ops]
         # print(op_names)
-        # TODO(zhiqiu): unify the name of pass
-        self.assertEqual(pm.passes(), ['DeadCodeEliminationPass'])
+        self.assertEqual(pm.passes(), ['dead_code_elimination'])
         self.assertFalse(pm.empty())
         self.assertTrue(
             'pd.uniform' not in op_names
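The updated test asserts the registered snake_case name rather than the C++ class name. The sketch below shows the name-keyed registry pattern that the add_pass binding relies on; Pass and PassRegistry here are simplified stand-ins for ir::Pass and ir::PassRegistry, whose real interfaces are not part of this patch.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Assumed, simplified stand-ins: passes register under a stable snake_case
// key and are later constructed by name, which is why pm.passes() now
// reports 'dead_code_elimination' instead of the class name.
class Pass {
 public:
  explicit Pass(std::string name) : name_(std::move(name)) {}
  virtual ~Pass() = default;
  const std::string& name() const { return name_; }

 private:
  std::string name_;
};

class PassRegistry {
 public:
  static PassRegistry& Instance() {
    static PassRegistry instance;  // process-wide singleton
    return instance;
  }
  void Register(const std::string& key,
                std::function<std::unique_ptr<Pass>()> maker) {
    makers_[key] = std::move(maker);
  }
  std::unique_ptr<Pass> Get(const std::string& key) const {
    return makers_.at(key)();  // throws std::out_of_range for unknown names
  }

 private:
  std::map<std::string, std::function<std::unique_ptr<Pass>()>> makers_;
};

int main() {
  PassRegistry::Instance().Register("dead_code_elimination", [] {
    return std::make_unique<Pass>("dead_code_elimination");
  });
  auto pass = PassRegistry::Instance().Get("dead_code_elimination");
  std::cout << pass->name() << "\n";  // dead_code_elimination
}

Keying the registry on a snake_case string keeps the Python-visible name decoupled from the C++ class, so the class can be renamed or reimplemented without breaking scripts that call add_pass("dead_code_elimination").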