From 09799566193ae4e797795a69993e4e80cb86179e Mon Sep 17 00:00:00 2001 From: chengduo Date: Tue, 12 Mar 2019 23:00:22 -0500 Subject: [PATCH] Add memory profiler (#16137) test=develop --- paddle/fluid/memory/allocation/CMakeLists.txt | 2 +- .../memory/allocation/legacy_allocator.cc | 12 +- paddle/fluid/platform/device_tracer.cc | 59 +++- paddle/fluid/platform/device_tracer.h | 21 ++ paddle/fluid/platform/event.h | 33 +++ paddle/fluid/platform/profiler.cc | 257 +++++++++++++----- paddle/fluid/platform/profiler.h | 77 +++++- paddle/fluid/platform/profiler.proto | 17 ++ tools/timeline.py | 104 +++++++ 9 files changed, 505 insertions(+), 77 deletions(-) diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 4b7b9064dcd..7c44e18f8f3 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -3,7 +3,7 @@ cc_library(cpu_allocator SRCS cpu_allocator.cc DEPS allocator) cc_library(best_fit_allocator SRCS best_fit_allocator.cc DEPS allocator) cc_library(locked_allocator SRCS locked_allocator.cc DEPS allocator) cc_library(buffered_allocator SRCS buffered_allocator.cc DEPS allocator) -cc_library(legacy_allocator SRCS legacy_allocator.cc DEPS allocator buddy_allocator) +cc_library(legacy_allocator SRCS legacy_allocator.cc DEPS allocator buddy_allocator profiler) cc_test(buffered_allocator_test SRCS buffered_allocator_test.cc DEPS best_fit_allocator locked_allocator buffered_allocator cpu_allocator) if (WITH_GPU) diff --git a/paddle/fluid/memory/allocation/legacy_allocator.cc b/paddle/fluid/memory/allocation/legacy_allocator.cc index a97d54a1917..c233bf4edf5 100644 --- a/paddle/fluid/memory/allocation/legacy_allocator.cc +++ b/paddle/fluid/memory/allocation/legacy_allocator.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/memory/allocation/legacy_allocator.h" - #include #include #include @@ -24,9 +22,11 @@ #endif #include "glog/logging.h" +#include "paddle/fluid/memory/allocation/legacy_allocator.h" #include "paddle/fluid/memory/detail/buddy_allocator.h" #include "paddle/fluid/memory/detail/system_allocator.h" #include "paddle/fluid/platform/gpu_info.h" +#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/split.h" @@ -329,18 +329,22 @@ size_t Usage::operator()(const platform::CUDAPinnedPlace &cuda_pinned) const { } // namespace legacy namespace allocation { - LegacyMemMonitor GPUMemMonitor; Allocation *LegacyAllocator::AllocateImpl(size_t size, Allocator::Attr attr) { void *ptr = boost::apply_visitor(legacy::AllocVisitor(size), place_); - return new Allocation(ptr, size, place_); + auto *tmp_alloc = new Allocation(ptr, size, place_); + platform::MemEvenRecorder::Instance().PushMemRecord( + static_cast(tmp_alloc), place_, size); + return tmp_alloc; } void LegacyAllocator::Free(Allocation *allocation) { boost::apply_visitor( legacy::FreeVisitor(allocation->ptr(), allocation->size()), allocation->place()); + platform::MemEvenRecorder::Instance().PopMemRecord( + static_cast(allocation), place_); delete allocation; } diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index b084f1a649b..8458b17f82a 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/platform/device_tracer.h" #include #include @@ -30,6 +29,8 @@ limitations under the License. */ #include "glog/logging.h" #include "google/protobuf/text_format.h" #include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/platform/device_tracer.h" +#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/string/printf.h" namespace paddle { @@ -317,6 +318,24 @@ class DeviceTracerImpl : public DeviceTracer { stream_id, correlation_id, bytes}); } + void AddMemInfoRecord(uint64_t start_ns, uint64_t end_ns, size_t bytes, + const Place &place, const std::string &alloc_in, + const std::string &free_in, int64_t thread_id) { + if (0 == start_ns || 0 == end_ns) { + VLOG(3) << alloc_in << ", " << free_in << " Cannot be traced."; + return; + } + thread_local std::forward_list *local_mem_info_record = + nullptr; + if (local_mem_info_record == nullptr) { + std::lock_guard l(trace_mu_); + mem_info_record_.emplace_front(); + local_mem_info_record = &mem_info_record_.front(); + } + local_mem_info_record->emplace_front(MemInfoRecord{ + start_ns, end_ns, bytes, place, thread_id, alloc_in, free_in}); + } + void AddActiveKindRecords(const std::string &anno, uint64_t start_ns, uint64_t end_ns, int64_t device_id, int64_t thread_id, uint32_t correlation_id) { @@ -409,6 +428,7 @@ class DeviceTracerImpl : public DeviceTracer { correlations_.clear(); for (auto &tmp : correlations_pairs) tmp.clear(); for (auto &tmp : cpu_records_) tmp.clear(); + for (auto &tmp : mem_info_record_) tmp.clear(); for (auto &tmp : active_kind_records_) tmp.clear(); } @@ -440,9 +460,12 @@ class DeviceTracerImpl : public DeviceTracer { proto::Profile profile_pb; profile_pb.set_start_ns(start_ns_); profile_pb.set_end_ns(end_ns_); - if (correlations_.empty()) - for (auto &tmp : correlations_pairs) + if (correlations_.empty()) { + for (auto &tmp : correlations_pairs) { for (auto &pair : tmp) correlations_[pair.first] = pair.second; + } + } + for (const KernelRecord &r : kernel_records_) { auto *event = profile_pb.add_events(); event->set_type(proto::Event::GPUKernel); @@ -462,6 +485,7 @@ class DeviceTracerImpl : public DeviceTracer { event->set_device_id(r.device_id); } VLOG(1) << "KernelRecord event miss: " << miss << " find: " << find; + for (auto &tmp : cpu_records_) { for (const CPURecord &r : tmp) { auto *event = profile_pb.add_events(); @@ -473,6 +497,7 @@ class DeviceTracerImpl : public DeviceTracer { event->set_device_id(r.device_id); } } + for (auto &tmp : active_kind_records_) { for (const ActiveKindRecord &r : tmp) { auto *event = profile_pb.add_events(); @@ -510,6 +535,31 @@ class DeviceTracerImpl : public DeviceTracer { event->mutable_memcopy()->set_bytes(r.bytes); } VLOG(1) << "MemRecord event miss: " << miss << " find: " << find; + + for (auto &tmp : mem_info_record_) { + for (const auto &r : tmp) { + auto *event = profile_pb.add_mem_events(); + event->set_device_id(0); + if (platform::is_cpu_place(r.place)) { + event->set_place(proto::MemEvent::CPUPlace); + } else if (platform::is_gpu_place(r.place)) { + event->set_place(proto::MemEvent::CUDAPlace); + event->set_device_id( + boost::get(r.place).GetDeviceId()); + } else if (platform::is_cuda_pinned_place(r.place)) { + event->set_place(proto::MemEvent::CUDAPinnedPlace); + } else { + PADDLE_THROW("The current place is not supported."); + } + event->set_alloc_in(r.alloc_in); + event->set_free_in(r.free_in); + event->set_start_ns(r.start_ns); + event->set_end_ns(r.end_ns); + event->set_bytes(r.bytes); + event->set_thread_id(r.thread_id); + } + } + std::ofstream profile_f; profile_f.open(profile_path, std::ios::out | std::ios::trunc | std::ios::binary); @@ -553,6 +603,7 @@ class DeviceTracerImpl : public DeviceTracer { std::forward_list kernel_records_; std::forward_list mem_records_; std::forward_list> cpu_records_; + std::forward_list> mem_info_record_; std::forward_list> active_kind_records_; std::forward_list>> correlations_pairs; @@ -575,7 +626,7 @@ Event *CurAnnotation() { return annotation_stack.back(); } std::string CurAnnotationName() { - if (annotation_stack.empty()) return ""; + if (annotation_stack.empty()) return "Unknown"; return annotation_stack.back()->name(); } diff --git a/paddle/fluid/platform/device_tracer.h b/paddle/fluid/platform/device_tracer.h index a8f1d89383d..85168a046fb 100644 --- a/paddle/fluid/platform/device_tracer.h +++ b/paddle/fluid/platform/device_tracer.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/cupti.h" #include "paddle/fluid/platform/event.h" +#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/profiler.pb.h" @@ -47,6 +48,7 @@ class DeviceTracer { int64_t stream_id; uint32_t correlation_id; }; + struct CPURecord { std::string name; uint64_t start_ns; @@ -54,6 +56,7 @@ class DeviceTracer { int64_t device_id; int64_t thread_id; }; + struct MemRecord { std::string name; uint64_t start_ns; @@ -63,6 +66,17 @@ class DeviceTracer { uint32_t correlation_id; uint64_t bytes; }; + + struct MemInfoRecord { + uint64_t start_ns; + uint64_t end_ns; + size_t bytes; + Place place; + int64_t thread_id; + std::string alloc_in; + std::string free_in; + }; + struct ActiveKindRecord { std::string name; uint64_t start_ns; @@ -71,6 +85,7 @@ class DeviceTracer { int64_t thread_id; uint32_t correlation_id; }; + virtual ~DeviceTracer() {} // Needs to be called once before use. virtual void Enable() = 0; @@ -97,6 +112,12 @@ class DeviceTracer { int64_t thread_id, uint32_t correlation_id) = 0; + virtual void AddMemInfoRecord(uint64_t start_ns, uint64_t end_ns, + size_t bytes, const Place& place, + const std::string& alloc_in, + const std::string& free_in, + int64_t thread_id) = 0; + // Add a cuda kernel stats. `correlation_id` will be mapped to annotation // added before for human readability. virtual void AddKernelRecords(std::string name, uint64_t start, uint64_t end, diff --git a/paddle/fluid/platform/event.h b/paddle/fluid/platform/event.h index 2dcf966754c..e9bdb82a50f 100644 --- a/paddle/fluid/platform/event.h +++ b/paddle/fluid/platform/event.h @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include #ifdef PADDLE_WITH_CUDA #include #endif +#include "paddle/fluid/platform/place.h" namespace paddle { namespace platform { @@ -64,5 +66,36 @@ class Event { #endif #endif }; + +class MemEvent { + public: + MemEvent(EventType type, uint64_t start_ns, uint64_t end_ns, size_t bytes, + Place place, int64_t thread_id, const std::string& annotation) + : type_(type), + start_ns_(start_ns), + end_ns_(end_ns), + bytes_(bytes), + place_(place), + thread_id_(thread_id), + annotation_(annotation) {} + + const EventType& type() const { return type_; } + uint64_t start_ns() const { return start_ns_; } + uint64_t end_ns() const { return end_ns_; } + size_t bytes() const { return bytes_; } + Place place() const { return place_; } + int64_t thread_id() const { return thread_id_; } + const std::string& annotation() const { return annotation_; } + + private: + EventType type_; + uint64_t start_ns_ = 0; + uint64_t end_ns_ = 0; + size_t bytes_; + Place place_; + int64_t thread_id_; + std::string annotation_; +}; + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index 9a285a6b533..6d055a44210 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/profiler.h" - #include #include #include @@ -21,6 +20,8 @@ limitations under the License. */ #include // NOLINT #include #include +#include + #ifdef PADDLE_WITH_CUDA #include #endif // PADDLE_WITH_CUDA @@ -36,8 +37,6 @@ DEFINE_bool(enable_rpc_profiler, false, "Enable rpc profiler or not."); namespace paddle { namespace platform { -struct EventList; - static int64_t profiler_lister_id = 0; static bool should_send_profile_state = false; std::mutex profiler_mu; @@ -53,43 +52,15 @@ static uint32_t g_next_thread_id = 0; // The global mutex static std::mutex g_all_event_lists_mutex; // The total event lists of all threads -static std::list> g_all_event_lists; +static std::list>> g_all_event_lists; // The thread local event list only can be accessed by the specific thread -static thread_local std::shared_ptr g_event_list; - -struct EventList { - constexpr static size_t kMB = 1024 * 1024; - constexpr static size_t kEventBlockSize = 16 * kMB; - constexpr static size_t kEventSize = sizeof(Event); - constexpr static size_t kEventAlign = alignof(Event); - constexpr static size_t kNumBlock = - kEventBlockSize / - ((kEventSize + kEventAlign - 1) / kEventAlign * kEventAlign); - - template - Event* Record(Args&&... args) { - if (event_blocks.empty() || event_blocks.front().size() == kNumBlock) { - event_blocks.emplace_front(); - event_blocks.front().reserve(kNumBlock); - } - event_blocks.front().emplace_back(std::forward(args)...); - return &event_blocks.front().back(); - } - - std::vector Reduce() { - std::vector result; - for (auto& block : event_blocks) { - result.insert(result.begin(), std::make_move_iterator(block.begin()), - std::make_move_iterator(block.end())); - } - event_blocks.clear(); - return result; - } +static thread_local std::shared_ptr> g_event_list; - void Clear() { event_blocks.clear(); } - - std::forward_list> event_blocks; -}; +static std::list>> g_all_mem_event_lists; +static thread_local std::shared_ptr> g_mem_event_list; +static std::mutex g_all_mem_event_lists_mutex; +static thread_local int32_t g_mem_thread_id; +static uint32_t g_mem_next_thread_id = 0; inline uint64_t GetTimeInNsec() { using clock = std::conditional &GetMemEventList() { + if (!g_mem_event_list) { + g_mem_event_list = std::make_shared>(); + std::lock_guard guard(g_all_mem_event_lists_mutex); + g_mem_thread_id = g_mem_next_thread_id++; + g_all_mem_event_lists.emplace_front(g_mem_event_list); + } + return *g_mem_event_list; +} + +void PushMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes, + const Place &place, const std::string &annotation) { + GetMemEventList().Record(EventType::kPushRange, start_ns, end_ns, bytes, + place, g_mem_thread_id, annotation); +} + +void PopMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes, + const Place &place, const std::string &annotation) { + GetMemEventList().Record(EventType::kPopRange, start_ns, end_ns, bytes, place, + g_mem_thread_id, annotation); +} + +inline EventList &GetEventList() { if (!g_event_list) { std::lock_guard guard(g_all_event_lists_mutex); - g_event_list = std::make_shared(); + g_event_list = std::make_shared>(); g_thread_id = g_next_thread_id++; g_all_event_lists.emplace_front(g_event_list); RecoreCurThreadId(g_thread_id); @@ -131,26 +124,26 @@ inline EventList& GetEventList() { return *g_event_list; } -void Mark(const std::string& name) { +void Mark(const std::string &name) { GetEventList().Record(EventType::kMark, name, g_thread_id); } -Event* PushEvent(const std::string& name) { +Event *PushEvent(const std::string &name) { return GetEventList().Record(EventType::kPushRange, name, g_thread_id); } -void PopEvent(const std::string& name) { +void PopEvent(const std::string &name) { GetEventList().Record(EventType::kPopRange, name, g_thread_id); } -RecordEvent::RecordEvent(const std::string& name) +RecordEvent::RecordEvent(const std::string &name) : is_enabled_(false), start_ns_(PosixInNsec()) { if (g_state == ProfilerState::kDisabled) return; // lock is not needed, the code below is thread-safe is_enabled_ = true; name_ = name; - Event* e = PushEvent(name_); + Event *e = PushEvent(name_); // Maybe need the same push/pop behavior. SetCurAnnotation(e); } @@ -158,7 +151,7 @@ RecordEvent::RecordEvent(const std::string& name) RecordEvent::~RecordEvent() { if (g_state == ProfilerState::kDisabled || !is_enabled_) return; // lock is not needed, the code below is thread-safe - DeviceTracer* tracer = GetDeviceTracer(); + DeviceTracer *tracer = GetDeviceTracer(); if (tracer) { tracer->AddCPURecords(CurAnnotationName(), start_ns_, PosixInNsec(), BlockDepth(), g_thread_id); @@ -167,7 +160,56 @@ RecordEvent::~RecordEvent() { PopEvent(name_); } -RecordRPCEvent::RecordRPCEvent(const std::string& name) { +MemEvenRecorder MemEvenRecorder::recorder; + +void MemEvenRecorder::PushMemRecord(const void *ptr, const Place &place, + size_t size) { + if (g_state == ProfilerState::kDisabled) return; + std::lock_guard guard(mtx_); + auto &events = address_memevent_[place]; + PADDLE_ENFORCE(events.count(ptr) == 0, ""); + events.emplace(ptr, std::unique_ptr( + new MemEvenRecorder::RecordMemEvent(place, size))); +} + +void MemEvenRecorder::PopMemRecord(const void *ptr, const Place &place) { + if (g_state == ProfilerState::kDisabled) return; + std::lock_guard guard(mtx_); + auto &events = address_memevent_[place]; + auto iter = events.find(ptr); + // The ptr maybe not in address_memevent + if (iter != events.end()) { + events.erase(iter); + } +} + +void MemEvenRecorder::Flush() { + std::lock_guard guard(mtx_); + address_memevent_.clear(); +} + +MemEvenRecorder::RecordMemEvent::RecordMemEvent(const Place &place, + size_t bytes) + : place_(place), + bytes_(bytes), + start_ns_(PosixInNsec()), + alloc_in_(CurAnnotationName()) { + PushMemEvent(start_ns_, end_ns_, bytes_, place_, alloc_in_); +} + +MemEvenRecorder::RecordMemEvent::~RecordMemEvent() { + DeviceTracer *tracer = GetDeviceTracer(); + end_ns_ = PosixInNsec(); + + auto annotation_free = CurAnnotationName(); + if (tracer) { + tracer->AddMemInfoRecord(start_ns_, end_ns_, bytes_, place_, alloc_in_, + annotation_free, g_mem_thread_id); + } + PopMemEvent(start_ns_, end_ns_, bytes_, place_, annotation_free); +} + +RecordRPCEvent::RecordRPCEvent(const std::string &name) { if (FLAGS_enable_rpc_profiler) { event_.reset(new platform::RecordEvent(name)); } @@ -185,7 +227,7 @@ RecordBlock::RecordBlock(int block_id) RecordBlock::~RecordBlock() { // lock is not needed, the code below is thread-safe if (g_state == ProfilerState::kDisabled || !is_enabled_) return; - DeviceTracer* tracer = GetDeviceTracer(); + DeviceTracer *tracer = GetDeviceTracer(); if (tracer) { // We try to put all blocks at the same nested depth in the // same timeline lane. and distinguish the using thread_id. @@ -232,11 +274,16 @@ void EnableProfiler(ProfilerState state) { void ResetProfiler() { SynchronizeAllDevice(); GetDeviceTracer()->Reset(); + MemEvenRecorder::Instance().Flush(); std::lock_guard guard(g_all_event_lists_mutex); for (auto it = g_all_event_lists.begin(); it != g_all_event_lists.end(); ++it) { (*it)->Clear(); } + for (auto it = g_all_mem_event_lists.begin(); + it != g_all_mem_event_lists.end(); ++it) { + (*it)->Clear(); + } } std::vector> GetAllEvents() { @@ -249,6 +296,15 @@ std::vector> GetAllEvents() { return result; } +std::vector> GetMemEvents() { + std::lock_guard guard(g_all_mem_event_lists_mutex); + std::vector> result; + for (auto &it : g_all_mem_event_lists) { + result.emplace_back((*it).Reduce()); + } + return result; +} + // The information of each event given in the profiling report struct EventItem { std::string name; @@ -263,8 +319,8 @@ struct EventItem { }; // Print results -void PrintProfiler(const std::vector>& events_table, - const std::string& sorted_domain, const size_t name_width, +void PrintProfiler(const std::vector> &events_table, + const std::string &sorted_domain, const size_t name_width, const size_t data_width, bool merge_thread) { // Output header information std::cout << "\n------------------------->" @@ -302,7 +358,7 @@ void PrintProfiler(const std::vector>& events_table, << std::setw(data_width) << "Ratio." << std::endl; for (size_t i = 0; i < events_table.size(); ++i) { for (size_t j = 0; j < events_table[i].size(); ++j) { - const EventItem& event_item = events_table[i][j]; + const EventItem &event_item = events_table[i][j]; std::cout << std::setw(name_width) << event_item.name << std::setw(data_width) << event_item.calls << std::setw(data_width) << event_item.total_time; @@ -326,54 +382,54 @@ void PrintProfiler(const std::vector>& events_table, } // Parse the event list and output the profiling report -void ParseEvents(const std::vector>& events, +void ParseEvents(const std::vector> &events, bool merge_thread, EventSortingKey sorted_by = EventSortingKey::kDefault) { if (g_state == ProfilerState::kDisabled) return; if (merge_thread && events.size() < 2) return; std::string sorted_domain; - std::function sorted_func; + std::function sorted_func; switch (sorted_by) { case EventSortingKey::kCalls: sorted_domain = "number of calls"; - sorted_func = [](const EventItem& a, const EventItem& b) { + sorted_func = [](const EventItem &a, const EventItem &b) { return a.calls > b.calls; }; break; case EventSortingKey::kTotal: sorted_domain = "total time"; - sorted_func = [](const EventItem& a, const EventItem& b) { + sorted_func = [](const EventItem &a, const EventItem &b) { return a.total_time > b.total_time; }; break; case EventSortingKey::kMin: sorted_domain = "minimum time"; - sorted_func = [](const EventItem& a, const EventItem& b) { + sorted_func = [](const EventItem &a, const EventItem &b) { return a.min_time > b.min_time; }; break; case EventSortingKey::kMax: sorted_domain = "maximum time"; - sorted_func = [](const EventItem& a, const EventItem& b) { + sorted_func = [](const EventItem &a, const EventItem &b) { return a.max_time > b.max_time; }; break; case EventSortingKey::kAve: sorted_domain = "average time"; - sorted_func = [](const EventItem& a, const EventItem& b) { + sorted_func = [](const EventItem &a, const EventItem &b) { return a.ave_time > b.ave_time; }; break; case EventSortingKey::kGPUTime: sorted_domain = "average time"; - sorted_func = [](const EventItem& a, const EventItem& b) { + sorted_func = [](const EventItem &a, const EventItem &b) { return a.gpu_time > b.gpu_time; }; break; case EventSortingKey::kCPUTime: sorted_domain = "average time"; - sorted_func = [](const EventItem& a, const EventItem& b) { + sorted_func = [](const EventItem &a, const EventItem &b) { return a.cpu_time > b.cpu_time; }; break; @@ -381,7 +437,7 @@ void ParseEvents(const std::vector>& events, sorted_domain = "event first end time"; } - const std::vector>* analyze_events; + const std::vector> *analyze_events; std::vector> merged_events_list; if (merge_thread) { std::vector merged_events; @@ -469,7 +525,7 @@ void ParseEvents(const std::vector>& events, } } // average time - for (auto& item : event_items) { + for (auto &item : event_items) { item.ave_time = item.total_time / item.calls; item.ratio = item.total_time / total; } @@ -493,15 +549,77 @@ void ParseEvents(const std::vector>& events, merge_thread); } +struct MemoryProfierReport { + size_t alloc_times{0}; + size_t alloc_size{0}; + size_t free_times{0}; + size_t free_size{0}; +}; + +// Print results +void PrintMemProfiler( + const std::map> + &annotation_report, + const size_t name_width, const size_t data_width) { + // Output header information + std::cout << "\n------------------------->" + << " Memory Profiling Report " + << "<-------------------------\n\n"; + + // Output events table + std::cout.setf(std::ios::left); + std::cout << std::setw(name_width) << "Event" << std::setw(data_width) + << "Alloc Calls" << std::setw(data_width) << "Size(MB)" + << std::setw(data_width) << "Free Calls" << std::setw(data_width) + << "Size(MB)" << std::endl; + + for (auto &tmp : annotation_report) { + for (auto &e : tmp.second) { + auto event_name = string::Sprintf("%s:%s", tmp.first, e.first); + std::cout << std::setw(name_width) << event_name; + std::cout << std::setw(data_width) << e.second.alloc_times; + std::cout << std::setw(data_width) + << e.second.alloc_size / (1024.0 * 1024.0); + std::cout << std::setw(data_width) << e.second.free_times; + std::cout << std::setw(data_width) + << e.second.free_size / (1024.0 * 1024.0) << std::endl; + } + } + std::cout << std::endl; +} + +// parse memory events +void ParseMemEvents(const std::vector> &events) { + if (g_state == ProfilerState::kDisabled) return; + // place, annotation, alloc times, alloc size + std::map> + annotation_report; + + for (auto &tmp : events) { + for (auto &e : tmp) { + if (e.type() == EventType::kPushRange) { + annotation_report[e.place()][e.annotation()].alloc_times += 1; + annotation_report[e.place()][e.annotation()].alloc_size += e.bytes(); + } else if (e.type() == EventType::kPopRange) { + annotation_report[e.place()][e.annotation()].free_times += 1; + annotation_report[e.place()][e.annotation()].free_size += e.bytes(); + } + } + } + PrintMemProfiler(annotation_report, 55, 18); +} + void DisableProfiler(EventSortingKey sorted_key, - const std::string& profile_path) { + const std::string &profile_path) { SynchronizeAllDevice(); + MemEvenRecorder::Instance().Flush(); + std::lock_guard l(profiler_mu); if (g_state == ProfilerState::kDisabled) return; // Mark the profiling stop. Mark("_stop_profiler_"); - DeviceTracer* tracer = GetDeviceTracer(); + DeviceTracer *tracer = GetDeviceTracer(); if (tracer->IsEnabled()) { tracer->Disable(); tracer->GenProfile(profile_path); @@ -511,6 +629,11 @@ void DisableProfiler(EventSortingKey sorted_key, std::vector> all_events = GetAllEvents(); ParseEvents(all_events, true, sorted_key); ParseEvents(all_events, false, sorted_key); + if (VLOG_IS_ON(5)) { + std::vector> all_mem_events = GetMemEvents(); + ParseMemEvents(all_mem_events); + } + ResetProfiler(); g_state = ProfilerState::kDisabled; should_send_profile_state = true; diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index aec0ae34292..8d11855b70d 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -15,10 +15,17 @@ limitations under the License. */ #pragma once #include #include +#include +#include +#include // NOLINT #include +#include +#include +#include #include #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/event.h" +#include "paddle/fluid/platform/place.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/gpu_info.h" #endif @@ -34,8 +41,41 @@ enum ProfilerState { void Mark(const std::string& name); -Event* PushEvent(const std::string& name); +void PushMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes, + const Place& place); +void PopMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes, + const Place& place); + +struct MemEvenRecorder { + public: + void PushMemRecord(const void* ptr, const Place& place, size_t size); + void PopMemRecord(const void* ptr, const Place& place); + void Flush(); + static MemEvenRecorder& Instance() { return recorder; } + private: + struct RecordMemEvent { + RecordMemEvent(const Place& place, size_t bytes); + ~RecordMemEvent(); + + Place place_; + size_t bytes_; + uint64_t start_ns_; + uint64_t end_ns_; + std::string alloc_in_; + std::string free_in_; + }; + + static MemEvenRecorder recorder; + std::map>> + address_memevent_; + std::mutex mtx_; + MemEvenRecorder() {} + DISABLE_COPY_AND_ASSIGN(MemEvenRecorder); +}; + +Event* PushEvent(const std::string& name); void PopEvent(const std::string& name); struct RecordEvent { @@ -87,6 +127,41 @@ enum EventSortingKey { kGPUTime }; +template +struct EventList { + constexpr static size_t kMB = 1024 * 1024; + constexpr static size_t kEventBlockSize = 16 * kMB; + constexpr static size_t kEventSize = sizeof(T); + constexpr static size_t kEventAlign = alignof(T); + constexpr static size_t kNumBlock = + kEventBlockSize / + ((kEventSize + kEventAlign - 1) / kEventAlign * kEventAlign); + + template + T* Record(Args&&... args) { + if (event_blocks.empty() || event_blocks.front().size() == kNumBlock) { + event_blocks.emplace_front(); + event_blocks.front().reserve(kNumBlock); + } + event_blocks.front().emplace_back(std::forward(args)...); + return &event_blocks.front().back(); + } + + std::vector Reduce() { + std::vector result; + for (auto& block : event_blocks) { + result.insert(result.begin(), std::make_move_iterator(block.begin()), + std::make_move_iterator(block.end())); + } + event_blocks.clear(); + return result; + } + + void Clear() { event_blocks.clear(); } + + std::forward_list> event_blocks; +}; + // Enable the profiling function. void EnableProfiler(ProfilerState state); diff --git a/paddle/fluid/platform/profiler.proto b/paddle/fluid/platform/profiler.proto index e761d7b266e..cfa3c6906f8 100644 --- a/paddle/fluid/platform/profiler.proto +++ b/paddle/fluid/platform/profiler.proto @@ -34,8 +34,25 @@ message Event { optional string detail_info = 9; } +message MemEvent { + enum Place { + CUDAPlace = 0; + CPUPlace = 1; + CUDAPinnedPlace = 2; + } + optional uint64 start_ns = 1; + optional uint64 end_ns = 2; + optional uint64 bytes = 3; + optional Place place = 4; + optional uint64 thread_id = 5; + optional uint32 device_id = 6; + optional string alloc_in = 7; + optional string free_in = 8; +} + message Profile { repeated Event events = 1; optional uint64 start_ns = 2; optional uint64 end_ns = 3; + repeated MemEvent mem_events = 4; } \ No newline at end of file diff --git a/tools/timeline.py b/tools/timeline.py index 78796664177..694ab1d50fd 100644 --- a/tools/timeline.py +++ b/tools/timeline.py @@ -95,6 +95,22 @@ class _ChromeTraceFormatter(object): event['args'] = args self._events.append(event) + def emit_counter(self, category, name, pid, timestamp, counter, value): + """Emits a record for a single counter. + + Args: + category: The event category as string + name: The event name as string + pid: Identifier of the process generating this event as integer + timestamp: The timestamps of this event as long integer + counter: Name of the counter as string + value: Value of the counter as integer + tid: Thread id of the allocation as integer + """ + event = self._create_event('C', category, name, pid, 0, timestamp) + event['args'] = {counter: value} + self._events.append(event) + def format_to_string(self, pretty=False): """Formats the chrome trace to a string. @@ -117,6 +133,7 @@ class Timeline(object): self._profile_dict = profile_dict self._pid = 0 self._devices = dict() + self._mem_devices = dict() self._chrome_trace = _ChromeTraceFormatter() def _allocate_pid(self): @@ -143,6 +160,45 @@ class Timeline(object): self._devices[(k, event.device_id, "GPUKernel")] = pid self._chrome_trace.emit_pid("%s:gpu:%d" % (k, event.device_id), pid) + for mevent in profile_pb.mem_events: + if mevent.place == profiler_pb2.MemEvent.CUDAPlace: + if (k, mevent.device_id, "GPU") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, mevent.device_id, "GPU")] = pid + self._chrome_trace.emit_pid( + "memory usage on %s:gpu:%d" % (k, mevent.device_id), + pid) + elif mevent.place == profiler_pb2.MemEvent.CPUPlace: + if (k, mevent.device_id, "CPU") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, mevent.device_id, "CPU")] = pid + self._chrome_trace.emit_pid( + "memory usage on %s:cpu:%d" % (k, mevent.device_id), + pid) + elif mevent.place == profiler_pb2.MemEvent.CUDAPinnedPlace: + if (k, mevent.device_id, "CUDAPinnedPlace" + ) not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, mevent.device_id, + "CUDAPinnedPlace")] = pid + self._chrome_trace.emit_pid( + "memory usage on %s:cudapinnedplace:%d" % + (k, mevent.device_id), pid) + if (k, 0, "CPU") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, 0, "CPU")] = pid + self._chrome_trace.emit_pid("memory usage on %s:cpu:%d" % + (k, 0), pid) + if (k, 0, "GPU") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, 0, "GPU")] = pid + self._chrome_trace.emit_pid("memory usage on %s:gpu:%d" % + (k, 0), pid) + if (k, 0, "CUDAPinnedPlace") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, 0, "CUDAPinnedPlace")] = pid + self._chrome_trace.emit_pid( + "memory usage on %s:cudapinnedplace:%d" % (k, 0), pid) def _allocate_events(self): for k, profile_pb in six.iteritems(self._profile_dict): @@ -163,9 +219,57 @@ class Timeline(object): event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid, event.sub_device_id, 'Op', event.name, args) + def _allocate_memory_event(self): + place_to_str = { + profiler_pb2.MemEvent.CPUPlace: "CPU", + profiler_pb2.MemEvent.CUDAPlace: "GPU", + profiler_pb2.MemEvent.CUDAPinnedPlace: "CUDAPinnedPlace" + } + for k, profile_pb in six.iteritems(self._profile_dict): + mem_list = [] + end_profiler = 0 + for mevent in profile_pb.mem_events: + crt_info = dict() + crt_info['time'] = mevent.start_ns + crt_info['size'] = mevent.bytes + if mevent.place in place_to_str: + place = place_to_str[mevent.place] + else: + place = "UnDefine" + crt_info['place'] = place + pid = self._mem_devices[(k, mevent.device_id, place)] + crt_info['pid'] = pid + crt_info['thread_id'] = mevent.thread_id + crt_info['device_id'] = mevent.device_id + mem_list.append(crt_info) + crt_info = dict() + crt_info['place'] = place + crt_info['pid'] = pid + crt_info['thread_id'] = mevent.thread_id + crt_info['device_id'] = mevent.device_id + crt_info['time'] = mevent.end_ns + crt_info['size'] = -mevent.bytes + mem_list.append(crt_info) + end_profiler = max(end_profiler, crt_info['time']) + mem_list.sort(key=lambda tmp: (tmp.get('time', 0))) + i = 0 + total_size = 0 + while i < len(mem_list): + total_size += mem_list[i]['size'] + while i < len(mem_list) - 1 and mem_list[i]['time'] == mem_list[ + i + 1]['time']: + total_size += mem_list[i + 1]['size'] + i += 1 + + self._chrome_trace.emit_counter( + "Memory", "Memory", mem_list[i]['pid'], mem_list[i]['time'], + 0, total_size) + i += 1 + def generate_chrome_trace(self): self._allocate_pids() self._allocate_events() + self._allocate_memory_event() return self._chrome_trace.format_to_string() -- GitLab