未验证 提交 e9f9da14 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Refine some code for NewIRInterpreter (#55169)

* fix bug

* fix bug

* refien code

* refien code

* fix bug

* refine code
上级 c234f1f2
...@@ -185,7 +185,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names, ...@@ -185,7 +185,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
if (!is_build_) { if (!is_build_) {
LOG_FIRST_N(INFO, 1) << "New Executor is Running."; LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
::ir::BuildScope( ::ir::BuildScope(
ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_); *ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes; std::vector<paddle::framework::OpFuncNode> op_func_nodes;
interpreter::BuildOpFuncList(place_, interpreter::BuildOpFuncList(place_,
...@@ -213,7 +213,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names, ...@@ -213,7 +213,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
} }
// return Fetch Tensors // return Fetch Tensors
Scope* inner_scope = HasLocalScope() ? local_scope_ : scope_; Scope* inner_scope = InnerScope();
auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName); auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
if (fetch_var && need_fetch) { if (fetch_var && need_fetch) {
auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>()); auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
...@@ -323,7 +323,7 @@ std::shared_ptr<interpreter::AsyncWorkQueue> NewIRInterpreter::GetWorkQueue() { ...@@ -323,7 +323,7 @@ std::shared_ptr<interpreter::AsyncWorkQueue> NewIRInterpreter::GetWorkQueue() {
} }
void NewIRInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) { void NewIRInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
Scope* inner_scope = HasLocalScope() ? local_scope_ : scope_; Scope* inner_scope = InnerScope();
VariableValueMap ins_map; VariableValueMap ins_map;
for (auto& var_name_item : instr_node->Inputs()) { for (auto& var_name_item : instr_node->Inputs()) {
std::vector<Variable*> input_vars; std::vector<Variable*> input_vars;
...@@ -350,8 +350,8 @@ void NewIRInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) { ...@@ -350,8 +350,8 @@ void NewIRInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
if (instr_node->OpBase()->Type() == "cinn_launch" || if (instr_node->OpBase()->Type() == "cinn_launch" ||
instr_node->OpBase()->Type() == "cinn_instruction_run") { // OP use scope instr_node->OpBase()->Type() == "cinn_instruction_run") { // OP use scope
// in kernel // in kernel
Scope* local_scope = HasLocalScope() ? local_scope_ : scope_; Scope* inner_scope = InnerScope();
instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope); instr_node->ResetContextWithScope(ins_map, outs_map, *inner_scope);
} else { } else {
instr_node->ResetContext(ins_map, outs_map); instr_node->ResetContext(ins_map, outs_map);
} }
...@@ -380,7 +380,7 @@ void NewIRInterpreter::BuildInplace() { ...@@ -380,7 +380,7 @@ void NewIRInterpreter::BuildInplace() {
} }
} }
Scope* local_scope = HasLocalScope() ? local_scope_ : scope_; Scope* local_scope = InnerScope();
std::vector<std::vector<size_t>> input_var2op(var_scope_.VarSize()); std::vector<std::vector<size_t>> input_var2op(var_scope_.VarSize());
for (Instruction& instr : vec_instruction_) { for (Instruction& instr : vec_instruction_) {
for (auto& item : instr.Inputs()) { for (auto& item : instr.Inputs()) {
...@@ -798,7 +798,7 @@ void NewIRInterpreter::BuildSkipShareLoDInfo() { ...@@ -798,7 +798,7 @@ void NewIRInterpreter::BuildSkipShareLoDInfo() {
void NewIRInterpreter::RunOperator(const Instruction& instr_node) { void NewIRInterpreter::RunOperator(const Instruction& instr_node) {
auto* op = instr_node.OpBase(); auto* op = instr_node.OpBase();
auto place = instr_node.DeviceContext().GetPlace(); auto place = instr_node.DeviceContext().GetPlace();
Scope* local_scope = HasLocalScope() ? local_scope_ : scope_; Scope* local_scope = InnerScope();
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope); VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);
auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op); auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
...@@ -1334,6 +1334,10 @@ void NewIRInterpreter::SetFeedVarsInplaceSkip( ...@@ -1334,6 +1334,10 @@ void NewIRInterpreter::SetFeedVarsInplaceSkip(
bool NewIRInterpreter::HasLocalScope() const { return local_scope_ != nullptr; } bool NewIRInterpreter::HasLocalScope() const { return local_scope_ != nullptr; }
Scope* NewIRInterpreter::InnerScope() {
return local_scope_ != nullptr ? local_scope_ : scope_;
}
// Note(zhangbo): // Note(zhangbo):
// (1) What is "Trace"? // (1) What is "Trace"?
// The OP execute scheduling rule adopted by Interpretercore by default is a // The OP execute scheduling rule adopted by Interpretercore by default is a
......
...@@ -116,6 +116,8 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -116,6 +116,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
// scope // scope
bool HasLocalScope() const; bool HasLocalScope() const;
Scope* InnerScope();
// For log and debug // For log and debug
std::string GetDepsString() const; std::string GetDepsString() const;
......
...@@ -101,6 +101,14 @@ const Scope* Scope::FindScope(const std::string& name) const { ...@@ -101,6 +101,14 @@ const Scope* Scope::FindScope(const std::string& name) const {
return FindScopeInternal(name); return FindScopeInternal(name);
} }
const Scope* Scope::root() const {
const Scope* root_scope = this;
while (root_scope->parent()) {
root_scope = root_scope->parent();
}
return root_scope;
}
void Scope::DropKids() { void Scope::DropKids() {
{ {
SCOPE_KIDS_WRITER_LOCK SCOPE_KIDS_WRITER_LOCK
......
...@@ -89,6 +89,8 @@ class Scope { ...@@ -89,6 +89,8 @@ class Scope {
const Scope* parent() const { return parent_; } const Scope* parent() const { return parent_; }
const Scope* root() const;
/// Find the scope or an ancestor scope that contains the given variable. /// Find the scope or an ancestor scope that contains the given variable.
const Scope* FindScope(const Variable* var) const; const Scope* FindScope(const Variable* var) const;
......
...@@ -56,7 +56,7 @@ class PhiKernelAdaptor { ...@@ -56,7 +56,7 @@ class PhiKernelAdaptor {
void run_kernel_prog(ir::Program* program) { void run_kernel_prog(ir::Program* program) {
auto block = program->block(); auto block = program->block();
std::unordered_map<ir::Value, std::string> name_map; std::unordered_map<ir::Value, std::string> name_map;
BuildScope(block, scope_, nullptr, &name_map); BuildScope(*block, scope_, nullptr, &name_map);
ir::IrContext* ctx = ir::IrContext::Instance(); ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>(); ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
......
...@@ -44,7 +44,7 @@ ...@@ -44,7 +44,7 @@
namespace ir { namespace ir {
paddle::framework::Variable* CreateVar(ir::Value value, paddle::framework::Variable* CreateVar(ir::Value value,
std::string name, const std::string& name,
paddle::framework::Scope* scope, paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope) { paddle::framework::Scope* local_scope) {
Operation* def_op = value.GetDefiningOp(); Operation* def_op = value.GetDefiningOp();
...@@ -56,12 +56,8 @@ paddle::framework::Variable* CreateVar(ir::Value value, ...@@ -56,12 +56,8 @@ paddle::framework::Variable* CreateVar(ir::Value value,
.data(); .data();
} }
if (is_persisable) { if (is_persisable) {
const paddle::framework::Scope* ancestor_scope = scope; VLOG(6) << "Create var: " << name << " in scope " << scope->root();
while (ancestor_scope->parent()) { return const_cast<paddle::framework::Scope*>(scope->root())->Var(name);
ancestor_scope = ancestor_scope->parent();
}
VLOG(6) << "Create var: " << name << " in scope " << ancestor_scope;
return const_cast<paddle::framework::Scope*>(ancestor_scope)->Var(name);
} else { } else {
VLOG(6) << "Create var: " << name << " in scope " << local_scope; VLOG(6) << "Create var: " << name << " in scope " << local_scope;
return local_scope->Var(name); return local_scope->Var(name);
...@@ -164,10 +160,13 @@ void HandleForSpecialOp(ir::Operation* op, ...@@ -164,10 +160,13 @@ void HandleForSpecialOp(ir::Operation* op,
} }
} }
void BuildScope(ir::Block* block, void BuildScope(const ir::Block& block,
paddle::framework::Scope* scope, paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope, paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map) { std::unordered_map<ir::Value, std::string>* name_map) {
VLOG(4) << "***** [before build] scope: ******\n"
<< paddle::framework::GenScopeTreeDebugInfo(
const_cast<paddle::framework::Scope*>(scope->root()));
// NOTE(zhiqiu): if use local_scope (local_scope != nullptr), the persistable // NOTE(zhiqiu): if use local_scope (local_scope != nullptr), the persistable
// is created in scope , and other is created in local_scope. // is created in scope , and other is created in local_scope.
auto inner_local_scope = local_scope != nullptr ? local_scope : scope; auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
...@@ -176,7 +175,7 @@ void BuildScope(ir::Block* block, ...@@ -176,7 +175,7 @@ void BuildScope(ir::Block* block,
// int count = name_map->size(); // int count = name_map->size();
int count = name_map->size(); int count = name_map->size();
for (auto it = block->begin(); it != block->end(); ++it) { for (auto it = block.begin(); it != block.end(); ++it) {
ir::Operation* op = *it; ir::Operation* op = *it;
auto attr_map = op->attributes(); auto attr_map = op->attributes();
...@@ -250,6 +249,9 @@ void BuildScope(ir::Block* block, ...@@ -250,6 +249,9 @@ void BuildScope(ir::Block* block,
} }
} }
} }
VLOG(4) << "***** [after build] scope: ******\n"
<< paddle::framework::GenScopeTreeDebugInfo(
const_cast<paddle::framework::Scope*>(scope->root()));
} }
} // namespace ir } // namespace ir
...@@ -41,7 +41,7 @@ ...@@ -41,7 +41,7 @@
namespace ir { namespace ir {
paddle::framework::Variable* CreateVar(ir::Value value, paddle::framework::Variable* CreateVar(ir::Value value,
std::string name, const std::string& name,
paddle::framework::Scope* scope, paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope); paddle::framework::Scope* local_scope);
...@@ -51,7 +51,7 @@ void HandleForSpecialOp(ir::Operation* op, ...@@ -51,7 +51,7 @@ void HandleForSpecialOp(ir::Operation* op,
std::unordered_map<ir::Value, std::string>* name_map, std::unordered_map<ir::Value, std::string>* name_map,
int& count); // NOLINT int& count); // NOLINT
void BuildScope(ir::Block* block, void BuildScope(const ir::Block& block,
paddle::framework::Scope* scope, paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope, paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map); std::unordered_map<ir::Value, std::string>* name_map);
......
...@@ -40,6 +40,8 @@ class IR_API Block { ...@@ -40,6 +40,8 @@ class IR_API Block {
bool empty() const { return ops_.empty(); } bool empty() const { return ops_.empty(); }
size_t size() const { return ops_.size(); } size_t size() const { return ops_.size(); }
const_iterator begin() const { return ops_.begin(); }
const_iterator end() const { return ops_.end(); }
iterator begin() { return ops_.begin(); } iterator begin() { return ops_.begin(); }
iterator end() { return ops_.end(); } iterator end() { return ops_.end(); }
reverse_iterator rbegin() { return ops_.rbegin(); } reverse_iterator rbegin() { return ops_.rbegin(); }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册