Unverified commit e9f9da14, authored by zhangbo9674, committed by GitHub

[IR] Refine some code for NewIRInterpreter (#55169)

* fix bug

* fix bug

* refine code

* refine code

* fix bug

* refine code
Parent c234f1f2
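
The recurring change below replaces the repeated ternary `HasLocalScope() ? local_scope_ : scope_` with a single `InnerScope()` helper on `NewIRInterpreter`. A minimal standalone sketch of that pattern (simplified stand-in classes, not Paddle's real `Scope` or interpreter):

```cpp
#include <iostream>
#include <string>

// Simplified stand-in for paddle::framework::Scope, for illustration only.
struct Scope {
  std::string name;
};

class MiniInterpreter {
 public:
  MiniInterpreter(Scope* scope, Scope* local_scope)
      : scope_(scope), local_scope_(local_scope) {}

  bool HasLocalScope() const { return local_scope_ != nullptr; }

  // One place that decides which scope the interpreter actually works in,
  // mirroring the InnerScope() helper added in this commit.
  Scope* InnerScope() { return local_scope_ != nullptr ? local_scope_ : scope_; }

  void Run() {
    // Before the refactor every call site repeated the ternary:
    //   Scope* inner = HasLocalScope() ? local_scope_ : scope_;
    Scope* inner = InnerScope();
    std::cout << "running in scope: " << inner->name << "\n";
  }

 private:
  Scope* scope_;
  Scope* local_scope_;
};

int main() {
  Scope global{"global"};
  Scope local{"local"};
  MiniInterpreter(&global, &local).Run();   // prints "running in scope: local"
  MiniInterpreter(&global, nullptr).Run();  // prints "running in scope: global"
}
```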
@@ -185,7 +185,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
if (!is_build_) {
LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
::ir::BuildScope(
- ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
+ *ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
interpreter::BuildOpFuncList(place_,
@@ -213,7 +213,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
}
// return Fetch Tensors
- Scope* inner_scope = HasLocalScope() ? local_scope_ : scope_;
+ Scope* inner_scope = InnerScope();
auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
if (fetch_var && need_fetch) {
auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
@@ -323,7 +323,7 @@ std::shared_ptr<interpreter::AsyncWorkQueue> NewIRInterpreter::GetWorkQueue() {
}
void NewIRInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
- Scope* inner_scope = HasLocalScope() ? local_scope_ : scope_;
+ Scope* inner_scope = InnerScope();
VariableValueMap ins_map;
for (auto& var_name_item : instr_node->Inputs()) {
std::vector<Variable*> input_vars;
@@ -350,8 +350,8 @@ void NewIRInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
if (instr_node->OpBase()->Type() == "cinn_launch" ||
instr_node->OpBase()->Type() == "cinn_instruction_run") { // OP use scope
// in kernel
- Scope* local_scope = HasLocalScope() ? local_scope_ : scope_;
- instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
+ Scope* inner_scope = InnerScope();
+ instr_node->ResetContextWithScope(ins_map, outs_map, *inner_scope);
} else {
instr_node->ResetContext(ins_map, outs_map);
}
@@ -380,7 +380,7 @@ void NewIRInterpreter::BuildInplace() {
}
}
- Scope* local_scope = HasLocalScope() ? local_scope_ : scope_;
+ Scope* local_scope = InnerScope();
std::vector<std::vector<size_t>> input_var2op(var_scope_.VarSize());
for (Instruction& instr : vec_instruction_) {
for (auto& item : instr.Inputs()) {
@@ -798,7 +798,7 @@ void NewIRInterpreter::BuildSkipShareLoDInfo() {
void NewIRInterpreter::RunOperator(const Instruction& instr_node) {
auto* op = instr_node.OpBase();
auto place = instr_node.DeviceContext().GetPlace();
- Scope* local_scope = HasLocalScope() ? local_scope_ : scope_;
+ Scope* local_scope = InnerScope();
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);
auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
@@ -1334,6 +1334,10 @@ void NewIRInterpreter::SetFeedVarsInplaceSkip(
bool NewIRInterpreter::HasLocalScope() const { return local_scope_ != nullptr; }
+ Scope* NewIRInterpreter::InnerScope() {
+ return local_scope_ != nullptr ? local_scope_ : scope_;
+ }
// Note(zhangbo):
// (1) What is "Trace"?
// The OP execute scheduling rule adopted by Interpretercore by default is a
......
@@ -116,6 +116,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
// scope
bool HasLocalScope() const;
+ Scope* InnerScope();
// For log and debug
std::string GetDepsString() const;
......
@@ -101,6 +101,14 @@ const Scope* Scope::FindScope(const std::string& name) const {
return FindScopeInternal(name);
}
+ const Scope* Scope::root() const {
+ const Scope* root_scope = this;
+ while (root_scope->parent()) {
+ root_scope = root_scope->parent();
+ }
+ return root_scope;
+ }
void Scope::DropKids() {
{
SCOPE_KIDS_WRITER_LOCK
......
@@ -89,6 +89,8 @@ class Scope {
const Scope* parent() const { return parent_; }
+ const Scope* root() const;
/// Find the scope or an ancestor scope that contains the given variable.
const Scope* FindScope(const Variable* var) const;
......
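
The new `Scope::root()` simply walks the `parent()` chain up to the top-level scope; `CreateVar` below then uses it to place persistable variables in the root scope. A minimal self-contained sketch of that traversal (hypothetical `MiniScope`, not Paddle's `Scope`):

```cpp
#include <iostream>

// Hypothetical scope with only a parent pointer, for illustration.
class MiniScope {
 public:
  explicit MiniScope(const MiniScope* parent = nullptr) : parent_(parent) {}

  const MiniScope* parent() const { return parent_; }

  // Same idea as the Scope::root() added in this commit: follow parent()
  // until no parent is left.
  const MiniScope* root() const {
    const MiniScope* root_scope = this;
    while (root_scope->parent()) {
      root_scope = root_scope->parent();
    }
    return root_scope;
  }

 private:
  const MiniScope* parent_;
};

int main() {
  MiniScope global;         // top-level scope
  MiniScope exec(&global);  // child
  MiniScope local(&exec);   // grandchild
  std::cout << std::boolalpha
            << (local.root() == &global) << "\n"   // true
            << (exec.root() == &global) << "\n";   // true
}
```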
@@ -56,7 +56,7 @@ class PhiKernelAdaptor {
void run_kernel_prog(ir::Program* program) {
auto block = program->block();
std::unordered_map<ir::Value, std::string> name_map;
- BuildScope(block, scope_, nullptr, &name_map);
+ BuildScope(*block, scope_, nullptr, &name_map);
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
......
@@ -44,7 +44,7 @@
namespace ir {
paddle::framework::Variable* CreateVar(ir::Value value,
- std::string name,
+ const std::string& name,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope) {
Operation* def_op = value.GetDefiningOp();
@@ -56,12 +56,8 @@ paddle::framework::Variable* CreateVar(ir::Value value,
.data();
}
if (is_persisable) {
- const paddle::framework::Scope* ancestor_scope = scope;
- while (ancestor_scope->parent()) {
- ancestor_scope = ancestor_scope->parent();
- }
- VLOG(6) << "Create var: " << name << " in scope " << ancestor_scope;
- return const_cast<paddle::framework::Scope*>(ancestor_scope)->Var(name);
+ VLOG(6) << "Create var: " << name << " in scope " << scope->root();
+ return const_cast<paddle::framework::Scope*>(scope->root())->Var(name);
} else {
VLOG(6) << "Create var: " << name << " in scope " << local_scope;
return local_scope->Var(name);
@@ -164,10 +160,13 @@ void HandleForSpecialOp(ir::Operation* op,
}
}
- void BuildScope(ir::Block* block,
+ void BuildScope(const ir::Block& block,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map) {
VLOG(4) << "***** [before build] scope: ******\n"
<< paddle::framework::GenScopeTreeDebugInfo(
const_cast<paddle::framework::Scope*>(scope->root()));
// NOTE(zhiqiu): if use local_scope (local_scope != nullptr), the persistable
// is created in scope , and other is created in local_scope.
auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
@@ -176,7 +175,7 @@ void BuildScope(ir::Block* block,
// int count = name_map->size();
int count = name_map->size();
- for (auto it = block->begin(); it != block->end(); ++it) {
+ for (auto it = block.begin(); it != block.end(); ++it) {
ir::Operation* op = *it;
auto attr_map = op->attributes();
@@ -250,6 +249,9 @@ void BuildScope(ir::Block* block,
}
}
}
VLOG(4) << "***** [after build] scope: ******\n"
<< paddle::framework::GenScopeTreeDebugInfo(
const_cast<paddle::framework::Scope*>(scope->root()));
}
} // namespace ir
@@ -41,7 +41,7 @@
namespace ir {
paddle::framework::Variable* CreateVar(ir::Value value,
- std::string name,
+ const std::string& name,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope);
@@ -51,7 +51,7 @@ void HandleForSpecialOp(ir::Operation* op,
std::unordered_map<ir::Value, std::string>* name_map,
int& count); // NOLINT
- void BuildScope(ir::Block* block,
+ void BuildScope(const ir::Block& block,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map);
......
@@ -40,6 +40,8 @@ class IR_API Block {
bool empty() const { return ops_.empty(); }
size_t size() const { return ops_.size(); }
+ const_iterator begin() const { return ops_.begin(); }
+ const_iterator end() const { return ops_.end(); }
iterator begin() { return ops_.begin(); }
iterator end() { return ops_.end(); }
reverse_iterator rbegin() { return ops_.rbegin(); }
......
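
The `const` overloads of `begin()`/`end()` on `Block` are what allow `BuildScope` to take a `const ir::Block&` and still iterate over its operations. A minimal sketch of the same pattern with a hypothetical container (not the real `ir::Block`):

```cpp
#include <iostream>
#include <list>
#include <string>

// Hypothetical container mirroring the iterator interface touched above.
class MiniBlock {
 public:
  using iterator = std::list<std::string>::iterator;
  using const_iterator = std::list<std::string>::const_iterator;

  void push_back(std::string op) { ops_.push_back(std::move(op)); }

  // The const overloads are what make iteration through a const reference work.
  const_iterator begin() const { return ops_.begin(); }
  const_iterator end() const { return ops_.end(); }
  iterator begin() { return ops_.begin(); }
  iterator end() { return ops_.end(); }

 private:
  std::list<std::string> ops_;
};

// Analogous to BuildScope(const ir::Block& block, ...): a const reference is
// enough because the pass only reads the operations.
void PrintOps(const MiniBlock& block) {
  for (auto it = block.begin(); it != block.end(); ++it) {
    std::cout << "op: " << *it << "\n";
  }
}

int main() {
  MiniBlock block;
  block.push_back("full");
  block.push_back("fetch");
  PrintOps(block);
}
```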