未验证 提交 c0ed75a8 编写于 作者: L liutiexing 提交者: GitHub

Update profiler (#42998)

* Update Profiler

* make HostEventRecorder templated
上级 b4a3dab7
......@@ -192,15 +192,15 @@ void RecordEvent::End() {
if (LIKELY(FLAGS_enable_host_event_recorder_hook && is_enabled_)) {
uint64_t end_ns = PosixInNsec();
if (LIKELY(shallow_copy_name_ != nullptr)) {
HostEventRecorder::GetInstance().RecordEvent(
HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
shallow_copy_name_, start_ns_, end_ns, role_, type_);
} else if (name_ != nullptr) {
if (attr_ == nullptr) {
HostEventRecorder::GetInstance().RecordEvent(*name_, start_ns_, end_ns,
role_, type_);
HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
*name_, start_ns_, end_ns, role_, type_);
} else {
HostEventRecorder::GetInstance().RecordEvent(*name_, start_ns_, end_ns,
role_, type_, *attr_);
HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
*name_, start_ns_, end_ns, role_, type_, *attr_);
delete attr_;
}
delete name_;
......@@ -232,8 +232,8 @@ RecordInstantEvent::RecordInstantEvent(const char *name, TracerEventType type,
return;
}
auto start_end_ns = PosixInNsec();
HostEventRecorder::GetInstance().RecordEvent(name, start_end_ns, start_end_ns,
EventRole::kOrdinary, type);
HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
name, start_end_ns, start_end_ns, EventRole::kOrdinary, type);
}
void MemEvenRecorder::PushMemRecord(const void *ptr, const Place &place,
......@@ -327,7 +327,7 @@ void PopMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes,
void Mark(const std::string &name) {
if (FLAGS_enable_host_event_recorder_hook) {
HostEventRecorder::GetInstance().RecordEvent(
HostEventRecorder<CommonEvent>::GetInstance().RecordEvent(
name, 0, 0, EventRole::kOrdinary, TracerEventType::UserDefined);
return;
}
......@@ -522,7 +522,8 @@ void DisableHostEventRecorder() {
std::string PrintHostEvents() {
std::ostringstream oss;
auto host_evt_sec = HostEventRecorder::GetInstance().GatherEvents();
auto host_evt_sec =
HostEventRecorder<CommonEvent>::GetInstance().GatherEvents();
for (const auto &thr_evt_sec : host_evt_sec.thr_sections) {
oss << thr_evt_sec.thread_id << std::endl;
for (const auto &evt : thr_evt_sec.events) {
......@@ -534,7 +535,8 @@ std::string PrintHostEvents() {
return oss.str();
}
static void EmulateEventPushAndPop(const HostEventSection &host_sec,
static void EmulateEventPushAndPop(
const HostEventSection<CommonEvent> &host_sec,
std::map<uint64_t, ThreadEvents> *out) {
for (const auto &thr_sec : host_sec.thr_sections) {
uint64_t tid = thr_sec.thread_id;
......@@ -582,7 +584,8 @@ static void EmulateEventPushAndPop(const HostEventSection &host_sec,
}
}
static void EmulateCPURecordsAdd(const HostEventSection &host_sec) {
static void EmulateCPURecordsAdd(
const HostEventSection<CommonEvent> &host_sec) {
DeviceTracer *tracer = GetDeviceTracer();
if (tracer == nullptr) {
return;
......@@ -610,7 +613,8 @@ static std::map<uint64_t, ThreadEvents> DockHostEventRecorderHostPart() {
if (FLAGS_enable_host_event_recorder_hook == false) {
return thr_events;
}
auto host_evt_sec = HostEventRecorder::GetInstance().GatherEvents();
auto host_evt_sec =
HostEventRecorder<CommonEvent>::GetInstance().GatherEvents();
EmulateEventPushAndPop(host_evt_sec, &thr_events);
EmulateCPURecordsAdd(host_evt_sec);
return thr_events;
......
......@@ -21,7 +21,6 @@
#include "paddle/fluid/framework/new_executor/workqueue/thread_data_registry.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/os_info.h"
#include "paddle/fluid/platform/profiler/common_event.h"
namespace paddle {
namespace platform {
......@@ -182,12 +181,14 @@ char *EventContainer<EventType>::GetStringStorage(size_t sz) {
return storage;
}
template <typename EventType>
struct ThreadEventSection {
std::string thread_name;
uint64_t thread_id;
std::vector<CommonEvent> events;
std::vector<EventType> events;
};
template <typename EventType>
class ThreadEventRecorder {
public:
ThreadEventRecorder() {
......@@ -204,8 +205,8 @@ class ThreadEventRecorder {
base_evt_cntr_.Record(std::forward<Args>(args)...);
}
ThreadEventSection GatherEvents() {
ThreadEventSection thr_sec;
ThreadEventSection<EventType> GatherEvents() {
ThreadEventSection<EventType> thr_sec;
thr_sec.thread_name = thread_name_;
thr_sec.thread_id = thread_id_;
thr_sec.events = std::move(base_evt_cntr_.Reduce());
......@@ -215,15 +216,17 @@ class ThreadEventRecorder {
private:
uint64_t thread_id_;
std::string thread_name_;
EventContainer<CommonEvent> base_evt_cntr_;
EventContainer<EventType> base_evt_cntr_;
};
template <typename EventType>
struct HostEventSection {
std::string process_name;
uint64_t process_id;
std::vector<ThreadEventSection> thr_sections;
std::vector<ThreadEventSection<EventType>> thr_sections;
};
template <typename EventType>
class HostEventRecorder {
public:
// singleton
......@@ -244,10 +247,10 @@ class HostEventRecorder {
// thread-unsafe, make sure make sure there is no running tracing.
// Poor performance, call it at the ending
HostEventSection GatherEvents() {
HostEventSection<EventType> GatherEvents() {
auto thr_recorders =
ThreadEventRecorderRegistry::GetInstance().GetAllThreadDataByRef();
HostEventSection host_sec;
HostEventSection<EventType> host_sec;
host_sec.process_id = GetProcessId();
host_sec.thr_sections.reserve(thr_recorders.size());
for (auto &kv : thr_recorders) {
......@@ -260,12 +263,12 @@ class HostEventRecorder {
private:
using ThreadEventRecorderRegistry =
framework::ThreadDataRegistry<ThreadEventRecorder>;
framework::ThreadDataRegistry<ThreadEventRecorder<EventType>>;
HostEventRecorder() = default;
DISABLE_COPY_AND_ASSIGN(HostEventRecorder);
ThreadEventRecorder *GetThreadLocalRecorder() {
ThreadEventRecorder<EventType> *GetThreadLocalRecorder() {
return ThreadEventRecorderRegistry::GetInstance()
.GetMutableCurrentThreadData();
}
......
......@@ -30,7 +30,7 @@ namespace platform {
namespace {
void ProcessHostEvents(const HostEventSection& host_events,
void ProcessHostEvents(const HostEventSection<CommonEvent>& host_events,
TraceEventCollector* collector) {
for (const auto& thr_sec : host_events.thr_sections) {
uint64_t tid = thr_sec.thread_id;
......@@ -62,7 +62,7 @@ void HostTracer::StartTracing() {
PADDLE_ENFORCE_EQ(
state_ == TracerState::READY || state_ == TracerState::STOPED, true,
platform::errors::PreconditionNotMet("TracerState must be READY"));
HostEventRecorder::GetInstance().GatherEvents();
HostEventRecorder<CommonEvent>::GetInstance().GatherEvents();
HostTraceLevel::GetInstance().SetLevel(options_.trace_level);
state_ = TracerState::STARTED;
}
......@@ -79,8 +79,8 @@ void HostTracer::CollectTraceData(TraceEventCollector* collector) {
PADDLE_ENFORCE_EQ(
state_, TracerState::STOPED,
platform::errors::PreconditionNotMet("TracerState must be STOPED"));
HostEventSection host_events =
HostEventRecorder::GetInstance().GatherEvents();
HostEventSection<CommonEvent> host_events =
HostEventRecorder<CommonEvent>::GetInstance().GatherEvents();
ProcessHostEvents(host_events, collector);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册