Unverified · Commit d74bfefe, authored by Leo Chen, committed by GitHub

polish code of pass and executor (#56886)

* polish code of pass and executor

* update ut
Parent: 061bb9d5
@@ -123,9 +123,8 @@ void ProgramInterpreter::RunImpl() {
 #endif
 }
 
-FetchList ProgramInterpreter::Run(
-    const std::vector<std::string>& feed_names,
-    const std::vector<phi::DenseTensor>& feed_tensors) {
+FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
+                                  bool need_fetch) {
   SetDeviceId(place_);
   CheckCUDAGraphBeforeRun(feed_names);
@@ -133,10 +132,32 @@ FetchList ProgramInterpreter::Run(
   platform::AttachPointerHashToMKLDNNKey(this, place_);
 #endif
 
-  bool is_build = is_build_;
-  Prepare(feed_names, feed_tensors, is_build);
+  if (!is_build_) {
+    LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
+    paddle::framework::interpreter::BuildVariableScope(
+        block_, execution_config_, &var_scope_);
 
-  if (is_build) {
+    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
+    paddle::framework::interpreter::BuildOpFuncList(
+        place_,
+        block_,
+        execution_config_.skip_gc_vars,
+        &op_func_nodes,
+        &var_scope_,
+        execution_config_,
+        HasLocalScope(),
+        static_build_);
+    SetFeedVarsInplaceSkip(feed_names);
+    // convert vec func_list to graph
+    Convert(&op_func_nodes);
+    UpdateSyncOpNum();
+    if (static_build_) {
+      VLOG(4) << "RUN impl";
+      RunImpl();
+    }
+    is_build_ = true;
+    is_shared_results_build_ = true;
+  } else {
     RunImpl();
   }
@@ -145,8 +166,10 @@ FetchList ProgramInterpreter::Run(
   }
 
   // return Fetch Tensors
-  auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
-  if (fetch_var) {
+  Scope* inner_scope =
+      HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
+  auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
+  if (fetch_var && need_fetch) {
     auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
 #ifdef PADDLE_WITH_CUDA
     if (platform::IsCUDAGraphCapturing()) {
@@ -162,8 +185,9 @@ FetchList ProgramInterpreter::Run(
   }
 }
 
-FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
-                                  bool need_fetch) {
+FetchList ProgramInterpreter::Run(
+    const std::vector<std::string>& feed_names,
+    const std::vector<phi::DenseTensor>& feed_tensors) {
   SetDeviceId(place_);
   CheckCUDAGraphBeforeRun(feed_names);
@@ -171,32 +195,10 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
   platform::AttachPointerHashToMKLDNNKey(this, place_);
 #endif
 
-  if (!is_build_) {
-    LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
-    paddle::framework::interpreter::BuildVariableScope(
-        block_, execution_config_, &var_scope_);
-
-    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
-    paddle::framework::interpreter::BuildOpFuncList(
-        place_,
-        block_,
-        execution_config_.skip_gc_vars,
-        &op_func_nodes,
-        &var_scope_,
-        execution_config_,
-        HasLocalScope(),
-        static_build_);
-    SetFeedVarsInplaceSkip(feed_names);
-    // convert vec func_list to graph
-    Convert(&op_func_nodes);
-    UpdateSyncOpNum();
-    if (static_build_) {
-      VLOG(4) << "RUN impl";
-      RunImpl();
-    }
-    is_build_ = true;
-    is_shared_results_build_ = true;
-  } else {
+  bool is_build = is_build_;
+  Prepare(feed_names, feed_tensors, is_build);
+
+  if (is_build) {
     RunImpl();
   }
@@ -208,7 +210,7 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
   Scope* inner_scope =
       HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
   auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
-  if (fetch_var && need_fetch) {
+  if (fetch_var) {
     auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
 #ifdef PADDLE_WITH_CUDA
     if (platform::IsCUDAGraphCapturing()) {
...
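The net effect of these hunks is that the (feed_names, need_fetch) overload now owns the lazy build path: lower the block to op-func nodes once, convert them to a graph, and cache the result behind is_build_. The need_fetch flag gates whether fetch results are returned at all. A toy Python sketch of this build-once/run-many pattern (illustrative names only, not Paddle's API):

class ToyInterpreter:
    """Toy sketch of the build-once/run-many flow; not Paddle's API."""

    def __init__(self, ops):
        self.ops = ops          # stands in for the ops of block_
        self.plan = None        # stands in for the converted op_func_nodes
        self.is_build = False   # stands in for is_build_

    def run(self, need_fetch=True):
        if not self.is_build:
            # First call: lower the ops into an executable plan exactly once.
            self.plan = list(self.ops)
            self.is_build = True
        fetch_list = [op() for op in self.plan]
        # need_fetch gates whether fetch results reach the caller.
        return fetch_list if need_fetch else None


interp = ToyInterpreter([lambda: 1, lambda: 2])
assert interp.run() == [1, 2]                 # first call builds, then runs
assert interp.run(need_fetch=False) is None   # cached plan runs, fetch skipped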
@@ -507,7 +507,7 @@ void BindPassManager(pybind11::module *m) {
           },
           py::arg("opt_level") = 2)
       .def("add_pass",
-           [](PassManager &self, std::string pass_name) {
+           [](PassManager &self, const std::string &pass_name) {
             self.AddPass(
                 std::move(ir::PassRegistry::Instance().Get(pass_name)));
           })
...
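Taking the pass name by const reference avoids copying the string on every add_pass call; Python-visible behavior is unchanged. A hedged sketch of how the binding is driven from Python (the PassManager import location and new_program are assumptions; add_pass, run, passes, and empty are the methods bound above and exercised in the test below):

# Assumption: PassManager is imported from Paddle's compiled ir module,
# and new_program is an already-constructed IR program.
pm = PassManager(opt_level=2)          # matches py::arg("opt_level") = 2
pm.add_pass("dead_code_elimination")   # looks the pass up in PassRegistry by name
pm.run(new_program)
assert pm.passes() == ["dead_code_elimination"]
assert not pm.empty()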
@@ -26,7 +26,7 @@ namespace {
 // Now just a naive implementation.
 class DeadCodeEliminationPass : public ir::Pass {
  public:
-  DeadCodeEliminationPass() : ir::Pass("DeadCodeEliminationPass", 0) {}
+  DeadCodeEliminationPass() : ir::Pass("dead_code_elimination", 0) {}
 
   void Run(ir::Operation *op) override {
     auto module_op = op->dyn_cast<ir::ModuleOp>();
...
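Apart from its registry name, the pass is unchanged; per its comment it is still a naive dead-code elimination: drop any operation whose results are never consumed. A self-contained Python sketch of that technique over a toy op list (Op and its fields are illustrative stand-ins, not Paddle IR types):

from dataclasses import dataclass, field


@dataclass
class Op:
    name: str
    inputs: list = field(default_factory=list)    # value names this op consumes
    outputs: list = field(default_factory=list)   # value names this op produces
    has_side_effect: bool = False


def dead_code_elimination(ops):
    # Walk backwards so every op's users are classified before the op itself;
    # one pass suffices for straight-line programs.
    live_values, kept = set(), []
    for op in reversed(ops):
        if op.has_side_effect or any(v in live_values for v in op.outputs):
            kept.append(op)
            live_values.update(op.inputs)
    kept.reverse()
    return kept


ops = [
    Op("pd.uniform", outputs=["x"]),                     # result unused: removed
    Op("pd.matmul", inputs=["a", "b"], outputs=["y"]),
    Op("pd.fetch", inputs=["y"], has_side_effect=True),  # anchors liveness
]
assert [op.name for op in dead_code_elimination(ops)] == ["pd.matmul", "pd.fetch"]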
@@ -56,8 +56,7 @@ class TestShadowOutputSlice(unittest.TestCase):
         pm.run(new_program)
         op_names = [op.name() for op in new_program.block().ops]
         # print(op_names)
-        # TODO(zhiqiu): unify the name of pass
-        self.assertEqual(pm.passes(), ['DeadCodeEliminationPass'])
+        self.assertEqual(pm.passes(), ['dead_code_elimination'])
         self.assertFalse(pm.empty())
         self.assertTrue(
             'pd.uniform' not in op_names
...