Unverified commit 85831c32, authored by zhangbo9674, committed by GitHub

[IR] New IR access InterpreterCore: add local scope logic (#55112)

* add local scope

* refine code

* refine code

* refine code

* support local scope for BuildFuncList

* fix bug

* add log

* fix bug

* polish code

* fix bug
Parent 902de74c
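
Note: the heart of this patch is one scope-selection idiom, applied at every call site below: prefer the interpreter's local scope when it exists, otherwise fall back to the outer scope. A minimal standalone sketch of that idiom (stand-in types, not Paddle's real Scope/InterpreterCore):

#include <cassert>

struct Scope {};  // stand-in for paddle::framework::Scope

struct InterpreterSketch {
  Scope* scope_{nullptr};        // outer scope, owned by the caller
  Scope* local_scope_{nullptr};  // optional child scope for temporaries

  bool HasLocalScope() const { return local_scope_ != nullptr; }

  // Every "HasLocalScope() ? local_scope_ : scope_" site in the diff below
  // reduces to this selection:
  Scope* InnerScope() { return HasLocalScope() ? local_scope_ : scope_; }
};

int main() {
  Scope outer, local;
  InterpreterSketch interp;
  interp.scope_ = &outer;
  assert(interp.InnerScope() == &outer);  // no local scope: use the outer one
  interp.local_scope_ = &local;
  assert(interp.InnerScope() == &local);  // local scope takes precedence
  return 0;
}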
@@ -940,6 +940,7 @@ void BuildOpFuncList(
     ::ir::Block* block,
     std::vector<OpFuncNode>* vec_func_list,
     framework::Scope* scope,
+    framework::Scope* local_scope,
     const std::unordered_map<::ir::Value, std::string>& value_2_name_map,
     const ExecutionConfig& execution_config) {
   vec_func_list->reserve(block->size());
@@ -979,6 +980,7 @@ void BuildOpFuncList(
               false>((*it),
                      value_2_name_map,
                      scope,
+                     local_scope,
                      op_yaml_info_parser,
                      &(op_func_node.infer_meta_context_));
@@ -1004,6 +1006,7 @@ void BuildOpFuncList(
             true>((*it),
                   value_2_name_map,
                   scope,
+                  local_scope,
                   op_yaml_info_parser,
                   &(op_func_node.kernel_context_),
                   &(op_func_node.input_index),
...
@@ -98,6 +98,7 @@ void BuildOpFuncList(
     ::ir::Block* block,
     std::vector<OpFuncNode>* vec_func_list,
     framework::Scope* scope,
+    framework::Scope* local_scope,
     const std::unordered_map<::ir::Value, std::string>& value_2_name_map,
     const ExecutionConfig& execution_config);
...
@@ -86,6 +86,8 @@ class InterpreterBaseImpl {
   virtual void reset_scope(Scope* new_scope) = 0;
 
+  virtual const Scope* local_scope() const = 0;
+
   virtual const platform::Place& GetPlace() const = 0;
 
   virtual void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) = 0;
...
@@ -101,6 +101,10 @@ void InterpreterCore::reset_scope(Scope* new_scope) {
   impl_->reset_scope(new_scope);
 }
 
+const Scope* InterpreterCore::local_scope() const {
+  return impl_->local_scope();
+}
+
 const platform::Place& InterpreterCore::GetPlace() const {
   return impl_->GetPlace();
 }
...
@@ -65,6 +65,8 @@ class InterpreterCore {
   void reset_scope(Scope* new_scope);
 
+  const Scope* local_scope() const;
+
   const platform::Place& GetPlace() const;
 
   void SetOutputHooks(const std::vector<HookFunc>& hookfuncs);
...
@@ -49,12 +49,14 @@ NewIRInterpreter::NewIRInterpreter(const platform::Place& place,
       stream_analyzer_(place),
       execution_config_(execution_config),
       var_scope_(scope),
+      scope_(scope),
       ir_program_(std::move(ir_prog)) {
   VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_;
 
   static_build_ = FLAGS_new_executor_static_build &&
                   !FLAGS_new_executor_use_cuda_graph &&
                   !execution_config.used_for_control_flow_op;
   //                  &&interpreter::BlockCanBeStaticBuilt(block);
+  static_build_ = true;
 
   exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
   completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
@@ -62,21 +64,19 @@ NewIRInterpreter::NewIRInterpreter(const platform::Place& place,
   if (!FLAGS_new_executor_use_local_scope) {
     execution_config_.create_local_scope = false;
   }
-  execution_config_.AnalyzeThreadPoolConfig(place,
-                                            ir_program_->block()->size());
-  execution_config_.Log(/*log_level=*/8);
 
   if (execution_config_.create_local_scope) {
-    auto local_scope = &var_scope_.GetMutableScope()->NewScope();
+    auto local_scope = &scope_->NewScope();
     local_scope_ = local_scope;
+    VLOG(6) << "new ir interpretercore scope: " << scope_ << "\t"
+            << "; local scope: " << local_scope_;
   }
-  // TODO(zhangbo): delete var_scope
-  // force use outer scope for now
-  local_scope_ = scope;
-  static_build_ = true;
   var_scope_.SetLocalScope(local_scope_);
 
+  execution_config_.AnalyzeThreadPoolConfig(place,
+                                            ir_program_->block()->size());
+  execution_config_.Log(/*log_level=*/8);
+
   instruction_scheduling_priority_less = [this](size_t lhs, size_t rhs) {
     SchedulingPriority lhs_scheduling_priority =
         vec_instruction_[lhs].GetSchedulingPriority();
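
Note: the constructor now derives the local scope directly from the raw scope_ pointer (scope_->NewScope()) rather than through var_scope_. This split works because Scope lookups walk the parent chain: code holding only the local scope still sees variables created in the outer scope, while the outer scope stays free of per-run temporaries. A toy model of that lookup behavior (a deliberate simplification, assumed to mirror paddle::framework::Scope, not its real implementation):

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

class ToyScope {
 public:
  explicit ToyScope(const ToyScope* parent = nullptr) : parent_(parent) {}
  ToyScope& NewScope() {
    kids_.push_back(std::make_unique<ToyScope>(this));
    return *kids_.back();
  }
  int* Var(const std::string& name) { return &vars_[name]; }  // create if absent
  int* FindVar(const std::string& name) {                     // search only
    auto it = vars_.find(name);
    if (it != vars_.end()) return &it->second;
    return parent_ ? const_cast<ToyScope*>(parent_)->FindVar(name) : nullptr;
  }

 private:
  const ToyScope* parent_;
  std::map<std::string, int> vars_;
  std::vector<std::unique_ptr<ToyScope>> kids_;
};

int main() {
  ToyScope scope;                      // outer scope
  ToyScope& local = scope.NewScope();  // like &scope_->NewScope()
  scope.Var("persistable_param");      // created in the outer scope
  local.Var("inner_var_0");            // temporary lives in the local scope
  std::cout << (local.FindVar("persistable_param") != nullptr) << "\n";  // 1
  std::cout << (scope.FindVar("inner_var_0") != nullptr) << "\n";        // 0
  return 0;
}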
@@ -185,12 +185,13 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
   if (!is_build_) {
     LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
     ::ir::BuildScope(
-        ir_program_->block(), local_scope_, &value_2_var_name_map_);
+        ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
 
     std::vector<paddle::framework::OpFuncNode> op_func_nodes;
     interpreter::BuildOpFuncList(place_,
                                  ir_program_->block(),
                                  &op_func_nodes,
+                                 scope_,
                                  local_scope_,
                                  value_2_var_name_map_,
                                  execution_config_);
@@ -212,8 +213,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
   }
 
   // return Fetch Tensors
-  Scope* inner_scope =
-      HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
+  Scope* inner_scope = HasLocalScope() ? local_scope_ : scope_;
   auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
   if (fetch_var && need_fetch) {
     auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
@@ -287,6 +287,8 @@ void NewIRInterpreter::reset_scope(Scope* new_scope) {
   }
 }
 
+const Scope* NewIRInterpreter::local_scope() const { return local_scope_; }
+
 void NewIRInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
   async_work_queue_ = reinterpret_cast<NewIRInterpreter*>(src)->GetWorkQueue();
   VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << src
@@ -321,8 +323,7 @@ std::shared_ptr<interpreter::AsyncWorkQueue> NewIRInterpreter::GetWorkQueue() {
 }
 
 void NewIRInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
-  Scope* inner_scope =
-      HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
+  Scope* inner_scope = HasLocalScope() ? local_scope_ : scope_;
   VariableValueMap ins_map;
   for (auto& var_name_item : instr_node->Inputs()) {
     std::vector<Variable*> input_vars;
@@ -349,8 +350,7 @@ void NewIRInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
   if (instr_node->OpBase()->Type() == "cinn_launch" ||
       instr_node->OpBase()->Type() == "cinn_instruction_run") {  // OP use scope
                                                                  // in kernel
-    Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
-                                         : var_scope_.GetMutableScope();
+    Scope* local_scope = HasLocalScope() ? local_scope_ : scope_;
     instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
   } else {
     instr_node->ResetContext(ins_map, outs_map);
@@ -380,8 +380,7 @@ void NewIRInterpreter::BuildInplace() {
     }
   }
 
-  Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
-                                       : var_scope_.GetMutableScope();
+  Scope* local_scope = HasLocalScope() ? local_scope_ : scope_;
   std::vector<std::vector<size_t>> input_var2op(var_scope_.VarSize());
   for (Instruction& instr : vec_instruction_) {
     for (auto& item : instr.Inputs()) {
@@ -799,8 +798,7 @@ void NewIRInterpreter::BuildSkipShareLoDInfo() {
 void NewIRInterpreter::RunOperator(const Instruction& instr_node) {
   auto* op = instr_node.OpBase();
   auto place = instr_node.DeviceContext().GetPlace();
-  Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
-                                       : var_scope_.GetMutableScope();
+  Scope* local_scope = HasLocalScope() ? local_scope_ : scope_;
   VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);
 
   auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
@@ -1047,7 +1045,7 @@ void NewIRInterpreter::ExecuteInstructionList(
       if (cancel) {
         break;
       }
-      VLOG(0) << "deps:\n" << GetDepsString();
+      VLOG(6) << "deps:\n" << GetDepsString();
       times++;
     }
     return times;
...
@@ -60,6 +60,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
   void reset_scope(Scope* new_scope) override;
 
+  const Scope* local_scope() const override;
+
   const platform::Place& GetPlace() const override { return place_; }
 
   void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) override {
@@ -143,6 +145,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
   ExecutionConfig execution_config_;
 
   VariableScope var_scope_;
+  Scope* scope_{nullptr};
   Scope* local_scope_{nullptr};  // not owned
   EventsWaiter main_thread_blocker_;
...
@@ -275,6 +275,7 @@ void ProgramInterpreter::reset_scope(Scope* new_scope) {
   }
 }
 
+const Scope* ProgramInterpreter::local_scope() const { return local_scope_; }
 void ProgramInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
   async_work_queue_ =
       reinterpret_cast<ProgramInterpreter*>(src)->GetWorkQueue();
...
@@ -62,6 +62,8 @@ class ProgramInterpreter : public InterpreterBaseImpl {
   void reset_scope(Scope* new_scope) override;
 
+  const Scope* local_scope() const override;
+
   const platform::Place& GetPlace() const override { return place_; }
 
   void SetOutputHooks(const std::vector<HookFunc>& hookfuncs) override {
...
@@ -56,7 +56,7 @@ class PhiKernelAdaptor {
   void run_kernel_prog(ir::Program* program) {
     auto block = program->block();
     std::unordered_map<ir::Value, std::string> name_map;
-    BuildScope(block, scope_, &name_map);
+    BuildScope(block, scope_, nullptr, &name_map);
     ir::IrContext* ctx = ir::IrContext::Instance();
 
     ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
@@ -87,7 +87,7 @@ class PhiKernelAdaptor {
         phi::MetaTensor,
         phi::MetaTensor,
         paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
-        false>((*it), name_map, scope_, op_yaml_info_parser, &ctx);
+        false>((*it), name_map, scope_, nullptr, op_yaml_info_parser, &ctx);
 
     infer_meta_impl->infer_meta_(&ctx);
@@ -107,7 +107,7 @@ class PhiKernelAdaptor {
         phi::TensorBase*,
         paddle::small_vector<const phi::TensorBase*>,
         true>(
-        (*it), name_map, scope_, op_yaml_info_parser, &kernel_ctx);
+        (*it), name_map, scope_, nullptr, op_yaml_info_parser, &kernel_ctx);
 
     kernel_fn(&kernel_ctx);
     auto out_value = (*it)->result(0);
...
@@ -43,114 +43,160 @@
 namespace ir {
 
-void BuildScope(ir::Block* block,
-                paddle::framework::Scope* scope,
-                std::unordered_map<ir::Value, std::string>* name_map) {
-  std::unordered_map<ir::Value, int> map_test;
-
-  int count = name_map->size();
-  for (auto it = block->begin(); it != block->end(); ++it) {
-    size_t input_num = (*it)->num_operands();
-    auto attr_map = (*it)->attributes();
-    std::string op_name = (*it)->name();
-    if (attr_map.count("op_name")) {
-      op_name = attr_map.at("op_name").dyn_cast<ir::StrAttribute>().data();
-    }
-    if (op_name == "pd.fetch") {
-      // fetch is a very special op, with no output
-      for (size_t i = 0; i < input_num; ++i) {
-        auto var = scope->Var("fetch");
-        auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
-        int index =
-            (*it)->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
-        fetch_list->resize(index + 1);
-      }
-      continue;
-    }
-
-    if (op_name == "builtin.set_parameter") {
-      auto param_name = (*it)
-                            ->attributes()
-                            .at("parameter_name")
-                            .dyn_cast<ir::StrAttribute>()
-                            .data();
-
-      auto in_ptr = (*it)->operand(0);
-      // change opreand name to param_name
-
-      auto orig_name = name_map->at(in_ptr);
-      (*name_map)[in_ptr] = param_name;
-      scope->Rename(orig_name, param_name);
-      continue;
-    }
-
-    if (op_name == "builtin.get_parameter") {
-      auto param_name = (*it)
-                            ->attributes()
-                            .at("parameter_name")
-                            .dyn_cast<ir::StrAttribute>()
-                            .data();
-
-      auto out_ptr = (*it)->result(0);
-
-      name_map->emplace(out_ptr, param_name);
-      continue;
-    }
-
-    if (op_name == "pd.feed") {
-      auto ptr = (*it)->result(0);
-      std::string name = "inner_var_" + std::to_string(count++);
-      name_map->emplace(ptr, name);
-      auto var = scope->Var(name);
-      // TODO(phlrain): need to update here, support StringTensor
-      auto out_tensor = var->GetMutable<phi::DenseTensor>();
-
-      auto feed_var = scope->Var("feed");
-      int index =
-          (*it)->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
-      auto feed_list = feed_var->Get<paddle::framework::FeedList>();
-      auto& in_tensor = (PADDLE_GET(phi::DenseTensor, feed_list.at(index)));
-
-      out_tensor->ShareDataWith(in_tensor);
-
-      continue;
-    }
-
-    if (op_name == "builtin.combine") {
-      auto out_value = (*it)->result(0);
-
-      VLOG(5) << "process builtin combine";
-      std::string name;
-      if (name_map->find(out_value) != name_map->end()) {
-        name = name_map->at(out_value);
-      } else {
-        name = "inner_var_" + std::to_string(count++);
-        name_map->emplace(out_value, name);
-      }
-
-      auto var = scope->Var(name);
-      auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
-
-      for (size_t i = 0; i < input_num; ++i) {
-        auto ptr = (*it)->operand(i);
-
-        PADDLE_ENFORCE_EQ(name_map->count(ptr),
-                          true,
-                          phi::errors::PreconditionNotMet(
-                              "can not found input of combine op"));
-        tensor_array->emplace_back(
-            &(scope->Var(name_map->at(ptr))->Get<phi::DenseTensor>()));
-      }
-
-      continue;
-    }
-
-    // TODO(zhangbo): support builtin.slice
+paddle::framework::Variable* CreateVar(ir::Value value,
+                                       std::string name,
+                                       paddle::framework::Scope* scope,
+                                       paddle::framework::Scope* local_scope) {
+  Operation* def_op = value.GetDefiningOp();
+  bool is_persisable = false;
+  if (def_op->attributes().count("is_persisable")) {
+    is_persisable = def_op->attributes()
+                        .at("is_persisable")
+                        .dyn_cast<ir::BoolAttribute>()
+                        .data();
+  }
+  if (is_persisable) {
+    const paddle::framework::Scope* ancestor_scope = scope;
+    while (ancestor_scope->parent()) {
+      ancestor_scope = ancestor_scope->parent();
+    }
+    VLOG(6) << "Create var: " << name << " in scope " << ancestor_scope;
+    return const_cast<paddle::framework::Scope*>(ancestor_scope)->Var(name);
+  } else {
+    VLOG(6) << "Create var: " << name << " in scope " << local_scope;
+    return local_scope->Var(name);
+  }
+}
+
+void HandleForSpecialOp(ir::Operation* op,
+                        paddle::framework::Scope* scope,
+                        paddle::framework::Scope* local_scope,
+                        std::unordered_map<ir::Value, std::string>* name_map,
+                        int& count) {  // NOLINT
+  std::string op_name = op->name();
+  if (op->attributes().count("op_name")) {
+    op_name =
+        op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
+  }
+  size_t input_num = op->num_operands();
+
+  if (op_name == "pd.fetch") {
+    // fetch is a very special op, with no output
+    VLOG(6) << "Handle for pd.fetch:";
+    for (size_t i = 0; i < input_num; ++i) {
+      auto var = scope->Var("fetch");
+      VLOG(6) << "Create var: fetch in scope " << scope;
+      auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
+      int index =
+          op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
+      fetch_list->resize(index + 1);
+    }
+  }
+
+  if (op_name == "pd.feed") {
+    VLOG(6) << "Handle for pd.feed:";
+    auto ptr = op->result(0);
+    std::string name = "inner_var_" + std::to_string(count++);
+    name_map->emplace(ptr, name);
+    auto var = CreateVar(ptr, name, scope, local_scope);
+    // TODO(phlrain): need to update here, support StringTensor
+    auto out_tensor = var->GetMutable<phi::DenseTensor>();
+
+    auto feed_var = scope->Var("feed");
+    VLOG(6) << "Create var: feed in scope " << scope;
+    int index =
+        op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
+    auto feed_list = feed_var->Get<paddle::framework::FeedList>();
+    auto& in_tensor = (PADDLE_GET(phi::DenseTensor, feed_list.at(index)));
+    out_tensor->ShareDataWith(in_tensor);
+  }
+
+  if (op_name == "builtin.combine") {
+    VLOG(6) << "Handle for builtin.combine:";
+    auto out_value = op->result(0);
+    std::string name;
+    if (name_map->find(out_value) != name_map->end()) {
+      name = name_map->at(out_value);
+    } else {
+      name = "inner_var_" + std::to_string(count++);
+      name_map->emplace(out_value, name);
+    }
+
+    auto var = CreateVar(out_value, name, scope, local_scope);
+    auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
+
+    for (size_t i = 0; i < input_num; ++i) {
+      auto ptr = op->operand(i);
+      PADDLE_ENFORCE_EQ(
+          name_map->count(ptr),
+          true,
+          phi::errors::PreconditionNotMet("can not found input of combine op"));
+      tensor_array->emplace_back(
+          &(CreateVar(ptr, name_map->at(ptr), scope, local_scope)
+                ->Get<phi::DenseTensor>()));
+    }
+  }
+
+  if (op_name == "builtin.set_parameter") {
+    VLOG(6) << "Handle for builtin.set_parameter:";
+    auto param_name = op->attributes()
+                          .at("parameter_name")
+                          .dyn_cast<ir::StrAttribute>()
+                          .data();
+
+    auto in_ptr = op->operand(0);
+    // change opreand name to param_name
+
+    auto orig_name = name_map->at(in_ptr);
+    (*name_map)[in_ptr] = param_name;
+    scope->Rename(orig_name, param_name);
+  }
+
+  if (op_name == "builtin.get_parameter") {
+    VLOG(6) << "Handle for builtin.get_parameter:";
+    auto param_name = op->attributes()
+                          .at("parameter_name")
+                          .dyn_cast<ir::StrAttribute>()
+                          .data();
+    auto out_ptr = op->result(0);
+
+    name_map->emplace(out_ptr, param_name);
+  }
+}
+
+void BuildScope(ir::Block* block,
+                paddle::framework::Scope* scope,
+                paddle::framework::Scope* local_scope,
+                std::unordered_map<ir::Value, std::string>* name_map) {
+  // NOTE(zhiqiu): if use local_scope (local_scope != nullptr), the persistable
+  // is created in scope , and other is created in local_scope.
+  auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
+  VLOG(6) << "Build: scope [" << scope << "] inner_local_scope ["
+          << inner_local_scope << "]";
+
+  // int count = name_map->size();
+  int count = name_map->size();
+  for (auto it = block->begin(); it != block->end(); ++it) {
+    ir::Operation* op = *it;
+    auto attr_map = op->attributes();
+    std::string op_name = op->name();
+    if (attr_map.count("op_name")) {
+      op_name = attr_map.at("op_name").dyn_cast<ir::StrAttribute>().data();
+    }
+
+    if (op_name == "pd.feed" || op_name == "pd.fetch" ||
+        op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
+        op_name == "builtin.get_parameter") {
+      VLOG(6) << "HandleForSpecialOp: " << op_name;
+      HandleForSpecialOp(op, scope, inner_local_scope, name_map, count);
+      continue;
+    }
+
+    size_t input_num = op->num_operands();
     if (input_num > 0) {
       for (size_t i = 0; i < input_num; ++i) {
-        auto ptr = (*it)->operand(i);
+        auto ptr = op->operand(i);
         if (ptr) {
           PADDLE_ENFORCE_NE(
               name_map->find(ptr),
@@ -163,11 +209,10 @@ void BuildScope(ir::Block* block,
         }
       }
 
-    int out_num = (*it)->num_results();
+    int out_num = op->num_results();
     if (out_num > 0) {
       for (int i = 0; i < out_num; ++i) {
-        ir::Value ptr = (*it)->result(i);
+        ir::Value ptr = op->result(i);
         std::string name;
         if (name_map->find(ptr) != name_map->end()) {
           name = name_map->at(ptr);
@@ -175,7 +220,7 @@ void BuildScope(ir::Block* block,
           name = "inner_var_" + std::to_string(count++);
           name_map->emplace(ptr, name);
         }
-        auto var = scope->Var(name);
+        auto var = CreateVar(ptr, name, scope, inner_local_scope);
         // Only support DenseTensor or Vector<DenseTensor>
         if (!ptr.type()) {
           var->GetMutable<phi::DenseTensor>();
@@ -195,7 +240,7 @@ void BuildScope(ir::Block* block,
               "Element of VectorType output only support "
               "DenseTensorType"));
           std::string name_i = "inner_var_" + std::to_string(count++);
-          auto var_i = scope->Var(name_i);
+          auto var_i = CreateVar(ptr, name_i, scope, inner_local_scope);
           tensor_array->emplace_back(var_i->GetMutable<phi::DenseTensor>());
         }
       } else {
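
Note: CreateVar above is the policy point of the whole change. Values whose defining op is flagged is_persisable are created in the outermost ancestor scope, so parameters survive across runs; everything else lands in the local scope. A compressed sketch of just that routing rule, using toy types rather than Paddle's Scope/Variable (the variable names in main are hypothetical):

#include <iostream>
#include <string>

struct Scope {
  Scope* parent{nullptr};
  Scope* Parent() const { return parent; }
  void Var(const std::string& name) {
    std::cout << "create " << name << " in scope " << this << "\n";
  }
};

void CreateVarSketch(bool is_persistable, const std::string& name,
                     Scope* scope, Scope* local_scope) {
  if (is_persistable) {
    Scope* ancestor = scope;
    while (ancestor->Parent() != nullptr) ancestor = ancestor->Parent();
    ancestor->Var(name);     // parameters live in the outermost scope
  } else {
    local_scope->Var(name);  // temporaries live in the local scope
  }
}

int main() {
  Scope root, outer, local;
  outer.parent = &root;
  local.parent = &outer;
  CreateVarSketch(true, "linear_0.w_0", &outer, &local);   // -> root scope
  CreateVarSketch(false, "inner_var_3", &outer, &local);   // -> local scope
  return 0;
}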
...
@@ -25,6 +25,7 @@
 #include "paddle/ir/core/utils.h"
 #include "paddle/phi/core/meta_tensor.h"
 
+#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/variable.h"
@@ -39,9 +40,20 @@
 #include "glog/logging.h"
 
 namespace ir {
 
+paddle::framework::Variable* CreateVar(ir::Value value,
+                                       std::string name,
+                                       paddle::framework::Scope* scope,
+                                       paddle::framework::Scope* local_scope);
+
+void HandleForSpecialOp(ir::Operation* op,
+                        paddle::framework::Scope* scope,
+                        paddle::framework::Scope* local_scope,
+                        std::unordered_map<ir::Value, std::string>* name_map,
+                        int& count);  // NOLINT
+
 void BuildScope(ir::Block* block,
                 paddle::framework::Scope* scope,
+                paddle::framework::Scope* local_scope,
                 std::unordered_map<ir::Value, std::string>* name_map);
 
 template <typename Context,
@@ -53,10 +65,15 @@ void BuildPhiContext(
     ir::Operation* op,
     const std::unordered_map<ir::Value, std::string>& name_map,
     paddle::framework::Scope* scope,
+    paddle::framework::Scope* local_scope,
     const paddle::dialect::OpYamlInfoParser& op_yaml_info,
     Context* ctx,
     std::map<std::string, std::vector<int>>* input_map = nullptr,
     std::map<std::string, std::vector<int>>* output_map = nullptr) {
+  paddle::framework::Scope* inner_scope =
+      local_scope != nullptr ? local_scope : scope;
+  VLOG(6) << "BuildPhiContext in scope[" << scope << "] inner_scope["
+          << inner_scope << "]";
   // inputs include input and mutable attributes
   auto attr_map = op->attributes();
@@ -80,11 +97,10 @@ void BuildPhiContext(
       auto in_var_name = name_map.at(ptr);
       VLOG(6) << "ctx->EmplaceBackInput: " << t << "\t" << in_var_name;
 
-      PADDLE_ENFORCE_NOT_NULL(scope->FindLocalVar(in_var_name),
+      PADDLE_ENFORCE_NOT_NULL(inner_scope->FindLocalVar(in_var_name),
                               phi::errors::PreconditionNotMet(
                                   "can not find var[%s] in scope", in_var_name));
-
-      auto var = scope->Var(in_var_name);
+      auto var = inner_scope->FindVar(in_var_name);
       if (var->IsType<phi::DenseTensor>()) {
         const phi::TensorBase* tensor_in = &(var->Get<phi::DenseTensor>());
         ctx->EmplaceBackInput(InType(tensor_in));
@@ -123,12 +139,12 @@ void BuildPhiContext(
       auto& tensor_attr_type = op_yaml_info.TensorAttrTypeName(t);
       VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t" << in_var_name;
       if (tensor_attr_type == "paddle::dialect::IntArrayAttribute") {
-        phi::Attribute r1 =
-            phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
+        phi::Attribute r1 = phi::TensorRef(
+            &(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
         ctx->EmplaceBackAttr(r1);
       } else if (tensor_attr_type == "paddle::dialect::ScalarAttribute") {
-        phi::Attribute r1 =
-            phi::TensorRef(&(scope->Var(in_var_name)->Get<phi::DenseTensor>()));
+        phi::Attribute r1 = phi::TensorRef(
+            &(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
         ctx->EmplaceBackAttr(r1);
       } else {
@@ -239,7 +255,7 @@ void BuildPhiContext(
       (op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data() ==
        "pd.fetch")) {
     // process fetch op
-    auto fetch_var = scope->Var("fetch");
+    auto fetch_var = inner_scope->FindVar("fetch");
     auto* fetch_list = fetch_var->GetMutable<paddle::framework::FetchList>();
     int index =
         op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
@@ -251,7 +267,7 @@ void BuildPhiContext(
     auto name = name_map.at(out_ptr);
     if (out_ptr.type()) {
       ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
-          &(scope->Var(name)->Get<phi::DenseTensor>()))));
+          &(inner_scope->FindVar(name)->Get<phi::DenseTensor>()))));
     } else {
       phi::DenseTensor* ptr = nullptr;
       OutType out_ptr(ptr);
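
Note: besides threading local_scope through, the hunks above consistently replace scope->Var(name) with inner_scope->FindVar(name). The distinction matters: Var() creates the variable in the queried scope if it is missing, so a lookup spelled as Var() from the wrong scope can silently duplicate a variable; FindVar() only searches, walking up parent scopes. A toy illustration (simplified semantics, assumed to mirror paddle::framework::Scope):

#include <cassert>
#include <map>
#include <string>

struct S {
  S* parent{nullptr};
  std::map<std::string, int> vars;
  int* Var(const std::string& n) { return &vars[n]; }  // creates if missing
  int* FindVar(const std::string& n) {                 // never creates
    auto it = vars.find(n);
    if (it != vars.end()) return &it->second;
    return parent ? parent->FindVar(n) : nullptr;
  }
};

int main() {
  S outer, local;
  local.parent = &outer;
  outer.Var("fetch");                         // created once, in the outer scope
  assert(local.FindVar("fetch") != nullptr);  // lookup sees the parent's var
  assert(local.vars.count("fetch") == 0);     // and did not duplicate it locally
  return 0;
}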
...
@@ -73,7 +73,11 @@ TEST(StandaloneExecutor, run) {
   test_core.Run({});
 
-  auto out_tensor = scope.Var("inner_var_2")->Get<phi::DenseTensor>();
+  auto out_tensor = test_core.local_scope() == nullptr
+                        ? scope.FindVar("inner_var_2")->Get<phi::DenseTensor>()
+                        : test_core.local_scope()
+                              ->FindVar("inner_var_2")
+                              ->Get<phi::DenseTensor>();
 
   bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
   bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
@@ -142,7 +146,11 @@ TEST(StandaloneExecutor, run_2) {
   test_core.Run({});
 
-  auto out_tensor = scope.Var("inner_var_10")->Get<phi::DenseTensor>();
+  auto out_tensor = test_core.local_scope() == nullptr
+                        ? scope.FindVar("inner_var_10")->Get<phi::DenseTensor>()
+                        : test_core.local_scope()
+                              ->FindVar("inner_var_10")
+                              ->Get<phi::DenseTensor>();
 
   bool res0 = simple_cmp(out_tensor.data<float>()[0], 1.80721);
   bool res1 = simple_cmp(out_tensor.data<float>()[1], 1.70047);
@@ -213,7 +221,11 @@ TEST(StandaloneExecutor, data_transfer) {
   test_core.Run({});
 
-  auto out_tensor = scope.Var("inner_var_9")->Get<phi::DenseTensor>();
+  auto out_tensor = test_core.local_scope() == nullptr
+                        ? scope.FindVar("inner_var_9")->Get<phi::DenseTensor>()
+                        : test_core.local_scope()
+                              ->FindVar("inner_var_9")
+                              ->Get<phi::DenseTensor>();
 
   auto& pool = phi::DeviceContextPool::Instance();
   phi::DenseTensor out;
...
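
Note: the three test updates above repeat the same ternary. A shared helper along these lines (hypothetical, not part of the patch; it assumes the test file's existing includes and the Scope, InterpreterCore, and phi::DenseTensor types used above) would encode the lookup once:

const phi::DenseTensor& GetOutputTensor(InterpreterCore& core,
                                        Scope& scope,
                                        const std::string& name) {
  // Prefer the interpreter's local scope when it created one; otherwise the
  // variable lives directly in the scope the test passed in.
  const Scope* s = core.local_scope() != nullptr ? core.local_scope() : &scope;
  return s->FindVar(name)->Get<phi::DenseTensor>();
}

// usage: auto& out_tensor = GetOutputTensor(test_core, scope, "inner_var_2");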