diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index 5d83f302c8251ec21cfd1be9f2b30d626ebcedb7..37d561eb6cbc27d351ca02d1463e24949a764493 100755 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -27,7 +27,7 @@ from ..models.builder import build_model from ..utils.visual import tensor2img, save_image from ..utils.filesystem import makedirs, save, load from ..utils.timer import TimeAverager - +from ..utils.profiler import add_profiler_step class IterLoader: def __init__(self, dataloader): @@ -147,6 +147,7 @@ class Trainer: self.time_count = {} self.best_metric = {} self.model.set_total_iter(self.total_iters) + self.profiler_options = cfg.profiler_options def distributed_data_parallel(self): paddle.distributed.init_parallel_env() @@ -178,6 +179,8 @@ class Trainer: self.current_epoch = iter_loader.epoch self.inner_iter = self.current_iter % self.iters_per_epoch + add_profiler_step(self.profiler_options) + start_time = step_start_time = time.time() data = next(iter_loader) reader_cost_averager.record(time.time() - step_start_time) diff --git a/ppgan/utils/options.py b/ppgan/utils/options.py index 32572c4032168d0398cc4284a507187d072ea20e..57e40b3c882e14058533d3ad3f3534fd41785ac1 100644 --- a/ppgan/utils/options.py +++ b/ppgan/utils/options.py @@ -60,6 +60,13 @@ def parse_args(): help="path to reference images") parser.add_argument("--model_path", default=None, help="model for loading") + # for profiler + parser.add_argument('-p', + '--profiler_options', + type=str, + default=None, + help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".' + ) args = parser.parse_args() return args diff --git a/ppgan/utils/profiler.py b/ppgan/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..b7359739fa48ec847891f034187b81d25406ff81 --- /dev/null +++ b/ppgan/utils/profiler.py @@ -0,0 +1,111 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import paddle + +# A global variable to record the number of calling times for profiler +# functions. It is used to specify the tracing range of training steps. +_profiler_step_id = 0 + +# A global variable to avoid parsing from string every time. +_profiler_options = None + + +class ProfilerOptions(object): + ''' + Use a string to initialize a ProfilerOptions. + The string should be in the format: "key1=value1;key2=value;key3=value3". + For example: + "profile_path=model.profile" + "batch_range=[50, 60]; profile_path=model.profile" + "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile" + ProfilerOptions supports following key-value pair: + batch_range - a integer list, e.g. [100, 110]. + state - a string, the optional values are 'CPU', 'GPU' or 'All'. + sorted_key - a string, the optional values are 'calls', 'total', + 'max', 'min' or 'ave. + tracer_option - a string, the optional values are 'Default', 'OpDetail', + 'AllOpDetail'. + profile_path - a string, the path to save the serialized profile data, + which can be used to generate a timeline. + exit_on_finished - a boolean. + ''' + + def __init__(self, options_str): + assert isinstance(options_str, str) + + self._options = { + 'batch_range': [10, 20], + 'state': 'All', + 'sorted_key': 'total', + 'tracer_option': 'Default', + 'profile_path': '/tmp/profile', + 'exit_on_finished': True + } + self._parse_from_string(options_str) + + def _parse_from_string(self, options_str): + for kv in options_str.replace(' ', '').split(';'): + key, value = kv.split('=') + if key == 'batch_range': + value_list = value.replace('[', '').replace(']', '').split(',') + value_list = list(map(int, value_list)) + if len(value_list) >= 2 and value_list[0] >= 0 and value_list[ + 1] > value_list[0]: + self._options[key] = value_list + elif key == 'exit_on_finished': + self._options[key] = value.lower() in ("yes", "true", "t", "1") + elif key in [ + 'state', 'sorted_key', 'tracer_option', 'profile_path' + ]: + self._options[key] = value + + def __getitem__(self, name): + if self._options.get(name, None) is None: + raise ValueError( + "ProfilerOptions does not have an option named %s." % name) + return self._options[name] + + +def add_profiler_step(options_str=None): + ''' + Enable the operator-level timing using PaddlePaddle's profiler. + The profiler uses a independent variable to count the profiler steps. + One call of this function is treated as a profiler step. + + Args: + options_str - a string to initialize the ProfilerOptions. + Default is None, and the profiler is disabled. + ''' + if options_str is None: + return + + global _profiler_step_id + global _profiler_options + + if _profiler_options is None: + _profiler_options = ProfilerOptions(options_str) + + if _profiler_step_id == _profiler_options['batch_range'][0]: + paddle.utils.profiler.start_profiler( + _profiler_options['state'], _profiler_options['tracer_option']) + elif _profiler_step_id == _profiler_options['batch_range'][1]: + paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'], + _profiler_options['profile_path']) + if _profiler_options['exit_on_finished']: + sys.exit(0) + + _profiler_step_id += 1 + diff --git a/ppgan/utils/setup.py b/ppgan/utils/setup.py index e37bde59793e33160edee56368ce9c817223d3de..d6b7d9cb88bd4d52a0ef106cf7ed7172220c3242 100644 --- a/ppgan/utils/setup.py +++ b/ppgan/utils/setup.py @@ -25,6 +25,11 @@ def setup(args, cfg): else: cfg.is_train = True + if args.profiler_options: + cfg.profiler_options = args.profiler_options + else: + cfg.profiler_options = None + cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime()) cfg.output_dir = os.path.join( cfg.output_dir,