未验证 提交 fee73135 编写于 作者: O Olatunji Ruwase 提交者: GitHub

Use cuda events to improve timing for multi-stream execution (#1881)

Co-authored-by: NJeff Rasley <jerasley@microsoft.com>
上级 66aae13d
......@@ -20,42 +20,55 @@ except ImportError:
pass
class CudaEventTimer(object):
def __init__(self, start_event: torch.cuda.Event, end_event: torch.cuda.Event):
self.start_event = start_event
self.end_event = end_event
def get_elapsed_msec(self):
torch.cuda.current_stream().wait_event(self.end_event)
self.end_event.synchronize()
return self.start_event.elapsed_time(self.end_event)
class SynchronizedWallClockTimer:
"""Group of timers. Borrowed from Nvidia Megatron code"""
class Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
self.records = []
self.event_timers = []
self.start_event = None
self.elapsed_records = None
def start(self):
"""Start the timer."""
assert not self.started_, "timer has already been started"
torch.cuda.synchronize()
self.start_time = time.time()
assert not self.started_, f"{self.name} timer has already been started"
self.start_event = torch.cuda.Event(enable_timing=True)
self.start_event.record()
self.started_ = True
def stop(self, reset=False, record=False):
"""Stop the timer."""
assert self.started_, "timer is not started"
torch.cuda.synchronize()
if reset:
self.elapsed_ = time.time() - self.start_time
else:
self.elapsed_ += time.time() - self.start_time
end_event = torch.cuda.Event(enable_timing=True)
end_event.record()
self.event_timers.append(CudaEventTimer(self.start_event, end_event))
self.start_event = None
self.started_ = False
if record:
self.records.append(self.elapsed_)
def _get_elapsed_msec(self):
self.elapsed_records = [et.get_elapsed_msec() for et in self.event_timers]
self.event_timers.clear()
return sum(self.elapsed_records)
def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False
self.acc_ = 0.0
self.cnt_ = 0
self.start_event = None
self.elapsed_records = None
self.event_timers.clear()
def elapsed(self, reset=True):
"""Calculate the elapsed time."""
......@@ -64,7 +77,7 @@ class SynchronizedWallClockTimer:
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
elapsed_ = self._get_elapsed_msec()
# Reset the elapsed time
if reset:
self.reset()
......@@ -74,7 +87,7 @@ class SynchronizedWallClockTimer:
return elapsed_
def mean(self):
return trim_mean(self.records, 0.1)
return trim_mean(self.elapsed_records, 0.1)
def __init__(self):
self.timers = {}
......@@ -102,8 +115,7 @@ class SynchronizedWallClockTimer:
string = f"rank={torch.distributed.get_rank()} time (ms)"
for name in names:
if name in self.timers:
elapsed_time = (self.timers[name].elapsed(reset=reset) * 1000.0 /
normalizer)
elapsed_time = (self.timers[name].elapsed(reset=reset) / normalizer)
string += " | {}: {:.2f}".format(name, elapsed_time)
log_dist(string, ranks=ranks or [0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册