提交 9d73950e 编写于 作者: D dangqingqing

Add profiling tools for fluid.

上级 95924686
...@@ -30,3 +30,6 @@ nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_ ...@@ -30,3 +30,6 @@ nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context) nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context)
nv_test(nccl_test SRCS nccl_test.cu DEPS dynload_cuda gpu_info device_context) nv_test(nccl_test SRCS nccl_test.cu DEPS dynload_cuda gpu_info device_context)
cc_library(profiler SRCS profiler.cc DEPS device_context)
cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
...@@ -103,6 +103,18 @@ class CUDADeviceContext : public DeviceContext { ...@@ -103,6 +103,18 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t cublas_handle_; cublasHandle_t cublas_handle_;
}; };
class DeviceGuard {
public:
explicit DeviceGuard(int device) {
original_device_ = platform::GetCurrentDeviceId();
platform::SetDeviceId(device);
}
~DeviceGuard() { platform::SetDeviceId(original_device_); }
private:
int original_device_;
};
#endif #endif
} // namespace platform } // namespace platform
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/profiler.h"
namespace paddle {
namespace platform {
ProfilerState kState = ProfilerState::kDisabled;
uint32_t kNextThreadId = 0;
std::mutex kAllEventListsMutex;
std::list<std::shared_ptr<EventList>> kAllEventLists;
thread_local std::shared_ptr<EventList> kEventList;
thread_local int32_t kThreadId;
void EnableProfiler(ProfilerState state) {
PADDLE_ENFORCE(state != ProfilerState::kDisabled,
"Can't enbale profling, since the input state is ",
"ProfilerState::kDisabled");
PADDLE_ENFORCE(kState == ProfilerState::kDisabled,
"The profiling state should be disabled when calling ",
"EnableProfiler.");
kState = state;
#ifdef PADDLE_WITH_CUDA
auto ForEachDevice = [](std::function<void(int)> op) {
int count = GetCUDADeviceCount();
for (int i = 0; i < count; i++) {
DeviceGuard dev_guard(i);
op(i);
}
};
if (kState == ProfilerState::kCUDA) {
// Generate some dummy evenets first to reduce the startup overhead.
for (int i = 0; i < 5; i++) {
ForEachDevice([](int d) {
DeviceContext* dev_ctx = new CUDADeviceContext(GPUPlace(d));
Mark("_cuda_startup_", dev_ctx);
dev_ctx->Wait();
});
}
}
#endif
// Mark the profiling start.
Mark("_start_profiler_");
}
std::vector<std::vector<Event>> DisableProfiler() {
PADDLE_ENFORCE(kState != ProfilerState::kDisabled,
"Can't disable profiling, since it's not starting.");
// Mark the profiling stop.
Mark("_stop_profiler_");
kState = ProfilerState::kDisabled;
std::vector<std::vector<Event>> result;
std::lock_guard<std::mutex> guard(kAllEventListsMutex);
for (auto it = kAllEventLists.begin(); it != kAllEventLists.end(); ++it) {
auto& list = *it;
result.emplace_back(list->Reduce());
}
return result;
}
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <forward_list>
#include <list>
#include <mutex>
#include <vector>
#include "paddle/platform/device_context.h"
namespace paddle {
namespace platform {
enum EventKind { kMark, kPushRange, kPopRange };
inline uint64_t GetTimeInNsec() {
// using std::chrono;
using clock = std::conditional<std::chrono::high_resolution_clock::is_steady,
std::chrono::high_resolution_clock,
std::chrono::steady_clock>::type;
return std::chrono::duration_cast<std::chrono::nanoseconds>(
clock::now().time_since_epoch())
.count();
}
class Event {
public:
// the DeviceContext is used to get the cuda stream.
Event(EventKind kind, std::string name, uint32_t thread_id,
const platform::DeviceContext* dev_ctx = nullptr)
: kind_(kind), name_(std::move(name)), thread_id_(thread_id) {
has_cuda_ = false;
#ifdef PADDLE_WITH_CUDA
auto* cuda_dev_ctx =
static_cast<const platform::CUDADeviceContext*>(dev_ctx);
if (cuda_dev_ctx) {
PADDLE_ENFORCE(cudaGetDevice(&device_));
PADDLE_ENFORCE(cudaEventCreate(&event_));
auto stream = cuda_dev_ctx->stream();
PADDLE_ENFORCE(cudaEventRecord(event_, stream));
has_cuda_ = true;
}
#endif
cpu_ns_ = GetTimeInNsec();
}
std::string kind() const {
switch (kind_) {
case EventKind::kMark:
return "mark";
case EventKind::kPushRange:
return "push";
case EventKind::kPopRange:
return "pop";
}
PADDLE_THROW("Unknown EventKind.");
}
std::string name() const { return name_; }
bool has_cuda() const { return has_cuda_; }
#ifdef PADDLE_WITH_CUDA
cudaEvent_t event() const { return event_; }
int device() const { return device_; }
#endif
double CpuElapsedUs(const Event& e) const {
return (e.cpu_ns_ - cpu_ns_) / (1000.0);
}
double CudaElapsedUs(const Event& e) const {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(e.has_cuda() && has_cuda());
PADDLE_ENFORCE(e.device() == device());
PADDLE_ENFORCE(cudaEventSynchronize(event_));
PADDLE_ENFORCE(cudaEventSynchronize(e.event()));
float ms;
PADDLE_ENFORCE(cudaEventElapsedTime(&ms, event_, e.event()));
return ms * 1000.0;
#else
PADDLE_THROW("CUDA is not enabled");
#endif
}
private:
EventKind kind_;
std::string name_;
uint32_t thread_id_;
int64_t cpu_ns_;
bool has_cuda_;
#ifdef PADDLE_WITH_CUDA
cudaEvent_t event_ = nullptr;
int device_ = -1;
#endif
};
struct EventList {
constexpr static std::size_t kMB = 1024 * 1024;
constexpr static std::size_t kEventBlockSize = 16 * kMB;
constexpr static std::size_t kEventSize = sizeof(Event);
constexpr static std::size_t kEventAlign = alignof(Event);
constexpr static std::size_t kNumBlock =
kEventBlockSize /
((kEventSize + kEventAlign - 1) / kEventAlign * kEventAlign);
template <typename... Args>
void Record(Args&&... args) {
if (event_blocks.empty() || event_blocks.front().size() == kNumBlock) {
event_blocks.emplace_front();
event_blocks.front().reserve(kNumBlock);
}
event_blocks.front().emplace_back(std::forward<Args>(args)...);
}
std::vector<Event> Reduce() {
std::vector<Event> result;
for (auto& block : event_blocks) {
result.insert(result.begin(), std::make_move_iterator(block.begin()),
std::make_move_iterator(block.end()));
}
event_blocks.clear();
return result;
}
std::forward_list<std::vector<Event>> event_blocks;
};
enum ProfilerState {
kDisabled,
kCPU,
kCUDA,
};
// The profiler state, the initial value is ProfilerState::kDisabled
extern ProfilerState kState;
// The global mutex
extern std::mutex kAllEventListsMutex;
// The total event lists of all threads
extern std::list<std::shared_ptr<EventList>> kAllEventLists;
// The thread local event list only can be accessed by the specific thread
extern thread_local std::shared_ptr<EventList> kEventList;
// The thread index of each thread
extern thread_local int32_t kThreadId;
// The kNextThreadId is a global counter for threads, by the kThreadId and
// kNextThreadId, we can know how many threads have created EventList.
extern uint32_t kNextThreadId;
inline EventList& GetEventList() {
if (!kEventList) {
std::lock_guard<std::mutex> guard(kAllEventListsMutex);
kEventList = std::make_shared<EventList>();
kThreadId = kNextThreadId++;
kAllEventLists.emplace_front(kEventList);
}
return *kEventList;
}
inline void Mark(const std::string name,
const platform::DeviceContext* dev_ctx = nullptr) {
GetEventList().Record(EventKind::kMark, std::move(name), kThreadId, dev_ctx);
}
struct RecordEvent {
explicit RecordEvent(const std::string name,
platform::DeviceContext* dev_ctx = nullptr) {
if (kState == ProfilerState::kDisabled) return;
dev_ctx_ = dev_ctx;
GetEventList().Record(EventKind::kPushRange, std::move(name), kThreadId,
dev_ctx_);
}
~RecordEvent() {
if (kState == ProfilerState::kDisabled) return;
GetEventList().Record(EventKind::kPopRange, std::string(), kThreadId,
dev_ctx_);
}
platform::DeviceContext* dev_ctx_;
};
void EnableProfiler(ProfilerState state);
std::vector<std::vector<Event>> DisableProfiler();
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/profiler.h"
#include "gtest/gtest.h"
TEST(Event, CpuElapsedTime) {
using paddle::platform::Event;
using paddle::platform::EventKind;
Event start_event(EventKind::kPushRange, "test", 0);
EXPECT_TRUE(start_event.has_cuda() == false);
int counter = 0;
while (counter != 1000) {
counter++;
}
Event stop_event(EventKind::kPopRange, "test", 0);
EXPECT_GT(start_event.CpuElapsedUs(stop_event), 0);
}
#ifdef PADDLE_WITH_CUDA
TEST(Event, CudaElapsedTime) {
using paddle::platform::DeviceContext;
using paddle::platform::CUDADeviceContext;
using paddle::platform::GPUPlace;
using paddle::platform::Event;
using paddle::platform::EventKind;
DeviceContext* dev_ctx = new CUDADeviceContext(GPUPlace(0));
Event start_event(EventKind::kPushRange, "test", 0, dev_ctx);
EXPECT_TRUE(start_event.has_cuda() == true);
int counter = 0;
while (counter != 1000) {
counter++;
}
Event stop_event(EventKind::kPopRange, "test", 0, dev_ctx);
EXPECT_GT(start_event.CudaElapsedUs(stop_event), 0);
}
#endif
TEST(RecordEvent, RecordEvent) {
using paddle::platform::DeviceContext;
using paddle::platform::CUDADeviceContext;
using paddle::platform::GPUPlace;
using paddle::platform::Event;
using paddle::platform::EventKind;
using paddle::platform::RecordEvent;
using paddle::platform::ProfilerState;
ProfilerState state = ProfilerState::kCPU;
DeviceContext* dev_ctx = nullptr;
#ifdef PADDLE_WITH_CUDA
state = ProfilerState::kCUDA;
dev_ctx =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace(0));
#endif
EnableProfiler(state);
for (int i = 1; i < 5; ++i) {
std::string name = "op_" + std::to_string(i);
RecordEvent record_event(name, dev_ctx);
int counter = 1;
while (counter != i * 1000) counter++;
}
std::vector<std::vector<Event>> events = paddle::platform::DisableProfiler();
int cuda_startup_count = 0;
int start_profiler_count = 0;
int stop_profiler_count = 0;
for (size_t i = 0; i < events.size(); ++i) {
for (size_t j = 0; j < events[i].size(); ++j) {
if (events[i][j].name() == "_cuda_startup_") ++cuda_startup_count;
if (events[i][j].name() == "_start_profiler_") ++start_profiler_count;
if (events[i][j].name() == "_stop_profiler_") ++stop_profiler_count;
if (events[i][j].name() == "push") {
EXPECT_EQ(events[i][j + 1].name(), "pop");
#ifdef PADDLE_WITH_CUDA
EXPECT_GT(events[i][j].CudaElapsedUs(events[i][j + 1]), 0);
#else
EXPECT_GT(events[i][j].CpuElapsedUs(events[i][j + 1]), 0);
#endif
}
}
}
EXPECT_EQ(cuda_startup_count % 5, 0);
EXPECT_EQ(start_profiler_count, 1);
EXPECT_EQ(stop_profiler_count, 1);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册