未验证 提交 621d3e0b 编写于 作者: W wangchaochaohu 提交者: GitHub

Fix profiler report bugs: incorrect sub-event average time and inner-op event naming (#22207)

* fix the bug of profile update test=develop
上级 443a713c
...@@ -167,7 +167,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { ...@@ -167,7 +167,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
} }
{ {
platform::RecordEvent record_event(Type() + "_op"); platform::RecordEvent record_event(Type());
RunImpl(scope, place); RunImpl(scope, place);
} }
...@@ -950,7 +950,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -950,7 +950,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
std::vector<std::string> transfered_inplace_vars; std::vector<std::string> transfered_inplace_vars;
Scope* transfer_scope = nullptr; Scope* transfer_scope = nullptr;
{ {
platform::RecordEvent record_event("prepare_data"); platform::RecordEvent record_event("prepare_data_inner_op");
transfer_scope = PrepareData(scope, *kernel_type_, &transfered_inplace_vars, transfer_scope = PrepareData(scope, *kernel_type_, &transfered_inplace_vars,
runtime_ctx); runtime_ctx);
} }
...@@ -963,7 +963,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -963,7 +963,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
} }
if (!all_kernels_must_compute_runtime_shape_) { if (!all_kernels_must_compute_runtime_shape_) {
platform::RecordEvent record_event("infer_shape"); platform::RecordEvent record_event("infer_shape_inner_op");
RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx); RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
} }
...@@ -975,7 +975,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -975,7 +975,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
// not Scope. Imperative mode only pass inputs and get outputs. // not Scope. Imperative mode only pass inputs and get outputs.
{ {
platform::RecordEvent record_event("compute"); platform::RecordEvent record_event("compute_inner_op");
(*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx, (*kernel_func_)(ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx,
kernel_configs)); kernel_configs));
} }
......
...@@ -372,12 +372,13 @@ void PrintProfiler(const std::vector<std::vector<EventItem>> &events_table, ...@@ -372,12 +372,13 @@ void PrintProfiler(const std::vector<std::vector<EventItem>> &events_table,
std::vector<std::vector<EventItem>> child_table; std::vector<std::vector<EventItem>> child_table;
std::vector<EventItem> table; std::vector<EventItem> table;
bool do_next = false; bool do_next = false;
std::string op_end_str = "_op"; std::string op_end_str = "inner_op";
for (auto it = child_map.begin(); it != child_map.end(); it++) { for (auto it = child_map.begin(); it != child_map.end(); it++) {
if (it->first == event_item.name) { if (it->first == event_item.name) {
table.push_back(it->second); table.push_back(it->second);
do_next = it->second.name.rfind(op_end_str) == if (!do_next)
(it->second.name.length() - op_end_str.length()); do_next = !(it->second.name.rfind(op_end_str) ==
(it->second.name.length() - op_end_str.length()));
} }
} }
child_table.push_back(table); child_table.push_back(table);
...@@ -579,6 +580,7 @@ void ParseEvents(const std::vector<std::vector<Event>> &events, ...@@ -579,6 +580,7 @@ void ParseEvents(const std::vector<std::vector<Event>> &events,
std::vector<EventItem> event_items; std::vector<EventItem> event_items;
std::vector<EventItem> main_event_items; std::vector<EventItem> main_event_items;
std::unordered_map<std::string, int> event_idx; std::unordered_map<std::string, int> event_idx;
std::multimap<std::string, EventItem> sub_child_map;
for (size_t j = 0; j < (*analyze_events)[i].size(); j++) { for (size_t j = 0; j < (*analyze_events)[i].size(); j++) {
Event analyze_event = (*analyze_events)[i][j]; Event analyze_event = (*analyze_events)[i][j];
...@@ -599,7 +601,7 @@ void ParseEvents(const std::vector<std::vector<Event>> &events, ...@@ -599,7 +601,7 @@ void ParseEvents(const std::vector<std::vector<Event>> &events,
(cname[fname.length()] == '/' && (cname[fname.length()] == '/' &&
cname.rfind('/') == fname.length()); cname.rfind('/') == fname.length());
if (condition) { if (condition) {
child_map.insert( sub_child_map.insert(
std::pair<std::string, EventItem>(fname, event_items[k])); std::pair<std::string, EventItem>(fname, event_items[k]));
child_index[k] = 1; child_index[k] = 1;
} }
...@@ -618,9 +620,9 @@ void ParseEvents(const std::vector<std::vector<Event>> &events, ...@@ -618,9 +620,9 @@ void ParseEvents(const std::vector<std::vector<Event>> &events,
item.ave_time = item.total_time / item.calls; item.ave_time = item.total_time / item.calls;
item.ratio = item.total_time / total; item.ratio = item.total_time / total;
} }
for (auto it = child_map.begin(); it != child_map.end(); it++) { for (auto it = sub_child_map.begin(); it != sub_child_map.end(); it++) {
it->second.ratio = it->second.total_time / total; it->second.ratio = it->second.total_time / total;
it->second.ave_time = it->second.ave_time / it->second.calls; it->second.ave_time = it->second.total_time / it->second.calls;
} }
// sort // sort
...@@ -636,6 +638,11 @@ void ParseEvents(const std::vector<std::vector<Event>> &events, ...@@ -636,6 +638,11 @@ void ParseEvents(const std::vector<std::vector<Event>> &events,
<< "\', which will be ignored in profiling report."; << "\', which will be ignored in profiling report.";
++rit; ++rit;
} }
for (auto it = sub_child_map.begin(); it != sub_child_map.end(); it++) {
child_map.insert(
std::pair<std::string, EventItem>(it->first, it->second));
}
} }
// Print report // Print report
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册