未验证 提交 c0ed75a8 编写于 作者: L liutiexing 提交者: GitHub

Update profiler (#42998)

* Update Profiler

* make HostEventRecorder templated
上级 b4a3dab7
...@@ -192,15 +192,15 @@ void RecordEvent::End() { ...@@ -192,15 +192,15 @@ void RecordEvent::End() {
if (LIKELY(FLAGS_enable_host_event_recorder_hook && is_enabled_)) { if (LIKELY(FLAGS_enable_host_event_recorder_hook && is_enabled_)) {
uint64_t end_ns = PosixInNsec(); uint64_t end_ns = PosixInNsec();
if (LIKELY(shallow_copy_name_ != nullptr)) { if (LIKELY(shallow_copy_name_ != nullptr)) {
HostEventRecorder::GetInstance().RecordEvent( HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
shallow_copy_name_, start_ns_, end_ns, role_, type_); shallow_copy_name_, start_ns_, end_ns, role_, type_);
} else if (name_ != nullptr) { } else if (name_ != nullptr) {
if (attr_ == nullptr) { if (attr_ == nullptr) {
HostEventRecorder::GetInstance().RecordEvent(*name_, start_ns_, end_ns, HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
role_, type_); *name_, start_ns_, end_ns, role_, type_);
} else { } else {
HostEventRecorder::GetInstance().RecordEvent(*name_, start_ns_, end_ns, HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
role_, type_, *attr_); *name_, start_ns_, end_ns, role_, type_, *attr_);
delete attr_; delete attr_;
} }
delete name_; delete name_;
...@@ -232,8 +232,8 @@ RecordInstantEvent::RecordInstantEvent(const char *name, TracerEventType type, ...@@ -232,8 +232,8 @@ RecordInstantEvent::RecordInstantEvent(const char *name, TracerEventType type,
return; return;
} }
auto start_end_ns = PosixInNsec(); auto start_end_ns = PosixInNsec();
HostEventRecorder::GetInstance().RecordEvent(name, start_end_ns, start_end_ns, HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
EventRole::kOrdinary, type); name, start_end_ns, start_end_ns, EventRole::kOrdinary, type);
} }
void MemEvenRecorder::PushMemRecord(const void *ptr, const Place &place, void MemEvenRecorder::PushMemRecord(const void *ptr, const Place &place,
...@@ -327,7 +327,7 @@ void PopMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes, ...@@ -327,7 +327,7 @@ void PopMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes,
void Mark(const std::string &name) { void Mark(const std::string &name) {
if (FLAGS_enable_host_event_recorder_hook) { if (FLAGS_enable_host_event_recorder_hook) {
HostEventRecorder::GetInstance().RecordEvent( HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
name, 0, 0, EventRole::kOrdinary, TracerEventType::UserDefined); name, 0, 0, EventRole::kOrdinary, TracerEventType::UserDefined);
return; return;
} }
...@@ -522,7 +522,8 @@ void DisableHostEventRecorder() { ...@@ -522,7 +522,8 @@ void DisableHostEventRecorder() {
std::string PrintHostEvents() { std::string PrintHostEvents() {
std::ostringstream oss; std::ostringstream oss;
auto host_evt_sec = HostEventRecorder::GetInstance().GatherEvents(); auto host_evt_sec =
HostEventRecorder<CommonEvent>::GetInstance().GatherEvents();
for (const auto &thr_evt_sec : host_evt_sec.thr_sections) { for (const auto &thr_evt_sec : host_evt_sec.thr_sections) {
oss << thr_evt_sec.thread_id << std::endl; oss << thr_evt_sec.thread_id << std::endl;
for (const auto &evt : thr_evt_sec.events) { for (const auto &evt : thr_evt_sec.events) {
...@@ -534,7 +535,8 @@ std::string PrintHostEvents() { ...@@ -534,7 +535,8 @@ std::string PrintHostEvents() {
return oss.str(); return oss.str();
} }
static void EmulateEventPushAndPop(const HostEventSection &host_sec, static void EmulateEventPushAndPop(
const HostEventSection<CommonEvent> &host_sec,
std::map<uint64_t, ThreadEvents> *out) { std::map<uint64_t, ThreadEvents> *out) {
for (const auto &thr_sec : host_sec.thr_sections) { for (const auto &thr_sec : host_sec.thr_sections) {
uint64_t tid = thr_sec.thread_id; uint64_t tid = thr_sec.thread_id;
...@@ -582,7 +584,8 @@ static void EmulateEventPushAndPop(const HostEventSection &host_sec, ...@@ -582,7 +584,8 @@ static void EmulateEventPushAndPop(const HostEventSection &host_sec,
} }
} }
static void EmulateCPURecordsAdd(const HostEventSection &host_sec) { static void EmulateCPURecordsAdd(
const HostEventSection<CommonEvent> &host_sec) {
DeviceTracer *tracer = GetDeviceTracer(); DeviceTracer *tracer = GetDeviceTracer();
if (tracer == nullptr) { if (tracer == nullptr) {
return; return;
...@@ -610,7 +613,8 @@ static std::map<uint64_t, ThreadEvents> DockHostEventRecorderHostPart() { ...@@ -610,7 +613,8 @@ static std::map<uint64_t, ThreadEvents> DockHostEventRecorderHostPart() {
if (FLAGS_enable_host_event_recorder_hook == false) { if (FLAGS_enable_host_event_recorder_hook == false) {
return thr_events; return thr_events;
} }
auto host_evt_sec = HostEventRecorder::GetInstance().GatherEvents(); auto host_evt_sec =
HostEventRecorder<CommonEvent>::GetInstance().GatherEvents();
EmulateEventPushAndPop(host_evt_sec, &thr_events); EmulateEventPushAndPop(host_evt_sec, &thr_events);
EmulateCPURecordsAdd(host_evt_sec); EmulateCPURecordsAdd(host_evt_sec);
return thr_events; return thr_events;
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include "paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h" #include "paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/os_info.h" #include "paddle/fluid/platform/os_info.h"
#include "paddle/fluid/platform/profiler/common_event.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -182,12 +181,14 @@ char *EventContainer<EventType>::GetStringStorage(size_t sz) { ...@@ -182,12 +181,14 @@ char *EventContainer<EventType>::GetStringStorage(size_t sz) {
return storage; return storage;
} }
template <typename EventType>
struct ThreadEventSection { struct ThreadEventSection {
std::string thread_name; std::string thread_name;
uint64_t thread_id; uint64_t thread_id;
std::vector<CommonEvent> events; std::vector<EventType> events;
}; };
template <typename EventType>
class ThreadEventRecorder { class ThreadEventRecorder {
public: public:
ThreadEventRecorder() { ThreadEventRecorder() {
...@@ -204,8 +205,8 @@ class ThreadEventRecorder { ...@@ -204,8 +205,8 @@ class ThreadEventRecorder {
base_evt_cntr_.Record(std::forward<Args>(args)...); base_evt_cntr_.Record(std::forward<Args>(args)...);
} }
ThreadEventSection GatherEvents() { ThreadEventSection<EventType> GatherEvents() {
ThreadEventSection thr_sec; ThreadEventSection<EventType> thr_sec;
thr_sec.thread_name = thread_name_; thr_sec.thread_name = thread_name_;
thr_sec.thread_id = thread_id_; thr_sec.thread_id = thread_id_;
thr_sec.events = std::move(base_evt_cntr_.Reduce()); thr_sec.events = std::move(base_evt_cntr_.Reduce());
...@@ -215,15 +216,17 @@ class ThreadEventRecorder { ...@@ -215,15 +216,17 @@ class ThreadEventRecorder {
private: private:
uint64_t thread_id_; uint64_t thread_id_;
std::string thread_name_; std::string thread_name_;
EventContainer<CommonEvent> base_evt_cntr_; EventContainer<EventType> base_evt_cntr_;
}; };
template <typename EventType>
struct HostEventSection { struct HostEventSection {
std::string process_name; std::string process_name;
uint64_t process_id; uint64_t process_id;
std::vector<ThreadEventSection> thr_sections; std::vector<ThreadEventSection<EventType>> thr_sections;
}; };
template <typename EventType>
class HostEventRecorder { class HostEventRecorder {
public: public:
// singleton // singleton
...@@ -244,10 +247,10 @@ class HostEventRecorder { ...@@ -244,10 +247,10 @@ class HostEventRecorder {
// thread-unsafe, make sure make sure there is no running tracing. // thread-unsafe, make sure make sure there is no running tracing.
// Poor performance, call it at the ending // Poor performance, call it at the ending
HostEventSection GatherEvents() { HostEventSection<EventType> GatherEvents() {
auto thr_recorders = auto thr_recorders =
ThreadEventRecorderRegistry::GetInstance().GetAllThreadDataByRef(); ThreadEventRecorderRegistry::GetInstance().GetAllThreadDataByRef();
HostEventSection host_sec; HostEventSection<EventType> host_sec;
host_sec.process_id = GetProcessId(); host_sec.process_id = GetProcessId();
host_sec.thr_sections.reserve(thr_recorders.size()); host_sec.thr_sections.reserve(thr_recorders.size());
for (auto &kv : thr_recorders) { for (auto &kv : thr_recorders) {
...@@ -260,12 +263,12 @@ class HostEventRecorder { ...@@ -260,12 +263,12 @@ class HostEventRecorder {
private: private:
using ThreadEventRecorderRegistry = using ThreadEventRecorderRegistry =
framework::ThreadDataRegistry<ThreadEventRecorder>; framework::ThreadDataRegistry<ThreadEventRecorder<EventType>>;
HostEventRecorder() = default; HostEventRecorder() = default;
DISABLE_COPY_AND_ASSIGN(HostEventRecorder); DISABLE_COPY_AND_ASSIGN(HostEventRecorder);
ThreadEventRecorder *GetThreadLocalRecorder() { ThreadEventRecorder<EventType> *GetThreadLocalRecorder() {
return ThreadEventRecorderRegistry::GetInstance() return ThreadEventRecorderRegistry::GetInstance()
.GetMutableCurrentThreadData(); .GetMutableCurrentThreadData();
} }
......
...@@ -30,7 +30,7 @@ namespace platform { ...@@ -30,7 +30,7 @@ namespace platform {
namespace { namespace {
void ProcessHostEvents(const HostEventSection& host_events, void ProcessHostEvents(const HostEventSection<CommonEvent>& host_events,
TraceEventCollector* collector) { TraceEventCollector* collector) {
for (const auto& thr_sec : host_events.thr_sections) { for (const auto& thr_sec : host_events.thr_sections) {
uint64_t tid = thr_sec.thread_id; uint64_t tid = thr_sec.thread_id;
...@@ -62,7 +62,7 @@ void HostTracer::StartTracing() { ...@@ -62,7 +62,7 @@ void HostTracer::StartTracing() {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
state_ == TracerState::READY || state_ == TracerState::STOPED, true, state_ == TracerState::READY || state_ == TracerState::STOPED, true,
platform::errors::PreconditionNotMet("TracerState must be READY")); platform::errors::PreconditionNotMet("TracerState must be READY"));
HostEventRecorder::GetInstance().GatherEvents(); HostEventRecorder<CommonEvent>::GetInstance().GatherEvents();
HostTraceLevel::GetInstance().SetLevel(options_.trace_level); HostTraceLevel::GetInstance().SetLevel(options_.trace_level);
state_ = TracerState::STARTED; state_ = TracerState::STARTED;
} }
...@@ -79,8 +79,8 @@ void HostTracer::CollectTraceData(TraceEventCollector* collector) { ...@@ -79,8 +79,8 @@ void HostTracer::CollectTraceData(TraceEventCollector* collector) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
state_, TracerState::STOPED, state_, TracerState::STOPED,
platform::errors::PreconditionNotMet("TracerState must be STOPED")); platform::errors::PreconditionNotMet("TracerState must be STOPED"));
HostEventSection host_events = HostEventSection<CommonEvent> host_events =
HostEventRecorder::GetInstance().GatherEvents(); HostEventRecorder<CommonEvent>::GetInstance().GatherEvents();
ProcessHostEvents(host_events, collector); ProcessHostEvents(host_events, collector);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册