From 85831c32ed7175c4f24de5df3240dcbaa810a088 Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Wed, 5 Jul 2023 17:23:16 +0800
Subject: [PATCH] [IR] New IR access InterpreterCore：add local scope logic
 (#55112)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add local scope

* refine code

* refine code

* refine code

* support local scope for BuildFuncList

* fix bug

* add log

* fix bug

* polish code

* fix bug
---
 .../interpreter/interpreter_util.cc           |   3 +
 .../interpreter/interpreter_util.h            |   1 +
 .../new_executor/interpreter_base_impl.h      |   2 +
 .../framework/new_executor/interpretercore.cc |   4 +
 .../framework/new_executor/interpretercore.h  |   2 +
 .../new_executor/new_ir_interpreter.cc        |  42 ++--
 .../new_executor/new_ir_interpreter.h         |   3 +
 .../new_executor/program_interpreter.cc       |   1 +
 .../new_executor/program_interpreter.h        |   2 +
 .../phi_kernel_adaptor/phi_kernel_adaptor.h   |   6 +-
 .../ir/phi_kernel_adaptor/phi_kernel_util.cc  | 227 +++++++++++-------
 .../ir/phi_kernel_adaptor/phi_kernel_util.h   |  34 ++-
 .../standalone_executor_new_ir_test.cc        |  18 +-
 13 files changed, 217 insertions(+), 128 deletions(-)

diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
index a8dbbedd038..13896b66f3c 100644
--- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
+++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
@@ -940,6 +940,7 @@ void BuildOpFuncList(
     ::ir::Block* block,
     std::vector<OpFuncNode>* vec_func_list,
     framework::Scope* scope,
+    framework::Scope* local_scope,
     const std::unordered_map<::ir::Value, std::string>& value_2_name_map,
     const ExecutionConfig& execution_config) {
   vec_func_list->reserve(block->size());
@@ -979,6 +980,7 @@ void BuildOpFuncList(
         false>((*it),
                value_2_name_map,
                scope,
+               local_scope,
                op_yaml_info_parser,
                &(op_func_node.infer_meta_context_));
 
@@ -1004,6 +1006,7 @@ void BuildOpFuncList(
         true>((*it),
               value_2_name_map,
               scope,
+              local_scope,
              op_yaml_info_parser,
              &(op_func_node.kernel_context_),
              &(op_func_node.input_index),
diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h
index f553609b04d..eb87c8bcb4c 100644
--- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h
+++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h
@@ -98,6 +98,7 @@ void BuildOpFuncList(
     ::ir::Block* block,
     std::vector<OpFuncNode>* vec_func_list,
     framework::Scope* scope,
+    framework::Scope* local_scope,
     const std::unordered_map<::ir::Value, std::string>& value_2_name_map,
     const ExecutionConfig& execution_config);
 
diff --git a/paddle/fluid/framework/new_executor/interpreter_base_impl.h b/paddle/fluid/framework/new_executor/interpreter_base_impl.h
index 8fc0a137d03..1ae7e5e59ce 100644
--- a/paddle/fluid/framework/new_executor/interpreter_base_impl.h
+++ b/paddle/fluid/framework/new_executor/interpreter_base_impl.h
@@ -86,6 +86,8 @@ class InterpreterBaseImpl {
   virtual void reset_scope(Scope* new_scope) = 0;
 
+  virtual const Scope* local_scope() const = 0;
+
   virtual const platform::Place& GetPlace() const = 0;
 
   virtual void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) = 0;
diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index ab68a477954..2a240158bcd 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -101,6 +101,10 @@ void InterpreterCore::reset_scope(Scope* new_scope) {
   impl_->reset_scope(new_scope);
 }
 
+const Scope* InterpreterCore::local_scope() const {
+  return impl_->local_scope();
+}
+
 const platform::Place& InterpreterCore::GetPlace() const {
   return impl_->GetPlace();
 }
diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h
index 69e0e910237..8f719c595d2 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.h
+++ b/paddle/fluid/framework/new_executor/interpretercore.h
@@ -65,6 +65,8 @@ class InterpreterCore {
   void reset_scope(Scope* new_scope);
 
+  const Scope* local_scope() const;
+
   const platform::Place& GetPlace() const;
 
   void SetOutputHooks(const std::vector<HookFunc>& hookfuncs);
diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
index 31819fc4a42..0e8d9d84199 100644
--- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc
@@ -49,12 +49,14 @@ NewIRInterpreter::NewIRInterpreter(const platform::Place& place,
       stream_analyzer_(place),
       execution_config_(execution_config),
       var_scope_(scope),
+      scope_(scope),
       ir_program_(std::move(ir_prog)) {
   VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_;
 
   static_build_ = FLAGS_new_executor_static_build &&
                   !FLAGS_new_executor_use_cuda_graph &&
                   !execution_config.used_for_control_flow_op;
   // &&interpreter::BlockCanBeStaticBuilt(block);
+  static_build_ = true;
 
   exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
   completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
@@ -62,21 +64,19 @@ NewIRInterpreter::NewIRInterpreter(const platform::Place& place,
   if (!FLAGS_new_executor_use_local_scope) {
     execution_config_.create_local_scope = false;
   }
-  execution_config_.AnalyzeThreadPoolConfig(place,
-                                            ir_program_->block()->size());
-  execution_config_.Log(/*log_level=*/8);
-
   if (execution_config_.create_local_scope) {
-    auto local_scope = &var_scope_.GetMutableScope()->NewScope();
+    auto local_scope = &scope_->NewScope();
     local_scope_ = local_scope;
+    VLOG(6) << "new ir interpretercore scope: " << scope_ << "\t"
+            << "; local scope: " << local_scope_;
   }
-
-  // force use outer scope for now
-  local_scope_ = scope;
-  static_build_ = true;
-
+  // TODO(zhangbo): delete var_scope
   var_scope_.SetLocalScope(local_scope_);
 
+  execution_config_.AnalyzeThreadPoolConfig(place,
+                                            ir_program_->block()->size());
+  execution_config_.Log(/*log_level=*/8);
+
   instruction_scheduling_priority_less = [this](size_t lhs, size_t rhs) {
     SchedulingPriority lhs_scheduling_priority =
         vec_instruction_[lhs].GetSchedulingPriority();
@@ -185,12 +185,13 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
   if (!is_build_) {
     LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
     ::ir::BuildScope(
-        ir_program_->block(), local_scope_, &value_2_var_name_map_);
+        ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
 
     std::vector<OpFuncNode> op_func_nodes;
     interpreter::BuildOpFuncList(place_,
                                  ir_program_->block(),
                                  &op_func_nodes,
+                                 scope_,
                                  local_scope_,
                                  value_2_var_name_map_,
                                  execution_config_);
@@ -212,8 +213,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
   }
 
   // return Fetch Tensors
-  Scope* inner_scope =
-      HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
+  Scope* inner_scope = HasLocalScope() ? local_scope_ : scope_;
   auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
   if (fetch_var && need_fetch) {
     auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
@@ -287,6 +287,8 @@ void NewIRInterpreter::reset_scope(Scope* new_scope) {
   }
 }
 
+const Scope* NewIRInterpreter::local_scope() const { return local_scope_; }
+
 void NewIRInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
   async_work_queue_ = reinterpret_cast<NewIRInterpreter*>(src)->GetWorkQueue();
   VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << src
@@ -321,8 +323,7 @@ std::shared_ptr<interpreter::AsyncWorkQueue> NewIRInterpreter::GetWorkQueue() {
 }
 
 void NewIRInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
-  Scope* inner_scope =
-      HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
+  Scope* inner_scope = HasLocalScope() ? local_scope_ : scope_;
   VariableValueMap ins_map;
   for (auto& var_name_item : instr_node->Inputs()) {
     std::vector<Variable*> input_vars;
@@ -349,8 +350,7 @@ void NewIRInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
   if (instr_node->OpBase()->Type() == "cinn_launch" ||
       instr_node->OpBase()->Type() == "cinn_instruction_run") {  // OP use scope
                                                                  // in kernel
-    Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
-                                         : var_scope_.GetMutableScope();
+    Scope* local_scope = HasLocalScope() ? local_scope_ : scope_;
     instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
   } else {
     instr_node->ResetContext(ins_map, outs_map);
   }
@@ -380,8 +380,7 @@ void NewIRInterpreter::BuildInplace() {
     }
   }
 
-  Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
-                                       : var_scope_.GetMutableScope();
+  Scope* local_scope = HasLocalScope() ? local_scope_ : scope_;
   std::vector<std::vector<size_t>> input_var2op(var_scope_.VarSize());
   for (Instruction& instr : vec_instruction_) {
     for (auto& item : instr.Inputs()) {
@@ -799,8 +798,7 @@ void NewIRInterpreter::BuildSkipShareLoDInfo() {
 void NewIRInterpreter::RunOperator(const Instruction& instr_node) {
   auto* op = instr_node.OpBase();
   auto place = instr_node.DeviceContext().GetPlace();
-  Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
-                                       : var_scope_.GetMutableScope();
+  Scope* local_scope = HasLocalScope() ? local_scope_ : scope_;
   VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);
 
   auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
@@ -1047,7 +1045,7 @@ void NewIRInterpreter::ExecuteInstructionList(
       if (cancel) {
         break;
       }
-      VLOG(0) << "deps:\n" << GetDepsString();
+      VLOG(6) << "deps:\n" << GetDepsString();
       times++;
     }
     return times;
diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.h b/paddle/fluid/framework/new_executor/new_ir_interpreter.h
index 59eb155c123..dbd47691d1b 100644
--- a/paddle/fluid/framework/new_executor/new_ir_interpreter.h
+++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.h
@@ -60,6 +60,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
   void reset_scope(Scope* new_scope) override;
 
+  const Scope* local_scope() const override;
+
   const platform::Place& GetPlace() const override { return place_; }
 
   void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) override {
@@ -143,6 +145,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
   ExecutionConfig execution_config_;
 
   VariableScope var_scope_;
+  Scope* scope_{nullptr};
   Scope* local_scope_{nullptr};  // not owned
 
   EventsWaiter main_thread_blocker_;
diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc
index 3156575ebd3..b6c54192a69 100644
--- a/paddle/fluid/framework/new_executor/program_interpreter.cc
+++ b/paddle/fluid/framework/new_executor/program_interpreter.cc
@@ -275,6 +275,7 @@ void ProgramInterpreter::reset_scope(Scope* new_scope) {
   }
 }
 
+const Scope* ProgramInterpreter::local_scope() const { return local_scope_; }
 
 void ProgramInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
   async_work_queue_ = reinterpret_cast<ProgramInterpreter*>(src)->GetWorkQueue();
diff --git a/paddle/fluid/framework/new_executor/program_interpreter.h b/paddle/fluid/framework/new_executor/program_interpreter.h
index ca66691608d..a21cf26072c 100644
--- a/paddle/fluid/framework/new_executor/program_interpreter.h
+++ b/paddle/fluid/framework/new_executor/program_interpreter.h
@@ -62,6 +62,8 @@ class ProgramInterpreter : public InterpreterBaseImpl {
   void reset_scope(Scope* new_scope) override;
 
+  const Scope* local_scope() const override;
+
   const platform::Place& GetPlace() const override { return place_; }
 
   void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) override {
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
index 8cc1b667967..6a67c09972a 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
@@ -56,7 +56,7 @@ class PhiKernelAdaptor {
   void run_kernel_prog(ir::Program* program) {
     auto block = program->block();
     std::unordered_map<ir::Value, std::string> name_map;
-    BuildScope(block, scope_, &name_map);
+    BuildScope(block, scope_, nullptr, &name_map);
     ir::IrContext* ctx = ir::IrContext::Instance();
 
     ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
@@ -87,7 +87,7 @@ class PhiKernelAdaptor {
           phi::MetaTensor,
           phi::MetaTensor,
           paddle::small_vector,
-          false>((*it), name_map, scope_, op_yaml_info_parser, &ctx);
+          false>((*it), name_map, scope_, nullptr, op_yaml_info_parser, &ctx);
 
       infer_meta_impl->infer_meta_(&ctx);
@@ -107,7 +107,7 @@ class PhiKernelAdaptor {
           phi::TensorBase*,
           paddle::small_vector,
           true>(
-          (*it), name_map, scope_, op_yaml_info_parser, &kernel_ctx);
+          (*it), name_map, scope_, nullptr, op_yaml_info_parser, &kernel_ctx);
 
       kernel_fn(&kernel_ctx);
       auto out_value = (*it)->result(0);
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
index 98e4487da46..6d40ad8bca5 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -43,114 +43,160 @@
 namespace ir {
 
-void BuildScope(ir::Block* block,
-                paddle::framework::Scope* scope,
-                std::unordered_map<ir::Value, std::string>* name_map) {
-  std::unordered_map map_test;
-
-  int count = name_map->size();
-  for (auto it = block->begin(); it != block->end(); ++it) {
-    size_t input_num = (*it)->num_operands();
-    auto attr_map = (*it)->attributes();
-    std::string op_name = (*it)->name();
-    if (attr_map.count("op_name")) {
-      op_name = attr_map.at("op_name").dyn_cast<ir::StrAttribute>().data();
-    }
-    if (op_name == "pd.fetch") {
-      // fetch is a very special op, with no output
-      for (size_t i = 0; i < input_num; ++i) {
-        auto var = scope->Var("fetch");
-        auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
-        int index =
-            (*it)->attributes().at("col").dyn_cast<ir::Int32_tAttribute>().data();
-        fetch_list->resize(index + 1);
-      }
-      continue;
+paddle::framework::Variable* CreateVar(ir::Value value,
+                                       std::string name,
+                                       paddle::framework::Scope* scope,
+                                       paddle::framework::Scope* local_scope) {
+  Operation* def_op = value.GetDefiningOp();
+  bool is_persisable = false;
+  if (def_op->attributes().count("is_persisable")) {
+    is_persisable = def_op->attributes()
+                        .at("is_persisable")
+                        .dyn_cast<ir::BoolAttribute>()
+                        .data();
+  }
+  if (is_persisable) {
+    const paddle::framework::Scope* ancestor_scope = scope;
+    while (ancestor_scope->parent()) {
+      ancestor_scope = ancestor_scope->parent();
     }
+    VLOG(6) << "Create var: " << name << " in scope " << ancestor_scope;
+    return const_cast<paddle::framework::Scope*>(ancestor_scope)->Var(name);
+  } else {
+    VLOG(6) << "Create var: " << name << " in scope " << local_scope;
+    return local_scope->Var(name);
+  }
+}
 
-    if (op_name == "builtin.set_parameter") {
-      auto param_name = (*it)
-                            ->attributes()
-                            .at("parameter_name")
-                            .dyn_cast<ir::StrAttribute>()
-                            .data();
-
-      auto in_ptr = (*it)->operand(0);
-      // change opreand name to param_name
-
-      auto orig_name = name_map->at(in_ptr);
-      (*name_map)[in_ptr] = param_name;
-      scope->Rename(orig_name, param_name);
-      continue;
+void HandleForSpecialOp(ir::Operation* op,
+                        paddle::framework::Scope* scope,
+                        paddle::framework::Scope* local_scope,
+                        std::unordered_map<ir::Value, std::string>* name_map,
+                        int& count) {  // NOLINT
+  std::string op_name = op->name();
+  if (op->attributes().count("op_name")) {
+    op_name =
+        op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
+  }
+  size_t input_num = op->num_operands();
+
+  if (op_name == "pd.fetch") {
+    // fetch is a very special op, with no output
+    VLOG(6) << "Handle for pd.fetch:";
+    for (size_t i = 0; i < input_num; ++i) {
+      auto var = scope->Var("fetch");
+      VLOG(6) << "Create var: fetch in scope " << scope;
+      auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
+      int index =
+          op->attributes().at("col").dyn_cast<ir::Int32_tAttribute>().data();
+      fetch_list->resize(index + 1);
     }
+  }
 
-    if (op_name == "builtin.get_parameter") {
-      auto param_name = (*it)
-                            ->attributes()
-                            .at("parameter_name")
-                            .dyn_cast<ir::StrAttribute>()
-                            .data();
-
-      auto out_ptr = (*it)->result(0);
+  if (op_name == "pd.feed") {
+    VLOG(6) << "Handle for pd.feed:";
+    auto ptr = op->result(0);
+    std::string name = "inner_var_" + std::to_string(count++);
+    name_map->emplace(ptr, name);
+    auto var = CreateVar(ptr, name, scope, local_scope);
+    // TODO(phlrain): need to update here, support StringTensor
+    auto out_tensor = var->GetMutable<phi::DenseTensor>();
+
+    auto feed_var = scope->Var("feed");
+    VLOG(6) << "Create var: feed in scope " << scope;
+    int index =
+        op->attributes().at("col").dyn_cast<ir::Int32_tAttribute>().data();
+    auto feed_list = feed_var->Get<paddle::framework::FeedList>();
+    auto& in_tensor = (PADDLE_GET(phi::DenseTensor, feed_list.at(index)));
+    out_tensor->ShareDataWith(in_tensor);
+  }
 
-      name_map->emplace(out_ptr, param_name);
-      continue;
+  if (op_name == "builtin.combine") {
+    VLOG(6) << "Handle for builtin.combine:";
+    auto out_value = op->result(0);
+    std::string name;
+    if (name_map->find(out_value) != name_map->end()) {
+      name = name_map->at(out_value);
+    } else {
+      name = "inner_var_" + std::to_string(count++);
+      name_map->emplace(out_value, name);
     }
 
-    if (op_name == "pd.feed") {
-      auto ptr = (*it)->result(0);
-      std::string name = "inner_var_" + std::to_string(count++);
-      name_map->emplace(ptr, name);
-      auto var = scope->Var(name);
-      // TODO(phlrain): need to update here, support StringTensor
-      auto out_tensor = var->GetMutable<phi::DenseTensor>();
-
-      auto feed_var = scope->Var("feed");
-      int index =
-          (*it)->attributes().at("col").dyn_cast<ir::Int32_tAttribute>().data();
-      auto feed_list = feed_var->Get<paddle::framework::FeedList>();
-      auto& in_tensor = (PADDLE_GET(phi::DenseTensor, feed_list.at(index)));
+    auto var = CreateVar(out_value, name, scope, local_scope);
+    auto tensor_array = var->GetMutable();
 
-      out_tensor->ShareDataWith(in_tensor);
+    for (size_t i = 0; i < input_num; ++i) {
+      auto ptr = op->operand(i);
 
-      continue;
+      PADDLE_ENFORCE_EQ(
+          name_map->count(ptr),
+          true,
+          phi::errors::PreconditionNotMet("can not find input of combine op"));
+      tensor_array->emplace_back(
+          &(CreateVar(ptr, name_map->at(ptr), scope, local_scope)
+                ->Get()));
     }
+  }
 
-    if (op_name == "builtin.combine") {
-      auto out_value = (*it)->result(0);
-
-      VLOG(5) << "process builtin combine";
-      std::string name;
-      if (name_map->find(out_value) != name_map->end()) {
-        name = name_map->at(out_value);
-      } else {
-        name = "inner_var_" + std::to_string(count++);
-        name_map->emplace(out_value, name);
-      }
+  if (op_name == "builtin.set_parameter") {
+    VLOG(6) << "Handle for builtin.set_parameter:";
+    auto param_name = op->attributes()
+                          .at("parameter_name")
+                          .dyn_cast<ir::StrAttribute>()
+                          .data();
 
-      auto var = scope->Var(name);
-      auto tensor_array = var->GetMutable();
+    auto in_ptr = op->operand(0);
+    // change operand name to param_name
 
-      for (size_t i = 0; i < input_num; ++i) {
-        auto ptr = (*it)->operand(i);
+    auto orig_name = name_map->at(in_ptr);
+    (*name_map)[in_ptr] = param_name;
+    scope->Rename(orig_name, param_name);
+  }
 
-        PADDLE_ENFORCE_EQ(name_map->count(ptr),
-                          true,
-                          phi::errors::PreconditionNotMet(
-                              "can not found input of combine op"));
+  if (op_name == "builtin.get_parameter") {
+    VLOG(6) << "Handle for builtin.get_parameter:";
+    auto param_name = op->attributes()
+                          .at("parameter_name")
+                          .dyn_cast<ir::StrAttribute>()
+                          .data();
+    auto out_ptr = op->result(0);
+    name_map->emplace(out_ptr, param_name);
+  }
+}
 
-        tensor_array->emplace_back(
-            &(scope->Var(name_map->at(ptr))->Get()));
-      }
+void BuildScope(ir::Block* block,
+                paddle::framework::Scope* scope,
+                paddle::framework::Scope* local_scope,
+                std::unordered_map<ir::Value, std::string>* name_map) {
+  // NOTE(zhiqiu): if local_scope is used (local_scope != nullptr), persistable
+  // variables are created in scope, and the others are created in local_scope.
+  auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
+  VLOG(6) << "Build: scope [" << scope << "] inner_local_scope ["
+          << inner_local_scope << "]";
+
+  int count = name_map->size();
+  for (auto it = block->begin(); it != block->end(); ++it) {
+    ir::Operation* op = *it;
+
+    auto attr_map = op->attributes();
+    std::string op_name = op->name();
+    if (attr_map.count("op_name")) {
+      op_name = attr_map.at("op_name").dyn_cast<ir::StrAttribute>().data();
+    }
+    if (op_name == "pd.feed" || op_name == "pd.fetch" ||
+        op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
+        op_name == "builtin.get_parameter") {
+      VLOG(6) << "HandleForSpecialOp: " << op_name;
+      HandleForSpecialOp(op, scope, inner_local_scope, name_map, count);
+      continue;
+    }
 
-    // TODO(zhangbo): support builtin.slice
-
+    size_t input_num = op->num_operands();
     if (input_num > 0) {
       for (size_t i = 0; i < input_num; ++i) {
-        auto ptr = (*it)->operand(i);
+        auto ptr = op->operand(i);
         if (ptr) {
           PADDLE_ENFORCE_NE(
               name_map->find(ptr),
@@ -163,11 +209,10 @@ void BuildScope(ir::Block* block,
       }
     }
 
-    int out_num = (*it)->num_results();
-
+    int out_num = op->num_results();
     if (out_num > 0) {
       for (int i = 0; i < out_num; ++i) {
-        ir::Value ptr = (*it)->result(i);
+        ir::Value ptr = op->result(i);
         std::string name;
         if (name_map->find(ptr) != name_map->end()) {
           name = name_map->at(ptr);
@@ -175,7 +220,7 @@ void BuildScope(ir::Block* block,
           name = "inner_var_" + std::to_string(count++);
           name_map->emplace(ptr, name);
         }
-        auto var = scope->Var(name);
+        auto var = CreateVar(ptr, name, scope, inner_local_scope);
         // Only support DenseTensor or Vector<DenseTensor>
         if (!ptr.type()) {
           var->GetMutable<phi::DenseTensor>();
@@ -195,7 +240,7 @@ void BuildScope(ir::Block* block,
                 "Element of VectorType output only support "
                 "DenseTensorType"));
             std::string name_i = "inner_var_" + std::to_string(count++);
-            auto var_i = scope->Var(name_i);
+            auto var_i = CreateVar(ptr, name_i, scope, inner_local_scope);
             tensor_array->emplace_back(var_i->GetMutable<phi::DenseTensor>());
           }
         } else {
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
index 7ecf94fe2fe..863f316a833 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
@@ -25,6 +25,7 @@
 #include "paddle/ir/core/utils.h"
 #include "paddle/phi/core/meta_tensor.h"
 
+#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/variable.h"
@@ -39,9 +40,20 @@
 #include "glog/logging.h"
 
 namespace ir {
+paddle::framework::Variable* CreateVar(ir::Value value,
+                                       std::string name,
+                                       paddle::framework::Scope* scope,
+                                       paddle::framework::Scope* local_scope);
+
+void HandleForSpecialOp(ir::Operation* op,
+                        paddle::framework::Scope* scope,
+                        paddle::framework::Scope* local_scope,
+                        std::unordered_map<ir::Value, std::string>* name_map,
+                        int& count);  // NOLINT
 
 void BuildScope(ir::Block* block,
                 paddle::framework::Scope* scope,
+                paddle::framework::Scope* local_scope,
                 std::unordered_map<ir::Value, std::string>* name_map);
 
 template <typename Context,
           typename InType,
           typename OutType,
           typename ListType,
           bool is_kernel>
 void BuildPhiContext(
     ir::Operation* op,
     const std::unordered_map<ir::Value, std::string>& name_map,
     paddle::framework::Scope* scope,
+    paddle::framework::Scope* local_scope,
     const paddle::dialect::OpYamlInfoParser& op_yaml_info,
     Context* ctx,
     std::map<std::string, std::vector<int>>* input_map = nullptr,
     std::map<std::string, std::vector<int>>* output_map = nullptr) {
+  paddle::framework::Scope* inner_scope =
+      local_scope != nullptr ? local_scope : scope;
+  VLOG(6) << "BuildPhiContext in scope[" << scope << "] inner_scope["
+          << inner_scope << "]";
   // inputs include input and mutable attributes
   auto attr_map = op->attributes();
@@ -80,11 +97,10 @@
       auto in_var_name = name_map.at(ptr);
       VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
 
-      PADDLE_ENFORCE_NOT_NULL(scope->FindLocalVar(in_var_name),
+      PADDLE_ENFORCE_NOT_NULL(inner_scope->FindLocalVar(in_var_name),
                               phi::errors::PreconditionNotMet(
                                   "can not find var[%s] in scope", in_var_name));
-
-      auto var = scope->Var(in_var_name);
+      auto var = inner_scope->FindVar(in_var_name);
       if (var->IsType<phi::DenseTensor>()) {
         const phi::TensorBase* tensor_in = &(var->Get<phi::DenseTensor>());
         ctx->EmplaceBackInput(InType(tensor_in));
@@ -123,12 +139,12 @@
     auto& tensor_attr_type = op_yaml_info.TensorAttrTypeName(t);
     VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t" << in_var_name;
     if (tensor_attr_type == "paddle::dialect::IntArrayAttribute") {
-      phi::Attribute r1 =
-          phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
+      phi::Attribute r1 = phi::TensorRef(
+          &(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
       ctx->EmplaceBackAttr(r1);
     } else if (tensor_attr_type == "paddle::dialect::ScalarAttribute") {
-      phi::Attribute r1 =
-          phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
+      phi::Attribute r1 = phi::TensorRef(
+          &(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
       ctx->EmplaceBackAttr(r1);
     } else {
@@ -239,7 +255,7 @@
       (op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data() ==
        "pd.fetch")) {
     // process fetch op
-    auto fetch_var = scope->Var("fetch");
+    auto fetch_var = inner_scope->FindVar("fetch");
     auto* fetch_list = fetch_var->GetMutable<paddle::framework::FetchList>();
     int index =
         op->attributes().at("col").dyn_cast<ir::Int32_tAttribute>().data();
@@ -251,7 +267,7 @@
     auto name = name_map.at(out_ptr);
     if (out_ptr.type()) {
       ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
-          &(scope->Var(name)->Get<phi::DenseTensor>()))));
+          &(inner_scope->FindVar(name)->Get<phi::DenseTensor>()))));
     } else {
       phi::DenseTensor* ptr = nullptr;
       OutType out_ptr(ptr);
diff --git a/test/cpp/new_executor/standalone_executor_new_ir_test.cc b/test/cpp/new_executor/standalone_executor_new_ir_test.cc
index c08c590b773..8501980b251 100644
--- a/test/cpp/new_executor/standalone_executor_new_ir_test.cc
+++ b/test/cpp/new_executor/standalone_executor_new_ir_test.cc
@@ -73,7 +73,11 @@ TEST(StandaloneExecutor, run) {
 
   test_core.Run({});
 
-  auto out_tensor = scope.Var("inner_var_2")->Get<phi::DenseTensor>();
+  auto out_tensor = test_core.local_scope() == nullptr
+                        ? scope.FindVar("inner_var_2")->Get<phi::DenseTensor>()
+                        : test_core.local_scope()
+                              ->FindVar("inner_var_2")
+                              ->Get<phi::DenseTensor>();
 
   bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
   bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
@@ -142,7 +146,11 @@ TEST(StandaloneExecutor, run_2) {
 
   test_core.Run({});
 
-  auto out_tensor = scope.Var("inner_var_10")->Get<phi::DenseTensor>();
+  auto out_tensor = test_core.local_scope() == nullptr
+                        ? scope.FindVar("inner_var_10")->Get<phi::DenseTensor>()
+                        : test_core.local_scope()
+                              ->FindVar("inner_var_10")
+                              ->Get<phi::DenseTensor>();
 
   bool res0 = simple_cmp(out_tensor.data<float>()[0], 1.80721);
   bool res1 = simple_cmp(out_tensor.data<float>()[1], 1.70047);
@@ -213,7 +221,11 @@ TEST(StandaloneExecutor, data_transfer) {
 
   test_core.Run({});
 
-  auto out_tensor = scope.Var("inner_var_9")->Get<phi::DenseTensor>();
+  auto out_tensor = test_core.local_scope() == nullptr
+                        ? scope.FindVar("inner_var_9")->Get<phi::DenseTensor>()
+                        : test_core.local_scope()
+                              ->FindVar("inner_var_9")
+                              ->Get<phi::DenseTensor>();
 
   auto& pool = phi::DeviceContextPool::Instance();
   phi::DenseTensor out;
-- 
GitLab
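
The scope-resolution pattern repeated in the three test updates above is the user-visible contract this patch introduces: when execution_config.create_local_scope is on, fetched variables live in the interpreter's local scope (reachable through the new local_scope() accessor), while CreateVar promotes only persistable variables to the root scope so they outlive the local scope. A minimal caller-side sketch of that contract follows; it is not part of the patch itself, FetchResult is a hypothetical helper name, and the include paths are assumed from the files touched above.

    #include <string>

    #include "paddle/fluid/framework/new_executor/interpretercore.h"
    #include "paddle/fluid/framework/scope.h"
    #include "paddle/phi/core/dense_tensor.h"

    // Reads a variable produced by a finished run. Prefer the interpreter's
    // local scope when one was created; local_scope() returns nullptr when
    // create_local_scope (or FLAGS_new_executor_use_local_scope) is off, in
    // which case the variable was created directly in the outer scope.
    const phi::DenseTensor& FetchResult(
        const paddle::framework::InterpreterCore& core,
        const paddle::framework::Scope& outer_scope,
        const std::string& name) {
      const paddle::framework::Scope* read_scope =
          core.local_scope() != nullptr ? core.local_scope() : &outer_scope;
      return read_scope->FindVar(name)->Get<phi::DenseTensor>();
    }

With such a helper, each assertion in the tests reduces to something like FetchResult(test_core, scope, "inner_var_2").data<float>()[0], and it behaves the same whether or not a local scope exists, which is exactly why the updated tests branch on test_core.local_scope() == nullptr.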