diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc index d687428a1cb2c7ee0a56093307fd9364437e40d1..11c2a3814a013372ef211fa20aad72df99bea94d 100644 --- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc @@ -333,7 +333,7 @@ void PhiKernelInstruction::InitInputsOutputsIds( std::unordered_map> outputs; for (size_t i = 0; i < op->num_results(); i++) { ir::Value value = op->result(i); - if (value) { + if (value && value.type()) { PADDLE_ENFORCE_NE( value_2_var_name.find(value), value_2_var_name.end(), diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 8aa91b434c1db1719b8693529fe4ef3a1fb9843f..b620ca4fc6395b20058aea6e3aa1906b6afc440a 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -977,10 +977,7 @@ void BuildOpFuncList( attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().AsString(); op_func_node.phi_op_name_ = op_name; - if (op_name == "builtin.combine" || op_name == "pd.feed" || - op_name == "builtin.set_parameter" || - op_name == "builtin.get_parameter" || op_name == "builtin.slice" || - op_name == "pd.data" || op_name == "pd.shadow_output") { + if (GetSpecialOpNames().count(op_name)) { VLOG(6) << "skip process " << op_name; continue; }
@@ -1171,6 +1168,21 @@ void SetDeviceCommContext(::ir::Operation* op,
   }
 }
 
+// Returns the names of the ops that BuildOpFuncList and
+// NewIRInterpreter::BuildInstruction skip ("skip process <op_name>").
+// The set is built once on first use and shared by every caller, because the
+// lookup runs once per op while the executor is being built.
+const std::unordered_set<std::string>& GetSpecialOpNames() {
+  static const std::unordered_set<std::string> special_op_names = {
+      "builtin.combine",
+      "builtin.slice",
+      "pd.feed",
+      "builtin.set_parameter",
+      "builtin.get_parameter",
+      "pd.data",
+      "pd.shadow_output",
+  };
+  return special_op_names;
+}
+
 }  // namespace interpreter
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h index b37e46d5206e661924645b0c363e2410c5350f74..186f9459fbac7a68586490ca80d54d07bf33bacc 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h
@@ -124,6 +124,9 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base,
 void SetDeviceCommContext(::ir::Operation* op,
                           platform::DeviceContext* dev_ctx);
+
+// Names of the ops that are skipped when op func nodes / instructions are built.
+const std::unordered_set<std::string>& GetSpecialOpNames();
 }  // namespace interpreter
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc index 28a9b8e75d6e6656a629ae3072528255ca01ae7f..17748c0c8b6e80c11c0dddf33b4cf278ed307081 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc @@ -219,10 +219,11 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, &value_2_var_name_, &variable_2_var_name_, &var_name_2_id_, - &variable_list_, - &parameter_values_); + &variable_list_); VLOG(4) << DebugValueInfo(); + SolvePersisableVarNames(); + std::vector op_func_nodes; interpreter::BuildOpFuncList(place_, ir_program_->block(), @@ -1595,7 +1596,9 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace( std::stringstream ss; ss << "trace order: "; for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) { - ss << trace_execute_order_[idx] << " -> "; + ss << vec_instruction_base_[trace_execute_order_[idx]]->Name() << "[" + << trace_execute_order_[idx] << "]" + << " -> "; } ss << "end\n"; VLOG(6) << ss.str(); @@ -1616,10 +1619,7 @@ void NewIRInterpreter::BuildInstruction() { .at("op_name") .dyn_cast<::ir::StrAttribute>() .AsString(); - if (op_name == "builtin.combine" || op_name == "pd.feed" || - op_name == "builtin.set_parameter" || - op_name
== "builtin.get_parameter" || op_name == "builtin.slice" || - op_name == "pd.data" || op_name == "pd.shadow_output") { + if (interpreter::GetSpecialOpNames().count(op_name)) { VLOG(6) << "skip process " << op_name; continue; } @@ -1793,18 +1793,8 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) { VLOG(4) << "GC sync " << GetNameById(var_id); // persistable var will be ignore while GC - ::ir::Value value = GetValueByName(GetNameById(var_id)); - bool is_parameter = false; - if (value) { - for (auto item : parameter_values_) { - if (item == value) { - is_parameter = true; - break; - } - } - } - if (is_parameter) { - VLOG(4) << "value " << value.impl() << " is a parameter, skip gc"; + if (parameter_var_names_.count(GetNameById(var_id))) { + VLOG(4) << GetNameById(var_id) << " is a parameter, skip gc"; continue; } @@ -1851,18 +1841,8 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) { << ", ref:" << refs_[var_id]->DynamicRef(); bool is_ready = refs_[var_id]->CheckAndDecrease(); // ignore all persistable var while GCphi - ::ir::Value value = GetValueByName(GetNameById(var_id)); - bool is_parameter = false; - if (value) { - for (auto item : parameter_values_) { - if (item == value) { - is_parameter = true; - break; - } - } - } - if (is_parameter) { - VLOG(4) << "value " << value.impl() << " is a parameter, skip gc"; + if (parameter_var_names_.count(GetNameById(var_id))) { + VLOG(4) << GetNameById(var_id) << " is a parameter, skip gc"; continue; } @@ -2020,11 +2000,17 @@ FetchList NewIRInterpreter::BetaRun(const std::vector& feed_names, &value_2_var_name_, &variable_2_var_name_, &var_name_2_id_, - &variable_list_, - ¶meter_values_); + &variable_list_); VLOG(4) << "Done BuildScope"; VLOG(4) << DebugValueInfo(); + SolvePersisableVarNames(); + + VLOG(4) << "Parameter value include: "; + for (auto parameter : parameter_var_names_) { + VLOG(4) << "Parameter value: " << parameter; + } + BuildInstruction(); VLOG(4) << "Done BuildInstruction"; @@ 
-2032,15 +2018,9 @@ FetchList NewIRInterpreter::BetaRun(const std::vector& feed_names, VLOG(4) << "Done PreAnalysis"; // Run - if (FLAGS_enable_new_ir_in_executor_loop_run) { - LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode " - "with for_loop version."; - LoopRunImpl(); - } else { - LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode " - "with trace version."; - TraceRunImpl(); - } + LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode " + "with for_loop version(First step)."; + LoopRunImpl(); is_build_ = true; } else { if (FLAGS_enable_new_ir_in_executor_loop_run) { @@ -2177,7 +2157,8 @@ void NewIRInterpreter::TraceRunInstructionList( auto instr_id = trace_execute_order_[idx]; InstructionBase* instr_node = vec_instruction_base_.at(instr_id).get(); - VLOG(6) << "Run InstructionBase " << instr_id; + VLOG(6) << "Run InstructionBase " << instr_node->Name() << "[" << instr_id + << "]"; RunInstructionBase(instr_node); if (UNLIKELY(exception_holder_.IsCaught())) {
@@ -2263,5 +2244,33 @@ void NewIRInterpreter::PreAnalysis() {
   return nullptr;
 }
 
+// Fills parameter_var_names_ with the name of every variable in
+// value_2_var_name_ whose defining op flags the corresponding result as
+// persistable through the kAttrIsPersisable array attribute.
+void NewIRInterpreter::SolvePersisableVarNames() {
+  VLOG(6) << "SolvePersisableVarNames";
+  for (const auto& kv : value_2_var_name_) {
+    ::ir::Value value = kv.first;
+    const std::string& var_name = kv.second;
+    ::ir::OpResult result = value.dyn_cast<::ir::OpResult>();
+    auto* defining_op = value.GetDefiningOp();
+    // Skip values that are not the result of an op (no defining op to query)
+    // and ops that carry no persistable flags, instead of dereferencing them.
+    if (!result || defining_op == nullptr ||
+        !defining_op->HasAttribute(kAttrIsPersisable)) {
+      continue;
+    }
+    auto is_persisables = defining_op->attribute(kAttrIsPersisable)
+                              .dyn_cast<::ir::ArrayAttribute>()
+                              .AsVector();
+    if (is_persisables[result.GetResultIndex()]
+            .dyn_cast<::ir::BoolAttribute>()
+            .data()) {
+      VLOG(6) << "parameter_var_names_ include: " << var_name;
+      parameter_var_names_.insert(var_name);
+    }
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.h b/paddle/fluid/framework/new_executor/new_ir_interpreter.h index
1388a0276e4655184bd9ebefe5b25198aa6659c5..6c9f975a0ef414963fb429832eb7e7da0e4fcaf0 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.h +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.h @@ -235,6 +235,8 @@ class NewIRInterpreter : public InterpreterBaseImpl { void RecordStreamForGC(InstructionBase* instr); + void SolvePersisableVarNames(); + InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less; std::unique_ptr<::ir::Program> ir_program_{nullptr}; @@ -260,7 +262,7 @@ class NewIRInterpreter : public InterpreterBaseImpl { // Note(zhangbo): set_parameter_op's input and get_parameter_op's output // belongs to a parameter and cannot GC. - std::vector<::ir::Value> parameter_values_; + std::unordered_set parameter_var_names_; }; } // namespace framework diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc index 9156c46dc6dc27f442989c257f83f146df7e8da3..9833d37cca5d37b20e4f2f718f7d988d3fd3ffcf 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.cc +++ b/paddle/fluid/framework/new_executor/program_interpreter.cc @@ -1503,6 +1503,18 @@ void ProgramInterpreter::AnalyseExecuteOrderForTrace() { "trace_order size should be equal to dependecy_count_.")); trace_execute_order_ = trace_order; + + std::stringstream ss; + ss << "trace order: "; + for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) { + ss << vec_instruction_[trace_execute_order_[idx]] + .OpFunc() + ->operator_base_->Type() + << "[" << trace_execute_order_[idx] << "]" + << " -> "; + } + ss << "end\n"; + VLOG(6) << ss.str(); } } // namespace framework diff --git a/paddle/fluid/ir/interface/op_yaml_info_parser.cc b/paddle/fluid/ir/interface/op_yaml_info_parser.cc index 11c99d3b3fc2739f075450d3bc1763f366ce6a3c..44453c160aedef277e89b09889efb3a4e47e6df5 100644 --- a/paddle/fluid/ir/interface/op_yaml_info_parser.cc +++ b/paddle/fluid/ir/interface/op_yaml_info_parser.cc 
@@ -118,6 +118,32 @@ const std::string& OpYamlInfoParser::InplaceName(
       "Can not find inplace input of [%s].", out_name));
 }
 
+// Returns true when the op's view info contains an entry whose first element
+// equals `out_name`.
+bool OpYamlInfoParser::HasView(const std::string& out_name) const {
+  auto& view_info = std::get<3>(op_info_tuple_).view;
+  for (size_t i = 0; i < view_info.size(); i++) {
+    if (out_name == view_info[i].first) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Returns the name paired with `out_name` in the op's view info. Throws
+// PreconditionNotMet when `out_name` has no view entry; check HasView first.
+const std::string& OpYamlInfoParser::ViewName(
+    const std::string& out_name) const {
+  auto& view_info = std::get<3>(op_info_tuple_).view;
+  for (size_t i = 0; i < view_info.size(); i++) {
+    if (out_name == view_info[i].first) {
+      return view_info[i].second;
+    }
+  }
+  PADDLE_THROW(phi::errors::PreconditionNotMet(
+      "Can not find view input of [%s].", out_name));
+}
+
 void OpYamlInfoParser::parse() {
   auto input_info = std::get<0>(op_info_tuple_);
diff --git a/paddle/fluid/ir/interface/op_yaml_info_parser.h b/paddle/fluid/ir/interface/op_yaml_info_parser.h index 356decadcf677fffe0c5383967822c619d46286a..8aae9ef10ee2499d3bccbead1511a9d9bf2ef4ad 100644 --- a/paddle/fluid/ir/interface/op_yaml_info_parser.h +++ b/paddle/fluid/ir/interface/op_yaml_info_parser.h @@ -53,6 +53,10 @@ class OpYamlInfoParser { const std::string& InplaceName(const std::string& out_name) const; + bool HasView(const std::string& out_name) const; + + const std::string& ViewName(const std::string& out_name) const; + private: void parse(); inline const std::vector& InputInfo() const { diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h index 980013e11973d7bf31e54ab095dc527dd8591a4b..24066abecc04344cbaefb714c08bdbf4efa6a46a 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h @@ -69,8 +69,7 @@ class PhiKernelAdaptor { &value_2_var_name, &variable_2_var_name, &var_name_2_id, - &variable_list, - nullptr); + &variable_list); ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect(); diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc index a81485fdeddbdbcc750f4db59e1077535257fbb3..797a98bc10e75845fd405f66ec36c65eaef558cb 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc @@ -217,8 +217,7 @@ void HandleForSpecialOp( std::unordered_map* variable_2_var_name, std::map* var_name_2_id, - std::vector* variable_list, - std::vector<::ir::Value>* parameter_values) { + std::vector* variable_list) { std::string op_name = op->name(); if (op->attributes().count("op_name")) { op_name = @@ -347,10 +346,6 @@ void HandleForSpecialOp( value_2_var_name, variable_2_var_name, var_name_2_id); - - if (parameter_values) { - parameter_values->push_back(value); - } } if (op_name == "pd.shadow_output") { @@ -390,10 +385,6 @@ void HandleForSpecialOp( variable_2_var_name, var_name_2_id, variable_list); - - if (parameter_values) { - parameter_values->push_back(value); - } } if (op_name == "builtin.slice") { @@ -458,6 +449,14 @@ void HandleForInplaceOp( VLOG(4) << "inplace: " << value_name << " -> " << inplace_name << " (var: " << var_name << ")"; value_2_var_name->emplace(value, var_name); + } else if (yaml_parser.HasView(value_name)) { + std::string view_name = yaml_parser.ViewName(value_name); + ir::Value view_value = + op->operand_source(yaml_parser.InputName2Id().at(view_name)); + std::string var_name = value_2_var_name->at(view_value); + VLOG(4) << "view: " << value_name << " -> " << view_name + << " (var: " << var_name << ")"; + value_2_var_name->emplace(value, var_name); } else { BuildValue(value, inner_scope, @@ -479,8 +478,7 @@ void BuildScope(const ir::Block& block, std::unordered_map* variable_2_var_name, std::map* var_name_2_id, - std::vector* variable_list, - std::vector<::ir::Value>* parameter_values) { + std::vector* variable_list) { VLOG(4) << "***** [before build] scope" << "(" << 
inner_scope << ") ******\n" << paddle::framework::GenScopeTreeDebugInfo( @@ -506,8 +504,7 @@ void BuildScope(const ir::Block& block, value_2_var_name, variable_2_var_name, var_name_2_id, - variable_list, - parameter_values); + variable_list); continue; } diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h index f14cd35621b5a55451b57f01e21a17f0bffe13ae..e39c452c788cbe9aeb88a52362ab5b74851adb80 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h @@ -49,8 +49,7 @@ void BuildScope(const ir::Block& block, std::unordered_map* variable_2_var_name, std::map* var_name_2_id, - std::vector* variable_list, - std::vector<::ir::Value>* parameter_values); + std::vector* variable_list); void BuildRuntimeContext( ir::Operation* op, @@ -288,23 +287,29 @@ void BuildPhiContext(ir::Operation* op, // TODO(phlrain): use var type instead of op name for (size_t i = 0; i < op->num_results(); ++i) { ir::Value out_ptr = op->result(i); - auto name = name_map.at(out_ptr); - VLOG(6) << "ctx->EmplaceBackOutput: " << name; auto out_type = out_ptr.type(); + if (out_type) { + auto name = name_map.at(out_ptr); + VLOG(6) << "ctx->EmplaceBackOutput: " << name; + } else { + VLOG(6) << "ctx->EmplaceBackOutput : an optioanl output"; + } if (!out_type) { phi::DenseTensor* ptr = nullptr; OutType out_ptr(ptr); ctx->EmplaceBackOutput(out_ptr); } else if (out_type.isa()) { ctx->EmplaceBackOutput(OutType(const_cast( - &(inner_scope->FindVar(name)->Get())))); + &(inner_scope->FindVar(name_map.at(out_ptr)) + ->Get())))); } else if (out_type.isa()) { ctx->EmplaceBackOutput(OutType(const_cast( - &(inner_scope->FindVar(name)->Get())))); + &(inner_scope->FindVar(name_map.at(out_ptr)) + ->Get())))); } else if (out_type.isa()) { OutListType outputs; - auto& variable_array = - scope->FindVar(name)->Get(); + auto& variable_array = scope->FindVar(name_map.at(out_ptr)) + ->Get(); for 
(size_t i = 0; i < variable_array.size(); ++i) { outputs.emplace_back(OutType(const_cast( &(variable_array[i]->Get())))); diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 7c30a1399fb8e8a3ba4f5137df22ad062dc4845d..1168a6af4d8b1b673ece06be5c32e1adc0b64ae8 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -72,9 +72,7 @@ using AttributeHandlerFn = std::function special_non_inplace_ops = { - "batch_norm", -}; +static const std::unordered_set special_non_inplace_ops = {}; static const std::unordered_set special_inplace_ops = { "adagrad", diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 357ca28d6f5adf1ceff10a76c32645ba3ab37505..342d5f1aa0ac070af733baf5d2b46f437f2d216c 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -77,6 +77,11 @@ void ProgramTranslator::Translate() { const BlockDesc& block = legacy_program_->Block(block_idx); SetStopGradientAttributeForAllValue(block); } + + for (size_t block_idx = 0; block_idx < legacy_program_->Size(); block_idx++) { + const BlockDesc& block = legacy_program_->Block(block_idx); + SetIsPersisableAttributeForAllValue(block); + } } inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx,
@@ -268,5 +273,50 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
   }
 }
 
+// Records on each defining operation which of its results are persistable.
+// For every param_map_ entry whose variable is found through `block`, the
+// result's slot in the op's kAttrIsPersisable array attribute is set to the
+// VarDesc's Persistable() flag. Throws PreconditionNotMet on a null value.
+void ProgramTranslator::SetIsPersisableAttributeForAllValue(
+    const BlockDesc& block) {
+  // Currently we set is persisable for operation that generated a value
+  // connected with VarDesc
+  for (const auto& [var_name, value_info] : param_map_) {
+    if (no_cast_var_names.count(var_name) != 0) continue;
+    VLOG(10) << "[op translated][is persisable]" << var_name;
+    VarDesc* var = block.FindVarRecursive(var_name);
+    if (var == nullptr) {
+      continue;
+    }
+    ir::OpResult value = value_info.value;
+    if (!value) {
+      PADDLE_THROW(phi::errors::PreconditionNotMet(
+          "Value of [%s] can not be None", var_name));
+    }
+    auto* defining_op = value.owner();
+    PADDLE_ENFORCE_NOT_NULL(
+        defining_op,
+        phi::errors::PreconditionNotMet(
+            "Defining operator of [%s] can not be nullptr", var_name));
+    VLOG(8) << "[op translated][is persisable]" << var_name
+            << " from: " << defining_op->name();
+    // Reuse the flags already stored on the op, if any, so entries written
+    // for its other results are kept; otherwise start with all false.
+    std::vector<ir::Attribute> is_persisable;
+    if (defining_op->HasAttribute(kAttrIsPersisable)) {
+      is_persisable = defining_op->attribute(kAttrIsPersisable)
+                          .dyn_cast<ir::ArrayAttribute>()
+                          .AsVector();
+    } else {
+      is_persisable = std::vector<ir::Attribute>(
+          defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
+    }
+    is_persisable[value.GetResultIndex()] =
+        ir::BoolAttribute::get(ctx_, var->Persistable());
+    defining_op->set_attribute(kAttrIsPersisable,
+                               ir::ArrayAttribute::get(ctx_, is_persisable));
+  }
+}
+
 }  // namespace translator
 }  // namespace paddle
diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index ce34fec141912e6acc60afb020c7418bb24ffab6..88901376ae3cb08bfbdfed875cedb68b73c27aa7 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -79,6 +79,7 @@ class ProgramTranslator { void InsertOperationToSingleBlock(const BlockDesc& block); void SetParameterFromSingleBlock(const BlockDesc& block); void SetStopGradientAttributeForAllValue(const BlockDesc& block); + void SetIsPersisableAttributeForAllValue(const BlockDesc& block); }; } // namespace translator diff --git a/paddle/ir/core/attribute.h b/paddle/ir/core/attribute.h index 3b96018293db79f8132484bcb72b5980092b77f9..4315e13b0fcaddcdaa81c5bcbe0e36cd5ab89276 100644 --- a/paddle/ir/core/attribute.h +++ b/paddle/ir/core/attribute.h @@ -18,6 +18,7 @@ #include "paddle/ir/core/type_id.h" constexpr char kAttrStopGradients[] = "stop_gradient"; +constexpr
char kAttrIsPersisable[] = "is_persisable"; namespace ir { class AttributeStorage;