From c3924a959b504891b28639f4abebcb1ac892d42c Mon Sep 17 00:00:00 2001
From: WenmuZhou <572459439@qq.com>
Date: Mon, 22 Aug 2022 11:32:37 +0000
Subject: [PATCH] add amp eval

---
 tools/eval.py    | 29 ++++++++++++++++++++++++++---
 tools/program.py | 16 ++++++++++------
 tools/train.py   | 11 +++++++----
 3 files changed, 43 insertions(+), 13 deletions(-)

diff --git a/tools/eval.py b/tools/eval.py
index 38d72d17..3d1d3813 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -23,6 +23,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.insert(0, __dir__)
 sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
 
+import paddle
 from ppocr.data import build_dataloader
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
@@ -86,6 +87,30 @@ def main():
     else:
         model_type = None
 
+    # build metric
+    eval_class = build_metric(config['Metric'])
+    # amp
+    use_amp = config["Global"].get("use_amp", False)
+    amp_level = config["Global"].get("amp_level", 'O2')
+    amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
+    if use_amp:
+        AMP_RELATED_FLAGS_SETTING = {
+            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
+            'FLAGS_max_inplace_grad_add': 8,
+        }
+        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
+        scale_loss = config["Global"].get("scale_loss", 1.0)
+        use_dynamic_loss_scaling = config["Global"].get(
+            "use_dynamic_loss_scaling", False)
+        scaler = paddle.amp.GradScaler(
+            init_loss_scaling=scale_loss,
+            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+        if amp_level == "O2":
+            model = paddle.amp.decorate(
+                models=model, level=amp_level, master_weight=True)
+    else:
+        scaler = None
+
     best_model_dict = load_model(
         config, model, model_type=config['Architecture']["model_type"])
     if len(best_model_dict):
@@ -93,11 +118,9 @@ def main():
         for k, v in best_model_dict.items():
             logger.info('{}:{}'.format(k, v))
 
-    # build metric
-    eval_class = build_metric(config['Metric'])
     # start eval
     metric = program.eval(model, valid_dataloader, post_process_class,
-                          eval_class, model_type, extra_input)
+                          eval_class, model_type, extra_input, scaler, amp_level, amp_custom_black_list)
     logger.info('metric eval ***************')
     for k, v in metric.items():
         logger.info('{}:{}'.format(k, v))
diff --git a/tools/program.py b/tools/program.py
index 50f8455e..94629109 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -191,7 +191,8 @@ def train(config,
           logger,
           log_writer=None,
           scaler=None,
-          amp_level='O2'):
+          amp_level='O2',
+          amp_custom_black_list=[]):
     cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                    False)
     calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
@@ -277,8 +278,7 @@ def train(config,
                 model_average = True
             # use amp
             if scaler:
-                custom_black_list = config['Global'].get('amp_custom_black_list',[])
-                with paddle.amp.auto_cast(level=amp_level, custom_black_list=custom_black_list):
+                with paddle.amp.auto_cast(level=amp_level, custom_black_list=amp_custom_black_list):
                     if model_type == 'table' or extra_input:
                         preds = model(images, data=batch[1:])
                     elif model_type in ["kie", 'vqa']:
@@ -383,7 +383,9 @@ def train(config,
                         eval_class,
                         model_type,
                         extra_input=extra_input,
-                        scaler=scaler)
+                        scaler=scaler,
+                        amp_level=amp_level,
+                        amp_custom_black_list=amp_custom_black_list)
                 cur_metric_str = 'cur metric, {}'.format(', '.join(
                     ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                 logger.info(cur_metric_str)
@@ -474,7 +476,9 @@ def eval(model,
         eval_class,
         model_type=None,
         extra_input=False,
-        scaler=None):
+        scaler=None,
+        amp_level='O2',
+        amp_custom_black_list=[]):
     model.eval()
     with paddle.no_grad():
         total_frame = 0.0
@@ -495,7 +499,7 @@ def eval(model,
 
             # use amp
             if scaler:
-                with paddle.amp.auto_cast(level='O2'):
+                with paddle.amp.auto_cast(level=amp_level, custom_black_list=amp_custom_black_list):
                     if model_type == 'table' or extra_input:
                         preds = model(images, data=batch[1:])
                     elif model_type in ["kie", 'vqa']:
diff --git a/tools/train.py b/tools/train.py
index 5f310938..d0f20018 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -138,9 +138,7 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
 
-    # load pretrain model
-    pre_best_model_dict = load_model(config, model, optimizer,
-                                     config['Architecture']["model_type"])
+
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None:
         logger.info('valid dataloader has {} iters'.format(
@@ -148,6 +146,7 @@ def main(config, device, logger, vdl_writer):
 
     use_amp = config["Global"].get("use_amp", False)
     amp_level = config["Global"].get("amp_level", 'O2')
+    amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
     if use_amp:
         AMP_RELATED_FLAGS_SETTING = {
             'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
@@ -166,12 +165,16 @@ def main(config, device, logger, vdl_writer):
     else:
         scaler = None
 
+    # load pretrain model
+    pre_best_model_dict = load_model(config, model, optimizer,
+                                     config['Architecture']["model_type"])
+
     if config['Global']['distributed']:
         model = paddle.DataParallel(model)
 
     # start train
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
-                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level)
+                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler, amp_level, amp_custom_black_list)
 
 def test_reader(config, device, logger):
-- 
GitLab
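
For reference, the evaluation-side pattern this patch threads through program.eval reduces to a short standalone sketch. The code below is illustrative only, not code from the patch: amp_eval and loader are hypothetical names, while the paddle.amp.decorate and paddle.amp.auto_cast calls and their arguments mirror the ones added above. A None default is used here instead of the patch's mutable [] default, which sidesteps the shared-default-list pitfall; paddle.amp.auto_cast accepts None for custom_black_list.

# Hypothetical sketch of AMP evaluation, mirroring the logic added above.
import paddle

def amp_eval(model, loader, amp_level='O2', amp_custom_black_list=None):
    # Under O2 the weights themselves are cast to fp16, so the model must
    # be decorated before evaluation, the same step tools/eval.py now does.
    if amp_level == 'O2':
        model = paddle.amp.decorate(
            models=model, level=amp_level, master_weight=True)
    model.eval()
    results = []
    with paddle.no_grad():
        for images in loader:
            # Ops in the custom black list stay in fp32 inside the autocast
            # region, matching the auto_cast call in program.eval.
            with paddle.amp.auto_cast(
                    level=amp_level,
                    custom_black_list=amp_custom_black_list):
                results.append(model(images))
    return results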