未验证 提交 a25f013b 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Sovle bugs (#55991)

* sovle conflict bug

* fix bug
上级 ddfbf135
......@@ -333,7 +333,7 @@ void PhiKernelInstruction::InitInputsOutputsIds(
std::unordered_map<ir::Value, std::vector<int>> outputs;
for (size_t i = 0; i < op->num_results(); i++) {
ir::Value value = op->result(i);
if (value) {
if (value && value.type()) {
PADDLE_ENFORCE_NE(
value_2_var_name.find(value),
value_2_var_name.end(),
......
......@@ -977,10 +977,7 @@ void BuildOpFuncList(
attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
op_func_node.phi_op_name_ = op_name;
if (op_name == "builtin.combine" || op_name == "pd.feed" ||
op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
op_name == "pd.data" || op_name == "pd.shadow_output") {
if (GetSpecialOpNames().count(op_name)) {
VLOG(6) << "skip process " << op_name;
continue;
}
......@@ -1171,6 +1168,18 @@ void SetDeviceCommContext(::ir::Operation* op,
}
}
std::unordered_set<std::string> GetSpecialOpNames() {
return {
"builtin.combine",
"builtin.slice",
"pd.feed",
"builtin.set_parameter",
"builtin.get_parameter",
"pd.data",
"pd.shadow_output",
};
}
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -124,6 +124,8 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base,
void SetDeviceCommContext(::ir::Operation* op,
platform::DeviceContext* dev_ctx);
std::unordered_set<std::string> GetSpecialOpNames();
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -219,10 +219,11 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
&value_2_var_name_,
&variable_2_var_name_,
&var_name_2_id_,
&variable_list_,
&parameter_values_);
&variable_list_);
VLOG(4) << DebugValueInfo();
SolvePersisableVarNames();
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
interpreter::BuildOpFuncList(place_,
ir_program_->block(),
......@@ -1595,7 +1596,9 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace(
std::stringstream ss;
ss << "trace order: ";
for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
ss << trace_execute_order_[idx] << " -> ";
ss << vec_instruction_base_[trace_execute_order_[idx]]->Name() << "["
<< trace_execute_order_[idx] << "]"
<< " -> ";
}
ss << "end\n";
VLOG(6) << ss.str();
......@@ -1616,10 +1619,7 @@ void NewIRInterpreter::BuildInstruction() {
.at("op_name")
.dyn_cast<::ir::StrAttribute>()
.AsString();
if (op_name == "builtin.combine" || op_name == "pd.feed" ||
op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
op_name == "pd.data" || op_name == "pd.shadow_output") {
if (interpreter::GetSpecialOpNames().count(op_name)) {
VLOG(6) << "skip process " << op_name;
continue;
}
......@@ -1793,18 +1793,8 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) {
VLOG(4) << "GC sync " << GetNameById(var_id);
// persistable var will be ignore while GC
::ir::Value value = GetValueByName(GetNameById(var_id));
bool is_parameter = false;
if (value) {
for (auto item : parameter_values_) {
if (item == value) {
is_parameter = true;
break;
}
}
}
if (is_parameter) {
VLOG(4) << "value " << value.impl() << " is a parameter, skip gc";
if (parameter_var_names_.count(GetNameById(var_id))) {
VLOG(4) << GetNameById(var_id) << " is a parameter, skip gc";
continue;
}
......@@ -1851,18 +1841,8 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) {
<< ", ref:" << refs_[var_id]->DynamicRef();
bool is_ready = refs_[var_id]->CheckAndDecrease();
// ignore all persistable var while GCphi
::ir::Value value = GetValueByName(GetNameById(var_id));
bool is_parameter = false;
if (value) {
for (auto item : parameter_values_) {
if (item == value) {
is_parameter = true;
break;
}
}
}
if (is_parameter) {
VLOG(4) << "value " << value.impl() << " is a parameter, skip gc";
if (parameter_var_names_.count(GetNameById(var_id))) {
VLOG(4) << GetNameById(var_id) << " is a parameter, skip gc";
continue;
}
......@@ -2020,11 +2000,17 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
&value_2_var_name_,
&variable_2_var_name_,
&var_name_2_id_,
&variable_list_,
&parameter_values_);
&variable_list_);
VLOG(4) << "Done BuildScope";
VLOG(4) << DebugValueInfo();
SolvePersisableVarNames();
VLOG(4) << "Parameter value include: ";
for (auto parameter : parameter_var_names_) {
VLOG(4) << "Parameter value: " << parameter;
}
BuildInstruction();
VLOG(4) << "Done BuildInstruction";
......@@ -2032,15 +2018,9 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
VLOG(4) << "Done PreAnalysis";
// Run
if (FLAGS_enable_new_ir_in_executor_loop_run) {
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
"with for_loop version.";
"with for_loop version(First step).";
LoopRunImpl();
} else {
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
"with trace version.";
TraceRunImpl();
}
is_build_ = true;
} else {
if (FLAGS_enable_new_ir_in_executor_loop_run) {
......@@ -2177,7 +2157,8 @@ void NewIRInterpreter::TraceRunInstructionList(
auto instr_id = trace_execute_order_[idx];
InstructionBase* instr_node = vec_instruction_base_.at(instr_id).get();
VLOG(6) << "Run InstructionBase " << instr_id;
VLOG(6) << "Run InstructionBase " << instr_node->Name() << "[" << instr_id
<< "]";
RunInstructionBase(instr_node);
if (UNLIKELY(exception_holder_.IsCaught())) {
......@@ -2263,5 +2244,26 @@ void NewIRInterpreter::PreAnalysis() {
return nullptr;
}
void NewIRInterpreter::SolvePersisableVarNames() {
VLOG(6) << "SolvePersisableVarNames";
for (auto kv : value_2_var_name_) {
::ir::Value value = kv.first;
std::string var_name = kv.second;
::ir::OpResult result = value.dyn_cast<::ir::OpResult>();
auto* defining_op = value.GetDefiningOp();
if (defining_op->HasAttribute(kAttrIsPersisable)) {
auto is_persisables = defining_op->attribute(kAttrIsPersisable)
.dyn_cast<::ir::ArrayAttribute>()
.AsVector();
if (is_persisables[result.GetResultIndex()]
.dyn_cast<::ir::BoolAttribute>()
.data()) {
VLOG(6) << "parameter_var_names_ include: " << var_name;
parameter_var_names_.insert(var_name);
}
}
}
}
} // namespace framework
} // namespace paddle
......@@ -235,6 +235,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void RecordStreamForGC(InstructionBase* instr);
void SolvePersisableVarNames();
InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less;
std::unique_ptr<::ir::Program> ir_program_{nullptr};
......@@ -260,7 +262,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
// Note(zhangbo): set_parameter_op's input and get_parameter_op's output
// belongs to a parameter and cannot GC.
std::vector<::ir::Value> parameter_values_;
std::unordered_set<std::string> parameter_var_names_;
};
} // namespace framework
......
......@@ -1503,6 +1503,18 @@ void ProgramInterpreter::AnalyseExecuteOrderForTrace() {
"trace_order size should be equal to dependecy_count_."));
trace_execute_order_ = trace_order;
std::stringstream ss;
ss << "trace order: ";
for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
ss << vec_instruction_[trace_execute_order_[idx]]
.OpFunc()
->operator_base_->Type()
<< "[" << trace_execute_order_[idx] << "]"
<< " -> ";
}
ss << "end\n";
VLOG(6) << ss.str();
}
} // namespace framework
......
......@@ -118,6 +118,28 @@ const std::string& OpYamlInfoParser::InplaceName(
"Can not find inplace input of [%s].", out_name));
}
bool OpYamlInfoParser::HasView(const std::string& out_name) const {
auto& view_info = std::get<3>(op_info_tuple_).view;
for (size_t i = 0; i < view_info.size(); i++) {
if (out_name == view_info[i].first) {
return true;
}
}
return false;
}
const std::string& OpYamlInfoParser::ViewName(
const std::string& out_name) const {
auto& view_info = std::get<3>(op_info_tuple_).view;
for (size_t i = 0; i < view_info.size(); i++) {
if (out_name == view_info[i].first) {
return view_info[i].second;
}
}
PADDLE_THROW(phi::errors::PreconditionNotMet(
"Can not find inplace input of [%s].", out_name));
}
void OpYamlInfoParser::parse() {
auto input_info = std::get<0>(op_info_tuple_);
......
......@@ -53,6 +53,10 @@ class OpYamlInfoParser {
const std::string& InplaceName(const std::string& out_name) const;
bool HasView(const std::string& out_name) const;
const std::string& ViewName(const std::string& out_name) const;
private:
void parse();
inline const std::vector<OpInputInfo>& InputInfo() const {
......
......@@ -69,8 +69,7 @@ class PhiKernelAdaptor {
&value_2_var_name,
&variable_2_var_name,
&var_name_2_id,
&variable_list,
nullptr);
&variable_list);
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
......
......@@ -217,8 +217,7 @@ void HandleForSpecialOp(
std::unordered_map<const paddle::framework::Variable*, std::string>*
variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list,
std::vector<::ir::Value>* parameter_values) {
std::vector<paddle::framework::Variable*>* variable_list) {
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
......@@ -347,10 +346,6 @@ void HandleForSpecialOp(
value_2_var_name,
variable_2_var_name,
var_name_2_id);
if (parameter_values) {
parameter_values->push_back(value);
}
}
if (op_name == "pd.shadow_output") {
......@@ -390,10 +385,6 @@ void HandleForSpecialOp(
variable_2_var_name,
var_name_2_id,
variable_list);
if (parameter_values) {
parameter_values->push_back(value);
}
}
if (op_name == "builtin.slice") {
......@@ -458,6 +449,14 @@ void HandleForInplaceOp(
VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
<< " (var: " << var_name << ")";
value_2_var_name->emplace(value, var_name);
} else if (yaml_parser.HasView(value_name)) {
std::string view_name = yaml_parser.ViewName(value_name);
ir::Value view_value =
op->operand_source(yaml_parser.InputName2Id().at(view_name));
std::string var_name = value_2_var_name->at(view_value);
VLOG(4) << "view: " << value_name << " -> " << view_name
<< " (var: " << var_name << ")";
value_2_var_name->emplace(value, var_name);
} else {
BuildValue(value,
inner_scope,
......@@ -479,8 +478,7 @@ void BuildScope(const ir::Block& block,
std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list,
std::vector<::ir::Value>* parameter_values) {
std::vector<paddle::framework::Variable*>* variable_list) {
VLOG(4) << "***** [before build] scope"
<< "(" << inner_scope << ") ******\n"
<< paddle::framework::GenScopeTreeDebugInfo(
......@@ -506,8 +504,7 @@ void BuildScope(const ir::Block& block,
value_2_var_name,
variable_2_var_name,
var_name_2_id,
variable_list,
parameter_values);
variable_list);
continue;
}
......
......@@ -49,8 +49,7 @@ void BuildScope(const ir::Block& block,
std::unordered_map<const paddle::framework::Variable*,
std::string>* variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list,
std::vector<::ir::Value>* parameter_values);
std::vector<paddle::framework::Variable*>* variable_list);
void BuildRuntimeContext(
ir::Operation* op,
......@@ -288,23 +287,29 @@ void BuildPhiContext(ir::Operation* op,
// TODO(phlrain): use var type instead of op name
for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value out_ptr = op->result(i);
auto out_type = out_ptr.type();
if (out_type) {
auto name = name_map.at(out_ptr);
VLOG(6) << "ctx->EmplaceBackOutput: " << name;
auto out_type = out_ptr.type();
} else {
VLOG(6) << "ctx->EmplaceBackOutput : an optioanl output";
}
if (!out_type) {
phi::DenseTensor* ptr = nullptr;
OutType out_ptr(ptr);
ctx->EmplaceBackOutput(out_ptr);
} else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
&(inner_scope->FindVar(name)->Get<phi::DenseTensor>()))));
&(inner_scope->FindVar(name_map.at(out_ptr))
->Get<phi::DenseTensor>()))));
} else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>(
&(inner_scope->FindVar(name)->Get<phi::SelectedRows>()))));
&(inner_scope->FindVar(name_map.at(out_ptr))
->Get<phi::SelectedRows>()))));
} else if (out_type.isa<ir::VectorType>()) {
OutListType outputs;
auto& variable_array =
scope->FindVar(name)->Get<paddle::framework::VariableRefArray>();
auto& variable_array = scope->FindVar(name_map.at(out_ptr))
->Get<paddle::framework::VariableRefArray>();
for (size_t i = 0; i < variable_array.size(); ++i) {
outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>(
&(variable_array[i]->Get<phi::DenseTensor>()))));
......
......@@ -72,9 +72,7 @@ using AttributeHandlerFn = std::function<ir::Attribute(
constexpr char kTargetDialectPrefix[] = "pd.";
constexpr char kEmptyVarName[] = "@EMPTY@";
static const std::unordered_set<std::string> special_non_inplace_ops = {
"batch_norm",
};
static const std::unordered_set<std::string> special_non_inplace_ops = {};
static const std::unordered_set<std::string> special_inplace_ops = {
"adagrad",
......
......@@ -77,6 +77,11 @@ void ProgramTranslator::Translate() {
const BlockDesc& block = legacy_program_->Block(block_idx);
SetStopGradientAttributeForAllValue(block);
}
for (size_t block_idx = 0; block_idx < legacy_program_->Size(); block_idx++) {
const BlockDesc& block = legacy_program_->Block(block_idx);
SetIsPersisableAttributeForAllValue(block);
}
}
inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx,
......@@ -268,5 +273,44 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
}
}
void ProgramTranslator::SetIsPersisableAttributeForAllValue(
const BlockDesc& block) {
// Currently we set is persisable for operation that generated a value
// connected with VarDesc
for (const auto& [var_name, value_info] : param_map_) {
if (no_cast_var_names.count(var_name) != 0) continue;
VLOG(10) << "[op translated][is persisable]" << var_name;
VarDesc* var = block.FindVarRecursive(var_name);
if (var == nullptr) {
continue;
}
ir::OpResult value = value_info.value;
if (!value) {
PADDLE_THROW(phi::errors::PreconditionNotMet(
"Value of [%s] can not ber None", var_name));
}
auto* defining_op = value.owner();
PADDLE_ENFORCE_NOT_NULL(
defining_op,
phi::errors::PreconditionNotMet(
"Defining operator of [%s] can not be nullptr", var_name));
VLOG(8) << "[op translated][is persisable]" << var_name
<< " from: " << defining_op->name();
std::vector<ir::Attribute> is_persisable;
if (defining_op->HasAttribute(kAttrIsPersisable)) {
is_persisable = defining_op->attribute(kAttrIsPersisable)
.dyn_cast<ir::ArrayAttribute>()
.AsVector();
} else {
is_persisable = std::vector<ir::Attribute>(
defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
}
is_persisable[value.GetResultIndex()] =
ir::BoolAttribute::get(ctx_, var->Persistable());
defining_op->set_attribute(kAttrIsPersisable,
ir::ArrayAttribute::get(ctx_, is_persisable));
}
}
} // namespace translator
} // namespace paddle
......@@ -79,6 +79,7 @@ class ProgramTranslator {
void InsertOperationToSingleBlock(const BlockDesc& block);
void SetParameterFromSingleBlock(const BlockDesc& block);
void SetStopGradientAttributeForAllValue(const BlockDesc& block);
void SetIsPersisableAttributeForAllValue(const BlockDesc& block);
};
} // namespace translator
......
......@@ -18,6 +18,7 @@
#include "paddle/ir/core/type_id.h"
constexpr char kAttrStopGradients[] = "stop_gradient";
constexpr char kAttrIsPersisable[] = "is_persisable";
namespace ir {
class AttributeStorage;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册