未验证 提交 611e71d0 编写于 作者: Z zhangbo9674 提交者: GitHub

refine code (#56020)

上级 edd5e9a8
......@@ -742,69 +742,6 @@ void NewIRInterpreter::Convert(
}
}
// calculate last_live_ops_
// for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
// Instruction& instr = vec_instruction_[op_idx];
// OpInOutInfo info;
// info.Build(instr.OpBase());
// std::set<size_t> gc_check_vars;
// const std::map<std::string, std::vector<int>>& ins = instr.Inputs();
// const std::map<std::string, std::vector<int>>& outs = instr.Outputs();
// std::multimap<std::string, std::vector<int>> ins_and_outs{ins.begin(),
// ins.end()};
// ins_and_outs.insert(outs.begin(), outs.end());
// for (auto& item : ins_and_outs) {
// for (auto id : item.second) {
// if (id == kEmptyVarIndex) {
// continue;
// }
// auto* var_desc = var_scope_.VarDesc(id);
// // skip no_need_buffer input vars
// if (var_desc && ins.count(item.first) &&
// !info.IsInArgBufferNeeded(var_desc->Name())) {
// continue;
// }
// // skip when this var is not in block and not a data_transferred
// var,
// // which means this var is managed by other block
// const auto& var_name = var_scope_.GetNameById(id);
// bool not_owned = !block_.HasVar(var_name);
// const auto& transferred_vars = var_scope_.DataTransferAddedVars();
// bool not_transferred =
// std::all_of(transferred_vars.begin(),
// transferred_vars.end(),
// [&](const std::pair<std::string, int>& elem) {
// return elem.first != var_name;
// });
// if (not_owned && not_transferred) {
// VLOG(10) << "[gc_check_inputs] skip gc: " << var_name;
// continue;
// }
// gc_check_vars.insert(id);
// }
// }
// for (auto var_id : gc_check_vars) {
// Scope* inner_scope =
// HasLocalScope() ? local_scope_ : var_scope_.GetMutableScope();
// paddle::framework::Variable* var =
// inner_scope->FindVar(var_scope_.GetNameById(var_id));
// if (var->IsType<phi::DenseTensor>() ||
// var->IsType<phi::SelectedRows>() ||
// var->IsType<LoDTensorArray>()) {
// last_live_ops_[var_id].insert(op_idx);
// } else {
// VLOG(4) << "not clear " << var_scope_.GetNameById(var_id) << "
// after "
// << instr.OpBase()->Type() << " because its type is "
// << framework::ToTypeName(var->Type());
// }
// }
// }
// clear the last_live_ops list for all vars in skip_gc_vars
for (const std::string& skip_gc_var : execution_config_.skip_gc_vars) {
int var_id = var_scope_.GetIdByName(skip_gc_var);
......@@ -846,23 +783,6 @@ void NewIRInterpreter::Convert(
vec_meta_info[i].var_ref_count_ = last_live_ops_[i].size();
}
// for (size_t i = 0; i < vec_instruction_.size(); ++i) {
// BuildAndCacheInstructionCtx(&vec_instruction_[i]);
// }
// bool inplaced = false;
// for (const Instruction& inst : vec_instruction_) {
// if (inst.OpBase()->Type() == "share_buffer" ||
// inst.OpBase()->Type() == "share_data") {
// VLOG(4) << "Already inplaced, skip inplace now.";
// inplaced = true;
// }
// }
// if (FLAGS_new_executor_use_inplace && !inplaced) {
// BuildInplace();
// }
for (auto& dep : dependecy_count_) {
deps_.emplace_back(std::make_shared<interpreter::OpDepInfo>(dep));
}
......@@ -1593,15 +1513,17 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace(
trace_execute_order_ = trace_order;
std::stringstream ss;
ss << "trace order: ";
for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
ss << vec_instruction_base_[trace_execute_order_[idx]]->Name() << "["
<< trace_execute_order_[idx] << "]"
<< " -> ";
if (VLOG_IS_ON(6)) {
std::stringstream ss;
ss << "trace order: ";
for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
ss << vec_instruction_base_[trace_execute_order_[idx]]->Name() << "["
<< trace_execute_order_[idx] << "]"
<< " -> ";
}
ss << "end\n";
VLOG(6) << ss.str();
}
ss << "end\n";
VLOG(6) << ss.str();
}
/// ======================== ///
......@@ -1921,9 +1843,6 @@ void NewIRInterpreter::CalculateLastLiveOps() {
VLOG(4) << "var_ref_count_.size() : " << var_ref_count_.size();
for (size_t i = 0; i < last_live_ops_.size(); ++i) {
std::set<size_t> minumum_last_live_ops;
for (auto val : last_live_ops_[i]) {
VLOG(4) << "last_live_ops_: " << val;
}
for (size_t item : last_live_ops_[i]) {
bool not_before_any = true;
// find the op that is not executed before any
......@@ -1945,7 +1864,6 @@ void NewIRInterpreter::CalculateLastLiveOps() {
last_live_ops_[i] = minumum_last_live_ops;
var_ref_count_[i] = last_live_ops_[i].size();
}
VLOG(4) << "calculate last_live_ops_ 2";
for (auto& dep : dependecy_count_) {
deps_.emplace_back(std::make_shared<interpreter::OpDepInfo>(dep));
......@@ -1954,7 +1872,6 @@ void NewIRInterpreter::CalculateLastLiveOps() {
refs_.emplace_back(std::make_shared<interpreter::VarRefInfo>(
var_ref_count_[i], variable_list_[i]));
}
VLOG(4) << "calculate last_live_ops_ 3";
}
void NewIRInterpreter::ConstructEventForJitInput() {
......@@ -2018,9 +1935,16 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
VLOG(4) << "Done PreAnalysis";
// Run
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
"with for_loop version(First step).";
LoopRunImpl();
if (FLAGS_enable_new_ir_in_executor_loop_run) {
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
"with for_loop version.";
LoopRunImpl();
} else {
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
"with trace version.";
TraceRunImpl();
}
is_build_ = true;
} else {
if (FLAGS_enable_new_ir_in_executor_loop_run) {
......@@ -2248,7 +2172,7 @@ void NewIRInterpreter::SolvePersisableVarNames() {
VLOG(6) << "SolvePersisableVarNames";
for (auto kv : value_2_var_name_) {
::ir::Value value = kv.first;
std::string var_name = kv.second;
const std::string& var_name = kv.second;
::ir::OpResult result = value.dyn_cast<::ir::OpResult>();
auto* defining_op = value.GetDefiningOp();
if (defining_op->HasAttribute(kAttrIsPersisable)) {
......
......@@ -1504,17 +1504,19 @@ void ProgramInterpreter::AnalyseExecuteOrderForTrace() {
trace_execute_order_ = trace_order;
std::stringstream ss;
ss << "trace order: ";
for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
ss << vec_instruction_[trace_execute_order_[idx]]
.OpFunc()
->operator_base_->Type()
<< "[" << trace_execute_order_[idx] << "]"
<< " -> ";
}
ss << "end\n";
VLOG(6) << ss.str();
if (VLOG_IS_ON(6)) {
std::stringstream ss;
ss << "trace order: ";
for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
ss << vec_instruction_[trace_execute_order_[idx]]
.OpFunc()
->operator_base_->Type()
<< "[" << trace_execute_order_[idx] << "]"
<< " -> ";
}
ss << "end\n";
VLOG(6) << ss.str();
}
}
} // namespace framework
......
......@@ -442,7 +442,7 @@ void HandleForInplaceOp(
}
std::string value_name = yaml_parser.OutputNames()[i];
if (yaml_parser.HasInplace(value_name)) {
std::string inplace_name = yaml_parser.InplaceName(value_name);
const std::string& inplace_name = yaml_parser.InplaceName(value_name);
ir::Value inplace_value =
op->operand_source(yaml_parser.InputName2Id().at(inplace_name));
std::string var_name = value_2_var_name->at(inplace_value);
......@@ -450,10 +450,10 @@ void HandleForInplaceOp(
<< " (var: " << var_name << ")";
value_2_var_name->emplace(value, var_name);
} else if (yaml_parser.HasView(value_name)) {
std::string view_name = yaml_parser.ViewName(value_name);
const std::string& view_name = yaml_parser.ViewName(value_name);
ir::Value view_value =
op->operand_source(yaml_parser.InputName2Id().at(view_name));
std::string var_name = value_2_var_name->at(view_value);
const std::string& var_name = value_2_var_name->at(view_value);
VLOG(4) << "view: " << value_name << " -> " << view_name
<< " (var: " << var_name << ")";
value_2_var_name->emplace(value, var_name);
......
......@@ -289,7 +289,7 @@ void BuildPhiContext(ir::Operation* op,
ir::Value out_ptr = op->result(i);
auto out_type = out_ptr.type();
if (out_type) {
auto name = name_map.at(out_ptr);
auto& name = name_map.at(out_ptr);
VLOG(6) << "ctx->EmplaceBackOutput: " << name;
} else {
VLOG(6) << "ctx->EmplaceBackOutput : an optioanl output";
......
......@@ -72,9 +72,9 @@ using AttributeHandlerFn = std::function<ir::Attribute(
constexpr char kTargetDialectPrefix[] = "pd."; // NOLINT
constexpr char kEmptyVarName[] = "@EMPTY@"; // NOLINT
static const std::unordered_set<std::string> special_non_inplace_ops = {};
static const std::unordered_set<std::string> SpecialNonInplaceOps = {};
static const std::unordered_set<std::string> special_inplace_ops = {
static const std::unordered_set<std::string> SpecialInplaceOps = {
"adagrad",
"adam",
"adamax",
......@@ -82,10 +82,10 @@ static const std::unordered_set<std::string> special_inplace_ops = {
};
inline bool IsInplace(const OpDesc& op_desc) {
if (special_non_inplace_ops.count(op_desc.Type())) {
if (SpecialNonInplaceOps.count(op_desc.Type())) {
return false;
}
if (special_inplace_ops.count(op_desc.Type())) {
if (SpecialInplaceOps.count(op_desc.Type())) {
return true;
}
bool inplace = false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册