Unverified commit f00a06d8, authored by zhangbo9674, committed by GitHub

[IR] Refine BuildScope in phi_kernel_util (#55423)

* add code

* fix bug

* refine code

* refine code

* fix bug
Parent: 7f6d222f
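In short: `BuildScope` no longer takes a `(scope, local_scope)` pair plus a single `Value`-to-name map; it takes the resolved inner scope and fills four bookkeeping containers that later phases (instruction building, phi-context building) reuse. A minimal sketch of the new call, with the signature taken from the diff below (the surrounding setup is illustrative only):

```cpp
// Containers BuildScope now populates (declaration shown in the
// phi_kernel_util header diff below):
std::unordered_map<ir::Value, std::string> value_2_var_name;
std::unordered_map<const paddle::framework::Variable*, std::string>
    variable_2_var_name;
std::map<std::string, int> var_name_2_id;
std::vector<paddle::framework::Variable*> variable_list;

ir::BuildScope(*program->block(),
               inner_scope,  // local scope if present, else the outer scope
               &value_2_var_name,
               &variable_2_var_name,
               &var_name_2_id,
               &variable_list);
```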
@@ -185,8 +185,12 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
   if (!is_build_) {
     LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
-    ::ir::BuildScope(
-        *ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
+    ::ir::BuildScope(*ir_program_->block(),
+                     InnerScope(),
+                     &value_2_var_name_,
+                     &variable_2_var_name_,
+                     &var_name_2_id_,
+                     &variable_list_);
     std::vector<paddle::framework::OpFuncNode> op_func_nodes;
     interpreter::BuildOpFuncList(place_,
@@ -194,7 +198,7 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
                                  &op_func_nodes,
                                  scope_,
                                  local_scope_,
-                                 value_2_var_name_map_,
+                                 value_2_var_name_,
                                  execution_config_);
     // SetFeedVarsInplaceSkip(feed_names);
     // convert vec func_list to graph
@@ -237,8 +241,12 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
   SetDeviceId(place_);
   if (!is_build_) {
     LOG_FIRST_N(INFO, 1) << "New Executor is BetaRunning.";
-    ::ir::BuildScope(
-        *ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
+    ::ir::BuildScope(*ir_program_->block(),
+                     InnerScope(),
+                     &value_2_var_name_,
+                     &variable_2_var_name_,
+                     &var_name_2_id_,
+                     &variable_list_);
     BuildInstruction();
     for (size_t instr_id = 0; instr_id < vec_instruction_base_.size();
          ++instr_id) {
@@ -1526,13 +1534,8 @@ void NewIRInterpreter::BuildInstruction() {
        ++it) {
     VLOG(0) << "Build Instruction for op: " << op_idx;
     if ((*it)->dialect()->name() == "pd_kernel") {
-      vec_instruction_base_.emplace_back(
-          std::make_unique<PhiKernelInstruction>(op_idx++,
-                                                 place_,
-                                                 (*it),
-                                                 scope_,
-                                                 local_scope_,
-                                                 value_2_var_name_map_));
+      vec_instruction_base_.emplace_back(std::make_unique<PhiKernelInstruction>(
+          op_idx++, place_, (*it), scope_, local_scope_, value_2_var_name_));
     } else {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Now only support pd_kernel dialect."));
...
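The interpreter now passes `InnerScope()` where it used to pass `scope_` and `local_scope_` separately. The diff does not show `InnerScope()` itself; judging from the old `BuildScope` body (which computed `local_scope != nullptr ? local_scope : scope` internally), it presumably resolves as below. Treat this as a hypothetical sketch, not code from the commit:

```cpp
// Hypothetical sketch (not part of this diff): the scope-selection logic the
// old BuildScope performed internally, assumed to now live behind InnerScope().
Scope* NewIRInterpreter::InnerScope() {
  return local_scope_ != nullptr ? local_scope_ : scope_;
}
```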
@@ -192,7 +192,11 @@ class NewIRInterpreter : public InterpreterBaseImpl {
   std::vector<std::unique_ptr<InstructionBase>> vec_instruction_base_;

-  std::unordered_map<::ir::Value, std::string> value_2_var_name_map_;
+  std::unordered_map<::ir::Value, std::string> value_2_var_name_;
+  std::unordered_map<const paddle::framework::Variable*, std::string>
+      variable_2_var_name_;
+  std::map<std::string, int> var_name_2_id_;
+  std::vector<Variable*> variable_list_;
 };

 }  // namespace framework
...
@@ -55,8 +55,18 @@ class PhiKernelAdaptor {
   void run_kernel_prog(ir::Program* program) {
     auto block = program->block();
-    std::unordered_map<ir::Value, std::string> name_map;
-    BuildScope(*block, scope_, nullptr, &name_map);
+    std::unordered_map<ir::Value, std::string> value_2_var_name;
+    std::unordered_map<const paddle::framework::Variable*, std::string>
+        variable_2_var_name;
+    std::map<std::string, int> var_name_2_id;
+    std::vector<paddle::framework::Variable*> variable_list;
+    BuildScope(*block,
+               scope_,
+               &value_2_var_name,
+               &variable_2_var_name,
+               &var_name_2_id,
+               &variable_list);
     ir::IrContext* ctx = ir::IrContext::Instance();

     ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
@@ -88,7 +98,8 @@ class PhiKernelAdaptor {
           phi::MetaTensor,
           paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
           paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
-          false>((*it), name_map, scope_, nullptr, op_yaml_info_parser, &ctx);
+          false>(
+          (*it), value_2_var_name, scope_, nullptr, op_yaml_info_parser, &ctx);

       infer_meta_impl->infer_meta_(&ctx);
@@ -108,12 +119,16 @@ class PhiKernelAdaptor {
           phi::TensorBase*,
           paddle::small_vector<const phi::TensorBase*>,
           paddle::small_vector<phi::TensorBase*>,
-          true>(
-          (*it), name_map, scope_, nullptr, op_yaml_info_parser, &kernel_ctx);
+          true>((*it),
+                value_2_var_name,
+                scope_,
+                nullptr,
+                op_yaml_info_parser,
+                &kernel_ctx);
       kernel_fn(&kernel_ctx);

       auto out_value = (*it)->result(0);
-      out_name = name_map[out_value];
+      out_name = value_2_var_name[out_value];
     }
   }
...
@@ -46,10 +46,15 @@ namespace ir {
 using VariableNameMap =
     std::unordered_map<const paddle::framework::Variable*, std::string>;

-paddle::framework::Variable* CreateVar(ir::Value value,
-                                       const std::string& name,
-                                       paddle::framework::Scope* scope,
-                                       paddle::framework::Scope* local_scope) {
+paddle::framework::Variable* CreateVar(
+    ir::Value value,
+    paddle::framework::Scope* inner_scope,
+    bool force_persisable,
+    std::unordered_map<ir::Value, std::string>* value_2_var_name,
+    std::unordered_map<const paddle::framework::Variable*, std::string>*
+        variable_2_var_name,
+    std::map<std::string, int>* var_name_2_id,
+    std::vector<paddle::framework::Variable*>* variable_list) {
   Operation* def_op = value.GetDefiningOp();
   bool is_persisable = false;
   if (def_op->attributes().count("is_persisable")) {
@@ -58,27 +63,41 @@ paddle::framework::Variable* CreateVar(ir::Value value,
             .dyn_cast<ir::BoolAttribute>()
             .data();
   }
-  if (is_persisable) {
-    VLOG(6) << "Create var: " << name << " in scope " << scope->root();
-    return const_cast<paddle::framework::Scope*>(scope->root())->Var(name);
+
+  paddle::framework::Variable* var = nullptr;
+  std::string name = "inner_var_" + std::to_string(variable_2_var_name->size());
+  if (force_persisable || is_persisable) {
+    VLOG(6) << "Create var: " << name << " in scope " << inner_scope->root();
+    var = const_cast<paddle::framework::Scope*>(inner_scope->root())->Var(name);
   } else {
-    VLOG(6) << "Create var: " << name << " in scope " << local_scope;
-    return local_scope->Var(name);
+    VLOG(6) << "Create var: " << name << " in scope " << inner_scope;
+    var = inner_scope->Var(name);
   }
+
+  value_2_var_name->emplace(value, name);
+  variable_2_var_name->emplace(var, name);
+  auto id = var_name_2_id->size();
+  var_name_2_id->emplace(name, id);
+  variable_list->push_back(var);
+  PADDLE_ENFORCE_EQ(
+      variable_list->size(),
+      var_name_2_id->size(),
+      paddle::platform::errors::InvalidArgument(
+          "The size of variable_list and var_name_2_id map should be equal"));
+  return var;
 }

 void CheckInputVars(
     ir::Operation* op,
     const std::string& op_name,
-    const std::unordered_map<ir::Value, std::string>& name_map) {
+    const std::unordered_map<ir::Value, std::string>& value_2_var_name) {
   size_t input_num = op->num_operands();
   if (input_num > 0) {
     for (size_t i = 0; i < input_num; ++i) {
       auto value = op->operand(i);
       if (value) {
         PADDLE_ENFORCE_NE(
-            name_map.find(value),
-            name_map.end(),
+            value_2_var_name.find(value),
+            value_2_var_name.end(),
             phi::errors::PreconditionNotMet(
                 "input should in name map, [%d] 'th input of [%s] op",
                 i,
@@ -89,20 +108,25 @@ void CheckInputVars(
 }

 void BuildValue(ir::Value value,
-                paddle::framework::Scope* scope,
-                paddle::framework::Scope* local_scope,
-                std::unordered_map<ir::Value, std::string>* name_map,
-                VariableNameMap* variable_name_map,
-                int& count) {  // NOLINT
-  auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
-  std::string name;
-  if (name_map->find(value) != name_map->end()) {
-    name = name_map->at(value);
+                paddle::framework::Scope* inner_scope,
+                std::unordered_map<ir::Value, std::string>* value_2_var_name,
+                std::unordered_map<const paddle::framework::Variable*,
+                                   std::string>* variable_2_var_name,
+                std::map<std::string, int>* var_name_2_id,
+                std::vector<paddle::framework::Variable*>* variable_list) {
+  paddle::framework::Variable* var = nullptr;
+  if (value_2_var_name->find(value) != value_2_var_name->end()) {
+    var = inner_scope->FindVar(value_2_var_name->at(value));
   } else {
-    name = "inner_var_" + std::to_string(count++);
-    name_map->emplace(value, name);
+    var = CreateVar(value,
+                    inner_scope,
+                    false,
+                    value_2_var_name,
+                    variable_2_var_name,
+                    var_name_2_id,
+                    variable_list);
   }
-  auto var = CreateVar(value, name, scope, inner_local_scope);
   // Only support DenseTensor or Vector<DenseTensor>
   if (!value.type()) {
     var->GetMutable<phi::DenseTensor>();
@@ -120,11 +144,15 @@ void BuildValue(ir::Value value,
           paddle::platform::errors::Fatal(
               "Element of VectorType output only support "
               "DenseTensorType"));
-      std::string name_i = "inner_var_" + std::to_string(count++);
-      auto var_i = CreateVar(value, name_i, scope, inner_local_scope);
+      auto var_i = CreateVar(value,
+                             inner_scope,
+                             false,
+                             value_2_var_name,
+                             variable_2_var_name,
+                             var_name_2_id,
+                             variable_list);
       var_i->GetMutable<phi::DenseTensor>();
       tensor_array->emplace_back(var_i);
-      variable_name_map->emplace(var_i, name_i);
     }
   } else {
     PADDLE_THROW(phi::errors::PreconditionNotMet(
@@ -132,24 +160,25 @@ void BuildValue(ir::Value value,
   }
 }

-void HandleForSpecialOp(ir::Operation* op,
-                        const VariableNameMap& variable_name_map,
-                        paddle::framework::Scope* scope,
-                        paddle::framework::Scope* local_scope,
-                        std::unordered_map<ir::Value, std::string>* name_map,
-                        int& count) {  // NOLINT
+void HandleForSpecialOp(
+    ir::Operation* op,
+    paddle::framework::Scope* inner_scope,
+    std::unordered_map<ir::Value, std::string>* value_2_var_name,
+    std::unordered_map<const paddle::framework::Variable*, std::string>*
+        variable_2_var_name,
+    std::map<std::string, int>* var_name_2_id,
+    std::vector<paddle::framework::Variable*>* variable_list) {
   std::string op_name = op->name();
   if (op->attributes().count("op_name")) {
     op_name =
         op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
   }
+
+  size_t input_num = op->num_operands();
+
   if (op_name == "pd.fetch") {
     // fetch is a very special op, with no output
-    VLOG(6) << "Handle for pd.fetch:";
-    auto var = scope->Var("fetch");
-    VLOG(6) << "Create var: fetch in scope " << scope;
+    auto var = const_cast<paddle::framework::Scope*>(inner_scope->root())
+                   ->Var("fetch");
+    VLOG(6) << "Create var: fetch in scope " << inner_scope->root();
     auto fetch_list = var->GetMutable<paddle::framework::FetchList>();
     int index =
         op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
@@ -157,16 +186,20 @@ void HandleForSpecialOp(ir::Operation* op,
   }

   if (op_name == "pd.feed") {
-    VLOG(6) << "Handle for pd.feed:";
     auto value = op->result(0);
-    std::string name = "inner_var_" + std::to_string(count++);
-    name_map->emplace(value, name);
-    auto var = CreateVar(value, name, scope, local_scope);
+    auto var = CreateVar(value,
+                         inner_scope,
+                         false,
+                         value_2_var_name,
+                         variable_2_var_name,
+                         var_name_2_id,
+                         variable_list);
     // TODO(phlrain): need to update here, support StringTensor
     auto out_tensor = var->GetMutable<phi::DenseTensor>();
-    auto feed_var = scope->Var("feed");
-    VLOG(6) << "Create var: feed in scope " << scope;
+    auto feed_var =
+        const_cast<paddle::framework::Scope*>(inner_scope->root())->Var("feed");
+    VLOG(6) << "Create var: feed in scope " << inner_scope->root();
     int index =
         op->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();
     auto feed_list = feed_var->Get<paddle::framework::FeedList>();
@@ -176,30 +209,33 @@ void HandleForSpecialOp(ir::Operation* op,
   }

   if (op_name == "builtin.combine") {
-    VLOG(6) << "Handle for builtin.combine:";
     auto out_value = op->result(0);
-    std::string name;
-    if (name_map->find(out_value) != name_map->end()) {
-      name = name_map->at(out_value);
+
+    paddle::framework::Variable* var = nullptr;
+    if (value_2_var_name->find(out_value) != value_2_var_name->end()) {
+      var = inner_scope->FindVar(value_2_var_name->at(out_value));
     } else {
-      name = "inner_var_" + std::to_string(count++);
-      name_map->emplace(out_value, name);
+      var = CreateVar(out_value,
+                      inner_scope,
+                      false,
+                      value_2_var_name,
+                      variable_2_var_name,
+                      var_name_2_id,
+                      variable_list);
     }
-    auto var = CreateVar(out_value, name, scope, local_scope);
+
     auto tensor_array = var->GetMutable<paddle::framework::VariableRefArray>();
     // clear tensor array
     tensor_array->clear();
-    size_t input_num = op->num_operands();
     for (size_t i = 0; i < input_num; ++i) {
       auto value = op->operand(i);
       PADDLE_ENFORCE_EQ(
-          name_map->count(value),
+          value_2_var_name->count(value),
           true,
           phi::errors::PreconditionNotMet("can not found input of combine op"));
       tensor_array->emplace_back(
-          CreateVar(value, name_map->at(value), scope, local_scope));
+          inner_scope->FindVar(value_2_var_name->at(value)));
     }
   }
@@ -210,14 +246,15 @@ void HandleForSpecialOp(ir::Operation* op,
             .dyn_cast<ir::StrAttribute>()
             .data();

-    auto in_ptr = op->operand(0);
+    auto value = op->operand(0);
     // change opreand name to param_name
-
-    auto orig_name = name_map->at(in_ptr);
-    if (scope->FindVar(param_name) == nullptr) {
-      scope->Rename(orig_name, param_name);
+    auto orig_name = value_2_var_name->at(value);
+    if (inner_scope->root()->FindVar(param_name) == nullptr) {
+      const_cast<paddle::framework::Scope*>(inner_scope->root())
+          ->Rename(orig_name, param_name);
     }
-    (*name_map)[in_ptr] = param_name;
+    (*value_2_var_name)[value] = param_name;
   }

   if (op_name == "builtin.get_parameter") {
@@ -226,44 +263,44 @@ void HandleForSpecialOp(ir::Operation* op,
             .at("parameter_name")
             .dyn_cast<ir::StrAttribute>()
             .data();
-    auto out_ptr = op->result(0);
-    name_map->emplace(out_ptr, param_name);
+    auto value = op->result(0);
+    value_2_var_name->emplace(value, param_name);
   }

   if (op_name == "builtin.slice") {
     VLOG(6) << "Handle for builtin.slice";
     auto out_value = op->result(0);
     auto in_value = op->operand(0);
-
-    PADDLE_ENFORCE_EQ(name_map->count(in_value),
+    PADDLE_ENFORCE_EQ(value_2_var_name->count(in_value),
                       true,
                       phi::errors::PreconditionNotMet(
                           "input of buildin slice not in name map"));
     int index =
         op->attributes().at("index").dyn_cast<ir::Int32Attribute>().data();
-    auto in_var = scope->FindVar(name_map->at(in_value));
+    auto in_var = inner_scope->FindVar(value_2_var_name->at(in_value));
     auto variable_array = in_var->Get<paddle::framework::VariableRefArray>();

     PADDLE_ENFORCE_EQ(
-        variable_name_map.count(variable_array[index]),
+        variable_2_var_name->count(variable_array[index]),
         true,
         phi::errors::PreconditionNotMet("[%d] the variable in build slice "
                                         "input MUST in variable name map",
                                         index));

-    std::string var_name = variable_name_map.at(variable_array[index]);
-
-    name_map->emplace(out_value, var_name);
+    std::string var_name = variable_2_var_name->at(variable_array[index]);
+    value_2_var_name->emplace(out_value, var_name);
   }
 }
-void HandleForInplaceOp(ir::Operation* op,
-                        paddle::framework::Scope* scope,
-                        paddle::framework::Scope* local_scope,
-                        std::unordered_map<ir::Value, std::string>* name_map,
-                        int& count) {  // NOLINT
+void HandleForInplaceOp(
+    ir::Operation* op,
+    paddle::framework::Scope* inner_scope,
+    std::unordered_map<ir::Value, std::string>* value_2_var_name,
+    std::unordered_map<const paddle::framework::Variable*, std::string>*
+        variable_2_var_name,
+    std::map<std::string, int>* var_name_2_id,
+    std::vector<paddle::framework::Variable*>* variable_list) {
   if (op->num_results() < 1) return;
   ir::IrContext* ctx = ir::IrContext::Instance();
   std::string op_name = op->name();
@@ -271,12 +308,12 @@ void HandleForInplaceOp(ir::Operation* op,
     op_name =
         op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
   }
+  VLOG(4) << "HandleForInplaceOp: " << op_name;

   ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
   paddle::dialect::OpYamlInfoParser yaml_parser(
       op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
           ->get_op_info_());
-  VariableNameMap variable_name_map;

   for (size_t i = 0; i < op->num_results(); ++i) {
     ir::Value value = op->result(i);
     std::string value_name = yaml_parser.OutputNames()[i];
@@ -284,35 +321,36 @@ void HandleForInplaceOp(ir::Operation* op,
       std::string inplace_name = yaml_parser.InplaceName(value_name);
       ir::Value inplace_value =
           op->operand(yaml_parser.InputName2Id().at(inplace_name));
-      std::string var_name = name_map->at(inplace_value);
+      std::string var_name = value_2_var_name->at(inplace_value);
       VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
               << " (var: " << var_name << ")";
-      name_map->emplace(value, var_name);
+      value_2_var_name->emplace(value, var_name);
     } else {
-      BuildValue(
-          value, scope, local_scope, name_map, &variable_name_map, count);
+      BuildValue(value,
+                 inner_scope,
+                 value_2_var_name,
+                 variable_2_var_name,
+                 var_name_2_id,
+                 variable_list);
     }
   }
 }
+// NOTE(zhiqiu): the persistable is created in inner_scope's root, and other is
+// created in inner_scope.
 void BuildScope(const ir::Block& block,
-                paddle::framework::Scope* scope,
-                paddle::framework::Scope* local_scope,
-                std::unordered_map<ir::Value, std::string>* name_map) {
-  VLOG(4) << "***** [before build] scope: ******\n"
+                paddle::framework::Scope* inner_scope,
+                std::unordered_map<ir::Value, std::string>* value_2_var_name,
+                std::unordered_map<const paddle::framework::Variable*,
+                                   std::string>* variable_2_var_name,
+                std::map<std::string, int>* var_name_2_id,
+                std::vector<paddle::framework::Variable*>* variable_list) {
+  VLOG(4) << "***** [before build] scope"
+          << "(" << inner_scope << ") ******\n"
           << paddle::framework::GenScopeTreeDebugInfo(
-                 const_cast<paddle::framework::Scope*>(scope->root()));
-  // NOTE(zhiqiu): if use local_scope (local_scope != nullptr), the persistable
-  // is created in scope , and other is created in local_scope.
-  auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
-  VLOG(6) << "Build: scope [" << scope << "] inner_local_scope ["
-          << inner_local_scope << "]";
-
-  std::unordered_map<const paddle::framework::Variable*, std::string>
-      variable_name_map;
-
-  // int count = name_map->size();
-  int count = name_map->size();
+                 const_cast<paddle::framework::Scope*>(inner_scope->root()));
+
+  // int count = value_2_var_name->size();

   for (auto it = block.begin(); it != block.end(); ++it) {
     ir::Operation* op = *it;
@@ -321,19 +359,21 @@ void BuildScope(const ir::Block& block,
       op_name =
           op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
     }
-    VLOG(4) << "build op:" << op_name;
+
+    VLOG(4) << "BuildScope for :" << op_name;

     if (op_name == "pd.feed" || op_name == "pd.fetch" ||
         op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
         op_name == "builtin.get_parameter" || op_name == "builtin.slice") {
-      VLOG(6) << "HandleForSpecialOp: " << op_name;
-      HandleForSpecialOp(
-          op, variable_name_map, scope, inner_local_scope, name_map, count);
+      HandleForSpecialOp(op,
+                         inner_scope,
+                         value_2_var_name,
+                         variable_2_var_name,
+                         var_name_2_id,
+                         variable_list);
       continue;
     }

-    CheckInputVars(op, op_name, *name_map);
+    CheckInputVars(op, op_name, *value_2_var_name);

     if (op->num_results() < 1) continue;
     if (op->attributes().count("is_inplace") != 0 &&
@@ -341,23 +381,29 @@ void BuildScope(const ir::Block& block,
             .at("is_inplace")
             .dyn_cast<ir::BoolAttribute>()
             .data()) {
-      HandleForInplaceOp(op, scope, inner_local_scope, name_map, count);
+      HandleForInplaceOp(op,
+                         inner_scope,
+                         value_2_var_name,
+                         variable_2_var_name,
+                         var_name_2_id,
+                         variable_list);
       continue;
     } else {
       for (size_t i = 0; i < op->num_results(); ++i) {
         BuildValue(op->result(i),
-                   scope,
-                   local_scope,
-                   name_map,
-                   &variable_name_map,
-                   count);
+                   inner_scope,
+                   value_2_var_name,
+                   variable_2_var_name,
+                   var_name_2_id,
+                   variable_list);
       }
     }
   }

-  VLOG(4) << "***** [after build] scope: ******\n"
+  VLOG(4) << "***** [after build] scope"
+          << "(" << inner_scope << ") ******\n"
           << paddle::framework::GenScopeTreeDebugInfo(
-                 const_cast<paddle::framework::Scope*>(scope->root()));
+                 const_cast<paddle::framework::Scope*>(inner_scope->root()));
 }

 }  // namespace ir
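The rewritten `CreateVar` replaces the caller-threaded `int& count` with names derived from container sizes, and keeps four structures in lock-step. A self-contained sketch of that bookkeeping contract (simplified from the code above; `Variable` here is a stand-in type, not Paddle's):

```cpp
#include <cassert>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

struct Variable {};  // stand-in for paddle::framework::Variable

// Mirrors the bookkeeping in the new CreateVar: every variable gets a name
// "inner_var_<N>", an integer id, and a slot in variable_list, and the
// id map and list always stay the same size.
void Record(Variable* var,
            std::unordered_map<const Variable*, std::string>* variable_2_var_name,
            std::map<std::string, int>* var_name_2_id,
            std::vector<Variable*>* variable_list) {
  std::string name =
      "inner_var_" + std::to_string(variable_2_var_name->size());
  variable_2_var_name->emplace(var, name);
  auto id = var_name_2_id->size();  // read size before inserting
  var_name_2_id->emplace(name, static_cast<int>(id));
  variable_list->push_back(var);
  assert(variable_list->size() == var_name_2_id->size());
}
```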
@@ -41,36 +41,13 @@
 #include "glog/logging.h"

 namespace ir {
-paddle::framework::Variable* CreateVar(ir::Value value,
-                                       const std::string& name,
-                                       paddle::framework::Scope* scope,
-                                       paddle::framework::Scope* local_scope);
-
-void BuildValue(ir::Value value,
-                paddle::framework::Scope* scope,
-                paddle::framework::Scope* local_scope,
-                std::unordered_map<ir::Value, std::string>* name_map,
-                int& count);  // NOLINT
-
-void HandleForSpecialOp(ir::Operation* op,
-                        paddle::framework::Scope* scope,
-                        paddle::framework::Scope* local_scope,
-                        std::unordered_map<ir::Value, std::string>* name_map,
-                        int& count);  // NOLINT
-
-void HandleForInplaceOp(ir::Operation* op,
-                        paddle::framework::Scope* scope,
-                        paddle::framework::Scope* local_scope,
-                        std::unordered_map<ir::Value, std::string>* name_map,
-                        int& count);  // NOLINT
-
-void CheckInputVars(ir::Operation* op,
-                    const std::unordered_map<ir::Value, std::string>& name_map);
-
 void BuildScope(const ir::Block& block,
-                paddle::framework::Scope* scope,
-                paddle::framework::Scope* local_scope,
-                std::unordered_map<ir::Value, std::string>* name_map);
+                paddle::framework::Scope* inner_scope,
+                std::unordered_map<ir::Value, std::string>* value_2_var_name,
+                std::unordered_map<const paddle::framework::Variable*,
+                                   std::string>* variable_2_var_name,
+                std::map<std::string, int>* var_name_2_id,
+                std::vector<paddle::framework::Variable*>* variable_list);

 template <typename Context,
           typename InType,
@@ -322,6 +299,7 @@ void BuildPhiContext(
   for (size_t i = 0; i < op->num_results(); ++i) {
     ir::Value out_ptr = op->result(i);
     auto name = name_map.at(out_ptr);
+    VLOG(6) << "ctx->EmplaceBackOutput: " << name;
     auto out_type = out_ptr.type();
     if (!out_type) {
       phi::DenseTensor* ptr = nullptr;
@@ -329,14 +307,14 @@ void BuildPhiContext(
       ctx->EmplaceBackOutput(out_ptr);
     } else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
       ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
-          &(inner_scope->Var(name)->Get<phi::DenseTensor>()))));
+          &(inner_scope->FindVar(name)->Get<phi::DenseTensor>()))));
     } else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
       ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>(
-          &(inner_scope->Var(name)->Get<phi::SelectedRows>()))));
+          &(inner_scope->FindVar(name)->Get<phi::SelectedRows>()))));
     } else if (out_type.isa<ir::VectorType>()) {
       OutListType outputs;
       auto& variable_array =
-          scope->Var(name)->Get<paddle::framework::VariableRefArray>();
+          scope->FindVar(name)->Get<paddle::framework::VariableRefArray>();
       for (size_t i = 0; i < variable_array.size(); ++i) {
         outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>(
             &(variable_array[i]->Get<phi::DenseTensor>()))));
@@ -360,6 +338,7 @@ void BuildPhiContext(
       }
     }
   }
+  VLOG(6) << "Done build phi context";
 }

 }  // namespace ir
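The `Var` to `FindVar` switch in `BuildPhiContext` matters because, in Paddle's `Scope` API, `Var(name)` creates the variable in the queried scope if it is missing, while `FindVar(name)` only looks it up (walking parent scopes) and returns nullptr otherwise. Since `BuildScope` now pre-creates every variable, output lookups must not silently create fresh variables in the wrong scope. A minimal illustration, assuming those `Scope` semantics:

```cpp
// After BuildScope has run, every output name must already resolve.
phi::DenseTensor* FindOutputTensor(paddle::framework::Scope* inner_scope,
                                   const std::string& name) {
  auto* var = inner_scope->FindVar(name);  // searches this scope and parents
  return var ? var->GetMutable<phi::DenseTensor>() : nullptr;  // no creation
}
```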
@@ -89,5 +89,42 @@ TEST(StandaloneExecutor, run) {
   EXPECT_EQ(res3, true);
 }

+TEST(StandaloneExecutor, run_inplace_sqrt) {
+  ir::IrContext* ctx = ir::IrContext::Instance();
+  ir::Program program((ctx));
+  ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
+  ir::Builder builder = ir::Builder(ctx, program.block());
+
+  paddle::dialect::FullOp full = builder.Build<paddle::dialect::FullOp>(
+      std::vector<int64_t>{2, 2}, 4.0, phi::DataType::FLOAT32, phi::CPUPlace());
+
+  builder.Build<paddle::dialect::Sqrt_Op>(full->result(0));
+
+  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
+
+  auto place = platform::CPUPlace();
+  Scope scope;
+  InterpreterCore test_core(place, std::move(kernel_program), &scope);
+  test_core.BetaRun({});
+
+  auto out_tensor = test_core.local_scope() == nullptr
+                        ? scope.FindVar("inner_var_0")->Get<phi::DenseTensor>()
+                        : test_core.local_scope()
+                              ->FindVar("inner_var_0")
+                              ->Get<phi::DenseTensor>();
+
+  bool res0 = simple_cmp(out_tensor.data<float>()[0], 2.0);
+  bool res1 = simple_cmp(out_tensor.data<float>()[1], 2.0);
+  bool res2 = simple_cmp(out_tensor.data<float>()[2], 2.0);
+  bool res3 = simple_cmp(out_tensor.data<float>()[3], 2.0);
+
+  EXPECT_EQ(scope.kids().size(), 1u);
+  EXPECT_EQ(scope.kids().front()->Size(), 1u);
+  EXPECT_EQ(res0, true);
+  EXPECT_EQ(res1, true);
+  EXPECT_EQ(res2, true);
+  EXPECT_EQ(res3, true);
+}
+
 }  // namespace framework
 }  // namespace paddle