提交 94c0a64d 编写于 作者: X Xin Pan

Fix a profiler race condition

In multi-thread condition, EnableProfiler can
be called after RecordEvent is constructed. In this
case, RecordEvent constructor will not init anything,
but RecordEvent destructor will do something since EnableProfiler
was called.
This PR fixes it.
上级 ca5ea65a
...@@ -173,8 +173,9 @@ void PopEvent(const std::string& name, const DeviceContext* dev_ctx) { ...@@ -173,8 +173,9 @@ void PopEvent(const std::string& name, const DeviceContext* dev_ctx) {
} }
RecordEvent::RecordEvent(const std::string& name, const DeviceContext* dev_ctx) RecordEvent::RecordEvent(const std::string& name, const DeviceContext* dev_ctx)
: start_ns_(PosixInNsec()) { : is_enabled_(false), start_ns_(PosixInNsec()) {
if (g_state == ProfilerState::kDisabled) return; if (g_state == ProfilerState::kDisabled) return;
is_enabled_ = true;
dev_ctx_ = dev_ctx; dev_ctx_ = dev_ctx;
name_ = name; name_ = name;
PushEvent(name_, dev_ctx_); PushEvent(name_, dev_ctx_);
...@@ -183,7 +184,7 @@ RecordEvent::RecordEvent(const std::string& name, const DeviceContext* dev_ctx) ...@@ -183,7 +184,7 @@ RecordEvent::RecordEvent(const std::string& name, const DeviceContext* dev_ctx)
} }
RecordEvent::~RecordEvent() { RecordEvent::~RecordEvent() {
if (g_state == ProfilerState::kDisabled) return; if (g_state == ProfilerState::kDisabled || !is_enabled_) return;
DeviceTracer* tracer = GetDeviceTracer(); DeviceTracer* tracer = GetDeviceTracer();
if (tracer) { if (tracer) {
tracer->AddCPURecords(CurAnnotation(), start_ns_, PosixInNsec(), tracer->AddCPURecords(CurAnnotation(), start_ns_, PosixInNsec(),
...@@ -193,14 +194,16 @@ RecordEvent::~RecordEvent() { ...@@ -193,14 +194,16 @@ RecordEvent::~RecordEvent() {
PopEvent(name_, dev_ctx_); PopEvent(name_, dev_ctx_);
} }
RecordBlock::RecordBlock(int block_id) : start_ns_(PosixInNsec()) { RecordBlock::RecordBlock(int block_id)
: is_enabled_(false), start_ns_(PosixInNsec()) {
if (g_state == ProfilerState::kDisabled) return; if (g_state == ProfilerState::kDisabled) return;
is_enabled_ = true;
SetCurBlock(block_id); SetCurBlock(block_id);
name_ = string::Sprintf("block_%d", block_id); name_ = string::Sprintf("block_%d", block_id);
} }
RecordBlock::~RecordBlock() { RecordBlock::~RecordBlock() {
if (g_state == ProfilerState::kDisabled) return; if (g_state == ProfilerState::kDisabled || !is_enabled_) return;
DeviceTracer* tracer = GetDeviceTracer(); DeviceTracer* tracer = GetDeviceTracer();
if (tracer) { if (tracer) {
// We try to put all blocks at the same nested depth in the // We try to put all blocks at the same nested depth in the
......
...@@ -74,6 +74,7 @@ struct RecordEvent { ...@@ -74,6 +74,7 @@ struct RecordEvent {
~RecordEvent(); ~RecordEvent();
bool is_enabled_;
uint64_t start_ns_; uint64_t start_ns_;
// The device context is used by Event to get the current cuda stream. // The device context is used by Event to get the current cuda stream.
const DeviceContext* dev_ctx_; const DeviceContext* dev_ctx_;
...@@ -89,6 +90,7 @@ struct RecordBlock { ...@@ -89,6 +90,7 @@ struct RecordBlock {
~RecordBlock(); ~RecordBlock();
private: private:
bool is_enabled_;
std::string name_; std::string name_;
uint64_t start_ns_; uint64_t start_ns_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册