未验证 提交 ff2f1373 编写于 作者: Y Yiqun Liu 提交者: GitHub

Enable Profiler to add nvtx tag. (#45162)

上级 d85cf4d4
......@@ -403,7 +403,8 @@ class Profiler:
on_trace_ready: Optional[Callable[..., Any]] = None,
record_shapes: Optional[bool] = False,
profile_memory=False,
timer_only: Optional[bool] = False):
timer_only: Optional[bool] = False,
emit_nvtx: Optional[bool] = False):
supported_targets = _get_supported_targets()
if targets:
self.targets = set(targets)
......@@ -456,6 +457,7 @@ class Profiler:
self.timer_only = timer_only
self.record_shapes = record_shapes
self.profile_memory = profile_memory
self.emit_nvtx = emit_nvtx
def __enter__(self):
self.start()
......@@ -488,6 +490,8 @@ class Profiler:
'''
# Timing only without profiling
benchmark().begin()
if not self.timer_only or self.emit_nvtx:
utils._is_profiler_used = True
if self.timer_only:
return
if self.record_shapes:
......@@ -495,7 +499,6 @@ class Profiler:
if self.profile_memory:
enable_memory_recorder()
# CLOSED -> self.current_state
utils._is_profiler_used = True
if self.current_state == ProfilerState.READY:
self.profiler.prepare()
elif self.current_state == ProfilerState.RECORD:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册