/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/profiler/output_logger.h" #include "paddle/fluid/platform/profiler/trace_event.h" namespace paddle { namespace platform { class DeviceTraceEventNode { public: // constructor explicit DeviceTraceEventNode(const DeviceTraceEvent& device_event) : device_event_(device_event) {} // destructor ~DeviceTraceEventNode() {} // getter std::string name() const { return device_event_.name; } TracerEventType type() const { return device_event_.type; } uint64_t start_ns() const { return device_event_.start_ns; } uint64_t end_ns() const { return device_event_.end_ns; } uint64_t device_id() const { return device_event_.device_id; } uint64_t context_id() const { return device_event_.context_id; } uint64_t stream_id() const { return device_event_.stream_id; } uint64_t duration() const { return device_event_.end_ns - device_event_.start_ns; } uint32_t correlation_id() const { return device_event_.correlation_id; } KernelEventInfo kernel_info() const { PADDLE_ENFORCE_EQ( device_event_.type, TracerEventType::Kernel, platform::errors::Unavailable( "Can not kernel_info, " "TracerEventType in node must be TracerEventType::Kernel.")); return device_event_.kernel_info; } MemcpyEventInfo memcpy_info() const { PADDLE_ENFORCE_EQ( device_event_.type, TracerEventType::Memcpy, platform::errors::Unavailable( "Can not get memcpy_info, " "TracerEventType in node must be TracerEventType::Memcpy.")); return device_event_.memcpy_info; } MemsetEventInfo memset_info() const { PADDLE_ENFORCE_EQ( device_event_.type, TracerEventType::Memset, platform::errors::Unavailable( "Can not get memset_info, " "TracerEventType in node must be TracerEventType::Memset.")); return device_event_.memset_info; } // member function void LogMe(BaseLogger* logger) { logger->LogDeviceTraceEventNode(*this); } private: // data DeviceTraceEvent device_event_; }; class CudaRuntimeTraceEventNode { public: // constructor explicit CudaRuntimeTraceEventNode(const RuntimeTraceEvent& runtime_event) : runtime_event_(runtime_event) {} // destructor ~CudaRuntimeTraceEventNode(); // getter std::string name() const { return runtime_event_.name; } TracerEventType type() const { return runtime_event_.type; } uint64_t start_ns() const { return runtime_event_.start_ns; } uint64_t end_ns() const { return runtime_event_.end_ns; } uint64_t process_id() const { return runtime_event_.process_id; } uint64_t thread_id() const { return runtime_event_.thread_id; } uint64_t duration() const { return runtime_event_.end_ns - runtime_event_.start_ns; } uint32_t correlation_id() const { return runtime_event_.correlation_id; } uint32_t callback_id() const { return runtime_event_.callback_id; } // member function void AddDeviceTraceEventNode(DeviceTraceEventNode* node) { device_node_ptrs_.push_back(node); } void LogMe(BaseLogger* logger) { logger->LogRuntimeTraceEventNode(*this); } std::vector& GetDeviceTraceEventNodes() { return device_node_ptrs_; } private: // data RuntimeTraceEvent runtime_event_; // device events called by this std::vector device_node_ptrs_; }; class HostTraceEventNode { public: // constructor explicit HostTraceEventNode(const HostTraceEvent& host_event) : host_event_(host_event) {} // destructor ~HostTraceEventNode(); // getter std::string name() const { return host_event_.name; } TracerEventType type() const { return host_event_.type; } uint64_t start_ns() const { return host_event_.start_ns; } uint64_t end_ns() const { return host_event_.end_ns; } uint64_t process_id() const { return host_event_.process_id; } uint64_t thread_id() const { return host_event_.thread_id; } uint64_t duration() const { return host_event_.end_ns - host_event_.start_ns; } // member function void AddChild(HostTraceEventNode* node) { children_.push_back(node); } void AddCudaRuntimeNode(CudaRuntimeTraceEventNode* node) { runtime_node_ptrs_.push_back(node); } std::vector& GetChildren() { return children_; } std::vector& GetRuntimeTraceEventNodes() { return runtime_node_ptrs_; } void LogMe(BaseLogger* logger) { logger->LogHostTraceEventNode(*this); } private: // data HostTraceEvent host_event_; // cuda runtime events called by this std::vector runtime_node_ptrs_; // host events called by this std::vector children_; }; class NodeTrees { public: // constructor NodeTrees(const std::list& host_events, const std::list& runtime_events, const std::list& device_events) { std::vector host_event_nodes; std::vector runtime_event_nodes; std::vector device_event_nodes; // encapsulate event into nodes for (auto it = host_events.begin(); it != host_events.end(); ++it) { host_event_nodes.push_back(new HostTraceEventNode(*it)); } for (auto it = runtime_events.begin(); it != runtime_events.end(); ++it) { runtime_event_nodes.push_back(new CudaRuntimeTraceEventNode(*it)); } for (auto it = device_events.begin(); it != device_events.end(); ++it) { device_event_nodes.push_back(new DeviceTraceEventNode(*it)); } // build tree BuildTrees(host_event_nodes, runtime_event_nodes, device_event_nodes); } explicit NodeTrees( const std::map& thread_event_trees_map) : thread_event_trees_map_(thread_event_trees_map) {} // destructor ~NodeTrees(); void LogMe(BaseLogger* logger); void HandleTrees(std::function, std::function, std::function); std::map GetNodeTrees() { return thread_event_trees_map_; } std::map> Traverse(bool bfs) const; private: std::map thread_event_trees_map_; void BuildTrees(const std::vector&, std::vector&, const std::vector&); HostTraceEventNode* BuildTreeRelationship( std::vector host_event_nodes, std::vector runtime_event_nodes); }; } // namespace platform } // namespace paddle