diff --git a/paddle/platform/profiler.cc b/paddle/platform/profiler.cc index 239df23128ede11d6f5245fc50015e7a56c5f15b..2562b2b5f07e03b91d506917d8b8163e2a848090 100644 --- a/paddle/platform/profiler.cc +++ b/paddle/platform/profiler.cc @@ -182,6 +182,44 @@ std::vector> DisableProfiler() { void ParseEvents(std::vector>& events, EventSortingKey sorted_by) { if (g_profiler_place == "") return; + + std::string sorted_domain; + std::function sorted_func; + switch (sorted_by) { + case EventSortingKey::kCalls: + sorted_domain = "number of calls"; + sorted_func = [](EventItem& a, EventItem& b) { + return a.calls > b.calls; + }; + break; + case EventSortingKey::kTotal: + sorted_domain = "total time"; + sorted_func = [](EventItem& a, EventItem& b) { + return a.total_time > b.total_time; + }; + break; + case EventSortingKey::kMin: + sorted_domain = "minimum time"; + sorted_func = [](EventItem& a, EventItem& b) { + return a.min_time > b.min_time; + }; + break; + case EventSortingKey::kMax: + sorted_domain = "maximum time"; + sorted_func = [](EventItem& a, EventItem& b) { + return a.max_time > b.max_time; + }; + break; + case EventSortingKey::kAve: + sorted_domain = "average time"; + sorted_func = [](EventItem& a, EventItem& b) { + return a.ave_time > b.ave_time; + }; + break; + default: + sorted_domain = "event end time"; + } + std::vector> events_table; size_t max_name_width = 0; for (size_t i = 0; i < events.size(); i++) { @@ -240,21 +278,7 @@ void ParseEvents(std::vector>& events, } // sort if (sorted_by != EventSortingKey::kDefault) { - std::sort(event_items.begin(), event_items.end(), - [&](EventItem& a, EventItem& b) { - switch (sorted_by) { - case EventSortingKey::kCalls: - return a.calls > b.calls; - case EventSortingKey::kTotal: - return a.total_time > b.total_time; - case EventSortingKey::kMin: - return a.min_time > b.min_time; - case EventSortingKey::kMax: - return a.max_time > b.max_time; - default: - return a.ave_time > b.ave_time; - } - }); + std::sort(event_items.begin(), event_items.end(), sorted_func); } events_table.push_back(event_items); @@ -268,11 +292,11 @@ void ParseEvents(std::vector>& events, } // Print report - PrintProfilingReport(events_table, sorted_by, max_name_width + 4, 12); + PrintProfilingReport(events_table, sorted_domain, max_name_width + 4, 12); } void PrintProfilingReport(std::vector>& events_table, - EventSortingKey sorted_by, const size_t name_width, + std::string& sorted_domain, const size_t name_width, const size_t data_width) { // Output header information std::cout << "\n------------------------->" @@ -280,27 +304,7 @@ void PrintProfilingReport(std::vector>& events_table, << "<-------------------------\n\n"; std::cout << "Place: " << g_profiler_place << std::endl; std::cout << "Time unit: ms" << std::endl; - std::string sort_domain = "event end time"; - switch (sorted_by) { - case EventSortingKey::kCalls: - sort_domain = "number of calls"; - break; - case EventSortingKey::kTotal: - sort_domain = "total time"; - break; - case EventSortingKey::kMin: - sort_domain = "minimum time"; - break; - case EventSortingKey::kMax: - sort_domain = "maximum time"; - break; - case EventSortingKey::kAve: - sort_domain = "average time"; - break; - default: - break; - } - std::cout << "Sorted by " << sort_domain + std::cout << "Sorted by " << sorted_domain << " in descending order in the same thread\n\n"; // Output events table std::cout.setf(std::ios::left); diff --git a/paddle/platform/profiler.h b/paddle/platform/profiler.h index f97a586787780aeea2aeeb970174a8b34544ac95..6df48ef8806e865f473b4317ac0283863c3c6f64 100644 --- a/paddle/platform/profiler.h +++ b/paddle/platform/profiler.h @@ -136,7 +136,7 @@ void ParseEvents(std::vector>&, // Print results void PrintProfilingReport(std::vector>& events_table, - EventSortingKey sorted_by, const size_t name_width, + std::string& sorted_domain, const size_t name_width, const size_t data_width); } // namespace platform } // namespace paddle