diff --git a/python/paddle/profiler/profiler.py b/python/paddle/profiler/profiler.py index b2e95d24b0bc6224698daded838dca17adadd310..45cb9651bb19b1a42ebbec294a498efff5b4e65e 100644 --- a/python/paddle/profiler/profiler.py +++ b/python/paddle/profiler/profiler.py @@ -403,7 +403,8 @@ class Profiler: on_trace_ready: Optional[Callable[..., Any]] = None, record_shapes: Optional[bool] = False, profile_memory=False, - timer_only: Optional[bool] = False): + timer_only: Optional[bool] = False, + emit_nvtx: Optional[bool] = False): supported_targets = _get_supported_targets() if targets: self.targets = set(targets) @@ -456,6 +457,7 @@ class Profiler: self.timer_only = timer_only self.record_shapes = record_shapes self.profile_memory = profile_memory + self.emit_nvtx = emit_nvtx def __enter__(self): self.start() @@ -488,6 +490,8 @@ class Profiler: ''' # Timing only without profiling benchmark().begin() + if not self.timer_only or self.emit_nvtx: + utils._is_profiler_used = True if self.timer_only: return if self.record_shapes: @@ -495,7 +499,6 @@ class Profiler: if self.profile_memory: enable_memory_recorder() # CLOSED -> self.current_state - utils._is_profiler_used = True if self.current_state == ProfilerState.READY: self.profiler.prepare() elif self.current_state == ProfilerState.RECORD: