diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index 609bc4245e99751492254a95b9f7db9cf95a3572..76b9e32bb5207edad075c3d8670dde5ec3429b96 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -400,7 +400,7 @@ class DeviceTracerImpl : public DeviceTracer { } else if (ret != CUPTI_SUCCESS) { fprintf(stderr, "Failed to create CUPTI subscriber.\n"); } - const std::vector cbids { + const std::vector runtime_cbids { CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020, CUPTI_RUNTIME_TRACE_CBID_cudaSetupArgument_v3020, CUPTI_RUNTIME_TRACE_CBID_cudaMemcpyAsync_v3020, @@ -414,9 +414,15 @@ class DeviceTracerImpl : public DeviceTracer { CUPTI_RUNTIME_TRACE_CBID_cudaLaunchCooperativeKernelMultiDevice_v9000 #endif }; - for (auto cbid : cbids) + const std::vector driver_cbids{CUPTI_DRIVER_TRACE_CBID_cuLaunch, + CUPTI_DRIVER_TRACE_CBID_cuLaunchGrid, + CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel}; + for (auto cbid : runtime_cbids) CUPTI_CALL(dynload::cuptiEnableCallback( 1, subscriber_, CUPTI_CB_DOMAIN_RUNTIME_API, cbid)); + for (auto cbid : driver_cbids) + CUPTI_CALL(dynload::cuptiEnableCallback( + 1, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid)); CUPTI_CALL(dynload::cuptiGetTimestamp(&start_ns_)); #endif // PADDLE_WITH_CUPTI enabled_ = true; diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index dabd2af6af806dc04748182729496c7390d54791..a3f9622ec36e8d9b0f7130166c18de05e1ccc176 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -702,8 +702,7 @@ void ParseMemEvents(const std::vector> &events) { } void DealWithShowName() { - std::unordered_map prefix_name; - std::vector op_out_name; + std::unordered_map> profiler_name_info; for (auto it = g_all_event_lists.begin(); it != g_all_event_lists.end(); ++it) { for (auto &block : (*it)->event_blocks) { @@ -714,20 +713,25 @@ void DealWithShowName() { std::string prefix_str = event_name.substr(0, start); while (start != std::string::npos && end != std::string::npos) { auto search_str = event_name.substr(start, end - start + 1); - auto it = find(op_out_name.begin(), op_out_name.end(), search_str); - std::string replace_str; - bool prefix_find = true; - if (prefix_name.find(prefix_str) == prefix_name.end()) { - prefix_find = false; - prefix_name[prefix_str] = 0; + std::string replace_str = ""; + int replace_index = 0; + + auto it = profiler_name_info.find(prefix_str); + if (it == profiler_name_info.end()) { + std::vector op_name_vector{search_str}; + profiler_name_info[prefix_str] = op_name_vector; + } else { + auto op_name_vector = it->second; + auto iter = + find(op_name_vector.begin(), op_name_vector.end(), search_str); + if (iter == op_name_vector.end()) { + replace_index = it->second.size(); + it->second.push_back(search_str); + } else { + replace_index = it->second.size() - 1; + } } - - if (it == op_out_name.end()) { - if (prefix_find) - prefix_name[prefix_str] = prefix_name[prefix_str] + 1; - op_out_name.push_back(search_str); - } - replace_str = std::to_string(prefix_name[prefix_str]); + replace_str = std::to_string(replace_index); event_name.replace(start, end - start + 1, replace_str); start = start + 1; start = event_name.find('%', start); @@ -792,8 +796,7 @@ std::string OpName(const framework::VariableNameMap &name_map, std::string ret = type_name + "%"; for (auto it = name_map.begin(); it != name_map.end(); it++) { auto name_outputs = it->second; - if (!name_outputs.empty() && - type_name.length() < name_outputs[0].length()) { + if (!name_outputs.empty()) { ret = ret + name_outputs[0]; break; }