提交 d2a70243 编写于 作者: D dangqingqing

Refine profiler and expose to Python.

上级 df9c13a8
......@@ -26,7 +26,7 @@ ExternalProject_Add(
extern_pybind
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/pybind/pybind11.git"
GIT_TAG "v2.1.1"
GIT_TAG "v2.2.1"
PREFIX ${PYBIND_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
......
......@@ -68,7 +68,8 @@ cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op)
cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog lod_rank_table)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table profiler)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/place.h"
#include "paddle/platform/profiler.h"
DEFINE_bool(check_nan_inf, false,
"Checking whether operator produce NAN/INF or not. It will be "
......@@ -116,6 +117,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(3) << op->DebugStringEx(local_scope);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto dev_ctx = const_cast<platform::DeviceContext*>(pool.Get(place_));
platform::RecordEvent record_event(op->Type(), dev_ctx);
op->Run(*local_scope, place_);
if (FLAGS_check_nan_inf) {
for (auto& vname : op->OutputVars(true)) {
......
......@@ -163,14 +163,17 @@ void EnableProfiler(ProfilerState state) {
Mark("_start_profiler_", nullptr);
}
std::vector<std::vector<Event>> DisableProfiler() {
PADDLE_ENFORCE(g_state != ProfilerState::kDisabled,
"Can't disable profiling, since it's not starting.");
// Mark the profiling stop.
Mark("_stop_profiler_", nullptr);
g_state = ProfilerState::kDisabled;
std::vector<std::vector<Event>> result;
void ResetProfiler() {
std::lock_guard<std::mutex> guard(g_all_event_lists_mutex);
for (auto it = g_all_event_lists.begin(); it != g_all_event_lists.end();
++it) {
(*it)->Clear();
}
}
std::vector<std::vector<Event>> GetAllEvents() {
std::lock_guard<std::mutex> guard(g_all_event_lists_mutex);
std::vector<std::vector<Event>> result;
for (auto it = g_all_event_lists.begin(); it != g_all_event_lists.end();
++it) {
result.emplace_back((*it)->Reduce());
......@@ -178,6 +181,18 @@ std::vector<std::vector<Event>> DisableProfiler() {
return result;
}
void DisableProfiler(EventSortingKey sorted_key) {
PADDLE_ENFORCE(g_state != ProfilerState::kDisabled,
"Can't disable profiling, since it's not starting.");
// Mark the profiling stop.
Mark("_stop_profiler_", nullptr);
g_state = ProfilerState::kDisabled;
std::vector<std::vector<Event>> all_events = GetAllEvents();
ParseEvents(all_events, sorted_key);
ResetProfiler();
}
void ParseEvents(std::vector<std::vector<Event>>& events,
EventSortingKey sorted_by) {
if (g_profiler_place == "") return;
......@@ -291,12 +306,12 @@ void ParseEvents(std::vector<std::vector<Event>>& events,
}
// Print report
PrintProfilingReport(events_table, sorted_domain, max_name_width + 4, 12);
PrintProfiler(events_table, sorted_domain, max_name_width + 4, 12);
}
void PrintProfilingReport(std::vector<std::vector<EventItem>>& events_table,
std::string& sorted_domain, const size_t name_width,
const size_t data_width) {
void PrintProfiler(std::vector<std::vector<EventItem>>& events_table,
std::string& sorted_domain, const size_t name_width,
const size_t data_width) {
// Output header information
std::cout << "\n------------------------->"
<< " Profiling Report "
......
......@@ -84,6 +84,8 @@ struct EventList {
return result;
}
void Clear() { event_blocks.clear(); }
std::forward_list<std::vector<Event>> event_blocks;
};
......@@ -110,12 +112,9 @@ struct RecordEvent {
std::string name_;
};
// Enable the profiling function.
void EnableProfiler(ProfilerState state);
// Return the event list of all threads. Asummed the returned value calls
// event_lists, event_lists[i][j] represents the j-th Event of i-th thread.
std::vector<std::vector<Event>> DisableProfiler();
std::vector<std::vector<Event>> GetAllEvents();
// The information of each event given in the profiling report
struct EventItem {
......@@ -130,13 +129,22 @@ struct EventItem {
// Candidate keys to sort the profiling report
enum EventSortingKey { kDefault, kCalls, kTotal, kMin, kMax, kAve };
// Enable the profiling function.
void EnableProfiler(ProfilerState state);
// Clear the g_all_event_lists, which is total event lists of all threads.
void ResetProfiler();
void DisableProfiler(EventSortingKey sorted_key);
// Parse the event list and output the profiling report
void ParseEvents(std::vector<std::vector<Event>>&,
EventSortingKey sorted_by = EventSortingKey::kDefault);
// Print results
void PrintProfilingReport(std::vector<std::vector<EventItem>>& events_table,
std::string& sorted_domain, const size_t name_width,
const size_t data_width);
void PrintProfiler(std::vector<std::vector<EventItem>>& events_table,
std::string& sorted_domain, const size_t name_width,
const size_t data_width);
} // namespace platform
} // namespace paddle
......@@ -103,18 +103,14 @@ TEST(RecordEvent, RecordEvent) {
// Bad Usage:
PushEvent("event_without_pop", dev_ctx);
PopEvent("event_without_push", dev_ctx);
std::vector<std::vector<Event>> events = paddle::platform::DisableProfiler();
// Will remove parsing-related code from test later
ParseEvents(events, EventSortingKey::kTotal);
std::vector<std::vector<Event>> events = paddle::platform::GetAllEvents();
int cuda_startup_count = 0;
int start_profiler_count = 0;
int stop_profiler_count = 0;
for (size_t i = 0; i < events.size(); ++i) {
for (size_t j = 0; j < events[i].size(); ++j) {
if (events[i][j].name() == "_cuda_startup_") ++cuda_startup_count;
if (events[i][j].name() == "_start_profiler_") ++start_profiler_count;
if (events[i][j].name() == "_stop_profiler_") ++stop_profiler_count;
if (events[i][j].name() == "push") {
EXPECT_EQ(events[i][j + 1].name(), "pop");
#ifdef PADDLE_WITH_CUDA
......@@ -127,5 +123,7 @@ TEST(RecordEvent, RecordEvent) {
}
EXPECT_EQ(cuda_startup_count % 5, 0);
EXPECT_EQ(start_profiler_count, 1);
EXPECT_EQ(stop_profiler_count, 1);
// Will remove parsing-related code from test later
DisableProfiler(EventSortingKey::kTotal);
}
if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc const_value.cc
DEPS pybind python backward proto_desc paddle_memory executor prune init
DEPS pybind python backward proto_desc paddle_memory executor prune init profiler
${GLOB_OP_LIB})
if(NOT APPLE AND NOT ANDROID)
target_link_libraries(paddle_pybind rt)
......
......@@ -21,74 +21,24 @@ limitations under the License. */
#include "paddle/framework/program_desc.h"
#include "paddle/framework/var_desc.h"
// Cast boost::variant for PyBind.
// Copy from
// https://github.com/pybind/pybind11/issues/576#issuecomment-269563199
using boost::variant;
namespace pybind11 {
namespace detail {
// Can be replaced by a generic lambda in C++14
struct variant_caster_visitor : public boost::static_visitor<handle> {
return_value_policy policy;
handle parent;
variant_caster_visitor(return_value_policy policy, handle parent)
: policy(policy), parent(parent) {}
template <class T>
handle operator()(T const &src) const {
return make_caster<T>::cast(src, policy, parent);
}
};
template <class Variant>
struct variant_caster;
template <template <class...> class V, class... Ts>
struct variant_caster<V<Ts...>> {
using Type = V<Ts...>;
template <typename T>
typename std::enable_if<
!std::is_same<T, boost::detail::variant::void_>::value, bool>::type
try_load(handle src, bool convert) {
auto caster = make_caster<T>();
if (!load_success_ && caster.load(src, convert)) {
load_success_ = true;
value = cast_op<T>(caster);
return true;
}
return false;
}
template <typename T>
typename std::enable_if<std::is_same<T, boost::detail::variant::void_>::value,
bool>::type
try_load(handle src, bool convert) {
return false;
}
bool load(handle src, bool convert) {
auto unused = {false, try_load<Ts>(src, convert)...};
(void)(unused);
return load_success_;
}
static handle cast(Type const &src, return_value_policy policy,
handle parent) {
variant_caster_visitor visitor(policy, parent);
return boost::apply_visitor(visitor, src);
}
PYBIND11_TYPE_CASTER(Type, _("Variant"));
bool load_success_{false};
};
// Add specialization for concrete variant type
template <class... Args>
struct type_caster<boost::variant<Args...>>
: variant_caster<boost::variant<Args...>> {};
template <>
struct visit_helper<boost::variant> {
template <typename... Args>
static auto call(Args &&... args) -> decltype(boost::apply_visitor(args...)) {
return boost::apply_visitor(args...);
}
};
} // namespace detail
} // namespace pybind11
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <Python.h>
#include <fstream>
#include <vector>
#include "paddle/platform/variant.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pybind/protobuf.h"
#include "pybind11/iostream.h"
#include <mutex> // for call_once
#include <unordered_map>
......@@ -30,6 +31,7 @@ limitations under the License. */
#include "paddle/operators/net_op.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/platform/profiler.h"
#include "paddle/pybind/const_value.h"
#include "paddle/pybind/exception.h"
#include "paddle/pybind/pybind.h"
......@@ -60,8 +62,8 @@ bool IsCompileGPU() {
#endif
}
PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle");
PYBIND11_MODULE(core, m) {
m.doc() = "C++ core of PaddlePaddle";
// using framework in this function. Since it is inside a function, it will
// not cause namespace pollution.
......@@ -481,7 +483,26 @@ All parameter, weight, gradient are variables in Paddle.
m.def("nvprof_stop", platform::CudaProfilerStop);
#endif
return m.ptr();
py::enum_<platform::ProfilerState>(m, "ProfilerState", py::arithmetic())
.value("kDisabled", platform::ProfilerState::kDisabled)
.value("kCPU", platform::ProfilerState::kCPU)
.value("kCUDA", platform::ProfilerState::kCUDA)
.export_values();
py::enum_<platform::EventSortingKey>(m, "EventSortingKey", py::arithmetic())
.value("kDefault", platform::EventSortingKey::kDefault)
.value("kCalls", platform::EventSortingKey::kCalls)
.value("kTotal", platform::EventSortingKey::kTotal)
.value("kMin", platform::EventSortingKey::kMin)
.value("kMax", platform::EventSortingKey::kMax)
.value("kAve", platform::EventSortingKey::kAve)
.export_values();
m.def("enable_profiler", platform::EnableProfiler);
m.def("disable_profiler", platform::DisableProfiler);
m.def("reset_profiler", platform::ResetProfiler);
py::add_ostream_redirect(m, "ostream_redirect");
}
} // namespace pybind
} // namespace paddle
......@@ -49,3 +49,48 @@ def cuda_profiler(output_file, output_mode=None, config=None):
# Disables profiler collection.
core.nvprof_stop()
os.remove(config_file)
def reset_profiler():
core.reset_profiler()
@contextmanager
def profiler(state, sorted_key=None):
"""The profiler interface.
Different from cuda_profiler, this fuction can be used to profile both CPU
and GPU program.
Args:
state (string) : The profiler state, It should be 'CPU' or 'GPU'.
sorted_key (string) : If None, the profiler results will be printed
without sorting. Otherwise, the profiler results will be sorted
by the this flag. This flag should be one of 'calls', 'total',
'max', 'min' or 'ave'.
The `calls` means sorting by the calling counter.
The `total` means sorting by the total execution time.
The `max` means sorting by the maximum execution time.
The `min` means sorting by the minimum execution time.
The `ave` means sorting by the average execution time.
"""
if state not in ['CPU', 'GPU']:
raise ValueError("The state must be 'CPU' or 'GPU'.")
prof_state = core.ProfilerState.kCUDA if state == "GPU" else core.ProfilerState.kCPU
core.enable_profiler(prof_state)
yield
if sorted_key not in ['calls', 'total', 'max', 'min', 'ave']:
raise ValueError("The state must be in 'calls', 'total', "
"'max', 'min', 'ave'")
sorted_key = 'default' if sorted_key is None else sorted_key
key_map = {
'default': core.EventSortingKey.kDefault,
'calls': core.EventSortingKey.kCalls,
'total': core.EventSortingKey.kTotal,
'max': core.EventSortingKey.kMax,
'min': core.EventSortingKey.kMin,
'ave': core.EventSortingKey.kAve,
}
with core.ostream_redirect(stdout=True, stderr=True):
core.disable_profiler(key_map[sorted_key])
import unittest
import os
import numpy as np
import paddle.v2.fluid as fluid
import paddle.v2.fluid.profiler as profiler
import paddle.v2.fluid.layers as layers
import os
import paddle.v2.fluid.core as core
class TestProfiler(unittest.TestCase):
......@@ -26,6 +27,40 @@ class TestProfiler(unittest.TestCase):
exe.run(fluid.default_main_program(), feed={'data': input})
os.remove(output_file)
def test_profiler(self):
image = fluid.layers.data(name='x', shape=[784], dtype='float32')
hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
opts = optimizer.minimize(avg_cost)
accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
states = ['CPU', 'GPU'] if core.is_compile_gpu() else ['CPU']
for state in states:
place = fluid.CPUPlace() if state == 'CPU' else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
accuracy.reset(exe)
with profiler.profiler(state, 'total') as prof:
for iter in range(10):
if iter == 2:
profiler.reset_profiler()
x = np.random.random((32, 784)).astype("float32")
y = np.random.randint(0, 10, (32, 1)).astype("int64")
outs = exe.run(fluid.default_main_program(),
feed={'x': x,
'y': y},
fetch_list=[avg_cost] + accuracy.metrics)
acc = np.array(outs[1])
pass_acc = accuracy.eval(exe)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册