From 83efeeae4fa8d825c2848a47f8fc602a7126cee4 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Wed, 30 Mar 2022 10:08:59 +0800 Subject: [PATCH] Add timer tool to Profiler (#40386) --- .../fluid/dataloader/dataloader_iter.py | 7 + .../fluid/tests/unittests/test_newprofiler.py | 72 +++ python/paddle/profiler/profiler.py | 133 +++++- python/paddle/profiler/timer.py | 418 ++++++++++++++++++ 4 files changed, 628 insertions(+), 2 deletions(-) create mode 100644 python/paddle/profiler/timer.py diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index edd22bb94d6..0dc733440fa 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -41,6 +41,7 @@ from .worker import ParentWatchDog, get_worker_info, _worker_loop, \ _DatasetKind, _IterableDatasetStopIteration, _WorkerException, \ _ResumeIteration from .flat import _flatten_batch, _restore_batch +from paddle.profiler.timer import benchmark __all__ = ['get_worker_info'] @@ -256,6 +257,8 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): event_type=profiler.TracerEventType.Dataloader) trace_event.begin() try: + benchmark().check_if_need_record(self) + benchmark().before_reader() if in_dygraph_mode(): data = core.eager.read_next_tensor_list( self._reader.read_next_list()[0]) @@ -283,6 +286,7 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase): data = data[0] else: data = self._reader.read_next() + benchmark().after_reader() return data except StopIteration: @@ -708,6 +712,8 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): event_type=profiler.TracerEventType.Dataloader) trace_event.begin() try: + benchmark().check_if_need_record(self) + benchmark().before_reader() # _batches_outstanding here record the total batch data number # in 'from after _try_put_indices to beforeoutput data', this # value should be _outstanding_capacity if data is not drained, @@ -750,6 +756,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): else: data = self._reader.read_next() self._on_output_batch() + benchmark().after_reader() return data except StopIteration: if not self._persistent_workers: diff --git a/python/paddle/fluid/tests/unittests/test_newprofiler.py b/python/paddle/fluid/tests/unittests/test_newprofiler.py index 12fb0fa61b0..c93e5dce86d 100755 --- a/python/paddle/fluid/tests/unittests/test_newprofiler.py +++ b/python/paddle/fluid/tests/unittests/test_newprofiler.py @@ -19,6 +19,9 @@ import numpy as np import paddle import paddle.profiler as profiler +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.io import Dataset, DataLoader class TestProfiler(unittest.TestCase): @@ -125,5 +128,74 @@ class TestProfiler(unittest.TestCase): result = profiler.utils.load_profiler_result('./test_profiler_pb.pb') +class RandomDataset(Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([100]).astype('float32') + label = np.random.randint(0, 10 - 1, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + +class SimpleNet(nn.Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.fc = nn.Linear(100, 10) + + def forward(self, image, label=None): + return self.fc(image) + + +class TestTimerOnly(unittest.TestCase): + def test_with_dataloader(self): + def train(step_num_samples=None): + dataset = RandomDataset(20 * 4) + simple_net = SimpleNet() + opt = paddle.optimizer.SGD(learning_rate=1e-3, + parameters=simple_net.parameters()) + loader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + drop_last=True, + num_workers=2) + step_info = '' + p = profiler.Profiler(timer_only=True) + p.start() + for i, (image, label) in enumerate(loader()): + out = simple_net(image) + loss = F.cross_entropy(out, label) + avg_loss = paddle.mean(loss) + avg_loss.backward() + opt.minimize(avg_loss) + simple_net.clear_gradients() + p.step(num_samples=step_num_samples) + if i % 10 == 0: + step_info = p.step_info() + print("Iter {}: {}".format(i, step_info)) + p.stop() + return step_info + + step_info = train(step_num_samples=None) + self.assertTrue('steps/s' in step_info) + step_info = train(step_num_samples=4) + self.assertTrue('samples/s' in step_info) + + def test_without_dataloader(self): + x = paddle.to_tensor(np.random.randn(10, 10)) + y = paddle.to_tensor(np.random.randn(10, 10)) + p = profiler.Profiler(timer_only=True) + p.start() + step_info = '' + for i in range(20): + out = x + y + p.step() + p.stop() + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/profiler/profiler.py b/python/paddle/profiler/profiler.py index d978e594399..c1c4f4ff8c1 100644 --- a/python/paddle/profiler/profiler.py +++ b/python/paddle/profiler/profiler.py @@ -25,6 +25,7 @@ from paddle.fluid.core import (_Profiler, _ProfilerResult, ProfilerOptions, from .utils import RecordEvent, wrap_optimizers from .profiler_statistic import StatisticData, _build_table, SortedKeys +from .timer import benchmark class ProfilerState(Enum): @@ -269,6 +270,8 @@ class Profiler: which means profiling range [start_batch, end_batch). on_trace_ready (Callable, optional): Callable object, serves as callback function, and takes the Profiler object as parameter, which provides a way for users to do post-processing. This callable object will be called when ``scheduler`` returns ``ProfilerState.RECORD_AND_RETURN``. The default value is :ref:`export_chrome_tracing ` (./profiler_log/). + timer_only (bool, optional): If it is True, the cost of Dataloader and every step of the model will be count without profiling. Otherwise, the model will + be timed and profiled. Default: False. Examples: 1. profiling range [2, 5). @@ -316,6 +319,68 @@ class Profiler: p.stop() p.summary() + 4. Use profiler to get throughput and cost of the model + + .. code-block:: python + :name: code-example-timer1 + + import paddle + import paddle.profiler as profiler + + class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = paddle.rand(shape=[100], dtype='float32') + label = paddle.randint(0, 10, shape=[1], dtype='int64') + return image, label + + def __len__(self): + return self.num_samples + + class SimpleNet(paddle.nn.Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.fc = paddle.nn.Linear(100, 10) + + def forward(self, image, label=None): + return self.fc(image) + + dataset = RandomDataset(20 * 4) + simple_net = SimpleNet() + opt = paddle.optimizer.SGD(learning_rate=1e-3, + parameters=simple_net.parameters()) + BATCH_SIZE = 4 + loader = paddle.io.DataLoader( + dataset, + batch_size=BATCH_SIZE) + p = profiler.Profiler(timer_only=True) + p.start() + for i, (image, label) in enumerate(loader()): + out = simple_net(image) + loss = paddle.nn.functional.cross_entropy(out, label) + avg_loss = paddle.mean(loss) + avg_loss.backward() + opt.minimize(avg_loss) + simple_net.clear_gradients() + p.step(num_samples=BATCH_SIZE) + if i % 10 == 0: + step_info = p.step_info(unit='images') + print("Iter {}: {}".format(i, step_info)) + # The average statistics for 10 steps between the last and this call will be + # printed when the "step_info" is called at 10 iteration intervals. + # The values you get may be different from the following. + # Iter 0: reader_cost: 0.51946 s batch_cost: 0.66077 s ips: 6.054 images/s + # Iter 10: reader_cost: 0.00014 s batch_cost: 0.00441 s ips: 907.009 images/s + p.stop() + # The performance summary will be automatically printed when the "stop" is called. + # Reader Ratio: 2.658% + # Time Unit: s, IPS Unit: images/s + # | | avg | max | min | + # | reader_cost | 0.00011 | 0.00013 | 0.00007 | + # | batch_cost | 0.00405 | 0.00434 | 0.00326 | + # | ips | 1086.42904 | 1227.30604 | 959.92796 | """ def __init__( @@ -323,7 +388,8 @@ class Profiler: *, targets: Optional[Iterable[ProfilerTarget]]=None, scheduler: Union[Callable[[int], ProfilerState], tuple, None]=None, - on_trace_ready: Optional[Callable[..., Any]]=None): + on_trace_ready: Optional[Callable[..., Any]]=None, + timer_only: Optional[bool]=False): supported_targets = _get_supported_targets() if targets: self.targets = set(targets) @@ -371,6 +437,7 @@ class Profiler: self.current_state = self.scheduler(self.step_num) self.record_event = None self.profiler_result = None + self.timer_only = timer_only def __enter__(self): self.start() @@ -399,7 +466,12 @@ class Profiler: #train() prof.step() prof.stop() + ''' + # Timing only without profiling + benchmark().begin() + if self.timer_only: + return # CLOSED -> self.current_state if self.current_state == ProfilerState.READY: self.profiler.prepare() @@ -435,6 +507,9 @@ class Profiler: prof.step() prof.stop() ''' + benchmark().end() + if self.timer_only: + return # self.current_state -> CLOSED # In this situation, RECORD state is regarded as RECORD_AND_RETURN if self.record_event: @@ -451,11 +526,15 @@ class Profiler: if self.on_trace_ready: self.on_trace_ready(self) - def step(self): + def step(self, num_samples: Optional[int]=None): r""" Signals the profiler that the next profiling step has started. Get the new ProfilerState and trigger corresponding action. + Args: + num_samples (int|None, optional): Specifies the batch size of every step of the model + that is used to compute throughput when timer_only is True. Default: None. + Examples: .. code-block:: python :name: code-example6 @@ -473,6 +552,9 @@ class Profiler: prof.step() prof.stop() """ + benchmark().step(num_samples) + if self.timer_only: + return if self.record_event: self.record_event.end() self.record_event = None @@ -485,6 +567,53 @@ class Profiler: event_type=TracerEventType.ProfileStep) self.record_event.begin() + def step_info(self, unit=None): + r""" + Get statistics for current step. If the function is called at certain iteration + intervals, the result is the average of all steps between the previous call and + this call. Statistics are as follows: + + 1. reader_cost: the cost of loading data measured in seconds. + + 2. batch_cost: the cost of step measured in seconds. + + 3. ips(Instance Per Second): the throughput of the model measured in `samples/s` + or others depends on the `unit`. When `num_samples` of `step()` is None, it is + measured in `steps/s`. + + Args: + unit (string, optional): The unit of input data is only used When `num_samples` + of `step()` is specified as a number. For example, when it is `images`, the unit + of throughput is `images/s`. Default: None, the unit of throughput is `samples/s`. + + Returns: + string: A string representing the statistic. + + Examples: + .. code-block:: python + :name: code-example-timer2 + + import paddle.profiler as profiler + prof = profiler.Profiler(timer_only=True) + prof.start() + for iter in range(20): + #train() + prof.step() + if iter % 10 == 0: + print("Iter {}: {}".format(iter, prof.step_info())) + # The example does not call the DataLoader, so there is no "reader_cost". + # Iter 0: batch_cost: 0.00001 s ips: 86216.623 steps/s + # Iter 10: batch_cost: 0.00001 s ips: 103645.034 steps/s + prof.stop() + # Time Unit: s, IPS Unit: steps/s + # | | avg | max | min | + # | batch_cost | 0.00000 | 0.00002 | 0.00000 | + # | ips | 267846.19437 | 712030.38727 | 45134.16662 | + """ + if unit is None: + unit = 'samples' + return benchmark().step_info(unit) + def _trigger_action(self): if self.previous_state == ProfilerState.CLOSED: if self.current_state == ProfilerState.READY: # CLOSED -> READY diff --git a/python/paddle/profiler/timer.py b/python/paddle/profiler/timer.py new file mode 100644 index 00000000000..1fb06ddc55e --- /dev/null +++ b/python/paddle/profiler/timer.py @@ -0,0 +1,418 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import timeit +import logging +from collections import OrderedDict + + +class Stack(object): + """ + The stack in a Last-In/First-Out (LIFO) manner. New element is added at + the end and an element is removed from that end. + """ + + def __init__(self): + self.items = [] + + def push(self, item): + self.items.append(item) + + def pop(self): + return self.items.pop() + + def is_empty(self): + return len(self.items) == 0 + + def peek(self): + if not self.is_empty(): + return self.items[len(self.items) - 1] + else: + return None + + +class Event(object): + """ + A Event is used to record the cost of every step and the cost of + the total steps except skipped steps. + """ + + def __init__(self): + self.reader_cost_averager = TimeAverager() + self.batch_cost_averager = TimeAverager() + self.total_samples = 0 + self.total_iters = 0 + self.skip_iter = 10 + self.reader_records = dict(max=0, min=float('inf'), total=0) + self.batch_records = dict(max=0, min=float('inf'), total=0) + self.speed_records = dict(max=0, min=float('inf')) + self.reader = None + self.need_record = True + # The speed mode depends on the setting of num_samples, there + # are 2 modes: steps/s(num_samples=None) or samples/s. + self.speed_mode = 'samples/s' + # The speed unit depends on the unit of samples that is + # specified in step_info and only works in this speed_mode="samples/s". + self.speed_unit = 'samples/s' + + def reset(self): + self.reader_cost_averager.reset() + self.batch_cost_averager.reset() + + def record_reader(self, usetime): + self.reader_cost_averager.record(usetime) + if self.total_iters >= self.skip_iter: + self._update_records(usetime, self.reader_records) + + def record_batch(self, usetime, num_samples=None): + if num_samples is None: + self.speed_mode = "steps/s" + self.speed_unit = "steps/s" + self.batch_cost_averager.record(usetime, num_samples) + self.total_iters += 1 + + if self.total_iters >= self.skip_iter: + self._update_records(usetime, self.batch_records) + if self.speed_mode == "samples/s": + current_speed = float(num_samples) / usetime + self.total_samples += num_samples + else: + current_speed = 1.0 / usetime # steps/s + self._update_records(current_speed, self.speed_records) + + def _update_records(self, current_record, records): + if current_record > records['max']: + records['max'] = current_record + elif current_record < records['min']: + records['min'] = current_record + if 'total' in records.keys(): + records['total'] += current_record + + def reader_average(self): + return self.reader_cost_averager.get_average() + + def batch_average(self): + return self.batch_cost_averager.get_average() + + def speed_average(self): + if self.speed_mode == "samples/s": + return self.batch_cost_averager.get_ips_average() + else: + return self.batch_cost_averager.get_step_average() + + def get_summary(self): + if self.total_iters <= self.skip_iter: + return {} + + reader_avg = 0 + batch_avg = 0 + speed_avg = 0 + + self.total_iters -= self.skip_iter + reader_avg = self.reader_records['total'] / float(self.total_iters) + batch_avg = self.batch_records['total'] / float(self.total_iters) + if self.speed_mode == "samples/s": + speed_avg = float(self.total_samples) / self.batch_records['total'] + else: + speed_avg = float(self.total_iters) / self.batch_records['total'] + + reader_summary = dict( + max=self.reader_records['max'], + min=self.reader_records['min'], + avg=reader_avg) + batch_summary = dict( + max=self.batch_records['max'], + min=self.batch_records['min'], + avg=batch_avg) + ips_summary = dict( + max=self.speed_records['max'], + min=self.speed_records['min'], + avg=speed_avg) + reader_ratio = (reader_avg / batch_avg) * 100 + summary = dict( + reader_summary=reader_summary, + batch_summary=batch_summary, + ips_summary=ips_summary, + reader_ratio=reader_ratio) + + return summary + + +class Hook(object): + """ + As the base class. All types of hooks should inherit from it. + """ + + def begin(self, benchmark): + pass + + def end(self, benchmark): + pass + + def before_reader(self, benchmark): + pass + + def after_reader(self, benchmark): + pass + + def after_step(self, benchmark): + pass + + +class TimerHook(Hook): + """ + A hook for recording real-time performance and the summary + performance of total steps. + """ + + def __init__(self): + self.start_time = timeit.default_timer() + self.start_reader = timeit.default_timer() + + def begin(self, benchmark): + """ + Create the event for timing and initialize the start time of a step. + This function will be called in `Profiler.start()`. + """ + + benchmark.events.push(Event()) + benchmark.current_event = benchmark.events.peek() + self.start_time = timeit.default_timer() + + def before_reader(self, benchmark): + """ + Initialize the start time of the dataloader. This function will be + called at the begining of `next` method in `_DataLoaderIterMultiProcess` or + `_DataLoaderIterSingleProcess`. + + """ + + self.start_reader = timeit.default_timer() + + def after_reader(self, benchmark): + """ + Record the cost of dataloader for the current step. Since the skipped steps + are 10, it will update the maximum, minimum and the total time from the step + 11 to the current step. This function will be called at the end of `next` + method in `_DataLoaderIterMultiProcess` or `_DataLoaderIterSingleProcess`. + + """ + + reader_cost = timeit.default_timer() - self.start_reader + if (benchmark.current_event is None) or ( + not benchmark.current_event.need_record) or (reader_cost == 0): + return + benchmark.current_event.record_reader(reader_cost) + + def after_step(self, benchmark): + """ + Record the cost for the current step. It will contain the cost of the loading + data if there is a dataloader. Similar to `after_reader`, it will also update + the maximum, minimum and the total time from the step 11 to the current step + as well as the the maximum and minimum speed of the model. This function will + be called in in `Profiler.step()`. + + """ + + if (benchmark.current_event is None) or ( + not benchmark.current_event.need_record): + return + batch_cost = timeit.default_timer() - self.start_time + benchmark.current_event.record_batch(batch_cost, benchmark.num_samples) + self.start_time = timeit.default_timer() + + def end(self, benchmark): + """ + Print the performance summary of the model and pop the current event + from the events stack. Since there may be nested timing events, such + as evaluation in the training process, the current event needs to be + update to the event at the top of the stack. + + """ + + if benchmark.events.is_empty(): + return + self._print_summary(benchmark) + benchmark.events.pop() + benchmark.current_event = benchmark.events.peek() + self.start_time = timeit.default_timer() + + def _print_summary(self, benchmark): + summary = benchmark.current_event.get_summary() + if not summary: + return + print('Perf Summary'.center(100, '=')) + if summary['reader_ratio'] != 0: + print('Reader Ratio: ' + '%.3f' % (summary['reader_ratio']) + '%') + print('Time Unit: s, IPS Unit: %s' % + (benchmark.current_event.speed_unit)) + print('|', ''.center(15), '|', 'avg'.center(15), '|', 'max'.center(15), + '|', 'min'.center(15), '|') + # if DataLoader is not called, reader_summary is unnecessary. + if summary['reader_summary']['avg'] != 0: + self._print_stats('reader_cost', summary['reader_summary']) + self._print_stats('batch_cost', summary['batch_summary']) + self._print_stats('ips', summary['ips_summary']) + + def _print_stats(self, item, message_dict): + avg_str = '%.5f' % (message_dict['avg']) + max_str = '%.5f' % (message_dict['max']) + min_str = '%.5f' % (message_dict['min']) + print('|', + item.center(15), '|', + avg_str.center(15), '|', + max_str.center(15), '|', min_str.center(15), '|') + + +class TimeAverager(object): + """ + Record the cost of every step and count the average. + """ + + def __init__(self): + self.reset() + + def reset(self): + self._total_iters = 0 + self._total_time = 0 + self._total_samples = 0 + + def record(self, usetime, num_samples=None): + self._total_iters += 1 + self._total_time += usetime + if num_samples: + self._total_samples += num_samples + + def get_average(self): + """ + Get the average cost of loading data or a step. + """ + + if self._total_iters == 0: + return 0 + return self._total_time / float(self._total_iters) + + def get_ips_average(self): + """ + Get the average throughput when speed mode is "samples/s". + """ + + if not self._total_samples or self._total_iters == 0: + return 0 + return float(self._total_samples) / self._total_time + + def get_step_average(self): + """ + Get the average speed when speed mode is "step/s". + """ + + if self._total_iters == 0: + return 0 + return float(self._total_iters) / self._total_time + + +class Benchmark(object): + """ + A tool for the statistics of model performance. The `before_reader` + and `after_reader` are called in the DataLoader to count the cost + of loading the data. The `begin`, `step` and `end` are called to + count the cost of a step or total steps. + """ + + def __init__(self): + self.num_samples = None + self.hooks = OrderedDict(timer_hook=TimerHook()) + self.current_event = None + self.events = Stack() + + def step(self, num_samples=None): + """ + Record the statistic for the current step. It will be called in + `Profiler.step()`. + """ + + self.num_samples = num_samples + self.after_step() + + def step_info(self, unit): + """ + It returns the statistic of the current step as a string. It contains + "reader_cost", "batch_cost" and "ips". + """ + + message = '' + reader_average = self.current_event.reader_average() + batch_average = self.current_event.batch_average() + if reader_average: + message += ' reader_cost: %.5f s' % (reader_average) + if batch_average: + if self.current_event.speed_mode == 'steps/s': + self.current_event.speed_unit = 'steps/s' + else: + self.current_event.speed_unit = unit + '/s' + message += ' %s: %.5f s' % ('batch_cost', batch_average) + speed_average = self.current_event.speed_average() + if speed_average: + message += ' ips: %.3f %s' % (speed_average, + self.current_event.speed_unit) + self.current_event.reset() + return message + + def begin(self): + for hook in self.hooks.values(): + hook.begin(self) + + def before_reader(self): + for hook in self.hooks.values(): + hook.before_reader(self) + + def after_reader(self): + for hook in self.hooks.values(): + hook.after_reader(self) + + def after_step(self): + for hook in self.hooks.values(): + hook.after_step(self) + + def end(self): + for hook in self.hooks.values(): + hook.end(self) + + def check_if_need_record(self, reader): + if self.current_event is None: + return + if self.current_event.need_record: + # set reader for the current event at the first iter + if self.current_event.reader is None: + self.current_event.reader = reader + elif self.current_event.reader.__dict__[ + '_dataset'] != reader.__dict__['_dataset']: + # enter a new task but not calling beign() to record it. + # we pause the timer until the end of new task, so that + # the cost of new task is not added to the current event. + # eg. start evaluation in the traing task + self.current_event.need_record = False + else: + # when the new task exits, continue timing for the current event. + if self.current_event.reader.__dict__[ + '_dataset'] == reader.__dict__['_dataset']: + self.current_event.need_record = True + self.hooks['timer_hook'].start_time = timeit.default_timer() + + +_benchmark_ = Benchmark() + + +def benchmark(): + return _benchmark_ -- GitLab