From 9d73950ec9ab7fb14c2ca2f8128f0b0944b5ed7e Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 14 Dec 2017 11:02:57 +0800 Subject: [PATCH] Add profiling tools for fluid. --- paddle/platform/CMakeLists.txt | 3 + paddle/platform/device_context.h | 12 ++ paddle/platform/profiler.cc | 74 ++++++++++++ paddle/platform/profiler.h | 197 +++++++++++++++++++++++++++++++ paddle/platform/profiler_test.cc | 98 +++++++++++++++ 5 files changed, 384 insertions(+) create mode 100644 paddle/platform/profiler.cc create mode 100644 paddle/platform/profiler.h create mode 100644 paddle/platform/profiler_test.cc diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 88df28a966..9fb6cd0de5 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -30,3 +30,6 @@ nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_ nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context) nv_test(nccl_test SRCS nccl_test.cu DEPS dynload_cuda gpu_info device_context) + +cc_library(profiler SRCS profiler.cc DEPS device_context) +cc_test(profiler_test SRCS profiler_test.cc DEPS profiler) diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index ef5f19214d..2b10cc5df8 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -103,6 +103,18 @@ class CUDADeviceContext : public DeviceContext { cublasHandle_t cublas_handle_; }; +class DeviceGuard { + public: + explicit DeviceGuard(int device) { + original_device_ = platform::GetCurrentDeviceId(); + platform::SetDeviceId(device); + } + ~DeviceGuard() { platform::SetDeviceId(original_device_); } + + private: + int original_device_; +}; + #endif } // namespace platform diff --git a/paddle/platform/profiler.cc b/paddle/platform/profiler.cc new file mode 100644 index 0000000000..40b34b732c --- /dev/null +++ b/paddle/platform/profiler.cc @@ -0,0 +1,74 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/profiler.h" + +namespace paddle { +namespace platform { + +ProfilerState kState = ProfilerState::kDisabled; +uint32_t kNextThreadId = 0; +std::mutex kAllEventListsMutex; +std::list> kAllEventLists; +thread_local std::shared_ptr kEventList; +thread_local int32_t kThreadId; + +void EnableProfiler(ProfilerState state) { + PADDLE_ENFORCE(state != ProfilerState::kDisabled, + "Can't enbale profling, since the input state is ", + "ProfilerState::kDisabled"); + PADDLE_ENFORCE(kState == ProfilerState::kDisabled, + "The profiling state should be disabled when calling ", + "EnableProfiler."); + kState = state; +#ifdef PADDLE_WITH_CUDA + auto ForEachDevice = [](std::function op) { + int count = GetCUDADeviceCount(); + for (int i = 0; i < count; i++) { + DeviceGuard dev_guard(i); + op(i); + } + }; + if (kState == ProfilerState::kCUDA) { + // Generate some dummy evenets first to reduce the startup overhead. + for (int i = 0; i < 5; i++) { + ForEachDevice([](int d) { + DeviceContext* dev_ctx = new CUDADeviceContext(GPUPlace(d)); + Mark("_cuda_startup_", dev_ctx); + dev_ctx->Wait(); + }); + } + } +#endif + // Mark the profiling start. + Mark("_start_profiler_"); +} + +std::vector> DisableProfiler() { + PADDLE_ENFORCE(kState != ProfilerState::kDisabled, + "Can't disable profiling, since it's not starting."); + // Mark the profiling stop. + Mark("_stop_profiler_"); + kState = ProfilerState::kDisabled; + std::vector> result; + std::lock_guard guard(kAllEventListsMutex); + for (auto it = kAllEventLists.begin(); it != kAllEventLists.end(); ++it) { + auto& list = *it; + result.emplace_back(list->Reduce()); + } + return result; +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/profiler.h b/paddle/platform/profiler.h new file mode 100644 index 0000000000..2242635024 --- /dev/null +++ b/paddle/platform/profiler.h @@ -0,0 +1,197 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace platform { + +enum EventKind { kMark, kPushRange, kPopRange }; + +inline uint64_t GetTimeInNsec() { + // using std::chrono; + using clock = std::conditional::type; + return std::chrono::duration_cast( + clock::now().time_since_epoch()) + .count(); +} + +class Event { + public: + // the DeviceContext is used to get the cuda stream. + Event(EventKind kind, std::string name, uint32_t thread_id, + const platform::DeviceContext* dev_ctx = nullptr) + : kind_(kind), name_(std::move(name)), thread_id_(thread_id) { + has_cuda_ = false; +#ifdef PADDLE_WITH_CUDA + auto* cuda_dev_ctx = + static_cast(dev_ctx); + if (cuda_dev_ctx) { + PADDLE_ENFORCE(cudaGetDevice(&device_)); + PADDLE_ENFORCE(cudaEventCreate(&event_)); + auto stream = cuda_dev_ctx->stream(); + PADDLE_ENFORCE(cudaEventRecord(event_, stream)); + has_cuda_ = true; + } +#endif + cpu_ns_ = GetTimeInNsec(); + } + + std::string kind() const { + switch (kind_) { + case EventKind::kMark: + return "mark"; + case EventKind::kPushRange: + return "push"; + case EventKind::kPopRange: + return "pop"; + } + PADDLE_THROW("Unknown EventKind."); + } + + std::string name() const { return name_; } + + bool has_cuda() const { return has_cuda_; } + +#ifdef PADDLE_WITH_CUDA + cudaEvent_t event() const { return event_; } + + int device() const { return device_; } +#endif + + double CpuElapsedUs(const Event& e) const { + return (e.cpu_ns_ - cpu_ns_) / (1000.0); + } + + double CudaElapsedUs(const Event& e) const { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE(e.has_cuda() && has_cuda()); + PADDLE_ENFORCE(e.device() == device()); + PADDLE_ENFORCE(cudaEventSynchronize(event_)); + PADDLE_ENFORCE(cudaEventSynchronize(e.event())); + float ms; + PADDLE_ENFORCE(cudaEventElapsedTime(&ms, event_, e.event())); + return ms * 1000.0; +#else + PADDLE_THROW("CUDA is not enabled"); +#endif + } + + private: + EventKind kind_; + std::string name_; + uint32_t thread_id_; + int64_t cpu_ns_; + bool has_cuda_; +#ifdef PADDLE_WITH_CUDA + cudaEvent_t event_ = nullptr; + int device_ = -1; +#endif +}; + +struct EventList { + constexpr static std::size_t kMB = 1024 * 1024; + constexpr static std::size_t kEventBlockSize = 16 * kMB; + constexpr static std::size_t kEventSize = sizeof(Event); + constexpr static std::size_t kEventAlign = alignof(Event); + constexpr static std::size_t kNumBlock = + kEventBlockSize / + ((kEventSize + kEventAlign - 1) / kEventAlign * kEventAlign); + + template + void Record(Args&&... args) { + if (event_blocks.empty() || event_blocks.front().size() == kNumBlock) { + event_blocks.emplace_front(); + event_blocks.front().reserve(kNumBlock); + } + event_blocks.front().emplace_back(std::forward(args)...); + } + + std::vector Reduce() { + std::vector result; + for (auto& block : event_blocks) { + result.insert(result.begin(), std::make_move_iterator(block.begin()), + std::make_move_iterator(block.end())); + } + event_blocks.clear(); + return result; + } + + std::forward_list> event_blocks; +}; + +enum ProfilerState { + kDisabled, + kCPU, + kCUDA, +}; + +// The profiler state, the initial value is ProfilerState::kDisabled +extern ProfilerState kState; +// The global mutex +extern std::mutex kAllEventListsMutex; +// The total event lists of all threads +extern std::list> kAllEventLists; +// The thread local event list only can be accessed by the specific thread +extern thread_local std::shared_ptr kEventList; +// The thread index of each thread +extern thread_local int32_t kThreadId; +// The kNextThreadId is a global counter for threads, by the kThreadId and +// kNextThreadId, we can know how many threads have created EventList. +extern uint32_t kNextThreadId; + +inline EventList& GetEventList() { + if (!kEventList) { + std::lock_guard guard(kAllEventListsMutex); + kEventList = std::make_shared(); + kThreadId = kNextThreadId++; + kAllEventLists.emplace_front(kEventList); + } + return *kEventList; +} + +inline void Mark(const std::string name, + const platform::DeviceContext* dev_ctx = nullptr) { + GetEventList().Record(EventKind::kMark, std::move(name), kThreadId, dev_ctx); +} + +struct RecordEvent { + explicit RecordEvent(const std::string name, + platform::DeviceContext* dev_ctx = nullptr) { + if (kState == ProfilerState::kDisabled) return; + dev_ctx_ = dev_ctx; + GetEventList().Record(EventKind::kPushRange, std::move(name), kThreadId, + dev_ctx_); + } + + ~RecordEvent() { + if (kState == ProfilerState::kDisabled) return; + GetEventList().Record(EventKind::kPopRange, std::string(), kThreadId, + dev_ctx_); + } + platform::DeviceContext* dev_ctx_; +}; + +void EnableProfiler(ProfilerState state); +std::vector> DisableProfiler(); + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/profiler_test.cc b/paddle/platform/profiler_test.cc new file mode 100644 index 0000000000..ed64ff40c9 --- /dev/null +++ b/paddle/platform/profiler_test.cc @@ -0,0 +1,98 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/profiler.h" +#include "gtest/gtest.h" + +TEST(Event, CpuElapsedTime) { + using paddle::platform::Event; + using paddle::platform::EventKind; + + Event start_event(EventKind::kPushRange, "test", 0); + EXPECT_TRUE(start_event.has_cuda() == false); + int counter = 0; + while (counter != 1000) { + counter++; + } + Event stop_event(EventKind::kPopRange, "test", 0); + EXPECT_GT(start_event.CpuElapsedUs(stop_event), 0); +} + +#ifdef PADDLE_WITH_CUDA +TEST(Event, CudaElapsedTime) { + using paddle::platform::DeviceContext; + using paddle::platform::CUDADeviceContext; + using paddle::platform::GPUPlace; + using paddle::platform::Event; + using paddle::platform::EventKind; + + DeviceContext* dev_ctx = new CUDADeviceContext(GPUPlace(0)); + Event start_event(EventKind::kPushRange, "test", 0, dev_ctx); + EXPECT_TRUE(start_event.has_cuda() == true); + int counter = 0; + while (counter != 1000) { + counter++; + } + Event stop_event(EventKind::kPopRange, "test", 0, dev_ctx); + EXPECT_GT(start_event.CudaElapsedUs(stop_event), 0); +} +#endif + +TEST(RecordEvent, RecordEvent) { + using paddle::platform::DeviceContext; + using paddle::platform::CUDADeviceContext; + using paddle::platform::GPUPlace; + using paddle::platform::Event; + using paddle::platform::EventKind; + using paddle::platform::RecordEvent; + using paddle::platform::ProfilerState; + + ProfilerState state = ProfilerState::kCPU; + DeviceContext* dev_ctx = nullptr; +#ifdef PADDLE_WITH_CUDA + state = ProfilerState::kCUDA; + dev_ctx = + new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace(0)); +#endif + EnableProfiler(state); + + for (int i = 1; i < 5; ++i) { + std::string name = "op_" + std::to_string(i); + RecordEvent record_event(name, dev_ctx); + int counter = 1; + while (counter != i * 1000) counter++; + } + std::vector> events = paddle::platform::DisableProfiler(); + int cuda_startup_count = 0; + int start_profiler_count = 0; + int stop_profiler_count = 0; + for (size_t i = 0; i < events.size(); ++i) { + for (size_t j = 0; j < events[i].size(); ++j) { + if (events[i][j].name() == "_cuda_startup_") ++cuda_startup_count; + if (events[i][j].name() == "_start_profiler_") ++start_profiler_count; + if (events[i][j].name() == "_stop_profiler_") ++stop_profiler_count; + if (events[i][j].name() == "push") { + EXPECT_EQ(events[i][j + 1].name(), "pop"); +#ifdef PADDLE_WITH_CUDA + EXPECT_GT(events[i][j].CudaElapsedUs(events[i][j + 1]), 0); +#else + EXPECT_GT(events[i][j].CpuElapsedUs(events[i][j + 1]), 0); +#endif + } + } + } + EXPECT_EQ(cuda_startup_count % 5, 0); + EXPECT_EQ(start_profiler_count, 1); + EXPECT_EQ(stop_profiler_count, 1); +} -- GitLab