未验证 提交 a25f013b 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Solve bugs (#55991)

* solve conflict bug

* fix bug
上级 ddfbf135
...@@ -333,7 +333,7 @@ void PhiKernelInstruction::InitInputsOutputsIds( ...@@ -333,7 +333,7 @@ void PhiKernelInstruction::InitInputsOutputsIds(
std::unordered_map<ir::Value, std::vector<int>> outputs; std::unordered_map<ir::Value, std::vector<int>> outputs;
for (size_t i = 0; i < op->num_results(); i++) { for (size_t i = 0; i < op->num_results(); i++) {
ir::Value value = op->result(i); ir::Value value = op->result(i);
if (value) { if (value && value.type()) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
value_2_var_name.find(value), value_2_var_name.find(value),
value_2_var_name.end(), value_2_var_name.end(),
......
...@@ -977,10 +977,7 @@ void BuildOpFuncList( ...@@ -977,10 +977,7 @@ void BuildOpFuncList(
attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().AsString(); attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
op_func_node.phi_op_name_ = op_name; op_func_node.phi_op_name_ = op_name;
if (op_name == "builtin.combine" || op_name == "pd.feed" || if (GetSpecialOpNames().count(op_name)) {
op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
op_name == "pd.data" || op_name == "pd.shadow_output") {
VLOG(6) << "skip process " << op_name; VLOG(6) << "skip process " << op_name;
continue; continue;
} }
...@@ -1171,6 +1168,18 @@ void SetDeviceCommContext(::ir::Operation* op, ...@@ -1171,6 +1168,18 @@ void SetDeviceCommContext(::ir::Operation* op,
} }
} }
// Returns the set of op names that the new-IR executor handles specially
// and therefore skips during generic op-func-node building / instruction
// creation (see the callers that do `GetSpecialOpNames().count(op_name)`).
//
// The set is built once into a function-local static: callers invoke this
// inside per-op loops, and the original rebuilt (re-allocated) the whole
// set on every call.
std::unordered_set<std::string> GetSpecialOpNames() {
  static const std::unordered_set<std::string> kSpecialOpNames = {
      "builtin.combine",
      "builtin.slice",
      "builtin.set_parameter",
      "builtin.get_parameter",
      "pd.feed",
      "pd.data",
      "pd.shadow_output",
  };
  return kSpecialOpNames;
}
} // namespace interpreter } // namespace interpreter
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -124,6 +124,8 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base, ...@@ -124,6 +124,8 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base,
void SetDeviceCommContext(::ir::Operation* op, void SetDeviceCommContext(::ir::Operation* op,
platform::DeviceContext* dev_ctx); platform::DeviceContext* dev_ctx);
std::unordered_set<std::string> GetSpecialOpNames();
} // namespace interpreter } // namespace interpreter
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -219,10 +219,11 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names, ...@@ -219,10 +219,11 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
&value_2_var_name_, &value_2_var_name_,
&variable_2_var_name_, &variable_2_var_name_,
&var_name_2_id_, &var_name_2_id_,
&variable_list_, &variable_list_);
&parameter_values_);
VLOG(4) << DebugValueInfo(); VLOG(4) << DebugValueInfo();
SolvePersisableVarNames();
std::vector<paddle::framework::OpFuncNode> op_func_nodes; std::vector<paddle::framework::OpFuncNode> op_func_nodes;
interpreter::BuildOpFuncList(place_, interpreter::BuildOpFuncList(place_,
ir_program_->block(), ir_program_->block(),
...@@ -1595,7 +1596,9 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace( ...@@ -1595,7 +1596,9 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace(
std::stringstream ss; std::stringstream ss;
ss << "trace order: "; ss << "trace order: ";
for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) { for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
ss << trace_execute_order_[idx] << " -> "; ss << vec_instruction_base_[trace_execute_order_[idx]]->Name() << "["
<< trace_execute_order_[idx] << "]"
<< " -> ";
} }
ss << "end\n"; ss << "end\n";
VLOG(6) << ss.str(); VLOG(6) << ss.str();
...@@ -1616,10 +1619,7 @@ void NewIRInterpreter::BuildInstruction() { ...@@ -1616,10 +1619,7 @@ void NewIRInterpreter::BuildInstruction() {
.at("op_name") .at("op_name")
.dyn_cast<::ir::StrAttribute>() .dyn_cast<::ir::StrAttribute>()
.AsString(); .AsString();
if (op_name == "builtin.combine" || op_name == "pd.feed" || if (interpreter::GetSpecialOpNames().count(op_name)) {
op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
op_name == "pd.data" || op_name == "pd.shadow_output") {
VLOG(6) << "skip process " << op_name; VLOG(6) << "skip process " << op_name;
continue; continue;
} }
...@@ -1793,18 +1793,8 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) { ...@@ -1793,18 +1793,8 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) {
VLOG(4) << "GC sync " << GetNameById(var_id); VLOG(4) << "GC sync " << GetNameById(var_id);
// persistable var will be ignore while GC // persistable var will be ignore while GC
::ir::Value value = GetValueByName(GetNameById(var_id)); if (parameter_var_names_.count(GetNameById(var_id))) {
bool is_parameter = false; VLOG(4) << GetNameById(var_id) << " is a parameter, skip gc";
if (value) {
for (auto item : parameter_values_) {
if (item == value) {
is_parameter = true;
break;
}
}
}
if (is_parameter) {
VLOG(4) << "value " << value.impl() << " is a parameter, skip gc";
continue; continue;
} }
...@@ -1851,18 +1841,8 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) { ...@@ -1851,18 +1841,8 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) {
<< ", ref:" << refs_[var_id]->DynamicRef(); << ", ref:" << refs_[var_id]->DynamicRef();
bool is_ready = refs_[var_id]->CheckAndDecrease(); bool is_ready = refs_[var_id]->CheckAndDecrease();
// ignore all persistable var while GCphi // ignore all persistable var while GCphi
::ir::Value value = GetValueByName(GetNameById(var_id)); if (parameter_var_names_.count(GetNameById(var_id))) {
bool is_parameter = false; VLOG(4) << GetNameById(var_id) << " is a parameter, skip gc";
if (value) {
for (auto item : parameter_values_) {
if (item == value) {
is_parameter = true;
break;
}
}
}
if (is_parameter) {
VLOG(4) << "value " << value.impl() << " is a parameter, skip gc";
continue; continue;
} }
...@@ -2020,11 +2000,17 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names, ...@@ -2020,11 +2000,17 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
&value_2_var_name_, &value_2_var_name_,
&variable_2_var_name_, &variable_2_var_name_,
&var_name_2_id_, &var_name_2_id_,
&variable_list_, &variable_list_);
&parameter_values_);
VLOG(4) << "Done BuildScope"; VLOG(4) << "Done BuildScope";
VLOG(4) << DebugValueInfo(); VLOG(4) << DebugValueInfo();
SolvePersisableVarNames();
VLOG(4) << "Parameter value include: ";
for (auto parameter : parameter_var_names_) {
VLOG(4) << "Parameter value: " << parameter;
}
BuildInstruction(); BuildInstruction();
VLOG(4) << "Done BuildInstruction"; VLOG(4) << "Done BuildInstruction";
...@@ -2032,15 +2018,9 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names, ...@@ -2032,15 +2018,9 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
VLOG(4) << "Done PreAnalysis"; VLOG(4) << "Done PreAnalysis";
// Run // Run
if (FLAGS_enable_new_ir_in_executor_loop_run) { LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode " "with for_loop version(First step).";
"with for_loop version."; LoopRunImpl();
LoopRunImpl();
} else {
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
"with trace version.";
TraceRunImpl();
}
is_build_ = true; is_build_ = true;
} else { } else {
if (FLAGS_enable_new_ir_in_executor_loop_run) { if (FLAGS_enable_new_ir_in_executor_loop_run) {
...@@ -2177,7 +2157,8 @@ void NewIRInterpreter::TraceRunInstructionList( ...@@ -2177,7 +2157,8 @@ void NewIRInterpreter::TraceRunInstructionList(
auto instr_id = trace_execute_order_[idx]; auto instr_id = trace_execute_order_[idx];
InstructionBase* instr_node = vec_instruction_base_.at(instr_id).get(); InstructionBase* instr_node = vec_instruction_base_.at(instr_id).get();
VLOG(6) << "Run InstructionBase " << instr_id; VLOG(6) << "Run InstructionBase " << instr_node->Name() << "[" << instr_id
<< "]";
RunInstructionBase(instr_node); RunInstructionBase(instr_node);
if (UNLIKELY(exception_holder_.IsCaught())) { if (UNLIKELY(exception_holder_.IsCaught())) {
...@@ -2263,5 +2244,26 @@ void NewIRInterpreter::PreAnalysis() { ...@@ -2263,5 +2244,26 @@ void NewIRInterpreter::PreAnalysis() {
return nullptr; return nullptr;
} }
// Collects the names of persistable variables (parameters) into
// parameter_var_names_ so that RecordStreamForGC/CheckGC can skip them.
// A value counts as persistable when its defining op carries the
// kAttrIsPersisable array attribute and the entry for this result index
// is a true BoolAttribute.
void NewIRInterpreter::SolvePersisableVarNames() {
  VLOG(6) << "SolvePersisableVarNames";
  for (const auto& kv : value_2_var_name_) {
    ::ir::Value value = kv.first;
    const std::string& var_name = kv.second;
    ::ir::OpResult result = value.dyn_cast<::ir::OpResult>();
    // Values that are not op results (e.g. block arguments) have no
    // defining op; the original dereferenced GetDefiningOp() without a
    // check, which would crash for such values.
    if (!result) {
      continue;
    }
    auto* defining_op = value.GetDefiningOp();
    if (defining_op == nullptr ||
        !defining_op->HasAttribute(kAttrIsPersisable)) {
      continue;
    }
    auto is_persisables = defining_op->attribute(kAttrIsPersisable)
                              .dyn_cast<::ir::ArrayAttribute>()
                              .AsVector();
    if (is_persisables[result.GetResultIndex()]
            .dyn_cast<::ir::BoolAttribute>()
            .data()) {
      VLOG(6) << "parameter_var_names_ include: " << var_name;
      parameter_var_names_.insert(var_name);
    }
  }
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -235,6 +235,8 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -235,6 +235,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void RecordStreamForGC(InstructionBase* instr); void RecordStreamForGC(InstructionBase* instr);
void SolvePersisableVarNames();
InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less; InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less;
std::unique_ptr<::ir::Program> ir_program_{nullptr}; std::unique_ptr<::ir::Program> ir_program_{nullptr};
...@@ -260,7 +262,7 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -260,7 +262,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
// Note(zhangbo): set_parameter_op's input and get_parameter_op's output // Note(zhangbo): set_parameter_op's input and get_parameter_op's output
// belongs to a parameter and cannot GC. // belongs to a parameter and cannot GC.
std::vector<::ir::Value> parameter_values_; std::unordered_set<std::string> parameter_var_names_;
}; };
} // namespace framework } // namespace framework
......
...@@ -1503,6 +1503,18 @@ void ProgramInterpreter::AnalyseExecuteOrderForTrace() { ...@@ -1503,6 +1503,18 @@ void ProgramInterpreter::AnalyseExecuteOrderForTrace() {
"trace_order size should be equal to dependecy_count_.")); "trace_order size should be equal to dependecy_count_."));
trace_execute_order_ = trace_order; trace_execute_order_ = trace_order;
std::stringstream ss;
ss << "trace order: ";
for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
ss << vec_instruction_[trace_execute_order_[idx]]
.OpFunc()
->operator_base_->Type()
<< "[" << trace_execute_order_[idx] << "]"
<< " -> ";
}
ss << "end\n";
VLOG(6) << ss.str();
} }
} // namespace framework } // namespace framework
......
...@@ -118,6 +118,28 @@ const std::string& OpYamlInfoParser::InplaceName( ...@@ -118,6 +118,28 @@ const std::string& OpYamlInfoParser::InplaceName(
"Can not find inplace input of [%s].", out_name)); "Can not find inplace input of [%s].", out_name));
} }
// Returns true when `out_name` is registered as a view output in the
// op's yaml view-info list (element 3 of op_info_tuple_).
bool OpYamlInfoParser::HasView(const std::string& out_name) const {
  const auto& view_pairs = std::get<3>(op_info_tuple_).view;
  for (const auto& pair : view_pairs) {
    if (pair.first == out_name) {
      return true;
    }
  }
  return false;
}
// Returns the name of the input that the output `out_name` is a view of,
// looked up in the op's yaml view-info list (element 3 of op_info_tuple_).
// Throws PreconditionNotMet when no view pairing exists; callers should
// guard with HasView() first.
const std::string& OpYamlInfoParser::ViewName(
    const std::string& out_name) const {
  const auto& view_info = std::get<3>(op_info_tuple_).view;
  for (size_t i = 0; i < view_info.size(); i++) {
    if (out_name == view_info[i].first) {
      return view_info[i].second;
    }
  }
  // Fixed message: the original said "inplace input" (copy-pasted from
  // InplaceName), which misreported the failing lookup.
  PADDLE_THROW(phi::errors::PreconditionNotMet(
      "Can not find view input of [%s].", out_name));
}
void OpYamlInfoParser::parse() { void OpYamlInfoParser::parse() {
auto input_info = std::get<0>(op_info_tuple_); auto input_info = std::get<0>(op_info_tuple_);
......
...@@ -53,6 +53,10 @@ class OpYamlInfoParser { ...@@ -53,6 +53,10 @@ class OpYamlInfoParser {
const std::string& InplaceName(const std::string& out_name) const; const std::string& InplaceName(const std::string& out_name) const;
bool HasView(const std::string& out_name) const;
const std::string& ViewName(const std::string& out_name) const;
private: private:
void parse(); void parse();
inline const std::vector<OpInputInfo>& InputInfo() const { inline const std::vector<OpInputInfo>& InputInfo() const {
......
...@@ -69,8 +69,7 @@ class PhiKernelAdaptor { ...@@ -69,8 +69,7 @@ class PhiKernelAdaptor {
&value_2_var_name, &value_2_var_name,
&variable_2_var_name, &variable_2_var_name,
&var_name_2_id, &var_name_2_id,
&variable_list, &variable_list);
nullptr);
ir::IrContext* ctx = ir::IrContext::Instance(); ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>(); ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
......
...@@ -217,8 +217,7 @@ void HandleForSpecialOp( ...@@ -217,8 +217,7 @@ void HandleForSpecialOp(
std::unordered_map<const paddle::framework::Variable*, std::string>* std::unordered_map<const paddle::framework::Variable*, std::string>*
variable_2_var_name, variable_2_var_name,
std::map<std::string, int>* var_name_2_id, std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list, std::vector<paddle::framework::Variable*>* variable_list) {
std::vector<::ir::Value>* parameter_values) {
std::string op_name = op->name(); std::string op_name = op->name();
if (op->attributes().count("op_name")) { if (op->attributes().count("op_name")) {
op_name = op_name =
...@@ -347,10 +346,6 @@ void HandleForSpecialOp( ...@@ -347,10 +346,6 @@ void HandleForSpecialOp(
value_2_var_name, value_2_var_name,
variable_2_var_name, variable_2_var_name,
var_name_2_id); var_name_2_id);
if (parameter_values) {
parameter_values->push_back(value);
}
} }
if (op_name == "pd.shadow_output") { if (op_name == "pd.shadow_output") {
...@@ -390,10 +385,6 @@ void HandleForSpecialOp( ...@@ -390,10 +385,6 @@ void HandleForSpecialOp(
variable_2_var_name, variable_2_var_name,
var_name_2_id, var_name_2_id,
variable_list); variable_list);
if (parameter_values) {
parameter_values->push_back(value);
}
} }
if (op_name == "builtin.slice") { if (op_name == "builtin.slice") {
...@@ -458,6 +449,14 @@ void HandleForInplaceOp( ...@@ -458,6 +449,14 @@ void HandleForInplaceOp(
VLOG(4) << "inplace: " << value_name << " -> " << inplace_name VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
<< " (var: " << var_name << ")"; << " (var: " << var_name << ")";
value_2_var_name->emplace(value, var_name); value_2_var_name->emplace(value, var_name);
} else if (yaml_parser.HasView(value_name)) {
std::string view_name = yaml_parser.ViewName(value_name);
ir::Value view_value =
op->operand_source(yaml_parser.InputName2Id().at(view_name));
std::string var_name = value_2_var_name->at(view_value);
VLOG(4) << "view: " << value_name << " -> " << view_name
<< " (var: " << var_name << ")";
value_2_var_name->emplace(value, var_name);
} else { } else {
BuildValue(value, BuildValue(value,
inner_scope, inner_scope,
...@@ -479,8 +478,7 @@ void BuildScope(const ir::Block& block, ...@@ -479,8 +478,7 @@ void BuildScope(const ir::Block& block,
std::unordered_map<const paddle::framework::Variable*, std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name, std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id, std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list, std::vector<paddle::framework::Variable*>* variable_list) {
std::vector<::ir::Value>* parameter_values) {
VLOG(4) << "***** [before build] scope" VLOG(4) << "***** [before build] scope"
<< "(" << inner_scope << ") ******\n" << "(" << inner_scope << ") ******\n"
<< paddle::framework::GenScopeTreeDebugInfo( << paddle::framework::GenScopeTreeDebugInfo(
...@@ -506,8 +504,7 @@ void BuildScope(const ir::Block& block, ...@@ -506,8 +504,7 @@ void BuildScope(const ir::Block& block,
value_2_var_name, value_2_var_name,
variable_2_var_name, variable_2_var_name,
var_name_2_id, var_name_2_id,
variable_list, variable_list);
parameter_values);
continue; continue;
} }
......
...@@ -49,8 +49,7 @@ void BuildScope(const ir::Block& block, ...@@ -49,8 +49,7 @@ void BuildScope(const ir::Block& block,
std::unordered_map<const paddle::framework::Variable*, std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name, std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id, std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list, std::vector<paddle::framework::Variable*>* variable_list);
std::vector<::ir::Value>* parameter_values);
void BuildRuntimeContext( void BuildRuntimeContext(
ir::Operation* op, ir::Operation* op,
...@@ -288,23 +287,29 @@ void BuildPhiContext(ir::Operation* op, ...@@ -288,23 +287,29 @@ void BuildPhiContext(ir::Operation* op,
// TODO(phlrain): use var type instead of op name // TODO(phlrain): use var type instead of op name
for (size_t i = 0; i < op->num_results(); ++i) { for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value out_ptr = op->result(i); ir::Value out_ptr = op->result(i);
auto name = name_map.at(out_ptr);
VLOG(6) << "ctx->EmplaceBackOutput: " << name;
auto out_type = out_ptr.type(); auto out_type = out_ptr.type();
if (out_type) {
auto name = name_map.at(out_ptr);
VLOG(6) << "ctx->EmplaceBackOutput: " << name;
} else {
VLOG(6) << "ctx->EmplaceBackOutput : an optioanl output";
}
if (!out_type) { if (!out_type) {
phi::DenseTensor* ptr = nullptr; phi::DenseTensor* ptr = nullptr;
OutType out_ptr(ptr); OutType out_ptr(ptr);
ctx->EmplaceBackOutput(out_ptr); ctx->EmplaceBackOutput(out_ptr);
} else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) { } else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>( ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
&(inner_scope->FindVar(name)->Get<phi::DenseTensor>())))); &(inner_scope->FindVar(name_map.at(out_ptr))
->Get<phi::DenseTensor>()))));
} else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) { } else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>( ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>(
&(inner_scope->FindVar(name)->Get<phi::SelectedRows>())))); &(inner_scope->FindVar(name_map.at(out_ptr))
->Get<phi::SelectedRows>()))));
} else if (out_type.isa<ir::VectorType>()) { } else if (out_type.isa<ir::VectorType>()) {
OutListType outputs; OutListType outputs;
auto& variable_array = auto& variable_array = scope->FindVar(name_map.at(out_ptr))
scope->FindVar(name)->Get<paddle::framework::VariableRefArray>(); ->Get<paddle::framework::VariableRefArray>();
for (size_t i = 0; i < variable_array.size(); ++i) { for (size_t i = 0; i < variable_array.size(); ++i) {
outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>( outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>(
&(variable_array[i]->Get<phi::DenseTensor>())))); &(variable_array[i]->Get<phi::DenseTensor>()))));
......
...@@ -72,9 +72,7 @@ using AttributeHandlerFn = std::function<ir::Attribute( ...@@ -72,9 +72,7 @@ using AttributeHandlerFn = std::function<ir::Attribute(
constexpr char kTargetDialectPrefix[] = "pd."; constexpr char kTargetDialectPrefix[] = "pd.";
constexpr char kEmptyVarName[] = "@EMPTY@"; constexpr char kEmptyVarName[] = "@EMPTY@";
static const std::unordered_set<std::string> special_non_inplace_ops = { static const std::unordered_set<std::string> special_non_inplace_ops = {};
"batch_norm",
};
static const std::unordered_set<std::string> special_inplace_ops = { static const std::unordered_set<std::string> special_inplace_ops = {
"adagrad", "adagrad",
......
...@@ -77,6 +77,11 @@ void ProgramTranslator::Translate() { ...@@ -77,6 +77,11 @@ void ProgramTranslator::Translate() {
const BlockDesc& block = legacy_program_->Block(block_idx); const BlockDesc& block = legacy_program_->Block(block_idx);
SetStopGradientAttributeForAllValue(block); SetStopGradientAttributeForAllValue(block);
} }
for (size_t block_idx = 0; block_idx < legacy_program_->Size(); block_idx++) {
const BlockDesc& block = legacy_program_->Block(block_idx);
SetIsPersisableAttributeForAllValue(block);
}
} }
inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx, inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx,
...@@ -268,5 +273,44 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue( ...@@ -268,5 +273,44 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
} }
} }
// Tags every translated value that corresponds to a persistable VarDesc
// with the kAttrIsPersisable array attribute on its defining op, so
// downstream passes (e.g. the interpreter's GC) can identify parameters.
void ProgramTranslator::SetIsPersisableAttributeForAllValue(
    const BlockDesc& block) {
  // Currently we set is persisable for operation that generated a value
  // connected with VarDesc
  for (const auto& [var_name, value_info] : param_map_) {
    // Names in no_cast_var_names were never materialized as IR values.
    if (no_cast_var_names.count(var_name) != 0) continue;
    VLOG(10) << "[op translated][is persisable]" << var_name;
    // Skip values whose VarDesc is not visible from this block.
    VarDesc* var = block.FindVarRecursive(var_name);
    if (var == nullptr) {
      continue;
    }
    ir::OpResult value = value_info.value;
    if (!value) {
      // NOTE(review): "ber" is a typo in this runtime message — fix in a
      // follow-up (message text deliberately left unchanged here).
      PADDLE_THROW(phi::errors::PreconditionNotMet(
          "Value of [%s] can not ber None", var_name));
    }
    auto* defining_op = value.owner();
    PADDLE_ENFORCE_NOT_NULL(
        defining_op,
        phi::errors::PreconditionNotMet(
            "Defining operator of [%s] can not be nullptr", var_name));
    VLOG(8) << "[op translated][is persisable]" << var_name
            << " from: " << defining_op->name();
    // Start from the op's existing attribute (so earlier-tagged results
    // are preserved), or default every result slot to false.
    std::vector<ir::Attribute> is_persisable;
    if (defining_op->HasAttribute(kAttrIsPersisable)) {
      is_persisable = defining_op->attribute(kAttrIsPersisable)
                          .dyn_cast<ir::ArrayAttribute>()
                          .AsVector();
    } else {
      is_persisable = std::vector<ir::Attribute>(
          defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
    }
    // Only this value's result slot reflects the VarDesc's flag.
    is_persisable[value.GetResultIndex()] =
        ir::BoolAttribute::get(ctx_, var->Persistable());
    defining_op->set_attribute(kAttrIsPersisable,
                               ir::ArrayAttribute::get(ctx_, is_persisable));
  }
}
} // namespace translator } // namespace translator
} // namespace paddle } // namespace paddle
...@@ -79,6 +79,7 @@ class ProgramTranslator { ...@@ -79,6 +79,7 @@ class ProgramTranslator {
void InsertOperationToSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block);
void SetParameterFromSingleBlock(const BlockDesc& block); void SetParameterFromSingleBlock(const BlockDesc& block);
void SetStopGradientAttributeForAllValue(const BlockDesc& block); void SetStopGradientAttributeForAllValue(const BlockDesc& block);
void SetIsPersisableAttributeForAllValue(const BlockDesc& block);
}; };
} // namespace translator } // namespace translator
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/ir/core/type_id.h" #include "paddle/ir/core/type_id.h"
constexpr char kAttrStopGradients[] = "stop_gradient"; constexpr char kAttrStopGradients[] = "stop_gradient";
constexpr char kAttrIsPersisable[] = "is_persisable";
namespace ir { namespace ir {
class AttributeStorage; class AttributeStorage;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册