diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index b83007ac3bc6ed8713ca65fddabccfd292a2732f..8d9260811a8c9274dcaade9b090bab727d1952ca 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -74,7 +74,8 @@ cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op) cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) -cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog lod_rank_table) +cc_library(executor SRCS executor.cc DEPS op_registry device_context scope +framework_proto backward glog lod_rank_table profiler) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index bd58c0a7f8161f6b45f2b500f3685e4028d97e96..c28ffefdd0872238299cdbb0653ee17cdad61699 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/framework/lod_tensor_array.h" #include "paddle/framework/op_registry.h" #include "paddle/platform/place.h" +#include "paddle/platform/profiler.h" DECLARE_bool(do_memory_benchmark); DEFINE_bool(check_nan_inf, false, @@ -117,6 +118,10 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, for (auto& op_desc : block.AllOps()) { auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); VLOG(4) << op->DebugStringEx(local_scope); + + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::RecordEvent record_event(op->Type(), pool.Get(place_)); + op->Run(*local_scope, place_); VLOG(3) << op->DebugStringEx(local_scope); if (FLAGS_do_memory_benchmark) { diff --git a/paddle/platform/profiler.cc b/paddle/platform/profiler.cc index 7e2e2d968ef877f6aa8b87ab8f044e89574dffa9..2a8afc940393baaaa939471f50f2d5c63edd6a84 100644 --- a/paddle/platform/profiler.cc +++ b/paddle/platform/profiler.cc @@ -47,16 +47,16 @@ inline uint64_t GetTimeInNsec() { } Event::Event(EventKind kind, std::string name, uint32_t thread_id, - DeviceContext* dev_ctx) + const DeviceContext* dev_ctx) : kind_(kind), name_(name), thread_id_(thread_id), has_cuda_(false) { #ifdef PADDLE_WITH_CUDA - auto* cuda_dev_ctx = static_cast(dev_ctx); - if (cuda_dev_ctx) { + has_cuda_ = dev_ctx ? platform::is_gpu_place(dev_ctx->GetPlace()) : false; + if (has_cuda_) { + auto* cuda_dev_ctx = static_cast(dev_ctx); PADDLE_ENFORCE(cudaGetDevice(&device_)); PADDLE_ENFORCE(cudaEventCreate(&event_)); auto stream = cuda_dev_ctx->stream(); PADDLE_ENFORCE(cudaEventRecord(event_, stream)); - has_cuda_ = true; } #endif cpu_ns_ = GetTimeInNsec(); @@ -114,19 +114,20 @@ inline EventList& GetEventList() { return *g_event_list; } -void Mark(const std::string& name, DeviceContext* dev_ctx) { +void Mark(const std::string& name, const DeviceContext* dev_ctx) { GetEventList().Record(EventKind::kMark, name, g_thread_id, dev_ctx); } -void PushEvent(const std::string& name, DeviceContext* dev_ctx) { +void PushEvent(const std::string& name, const DeviceContext* dev_ctx) { GetEventList().Record(EventKind::kPushRange, name, g_thread_id, dev_ctx); } -void PopEvent(const std::string& name, DeviceContext* dev_ctx) { +void PopEvent(const std::string& name, const DeviceContext* dev_ctx) { GetEventList().Record(EventKind::kPopRange, name, g_thread_id, dev_ctx); } -RecordEvent::RecordEvent(const std::string& name, DeviceContext* dev_ctx) { +RecordEvent::RecordEvent(const std::string& name, + const DeviceContext* dev_ctx) { if (g_state == ProfilerState::kDisabled) return; dev_ctx_ = dev_ctx; name_ = name; @@ -155,6 +156,7 @@ void EnableProfiler(ProfilerState state) { DeviceContext* dev_ctx = new CUDADeviceContext(CUDAPlace(d)); Mark("_cuda_startup_", dev_ctx); dev_ctx->Wait(); + delete dev_ctx; }); } } @@ -163,14 +165,17 @@ void EnableProfiler(ProfilerState state) { Mark("_start_profiler_", nullptr); } -std::vector> DisableProfiler() { - PADDLE_ENFORCE(g_state != ProfilerState::kDisabled, - "Can't disable profiling, since it's not starting."); - // Mark the profiling stop. - Mark("_stop_profiler_", nullptr); - g_state = ProfilerState::kDisabled; - std::vector> result; +void ResetProfiler() { std::lock_guard guard(g_all_event_lists_mutex); + for (auto it = g_all_event_lists.begin(); it != g_all_event_lists.end(); + ++it) { + (*it)->Clear(); + } +} + +std::vector> GetAllEvents() { + std::lock_guard guard(g_all_event_lists_mutex); + std::vector> result; for (auto it = g_all_event_lists.begin(); it != g_all_event_lists.end(); ++it) { result.emplace_back((*it)->Reduce()); @@ -178,6 +183,18 @@ std::vector> DisableProfiler() { return result; } +void DisableProfiler(EventSortingKey sorted_key) { + PADDLE_ENFORCE(g_state != ProfilerState::kDisabled, + "Can't disable profiling, since it's not starting."); + // Mark the profiling stop. + Mark("_stop_profiler_", nullptr); + g_state = ProfilerState::kDisabled; + + std::vector> all_events = GetAllEvents(); + ParseEvents(all_events, sorted_key); + ResetProfiler(); +} + void ParseEvents(std::vector>& events, EventSortingKey sorted_by) { if (g_profiler_place == "") return; @@ -291,12 +308,12 @@ void ParseEvents(std::vector>& events, } // Print report - PrintProfilingReport(events_table, sorted_domain, max_name_width + 4, 12); + PrintProfiler(events_table, sorted_domain, max_name_width + 4, 12); } -void PrintProfilingReport(std::vector>& events_table, - std::string& sorted_domain, const size_t name_width, - const size_t data_width) { +void PrintProfiler(std::vector>& events_table, + std::string& sorted_domain, const size_t name_width, + const size_t data_width) { // Output header information std::cout << "\n------------------------->" << " Profiling Report " diff --git a/paddle/platform/profiler.h b/paddle/platform/profiler.h index 6df48ef8806e865f473b4317ac0283863c3c6f64..8de1e6ad296d1e15c1659ccf431f1d5013eb608c 100644 --- a/paddle/platform/profiler.h +++ b/paddle/platform/profiler.h @@ -29,7 +29,7 @@ class Event { // The DeviceContext is used to get the cuda stream. // If CPU profiling mode, can pass nullptr. Event(EventKind kind, std::string name, uint32_t thread_id, - DeviceContext* dev_ctx); + const DeviceContext* dev_ctx); std::string kind() const; std::string name() const { return name_; } @@ -84,6 +84,8 @@ struct EventList { return result; } + void Clear() { event_blocks.clear(); } + std::forward_list> event_blocks; }; @@ -93,29 +95,26 @@ enum ProfilerState { kCUDA, // GPU profiling state }; -void Mark(const std::string& name, DeviceContext* dev_ctx); +void Mark(const std::string& name, const DeviceContext* dev_ctx); -void PushEvent(const std::string& name, DeviceContext* dev_ctx); +void PushEvent(const std::string& name, const DeviceContext* dev_ctx); -void PopEvent(const std::string& name, DeviceContext* dev_ctx); +void PopEvent(const std::string& name, const DeviceContext* dev_ctx); struct RecordEvent { - explicit RecordEvent(const std::string& name, DeviceContext* dev_ctx); + explicit RecordEvent(const std::string& name, const DeviceContext* dev_ctx); ~RecordEvent(); // The device context is used by Event to get the current cuda stream. - DeviceContext* dev_ctx_; + const DeviceContext* dev_ctx_; // Event name std::string name_; }; -// Enable the profiling function. -void EnableProfiler(ProfilerState state); - // Return the event list of all threads. Asummed the returned value calls // event_lists, event_lists[i][j] represents the j-th Event of i-th thread. -std::vector> DisableProfiler(); +std::vector> GetAllEvents(); // The information of each event given in the profiling report struct EventItem { @@ -130,13 +129,22 @@ struct EventItem { // Candidate keys to sort the profiling report enum EventSortingKey { kDefault, kCalls, kTotal, kMin, kMax, kAve }; +// Enable the profiling function. +void EnableProfiler(ProfilerState state); + +// Clear the g_all_event_lists, which is total event lists of all threads. +void ResetProfiler(); + +void DisableProfiler(EventSortingKey sorted_key); + // Parse the event list and output the profiling report void ParseEvents(std::vector>&, EventSortingKey sorted_by = EventSortingKey::kDefault); // Print results -void PrintProfilingReport(std::vector>& events_table, - std::string& sorted_domain, const size_t name_width, - const size_t data_width); +void PrintProfiler(std::vector>& events_table, + std::string& sorted_domain, const size_t name_width, + const size_t data_width); + } // namespace platform } // namespace paddle diff --git a/paddle/platform/profiler_test.cc b/paddle/platform/profiler_test.cc index 13dea713c71e147ed5dd8d090e92d86c96256c09..81f10c91342f76910cc780b0ebd0c0df04e9d7bf 100644 --- a/paddle/platform/profiler_test.cc +++ b/paddle/platform/profiler_test.cc @@ -103,18 +103,14 @@ TEST(RecordEvent, RecordEvent) { // Bad Usage: PushEvent("event_without_pop", dev_ctx); PopEvent("event_without_push", dev_ctx); - std::vector> events = paddle::platform::DisableProfiler(); - // Will remove parsing-related code from test later - ParseEvents(events, EventSortingKey::kTotal); + std::vector> events = paddle::platform::GetAllEvents(); int cuda_startup_count = 0; int start_profiler_count = 0; - int stop_profiler_count = 0; for (size_t i = 0; i < events.size(); ++i) { for (size_t j = 0; j < events[i].size(); ++j) { if (events[i][j].name() == "_cuda_startup_") ++cuda_startup_count; if (events[i][j].name() == "_start_profiler_") ++start_profiler_count; - if (events[i][j].name() == "_stop_profiler_") ++stop_profiler_count; if (events[i][j].name() == "push") { EXPECT_EQ(events[i][j + 1].name(), "pop"); #ifdef PADDLE_WITH_CUDA @@ -127,5 +123,7 @@ TEST(RecordEvent, RecordEvent) { } EXPECT_EQ(cuda_startup_count % 5, 0); EXPECT_EQ(start_profiler_count, 1); - EXPECT_EQ(stop_profiler_count, 1); + + // Will remove parsing-related code from test later + DisableProfiler(EventSortingKey::kTotal); } diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 7b374307071d2da91a677361b404448f1a3816b0..e78673e0baa03496faab13d069b3bd456660bad6 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,7 +1,7 @@ if(WITH_PYTHON) cc_library(paddle_pybind SHARED SRCS pybind.cc exception.cc protobuf.cc const_value.cc - DEPS pybind python backward proto_desc paddle_memory executor prune init + DEPS pybind python backward proto_desc paddle_memory executor prune init profiler ${GLOB_OP_LIB}) if(NOT APPLE AND NOT ANDROID) target_link_libraries(paddle_pybind rt) diff --git a/paddle/pybind/protobuf.h b/paddle/pybind/protobuf.h index 089183accc08c3c486a7ae78ccfe060853ec54f5..9e747e9ea60fd95c74937daa283bc7a9eb9368c0 100644 --- a/paddle/pybind/protobuf.h +++ b/paddle/pybind/protobuf.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include "paddle/platform/variant.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index c5d70bc9f91bc92b28a546cc79b08a9fda150050..82f5b1922c6e97ee73a187e838350a965f1fd269 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -30,6 +30,7 @@ limitations under the License. */ #include "paddle/operators/net_op.h" #include "paddle/platform/enforce.h" #include "paddle/platform/place.h" +#include "paddle/platform/profiler.h" #include "paddle/pybind/const_value.h" #include "paddle/pybind/exception.h" #include "paddle/pybind/pybind.h" @@ -476,6 +477,24 @@ All parameter, weight, gradient are variables in Paddle. m.def("nvprof_stop", platform::CudaProfilerStop); #endif + py::enum_(m, "ProfilerState", py::arithmetic()) + .value("kDisabled", platform::ProfilerState::kDisabled) + .value("kCPU", platform::ProfilerState::kCPU) + .value("kCUDA", platform::ProfilerState::kCUDA) + .export_values(); + + py::enum_(m, "EventSortingKey", py::arithmetic()) + .value("kDefault", platform::EventSortingKey::kDefault) + .value("kCalls", platform::EventSortingKey::kCalls) + .value("kTotal", platform::EventSortingKey::kTotal) + .value("kMin", platform::EventSortingKey::kMin) + .value("kMax", platform::EventSortingKey::kMax) + .value("kAve", platform::EventSortingKey::kAve) + .export_values(); + + m.def("enable_profiler", platform::EnableProfiler); + m.def("disable_profiler", platform::DisableProfiler); + m.def("reset_profiler", platform::ResetProfiler); return m.ptr(); } } // namespace pybind diff --git a/python/paddle/v2/fluid/profiler.py b/python/paddle/v2/fluid/profiler.py index 29e0d54a3ac9622e5505c8e5de38616d9c636e67..51c1c8aa705513825b46fb936c6c99090c50fb7d 100644 --- a/python/paddle/v2/fluid/profiler.py +++ b/python/paddle/v2/fluid/profiler.py @@ -63,3 +63,58 @@ def cuda_profiler(output_file, output_mode=None, config=None): # Disables profiler collection. core.nvprof_stop() os.remove(config_file) + + +def reset_profiler(): + """The profiler clear interface. + reset_profiler will clear the previous time record. + """ + core.reset_profiler() + + +@contextmanager +def profiler(state, sorted_key=None): + """The profiler interface. + Different from cuda_profiler, this profiler can be used to profile both CPU + and GPU program. By defalut, it records the CPU and GPU operator kernels, + if you want to profile other program, you can refer the profiling tutorial + to add more records. + + Args: + state (string) : The profiling state, which should be 'CPU' or 'GPU', + telling the profiler to use CPU timer or GPU timer for profiling. + Although users may have already specified the execution place + (CPUPlace/CUDAPlace) in the begining, for flexibility the profiler + would not inherit this place. + sorted_key (string) : If None, the profiling results will be printed + in the order of first end time of events. Otherwise, the profiling + results will be sorted by the this flag. This flag should be one + of 'calls', 'total', 'max', 'min' or 'ave'. + The `calls` means sorting by the number of calls. + The `total` means sorting by the total execution time. + The `max` means sorting by the maximum execution time. + The `min` means sorting by the minimum execution time. + The `ave` means sorting by the average execution time. + """ + + if state not in ['CPU', 'GPU']: + raise ValueError("The state must be 'CPU' or 'GPU'.") + prof_state = core.ProfilerState.kCUDA if state == "GPU" else core.ProfilerState.kCPU + core.enable_profiler(prof_state) + yield + + if sorted_key not in ['calls', 'total', 'max', 'min', 'ave']: + raise ValueError("The state must be in 'calls', 'total', " + "'max', 'min', 'ave'") + sorted_key = 'default' if sorted_key is None else sorted_key + key_map = { + 'default': core.EventSortingKey.kDefault, + 'calls': core.EventSortingKey.kCalls, + 'total': core.EventSortingKey.kTotal, + 'max': core.EventSortingKey.kMax, + 'min': core.EventSortingKey.kMin, + 'ave': core.EventSortingKey.kAve, + } + # TODO(qingqing) : redirect C++ ostream to Python stream. + # with core.ostream_redirect(stdout=True, stderr=True): + core.disable_profiler(key_map[sorted_key]) diff --git a/python/paddle/v2/fluid/tests/test_profiler.py b/python/paddle/v2/fluid/tests/test_profiler.py index abf8881b6786416f56f93e498761a4791b35d7c3..34700df37d22cf71bad2d86efa4718a3767c2d4f 100644 --- a/python/paddle/v2/fluid/tests/test_profiler.py +++ b/python/paddle/v2/fluid/tests/test_profiler.py @@ -13,11 +13,12 @@ # limitations under the License. import unittest +import os import numpy as np import paddle.v2.fluid as fluid import paddle.v2.fluid.profiler as profiler import paddle.v2.fluid.layers as layers -import os +import paddle.v2.fluid.core as core class TestProfiler(unittest.TestCase): @@ -40,6 +41,50 @@ class TestProfiler(unittest.TestCase): exe.run(fluid.default_main_program(), feed={'data': input}) os.remove(output_file) + def net_profiler(self, state): + if state == 'GPU' and not core.is_compile_gpu(): + return + startup_program = fluid.Program() + main_program = fluid.Program() + + with fluid.program_guard(main_program, startup_program): + image = fluid.layers.data(name='x', shape=[784], dtype='float32') + hidden1 = fluid.layers.fc(input=image, size=128, act='relu') + hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu') + predict = fluid.layers.fc(input=hidden2, size=10, act='softmax') + label = fluid.layers.data(name='y', shape=[1], dtype='int64') + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + accuracy = fluid.evaluator.Accuracy(input=predict, label=label) + + optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9) + opts = optimizer.minimize(avg_cost, startup_program=startup_program) + + place = fluid.CPUPlace() if state == 'CPU' else fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(startup_program) + + accuracy.reset(exe) + with profiler.profiler(state, 'total') as prof: + for iter in range(10): + if iter == 2: + profiler.reset_profiler() + x = np.random.random((32, 784)).astype("float32") + y = np.random.randint(0, 10, (32, 1)).astype("int64") + + outs = exe.run(main_program, + feed={'x': x, + 'y': y}, + fetch_list=[avg_cost] + accuracy.metrics) + acc = np.array(outs[1]) + pass_acc = accuracy.eval(exe) + + def test_cpu_profiler(self): + self.net_profiler('CPU') + + def test_cuda_profiler(self): + self.net_profiler('GPU') + if __name__ == '__main__': unittest.main()