From f38c2e5c772d24ab83b67034efbcc0de693d6198 Mon Sep 17 00:00:00 2001 From: chenjian Date: Fri, 11 Feb 2022 10:31:02 +0800 Subject: [PATCH] Add profiler node tree implementation (#39316) * add event node implementation * modify profiler.stop interface * fix according to review * fix file mode * modify class method name in event_node.cc * modify LLONG_MAX to ULLONG_MAX * fix ci error * fix ci error --- paddle/fluid/platform/profiler/CMakeLists.txt | 5 +- .../platform/profiler/chrometracing_logger.cc | 371 ++++++++++++++++++ .../platform/profiler/chrometracing_logger.h | 45 +++ paddle/fluid/platform/profiler/event_node.cc | 306 +++++++++++++++ paddle/fluid/platform/profiler/event_node.h | 58 +-- paddle/fluid/platform/profiler/event_python.h | 2 +- .../fluid/platform/profiler/output_logger.h | 2 +- paddle/fluid/platform/profiler/profiler.cc | 8 +- paddle/fluid/platform/profiler/profiler.h | 4 +- .../fluid/platform/profiler/profiler_test.cc | 9 +- .../platform/profiler/test_event_node.cc | 203 ++++++++++ paddle/fluid/platform/profiler/trace_event.h | 2 +- 12 files changed, 974 insertions(+), 41 deletions(-) create mode 100644 paddle/fluid/platform/profiler/chrometracing_logger.cc create mode 100644 paddle/fluid/platform/profiler/chrometracing_logger.h create mode 100644 paddle/fluid/platform/profiler/event_node.cc mode change 100755 => 100644 paddle/fluid/platform/profiler/event_node.h create mode 100644 paddle/fluid/platform/profiler/test_event_node.cc diff --git a/paddle/fluid/platform/profiler/CMakeLists.txt b/paddle/fluid/platform/profiler/CMakeLists.txt index e25e4f3f56c..ce062175a53 100644 --- a/paddle/fluid/platform/profiler/CMakeLists.txt +++ b/paddle/fluid/platform/profiler/CMakeLists.txt @@ -1,3 +1,6 @@ cc_library(host_tracer SRCS host_tracer.cc DEPS enforce) cc_library(new_profiler SRCS profiler.cc DEPS host_tracer) -cc_test(new_profiler_test SRCS profiler_test.cc DEPS new_profiler) +cc_library(chrometracinglogger SRCS chrometracing_logger.cc) +cc_library(event_node SRCS event_node.cc) +cc_test(test_event_node SRCS test_event_node.cc DEPS event_node chrometracinglogger) +cc_test(new_profiler_test SRCS profiler_test.cc DEPS new_profiler event_node) diff --git a/paddle/fluid/platform/profiler/chrometracing_logger.cc b/paddle/fluid/platform/profiler/chrometracing_logger.cc new file mode 100644 index 00000000000..7b207ea7b20 --- /dev/null +++ b/paddle/fluid/platform/profiler/chrometracing_logger.cc @@ -0,0 +1,371 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "glog/logging.h" + +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/os_info.h" +#include "paddle/fluid/platform/profiler/chrometracing_logger.h" +#include "paddle/fluid/platform/profiler/event_node.h" + +namespace paddle { +namespace platform { + +static const char* kSchemaVersion = "1.0.0"; +static const char* kDefaultFilename = "pid_%s_time_%s.paddle_trace.json"; +static uint32_t num_span = 0; + +static int64_t nsToUs(int64_t ns) { return ns / 1000; } + +template +std::string string_format(const std::string& format, Args... args) { + int size_s = std::snprintf(nullptr, 0, format.c_str(), args...) + + 1; // Extra space for '\0' + PADDLE_ENFORCE_GE(size_s, 0, platform::errors::Fatal( + "Error during profiler data formatting.")); + auto size = static_cast(size_s); + auto buf = std::make_unique(size); + std::snprintf(buf.get(), size, format.c_str(), args...); + return std::string(buf.get(), size - 1); // exclude the '\0' +} + +std::string GetStringFormatLocalTime() { + std::time_t rawtime; + std::tm* timeinfo; + char buf[100]; + std::time(&rawtime); + timeinfo = std::localtime(&rawtime); + std::strftime(buf, 100, "%F-%X", timeinfo); + return std::string(buf); +} + +static std::string DefaultFileName() { + auto pid = GetProcessId(); + return string_format(std::string(kDefaultFilename), pid, + GetStringFormatLocalTime().c_str()); +} + +const char* ChromeTracingLogger::categary_name_[] = { + "operator", "dataloader", "profile_step", "cuda_runtime", "kernel", + "memcpy", "memset", "user_defined", "others"}; + +void ChromeTracingLogger::OpenFile() { + output_file_stream_.open(filename_, + std::ofstream::out | std::ofstream::trunc); + if (!output_file_stream_) { + VLOG(2) << "Unable to open file for writing profiling data." << std::endl; + } else { + VLOG(0) << "writing profiling data to " << filename_ << std::endl; + } +} + +ChromeTracingLogger::ChromeTracingLogger(const std::string& filename) { + filename_ = filename.empty() ? DefaultFileName() : filename; + OpenFile(); + StartLog(); +} + +ChromeTracingLogger::ChromeTracingLogger(const char* filename_cstr) { + std::string filename(filename_cstr); + filename_ = filename.empty() ? DefaultFileName() : filename; + OpenFile(); + StartLog(); +} + +ChromeTracingLogger::~ChromeTracingLogger() { + EndLog(); + output_file_stream_.close(); +} + +void ChromeTracingLogger::LogNodeTrees(const NodeTrees& node_trees) { + // log all nodes except root node, root node is a helper node. + const std::map> + thread2host_event_nodes = node_trees.Traverse(true); + for (auto it = thread2host_event_nodes.begin(); + it != thread2host_event_nodes.end(); ++it) { + for (auto hostnode = it->second.begin(); hostnode != it->second.end(); + ++hostnode) { + if (hostnode != it->second.begin()) { // skip root node + (*hostnode)->LogMe(this); + } + for (auto runtimenode = (*hostnode)->GetRuntimeTraceEventNodes().begin(); + runtimenode != (*hostnode)->GetRuntimeTraceEventNodes().end(); + ++runtimenode) { + (*runtimenode)->LogMe(this); + for (auto devicenode = + (*runtimenode)->GetDeviceTraceEventNodes().begin(); + devicenode != (*runtimenode)->GetDeviceTraceEventNodes().end(); + ++devicenode) { + (*devicenode)->LogMe(this); + } + } + } + } +} + +void ChromeTracingLogger::LogHostTraceEventNode( + const HostTraceEventNode& host_node) { + if (!output_file_stream_) { + return; + } + output_file_stream_ << string_format( + std::string( + R"JSON( + { + "name": "%s", "pid": %lld, "tid": %lld, + "ts": %lld, "dur": %lld, + "ph": "X", "cat": "%s", + "args": { + + } + }, + )JSON"), + host_node.Name().c_str(), host_node.ProcessId(), host_node.ThreadId(), + nsToUs(host_node.StartNs()), nsToUs(host_node.Duration()), + categary_name_[static_cast(host_node.Type())]); +} + +void ChromeTracingLogger::LogRuntimeTraceEventNode( + const CudaRuntimeTraceEventNode& runtime_node) { + if (!output_file_stream_) { + return; + } + output_file_stream_ << string_format( + std::string( + R"JSON( + { + "name": "%s", "pid": %lld, "tid": %lld, + "ts": %lld, "dur": %lld, + "ph": "X", "cat": "%s", + "args": { + "correlation id": %d + } + }, + )JSON"), + runtime_node.Name().c_str(), runtime_node.ProcessId(), + runtime_node.ThreadId(), nsToUs(runtime_node.StartNs()), + nsToUs(runtime_node.Duration()), + categary_name_[static_cast(runtime_node.Type())], + runtime_node.CorrelationId()); +} + +void ChromeTracingLogger::LogDeviceTraceEventNode( + const DeviceTraceEventNode& device_node) { + if (!output_file_stream_) { + return; + } + switch (device_node.Type()) { + case TracerEventType::Kernel: + HandleTypeKernel(device_node); + break; + case TracerEventType::Memcpy: + HandleTypeMemcpy(device_node); + break; + case TracerEventType::Memset: + HandleTypeMemset(device_node); + default: + break; + } +} + +void ChromeTracingLogger::HandleTypeKernel( + const DeviceTraceEventNode& device_node) { + KernelEventInfo kernel_info = device_node.KernelInfo(); + float blocks_per_sm = 0.0; + float warps_per_sm = 0.0; + float occupancy = 0.0; +#if defined(PADDLE_WITH_CUDA) + constexpr int threads_per_warp = 32; + const gpuDeviceProp& device_property = + GetDeviceProperties(device_node.DeviceId()); + blocks_per_sm = + (kernel_info.grid_x * kernel_info.grid_y * kernel_info.grid_z) / + device_property.multiProcessorCount; + warps_per_sm = blocks_per_sm * (kernel_info.block_x * kernel_info.block_y * + kernel_info.block_z) / + threads_per_warp; +#endif + + output_file_stream_ << string_format( + std::string( + R"JSON( + { + "name": "%s", "pid": %lld, "tid": %lld, + "ts": %lld, "dur": %lld, + "ph": "X", "cat": "%s", + "args": { + "device": %d, "context": %d, + "stream": %d, "correlation id": %d, + "registers per thread": %d, + "shared memory": %f, + "blocks per SM": %f, + "warps per SM": %f, + "grid": [%d, %d, %d], + "block": [%d, %d, %d], + "est. achieved occupancy %": %f + } + }, + )JSON"), + device_node.Name().c_str(), device_node.DeviceId(), + device_node.StreamId(), nsToUs(device_node.StartNs()), + nsToUs(device_node.Duration()), + categary_name_[static_cast(device_node.Type())], + device_node.DeviceId(), device_node.ContextId(), device_node.StreamId(), + device_node.CorrelationId(), kernel_info.registers_per_thread, + kernel_info.static_shared_memory + kernel_info.dynamic_shared_memory, + blocks_per_sm, warps_per_sm, kernel_info.grid_x, kernel_info.grid_y, + kernel_info.grid_z, kernel_info.block_x, kernel_info.block_y, + kernel_info.block_z, occupancy); +} + +void ChromeTracingLogger::HandleTypeMemcpy( + const DeviceTraceEventNode& device_node) { + MemcpyEventInfo memcpy_info = device_node.MemcpyInfo(); + float memory_bandwidth = 0; + if (device_node.Duration() > 0) { + memory_bandwidth = memcpy_info.num_bytes * 1.0 / device_node.Duration(); + } + output_file_stream_ << string_format( + std::string( + R"JSON( + { + "name": "%s", "pid": %lld, "tid": %lld, + "ts": %lld, "dur": %lld, + "ph": "X", "cat": "%s", + "args": { + "stream": %d, "correlation id": %d, + "bytes": %d, "memory bandwidth (GB/s)": %f + } + }, + )JSON"), + device_node.Name().c_str(), device_node.DeviceId(), + device_node.StreamId(), nsToUs(device_node.StartNs()), + nsToUs(device_node.Duration()), + categary_name_[static_cast(device_node.Type())], + device_node.StreamId(), device_node.CorrelationId(), + memcpy_info.num_bytes, memory_bandwidth); +} + +void ChromeTracingLogger::HandleTypeMemset( + const DeviceTraceEventNode& device_node) { + MemsetEventInfo memset_info = device_node.MemsetInfo(); + output_file_stream_ << string_format( + std::string( + R"JSON( + { + "name": "%s", "pid": %lld, "tid": %lld, + "ts": %lld, "dur": %lld, + "ph": "X", "cat": "%s", + "args": { + "device": %d, "context": %d, + "stream": %d, "correlation id": %d, + "bytes": %d, "value": %d + } + }, + )JSON"), + device_node.Name().c_str(), device_node.DeviceId(), + device_node.StreamId(), nsToUs(device_node.StartNs()), + nsToUs(device_node.Duration()), + categary_name_[static_cast(device_node.Type())], + device_node.DeviceId(), device_node.ContextId(), device_node.StreamId(), + device_node.CorrelationId(), memset_info.num_bytes, memset_info.value); +} + +void ChromeTracingLogger::StartLog() { + output_file_stream_ << string_format(std::string( + R"JSON( + { + "schemaVersion": "%s", + "displayTimeUnit": "us", + "SpanNumber": "%d", + )JSON"), + kSchemaVersion, num_span); +// add device property information +#if defined(PADDLE_WITH_CUDA) + output_file_stream_ << std::string(R"JSON( + "deviceProperties": [ + )JSON"); + std::vector device_ids = GetSelectedDevices(); + for (auto index = 0u; index < device_ids.size() - 1; index++) { + const gpuDeviceProp& device_property = + GetDeviceProperties(device_ids[index]); + output_file_stream_ << string_format( + std::string( + R"JSON( + { + "id": %d, "name": "%s", "totalGlobalMem": %u, + "computeMajor": %d, "computeMinor": %d, + "maxThreadsPerBlock": %d, "maxThreadsPerMultiprocessor": %d, + "regsPerBlock": %d, "regsPerMultiprocessor": %d, "warpSize": %d, + "sharedMemPerBlock": %d, "sharedMemPerMultiprocessor": %d, + "smCount": %d, "sharedMemPerBlockOptin": %d + }, + )JSON"), + device_ids[index], device_property.name, device_property.totalGlobalMem, + device_property.major, device_property.minor, + device_property.maxThreadsPerBlock, + device_property.maxThreadsPerMultiProcessor, + device_property.regsPerBlock, device_property.regsPerMultiprocessor, + device_property.warpSize, device_property.sharedMemPerBlock, + device_property.sharedMemPerMultiprocessor, + device_property.multiProcessorCount, + device_property.sharedMemPerBlockOptin); + } + if (device_ids.size() > 0) { + const gpuDeviceProp& device_property = + GetDeviceProperties(device_ids[device_ids.size() - 1]); + output_file_stream_ << string_format( + std::string( + R"JSON( + { + "id": %d, "name": "%s", "totalGlobalMem": %u, + "computeMajor": %d, "computeMinor": %d, + "maxThreadsPerBlock": %d, "maxThreadsPerMultiprocessor": %d, + "regsPerBlock": %d, "regsPerMultiprocessor": %d, "warpSize": %d, + "sharedMemPerBlock": %d, "sharedMemPerMultiprocessor": %d, + "smCount": %d, "sharedMemPerBlockOptin": %d + }], + )JSON"), + device_ids[device_ids.size() - 1], device_property.name, + device_property.totalGlobalMem, device_property.major, + device_property.minor, device_property.maxThreadsPerBlock, + device_property.maxThreadsPerMultiProcessor, + device_property.regsPerBlock, device_property.regsPerMultiprocessor, + device_property.warpSize, device_property.sharedMemPerBlock, + device_property.sharedMemPerMultiprocessor, + device_property.multiProcessorCount, + device_property.sharedMemPerBlockOptin); + } +#endif + + output_file_stream_ << std::string( + R"JSON( + "traceEvents": [ + )JSON"); +} + +void ChromeTracingLogger::EndLog() { + output_file_stream_ << std::string( + R"JSON( + {} + ] + } + )JSON"); +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/profiler/chrometracing_logger.h b/paddle/fluid/platform/profiler/chrometracing_logger.h new file mode 100644 index 00000000000..06734418609 --- /dev/null +++ b/paddle/fluid/platform/profiler/chrometracing_logger.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include "paddle/fluid/platform/profiler/output_logger.h" + +namespace paddle { +namespace platform { + +class ChromeTracingLogger : public BaseLogger { + public: + explicit ChromeTracingLogger(const std::string& filename); + explicit ChromeTracingLogger(const char* filename); + ~ChromeTracingLogger(); + std::string filename() { return filename_; } + void LogDeviceTraceEventNode(const DeviceTraceEventNode&) override; + void LogHostTraceEventNode(const HostTraceEventNode&) override; + void LogRuntimeTraceEventNode(const CudaRuntimeTraceEventNode&) override; + void LogNodeTrees(const NodeTrees&) override; + + private: + void OpenFile(); + void HandleTypeKernel(const DeviceTraceEventNode&); + void HandleTypeMemset(const DeviceTraceEventNode&); + void HandleTypeMemcpy(const DeviceTraceEventNode&); + void StartLog(); + void EndLog(); + std::string filename_; + std::ofstream output_file_stream_; + static const char* categary_name_[]; +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/profiler/event_node.cc b/paddle/fluid/platform/profiler/event_node.cc new file mode 100644 index 00000000000..6c8be1811d7 --- /dev/null +++ b/paddle/fluid/platform/profiler/event_node.cc @@ -0,0 +1,306 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/profiler/event_node.h" + +#include +#include +#include +#include +#include + +namespace paddle { +namespace platform { + +HostTraceEventNode::~HostTraceEventNode() { + // delete all runtime nodes and recursive delete children + for (auto it = runtime_node_ptrs_.begin(); it != runtime_node_ptrs_.end(); + ++it) { + delete *it; + } + for (auto it = children_.begin(); it != children_.end(); ++it) { + delete *it; + } +} + +CudaRuntimeTraceEventNode::~CudaRuntimeTraceEventNode() { + // delete all device nodes + for (auto it = device_node_ptrs_.begin(); it != device_node_ptrs_.end(); + ++it) { + delete *it; + } +} + +NodeTrees::~NodeTrees() { + // delete all root nodes + for (auto it = thread_event_trees_map_.begin(); + it != thread_event_trees_map_.end(); ++it) { + delete it->second; + } +} + +void NodeTrees::BuildTrees( + const std::vector& host_event_nodes, + std::vector& runtime_event_nodes, + const std::vector& device_event_nodes) { + // seperate Host Event Nodes into different threads + std::map> + thread2host_event_nodes; // used to store HostTraceEventNodes per thread + std::map> + thread2runtime_event_nodes; // used to store CudaRuntimeTraceEventNode + // per + // thread + std::map + correlation_id2runtime_event_node; // used to store the relation between + // correlation id and runtime node + // construct thread2host_event_nodes + for (auto it = host_event_nodes.begin(); it != host_event_nodes.end(); ++it) { + thread2host_event_nodes[(*it)->ThreadId()].push_back(*it); + } + // construct thread2runtime_event_nodes and + // correlation_id2runtime_event_node + for (auto it = runtime_event_nodes.begin(); it != runtime_event_nodes.end(); + ++it) { + thread2runtime_event_nodes[(*it)->ThreadId()].push_back(*it); + correlation_id2runtime_event_node[(*it)->CorrelationId()] = *it; + } + // associate CudaRuntimeTraceEventNode and DeviceTraceEventNode + // construct correlation_id2device_event_nodes + for (auto it = device_event_nodes.begin(); it != device_event_nodes.end(); + ++it) { + auto dst_iter = + correlation_id2runtime_event_node.find((*it)->CorrelationId()); + PADDLE_ENFORCE_NE( + dst_iter, correlation_id2runtime_event_node.end(), + platform::errors::NotFound("Unknown device events, " + "no corresponding cuda runtime events")); + dst_iter->second->AddDeviceTraceEventNode(*it); + } + // sort host event nodes and runtime event nodes according to start_ns and + // end_ns + // the smaller start_ns is, the further ahead position is. + // when start_ns of two nodes are equal, the one with bigger end_ns should be + // ahead. + for (auto it = thread2host_event_nodes.begin(); + it != thread2host_event_nodes.end(); ++it) { + std::sort(it->second.begin(), it->second.end(), + [](HostTraceEventNode* node1, HostTraceEventNode* node2) { + if (node1->StartNs() < node2->StartNs()) { + return true; + } + if ((node1->StartNs() == node2->StartNs()) && + (node1->EndNs() > node2->EndNs())) { + return true; + } + return false; + }); + } + for (auto it = thread2runtime_event_nodes.begin(); + it != thread2runtime_event_nodes.end(); ++it) { + std::sort( + it->second.begin(), it->second.end(), + [](CudaRuntimeTraceEventNode* node1, CudaRuntimeTraceEventNode* node2) { + if (node1->StartNs() < node2->StartNs()) { + return true; + } + if ((node1->StartNs() == node2->StartNs()) && + (node1->EndNs() > node2->EndNs())) { + return true; + } + return false; + }); + } + + // construct trees + std::set thread_set; + for (auto it = thread2host_event_nodes.begin(); + it != thread2host_event_nodes.end(); ++it) { + thread_set.insert(it->first); + } + + for (auto it = thread2runtime_event_nodes.begin(); + it != thread2runtime_event_nodes.end(); ++it) { + thread_set.insert(it->first); + } + + for (auto it = thread_set.begin(); it != thread_set.end(); ++it) { + thread_event_trees_map_[*it] = BuildTreeRelationship( + thread2host_event_nodes[*it], thread2runtime_event_nodes[*it]); + } +} + +HostTraceEventNode* NodeTrees::BuildTreeRelationship( + std::vector host_event_nodes, + std::vector runtime_event_nodes) { + // a stack used for analyse relationship + auto node_stack = std::vector(); + // root node, top level + auto root_node = new HostTraceEventNode( + HostTraceEvent(std::string("root node"), TracerEventType::UserDefined, 0, + ULLONG_MAX, 0, 0)); + // push root node into node_stack + node_stack.push_back(root_node); + // handle host_event_nodes + for (auto it = host_event_nodes.begin(); it != host_event_nodes.end(); ++it) { + while (true) { + auto stack_top_node = node_stack.back(); + if ((*it)->StartNs() < stack_top_node->EndNs()) { + // current node is the child of stack_top_node + PADDLE_ENFORCE_LE( + (*it)->EndNs(), stack_top_node->EndNs(), + platform::errors::Fatal( + "should not have time range intersection within one thread")); + stack_top_node->AddChild(*it); + node_stack.push_back(*it); + break; + } else { + node_stack.pop_back(); + // insert runtime node + // select runtime nodes which time range within stack_top_node + std::vector::iterator firstposition; + std::vector::iterator lastposition = + runtime_event_nodes.end(); + bool hasenter = false; + for (auto runtimenode = runtime_event_nodes.begin(); + runtimenode != runtime_event_nodes.end(); ++runtimenode) { + if (((*runtimenode)->StartNs() >= stack_top_node->StartNs()) && + ((*runtimenode)->EndNs() <= stack_top_node->EndNs())) { + if (!hasenter) { + firstposition = runtimenode; + hasenter = true; + } + stack_top_node->AddCudaRuntimeNode(*runtimenode); + } else { + // from this runtime node, not within stack_top_node, erase the + // nodes from runtime_event_nodes + if ((*runtimenode)->StartNs() > stack_top_node->EndNs()) { + lastposition = runtimenode; + break; + } + } + } + if (hasenter) { + runtime_event_nodes.erase(firstposition, lastposition); + } + } + } + } + // to insert left runtimenode into host_event_nodes + while (!node_stack.empty()) { + auto stack_top_node = node_stack.back(); + // insert runtime node + // select runtime nodes which time range within stack_top_node + std::vector::iterator firstposition; + std::vector::iterator lastposition = + runtime_event_nodes.end(); + bool hasenter = false; + for (auto runtimenode = runtime_event_nodes.begin(); + runtimenode != runtime_event_nodes.end(); ++runtimenode) { + if (((*runtimenode)->StartNs() >= stack_top_node->StartNs()) && + ((*runtimenode)->EndNs() <= stack_top_node->EndNs())) { + if (!hasenter) { + firstposition = runtimenode; + hasenter = true; + } + stack_top_node->AddCudaRuntimeNode(*runtimenode); + } else { + // from this runtime node, not within stack_top_node, erase the + // nodes from runtime_event_nodes + if ((*runtimenode)->StartNs() > stack_top_node->EndNs()) { + lastposition = runtimenode; + break; + } + } + } + if (hasenter) { + runtime_event_nodes.erase(firstposition, lastposition); + } + node_stack.pop_back(); + } + return root_node; +} + +std::map> NodeTrees::Traverse( + bool bfs) const { + // traverse the tree, provide two methods: bfs(breadth first search) or + // dfs(depth first search) + std::map> thread2host_event_nodes; + if (bfs == true) { + for (auto it = thread_event_trees_map_.begin(); + it != thread_event_trees_map_.end(); ++it) { + auto deque = std::deque(); + uint64_t thread_id = it->first; + auto root_node = it->second; + deque.push_back(root_node); + while (!deque.empty()) { + auto current_node = deque.front(); + deque.pop_front(); + thread2host_event_nodes[thread_id].push_back(current_node); + for (auto child = current_node->GetChildren().begin(); + child != current_node->GetChildren().end(); ++child) { + deque.push_back(*child); + } + } + } + + } else { + for (auto it = thread_event_trees_map_.begin(); + it != thread_event_trees_map_.end(); ++it) { + auto stack = std::stack(); + uint64_t thread_id = it->first; + auto root_node = it->second; + stack.push(root_node); + while (!stack.empty()) { + auto current_node = stack.top(); + stack.pop(); + thread2host_event_nodes[thread_id].push_back(current_node); + for (auto child = current_node->GetChildren().begin(); + child != current_node->GetChildren().end(); ++child) { + stack.push(*child); + } + } + } + } + return thread2host_event_nodes; +} + +void NodeTrees::LogMe(BaseLogger* logger) { logger->LogNodeTrees(*this); } + +void NodeTrees::HandleTrees( + std::function host_event_node_handle, + std::function runtime_event_node_handle, + std::function device_event_node_handle) { + // using different user-defined function to handle different nodes + const std::map> + thread2host_event_nodes = Traverse(true); + for (auto it = thread2host_event_nodes.begin(); + it != thread2host_event_nodes.end(); ++it) { + for (auto hostnode = it->second.begin(); hostnode != it->second.end(); + ++hostnode) { + if (hostnode != it->second.begin()) { // skip root node + host_event_node_handle(*hostnode); + } + for (auto runtimenode = (*hostnode)->GetRuntimeTraceEventNodes().begin(); + runtimenode != (*hostnode)->GetRuntimeTraceEventNodes().end(); + ++runtimenode) { + runtime_event_node_handle(*runtimenode); + for (auto devicenode = + (*runtimenode)->GetDeviceTraceEventNodes().begin(); + devicenode != (*runtimenode)->GetDeviceTraceEventNodes().end(); + ++devicenode) { + device_event_node_handle(*devicenode); + } + } + } + } +} +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/profiler/event_node.h b/paddle/fluid/platform/profiler/event_node.h old mode 100755 new mode 100644 index 05190bc4666..dd8dfd32df4 --- a/paddle/fluid/platform/profiler/event_node.h +++ b/paddle/fluid/platform/profiler/event_node.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -35,18 +35,18 @@ class DeviceTraceEventNode { // destructor ~DeviceTraceEventNode() {} // getter - std::string name() const { return device_event_.name; } - TracerEventType type() const { return device_event_.type; } - uint64_t start_ns() const { return device_event_.start_ns; } - uint64_t end_ns() const { return device_event_.end_ns; } - uint64_t device_id() const { return device_event_.device_id; } - uint64_t context_id() const { return device_event_.context_id; } - uint64_t stream_id() const { return device_event_.stream_id; } - uint64_t duration() const { + std::string Name() const { return device_event_.name; } + TracerEventType Type() const { return device_event_.type; } + uint64_t StartNs() const { return device_event_.start_ns; } + uint64_t EndNs() const { return device_event_.end_ns; } + uint64_t DeviceId() const { return device_event_.device_id; } + uint64_t ContextId() const { return device_event_.context_id; } + uint64_t StreamId() const { return device_event_.stream_id; } + uint64_t Duration() const { return device_event_.end_ns - device_event_.start_ns; } - uint32_t correlation_id() const { return device_event_.correlation_id; } - KernelEventInfo kernel_info() const { + uint32_t CorrelationId() const { return device_event_.correlation_id; } + KernelEventInfo KernelInfo() const { PADDLE_ENFORCE_EQ( device_event_.type, TracerEventType::Kernel, platform::errors::Unavailable( @@ -54,7 +54,7 @@ class DeviceTraceEventNode { "TracerEventType in node must be TracerEventType::Kernel.")); return device_event_.kernel_info; } - MemcpyEventInfo memcpy_info() const { + MemcpyEventInfo MemcpyInfo() const { PADDLE_ENFORCE_EQ( device_event_.type, TracerEventType::Memcpy, platform::errors::Unavailable( @@ -62,7 +62,7 @@ class DeviceTraceEventNode { "TracerEventType in node must be TracerEventType::Memcpy.")); return device_event_.memcpy_info; } - MemsetEventInfo memset_info() const { + MemsetEventInfo MemsetInfo() const { PADDLE_ENFORCE_EQ( device_event_.type, TracerEventType::Memset, platform::errors::Unavailable( @@ -87,17 +87,17 @@ class CudaRuntimeTraceEventNode { // destructor ~CudaRuntimeTraceEventNode(); // getter - std::string name() const { return runtime_event_.name; } - TracerEventType type() const { return runtime_event_.type; } - uint64_t start_ns() const { return runtime_event_.start_ns; } - uint64_t end_ns() const { return runtime_event_.end_ns; } - uint64_t process_id() const { return runtime_event_.process_id; } - uint64_t thread_id() const { return runtime_event_.thread_id; } - uint64_t duration() const { + std::string Name() const { return runtime_event_.name; } + TracerEventType Type() const { return runtime_event_.type; } + uint64_t StartNs() const { return runtime_event_.start_ns; } + uint64_t EndNs() const { return runtime_event_.end_ns; } + uint64_t ProcessId() const { return runtime_event_.process_id; } + uint64_t ThreadId() const { return runtime_event_.thread_id; } + uint64_t Duration() const { return runtime_event_.end_ns - runtime_event_.start_ns; } - uint32_t correlation_id() const { return runtime_event_.correlation_id; } - uint32_t callback_id() const { return runtime_event_.callback_id; } + uint32_t CorrelationId() const { return runtime_event_.correlation_id; } + uint32_t CallbackId() const { return runtime_event_.callback_id; } // member function void AddDeviceTraceEventNode(DeviceTraceEventNode* node) { device_node_ptrs_.push_back(node); @@ -124,13 +124,13 @@ class HostTraceEventNode { ~HostTraceEventNode(); // getter - std::string name() const { return host_event_.name; } - TracerEventType type() const { return host_event_.type; } - uint64_t start_ns() const { return host_event_.start_ns; } - uint64_t end_ns() const { return host_event_.end_ns; } - uint64_t process_id() const { return host_event_.process_id; } - uint64_t thread_id() const { return host_event_.thread_id; } - uint64_t duration() const { + std::string Name() const { return host_event_.name; } + TracerEventType Type() const { return host_event_.type; } + uint64_t StartNs() const { return host_event_.start_ns; } + uint64_t EndNs() const { return host_event_.end_ns; } + uint64_t ProcessId() const { return host_event_.process_id; } + uint64_t ThreadId() const { return host_event_.thread_id; } + uint64_t Duration() const { return host_event_.end_ns - host_event_.start_ns; } diff --git a/paddle/fluid/platform/profiler/event_python.h b/paddle/fluid/platform/profiler/event_python.h index 2241cf9e49e..b0d8eaa2427 100755 --- a/paddle/fluid/platform/profiler/event_python.h +++ b/paddle/fluid/platform/profiler/event_python.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/paddle/fluid/platform/profiler/output_logger.h b/paddle/fluid/platform/profiler/output_logger.h index 6901ed0c444..ff4effad5ec 100755 --- a/paddle/fluid/platform/profiler/output_logger.h +++ b/paddle/fluid/platform/profiler/output_logger.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/paddle/fluid/platform/profiler/profiler.cc b/paddle/fluid/platform/profiler/profiler.cc index e9f0eb98d53..96fa157f399 100644 --- a/paddle/fluid/platform/profiler/profiler.cc +++ b/paddle/fluid/platform/profiler/profiler.cc @@ -25,6 +25,7 @@ limitations under the License. */ #endif #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/profiler/host_tracer.h" +#include "paddle/fluid/platform/profiler/trace_event_collector.h" namespace paddle { namespace platform { @@ -62,14 +63,17 @@ void Profiler::Start() { } } -TraceEventCollector Profiler::Stop() { +std::unique_ptr Profiler::Stop() { SynchronizeAllDevice(); TraceEventCollector collector; for (auto& tracer : tracers_) { tracer.Get().StopTracing(); tracer.Get().CollectTraceData(&collector); } - return collector; + std::unique_ptr tree(new NodeTrees(collector.HostEvents(), + collector.RuntimeEvents(), + collector.DeviceEvents())); + return tree; } } // namespace platform diff --git a/paddle/fluid/platform/profiler/profiler.h b/paddle/fluid/platform/profiler/profiler.h index 1324d81f959..33fc844b0f3 100644 --- a/paddle/fluid/platform/profiler/profiler.h +++ b/paddle/fluid/platform/profiler/profiler.h @@ -20,7 +20,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/platform/macros.h" -#include "paddle/fluid/platform/profiler/trace_event_collector.h" +#include "paddle/fluid/platform/profiler/event_node.h" #include "paddle/fluid/platform/profiler/tracer_base.h" namespace paddle { @@ -38,7 +38,7 @@ class Profiler { void Start(); - TraceEventCollector Stop(); + std::unique_ptr Stop(); ~Profiler(); diff --git a/paddle/fluid/platform/profiler/profiler_test.cc b/paddle/fluid/platform/profiler/profiler_test.cc index 414987d2f10..6bd3ed9d809 100644 --- a/paddle/fluid/platform/profiler/profiler_test.cc +++ b/paddle/fluid/platform/profiler/profiler_test.cc @@ -42,11 +42,12 @@ TEST(ProfilerTest, TestHostTracer) { RecordInstantEvent("TestTraceLevel_record2", TracerEventType::UserDefined, 3); } - auto collector = profiler->Stop(); + auto nodetree = profiler->Stop(); std::set host_events; - for (const auto evt : collector.HostEvents()) { - host_events.insert(evt.name); - } + for (const auto pair : nodetree->Traverse(true)) + for (const auto evt : pair.second) { + host_events.insert(evt->Name()); + } EXPECT_EQ(host_events.count("TestTraceLevel_record1"), 1u); EXPECT_EQ(host_events.count("TestTraceLevel_record2"), 0u); } diff --git a/paddle/fluid/platform/profiler/test_event_node.cc b/paddle/fluid/platform/profiler/test_event_node.cc new file mode 100644 index 00000000000..b8d1306ad07 --- /dev/null +++ b/paddle/fluid/platform/profiler/test_event_node.cc @@ -0,0 +1,203 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" + +#include "paddle/fluid/platform/profiler/chrometracing_logger.h" +#include "paddle/fluid/platform/profiler/event_node.h" + +using paddle::platform::ChromeTracingLogger; +using paddle::platform::NodeTrees; +using paddle::platform::HostTraceEventNode; +using paddle::platform::CudaRuntimeTraceEventNode; +using paddle::platform::DeviceTraceEventNode; +using paddle::platform::HostTraceEvent; +using paddle::platform::RuntimeTraceEvent; +using paddle::platform::DeviceTraceEvent; +using paddle::platform::TracerEventType; +using paddle::platform::KernelEventInfo; +using paddle::platform::MemcpyEventInfo; +using paddle::platform::MemsetEventInfo; +TEST(NodeTreesTest, LogMe_case0) { + std::list host_events; + std::list runtime_events; + std::list device_events; + host_events.push_back(HostTraceEvent(std::string("dataloader#1"), + TracerEventType::Dataloader, 1000, 10000, + 10, 10)); + host_events.push_back(HostTraceEvent( + std::string("op1"), TracerEventType::Operator, 11000, 20000, 10, 10)); + host_events.push_back(HostTraceEvent( + std::string("op2"), TracerEventType::Operator, 21000, 30000, 10, 10)); + host_events.push_back(HostTraceEvent( + std::string("op3"), TracerEventType::Operator, 31000, 40000, 10, 11)); + runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch1"), 15000, + 17000, 10, 10, 1, 0)); + runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch2"), 25000, + 35000, 10, 10, 2, 0)); + runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch3"), 33000, + 37000, 10, 11, 3, 0)); + runtime_events.push_back(RuntimeTraceEvent(std::string("cudaMemcpy1"), 18000, + 19000, 10, 10, 4, 0)); + runtime_events.push_back(RuntimeTraceEvent(std::string("cudaMemset1"), 38000, + 39000, 10, 11, 5, 0)); + device_events.push_back( + DeviceTraceEvent(std::string("kernel1"), TracerEventType::Kernel, 40000, + 55000, 0, 10, 10, 1, KernelEventInfo())); + device_events.push_back( + DeviceTraceEvent(std::string("kernel2"), TracerEventType::Kernel, 70000, + 95000, 0, 10, 10, 2, KernelEventInfo())); + device_events.push_back( + DeviceTraceEvent(std::string("kernel3"), TracerEventType::Kernel, 60000, + 65000, 0, 10, 11, 3, KernelEventInfo())); + device_events.push_back( + DeviceTraceEvent(std::string("memcpy1"), TracerEventType::Memcpy, 56000, + 59000, 0, 10, 10, 4, MemcpyEventInfo())); + device_events.push_back( + DeviceTraceEvent(std::string("memset1"), TracerEventType::Memset, 66000, + 69000, 0, 10, 11, 5, MemsetEventInfo())); + ChromeTracingLogger logger("test_nodetrees_logme_case0.json"); + NodeTrees tree(host_events, runtime_events, device_events); + std::map> nodes = + tree.Traverse(true); + EXPECT_EQ(nodes[10].size(), 4u); + EXPECT_EQ(nodes[11].size(), 2u); + std::vector thread1_nodes = nodes[10]; + std::vector thread2_nodes = nodes[11]; + for (auto it = thread1_nodes.begin(); it != thread1_nodes.end(); it++) { + if ((*it)->Name() == "root node") { + EXPECT_EQ((*it)->GetChildren().size(), 3u); + } + if ((*it)->Name() == "op1") { + EXPECT_EQ((*it)->GetChildren().size(), 0u); + EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u); + } + } + for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) { + if ((*it)->Name() == "op3") { + EXPECT_EQ((*it)->GetChildren().size(), 0u); + EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u); + } + } + tree.LogMe(&logger); +} + +TEST(NodeTreesTest, LogMe_case1) { + std::list host_events; + std::list runtime_events; + std::list device_events; + runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch1"), 15000, + 17000, 10, 10, 1, 0)); + runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch2"), 25000, + 35000, 10, 10, 2, 0)); + runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch3"), 33000, + 37000, 10, 11, 3, 0)); + runtime_events.push_back(RuntimeTraceEvent(std::string("cudaMemcpy1"), 18000, + 19000, 10, 10, 4, 0)); + runtime_events.push_back(RuntimeTraceEvent(std::string("cudaMemset1"), 38000, + 39000, 10, 11, 5, 0)); + device_events.push_back( + DeviceTraceEvent(std::string("kernel1"), TracerEventType::Kernel, 40000, + 55000, 0, 10, 10, 1, KernelEventInfo())); + device_events.push_back( + DeviceTraceEvent(std::string("kernel2"), TracerEventType::Kernel, 70000, + 95000, 0, 10, 10, 2, KernelEventInfo())); + device_events.push_back( + DeviceTraceEvent(std::string("kernel3"), TracerEventType::Kernel, 60000, + 65000, 0, 10, 11, 3, KernelEventInfo())); + device_events.push_back( + DeviceTraceEvent(std::string("memcpy1"), TracerEventType::Memcpy, 56000, + 59000, 0, 10, 10, 4, MemcpyEventInfo())); + device_events.push_back( + DeviceTraceEvent(std::string("memset1"), TracerEventType::Memset, 66000, + 69000, 0, 10, 11, 5, MemsetEventInfo())); + ChromeTracingLogger logger("test_nodetrees_logme_case1.json"); + NodeTrees tree(host_events, runtime_events, device_events); + std::map> nodes = + tree.Traverse(true); + EXPECT_EQ(nodes[10].size(), 1u); + EXPECT_EQ(nodes[11].size(), 1u); + std::vector thread1_nodes = nodes[10]; + std::vector thread2_nodes = nodes[11]; + for (auto it = thread1_nodes.begin(); it != thread1_nodes.end(); it++) { + if ((*it)->Name() == "root node") { + EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 3u); + } + } + for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) { + if ((*it)->Name() == "root node") { + EXPECT_EQ((*it)->GetChildren().size(), 0u); + EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u); + } + } + tree.LogMe(&logger); +} + +TEST(NodeTreesTest, HandleTrees_case0) { + std::list host_events; + std::list runtime_events; + std::list device_events; + host_events.push_back(HostTraceEvent( + std::string("op1"), TracerEventType::Operator, 10000, 100000, 10, 10)); + host_events.push_back(HostTraceEvent( + std::string("op2"), TracerEventType::Operator, 30000, 70000, 10, 10)); + host_events.push_back(HostTraceEvent( + std::string("op3"), TracerEventType::Operator, 2000, 120000, 10, 11)); + runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch1"), 15000, + 25000, 10, 10, 1, 0)); + runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch2"), 35000, + 45000, 10, 10, 2, 0)); + runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch3"), 10000, + 55000, 10, 11, 3, 0)); + device_events.push_back( + DeviceTraceEvent(std::string("kernel1"), TracerEventType::Kernel, 40000, + 55000, 0, 10, 10, 1, KernelEventInfo())); + device_events.push_back( + DeviceTraceEvent(std::string("kernel2"), TracerEventType::Kernel, 70000, + 95000, 0, 10, 10, 2, KernelEventInfo())); + device_events.push_back( + DeviceTraceEvent(std::string("kernel3"), TracerEventType::Kernel, 60000, + 75000, 0, 10, 11, 3, KernelEventInfo())); + ChromeTracingLogger logger("test_nodetrees_handletrees_case0.json"); + NodeTrees tree(host_events, runtime_events, device_events); + std::map> nodes = + tree.Traverse(true); + EXPECT_EQ(nodes[10].size(), 3u); + EXPECT_EQ(nodes[11].size(), 2u); + std::vector thread1_nodes = nodes[10]; + std::vector thread2_nodes = nodes[11]; + for (auto it = thread1_nodes.begin(); it != thread1_nodes.end(); it++) { + if ((*it)->Name() == "root node") { + EXPECT_EQ((*it)->GetChildren().size(), 1u); + } + if ((*it)->Name() == "op1") { + EXPECT_EQ((*it)->GetChildren().size(), 1u); + EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 1u); + } + } + for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) { + if ((*it)->Name() == "op3") { + EXPECT_EQ((*it)->GetChildren().size(), 0u); + EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 1u); + } + } + std::function host_event_node_handle( + [&](HostTraceEventNode* a) { logger.LogHostTraceEventNode(*a); }); + std::function runtime_event_node_handle([&]( + CudaRuntimeTraceEventNode* a) { logger.LogRuntimeTraceEventNode(*a); }); + std::function device_event_node_handle( + [&](DeviceTraceEventNode* a) { logger.LogDeviceTraceEventNode(*a); }); + tree.HandleTrees(host_event_node_handle, runtime_event_node_handle, + device_event_node_handle); +} diff --git a/paddle/fluid/platform/profiler/trace_event.h b/paddle/fluid/platform/profiler/trace_event.h index 1f146adf4f7..3e4903f6ffb 100644 --- a/paddle/fluid/platform/profiler/trace_event.h +++ b/paddle/fluid/platform/profiler/trace_event.h @@ -84,7 +84,7 @@ struct MemcpyEventInfo { // The kind of the memory copy. // Each kind represents the source and destination targets of a memory copy. // Targets are host, device, and array. Refer to CUpti_ActivityMemcpyKind - // std::string copy_kind; + char copy_kind[kMemKindMaxLen]; // The source memory kind read by the memory copy. // Each kind represents the type of the memory accessed by a memory // operation/copy. Refer to CUpti_ActivityMemoryKind -- GitLab