Unverified commit d74bfefe, authored by Leo Chen, committed by GitHub

polish code of pass and executor (#56886)

* polish code of pass and executor

* update ut
Parent 061bb9d5
......
@@ -123,9 +123,8 @@ void ProgramInterpreter::RunImpl() {
 #endif
 }
 
-FetchList ProgramInterpreter::Run(
-    const std::vector<std::string>& feed_names,
-    const std::vector<phi::DenseTensor>& feed_tensors) {
+FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
+                                  bool need_fetch) {
   SetDeviceId(place_);
   CheckCUDAGraphBeforeRun(feed_names);
......
@@ -133,10 +132,32 @@ FetchList ProgramInterpreter::Run(
   platform::AttachPointerHashToMKLDNNKey(this, place_);
 #endif
 
-  bool is_build = is_build_;
-  Prepare(feed_names, feed_tensors, is_build);
+  if (!is_build_) {
+    LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
+    paddle::framework::interpreter::BuildVariableScope(
+        block_, execution_config_, &var_scope_);
 
-  if (is_build) {
+    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
+    paddle::framework::interpreter::BuildOpFuncList(
+        place_,
+        block_,
+        execution_config_.skip_gc_vars,
+        &op_func_nodes,
+        &var_scope_,
+        execution_config_,
+        HasLocalScope(),
+        static_build_);
+    SetFeedVarsInplaceSkip(feed_names);
+    // convert vec func_list to graph
+    Convert(&op_func_nodes);
+    UpdateSyncOpNum();
+    if (static_build_) {
+      VLOG(4) << "RUN impl";
+      RunImpl();
+    }
+    is_build_ = true;
+    is_shared_results_build_ = true;
+  } else {
     RunImpl();
   }
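The `need_fetch` overload now performs the lazy build inline: the first call constructs the variable scope and the op-func list, converts it to a graph, and flips `is_build_`; subsequent calls go straight to `RunImpl()`. Below is a minimal standalone sketch of that build-once-then-reuse pattern; names such as `MiniInterpreter` are hypothetical, and the `static_build_` special case is omitted.

#include <functional>
#include <iostream>
#include <vector>

class MiniInterpreter {
 public:
  void Run() {
    if (!is_build_) {
      // First call: translate the program into executable nodes once.
      BuildOpList();
      is_build_ = true;  // analogous to is_build_ / is_shared_results_build_
    }
    // Later calls reuse the cached nodes directly.
    for (auto& op : ops_) op();
  }

 private:
  void BuildOpList() {
    // Stand-in for BuildVariableScope + BuildOpFuncList + Convert above.
    ops_.push_back([] { std::cout << "feed\n"; });
    ops_.push_back([] { std::cout << "matmul\n"; });
  }

  bool is_build_ = false;
  std::vector<std::function<void()>> ops_;
};

int main() {
  MiniInterpreter interp;
  interp.Run();  // builds, then runs
  interp.Run();  // skips the build
}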
......
@@ -145,8 +166,10 @@ FetchList ProgramInterpreter::Run(
   }
 
   // return Fetch Tensors
-  auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
-  if (fetch_var) {
+  Scope* inner_scope =
+      HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
+  auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
+  if (fetch_var && need_fetch) {
     auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
 #ifdef PADDLE_WITH_CUDA
     if (platform::IsCUDAGraphCapturing()) {
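Two behavioral points in the fetch block above: the fetch variable is now resolved against whichever scope is active (`local_scope_` when present, otherwise the variable scope), and results are only moved out when the caller asked for them. A hedged sketch of that logic with toy types (`MaybeFetch` and the map-based `Scope` are illustrative, not Paddle API):

#include <map>
#include <string>
#include <utility>
#include <vector>

using FetchList = std::vector<float>;
using Scope = std::map<std::string, FetchList>;

FetchList MaybeFetch(Scope* local_scope, Scope* var_scope, bool need_fetch) {
  // Mirror of: HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope()
  Scope* inner_scope = local_scope != nullptr ? local_scope : var_scope;
  auto it = inner_scope->find("fetch");  // stand-in for kFetchVarName
  if (it != inner_scope->end() && need_fetch) {
    return std::move(it->second);  // move out, like GetMutable + std::move
  }
  return {};  // no fetch variable, or caller did not request results
}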
......
@@ -162,8 +185,9 @@ FetchList ProgramInterpreter::Run(
   }
 }
 
-FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
-                                  bool need_fetch) {
+FetchList ProgramInterpreter::Run(
+    const std::vector<std::string>& feed_names,
+    const std::vector<phi::DenseTensor>& feed_tensors) {
   SetDeviceId(place_);
   CheckCUDAGraphBeforeRun(feed_names);
......
@@ -171,32 +195,10 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
   platform::AttachPointerHashToMKLDNNKey(this, place_);
 #endif
 
-  if (!is_build_) {
-    LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
-    paddle::framework::interpreter::BuildVariableScope(
-        block_, execution_config_, &var_scope_);
+  bool is_build = is_build_;
+  Prepare(feed_names, feed_tensors, is_build);
 
-    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
-    paddle::framework::interpreter::BuildOpFuncList(
-        place_,
-        block_,
-        execution_config_.skip_gc_vars,
-        &op_func_nodes,
-        &var_scope_,
-        execution_config_,
-        HasLocalScope(),
-        static_build_);
-    SetFeedVarsInplaceSkip(feed_names);
-    // convert vec func_list to graph
-    Convert(&op_func_nodes);
-    UpdateSyncOpNum();
-    if (static_build_) {
-      VLOG(4) << "RUN impl";
-      RunImpl();
-    }
-    is_build_ = true;
-    is_shared_results_build_ = true;
-  } else {
+  if (is_build) {
     RunImpl();
   }
......
@@ -208,7 +210,7 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
   Scope* inner_scope =
       HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
   auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
-  if (fetch_var && need_fetch) {
+  if (fetch_var) {
     auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
 #ifdef PADDLE_WITH_CUDA
     if (platform::IsCUDAGraphCapturing()) {
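For the `feed_tensors` overload the flow is the converse: `Prepare` binds the given tensors to scope variables named by `feed_names` before running, and the fetch list is returned unconditionally, since this entry point has no `need_fetch` flag. A toy end-to-end sketch of that feed-by-name contract (all types here are stand-ins, not Paddle classes):

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

using Tensor = std::vector<float>;
using MiniScope = std::map<std::string, Tensor>;

Tensor Run(MiniScope& scope,
           const std::vector<std::string>& feed_names,
           const std::vector<Tensor>& feed_tensors) {
  // Prepare(...) analogue: bind each feed tensor to its named variable.
  for (size_t i = 0; i < feed_names.size(); ++i) {
    scope[feed_names[i]] = feed_tensors[i];
  }
  // ... run ops that read from the scope ...
  scope["fetch"] = scope[feed_names[0]];  // toy "result"
  return std::move(scope["fetch"]);
}

int main() {
  MiniScope scope;
  Tensor out = Run(scope, {"x"}, {{1.f, 2.f}});
  std::cout << out.size() << " values fetched\n";
}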
......
@@ -507,7 +507,7 @@ void BindPassManager(pybind11::module *m) {
           },
           py::arg("opt_level") = 2)
       .def("add_pass",
-           [](PassManager &self, std::string pass_name) {
+           [](PassManager &self, const std::string &pass_name) {
             self.AddPass(
                 std::move(ir::PassRegistry::Instance().Get(pass_name)));
           })
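Taking `pass_name` by `const std::string &` removes a needless copy at the binding boundary: pybind11 still materializes one `std::string` from the Python `str`, but the lambda no longer takes its own copy on top of that. A toy binding with the same shape (the `toy_pass` module and `ToyPassManager` are hypothetical, not Paddle's real binding):

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <string>
#include <vector>

namespace py = pybind11;

struct ToyPassManager {
  std::vector<std::string> names;
  void AddPass(const std::string &name) { names.push_back(name); }
};

PYBIND11_MODULE(toy_pass, m) {
  py::class_<ToyPassManager>(m, "PassManager")
      .def(py::init<>())
      .def("add_pass",
           // const & binds to the string pybind11 created; no second copy.
           [](ToyPassManager &self, const std::string &pass_name) {
             self.AddPass(pass_name);
           })
      .def("passes",
           [](const ToyPassManager &self) { return self.names; });
}

From Python this would be driven the same way as the test further down: pm.add_pass("dead_code_elimination").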
......
@@ -26,7 +26,7 @@ namespace {
 // Now just a naive implementation.
 class DeadCodeEliminationPass : public ir::Pass {
  public:
-  DeadCodeEliminationPass() : ir::Pass("DeadCodeEliminationPass", 0) {}
+  DeadCodeEliminationPass() : ir::Pass("dead_code_elimination", 0) {}
 
   void Run(ir::Operation *op) override {
     auto module_op = op->dyn_cast<ir::ModuleOp>();
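With this rename, the pass registers under the snake_case key `dead_code_elimination` rather than its C++ class name, so the string passed to `ir::PassRegistry::Instance().Get(...)` and the name asserted in the Python test below follow one convention. A self-contained sketch of name-keyed pass registration (a toy registry, not Paddle's `PassRegistry`):

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Minimal stand-in for ir::Pass(name, opt_level).
struct Pass {
  Pass(std::string name, int opt_level)
      : name_(std::move(name)), opt_level_(opt_level) {}
  virtual ~Pass() = default;
  virtual void Run() = 0;
  const std::string &name() const { return name_; }
  int opt_level() const { return opt_level_; }

 private:
  std::string name_;
  int opt_level_;
};

struct DeadCodeEliminationPass : Pass {
  // Registered under the snake_case key the Python side looks up.
  DeadCodeEliminationPass() : Pass("dead_code_elimination", 0) {}
  void Run() override { std::cout << name() << ": removing dead ops\n"; }
};

// Toy registry keyed by pass name.
using PassFactory = std::function<std::unique_ptr<Pass>()>;
std::map<std::string, PassFactory> &Registry() {
  static std::map<std::string, PassFactory> registry;
  return registry;
}

int main() {
  Registry()["dead_code_elimination"] = [] {
    return std::make_unique<DeadCodeEliminationPass>();
  };
  auto pass = Registry()["dead_code_elimination"]();  // lookup by key
  pass->Run();
}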
......
@@ -56,8 +56,7 @@ class TestShadowOutputSlice(unittest.TestCase):
         pm.run(new_program)
         op_names = [op.name() for op in new_program.block().ops]
         # print(op_names)
-        # TODO(zhiqiu): unify the name of pass
-        self.assertEqual(pm.passes(), ['DeadCodeEliminationPass'])
+        self.assertEqual(pm.passes(), ['dead_code_elimination'])
         self.assertFalse(pm.empty())
         self.assertTrue(
             'pd.uniform' not in op_names
......