From aa7a0cb2a89cb436f1c5d64d8da0d7f592a689c2 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Mon, 8 Aug 2022 11:31:12 +0000 Subject: [PATCH] fix bug in amp eval --- .../det_r50_vd_dcn_fce_ctw.yml | 2 +- .../layoutxlm_ser/train_infer_python.txt | 4 +-- tools/program.py | 36 +++++++++++++------ 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml b/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml index 3a513b8f..29f6f32a 100644 --- a/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml +++ b/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml @@ -8,7 +8,7 @@ Global: # evaluation is run every 835 iterations eval_batch_step: [0, 4000] cal_metric_during_train: False - pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained + pretrained_model: pretrain_models/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy.pdparams checkpoints: save_inference_dir: use_visualdl: False diff --git a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt index 887c3285..53415b3e 100644 --- a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt +++ b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg null:null ## trainer:norm_train -norm_train:tools/train.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false +norm_train:tools/train.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -27,7 +27,7 @@ null:null ===========================infer_params=========================== Global.save_inference_dir:./output/ Architecture.Backbone.checkpoints: -norm_export:tools/export_model.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o +norm_export:tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o quant_export: fpgm_export: distill_export:null diff --git a/tools/program.py b/tools/program.py index 0fa0e609..c4a9f916 100755 --- a/tools/program.py +++ b/tools/program.py @@ -154,12 +154,13 @@ def check_xpu(use_xpu): except Exception as e: pass + def to_float32(preds): if isinstance(preds, dict): for k in preds: if isinstance(preds[k], dict) or isinstance(preds[k], list): preds[k] = to_float32(preds[k]) - else: + elif isinstance(preds[k], paddle.Tensor): preds[k] = preds[k].astype(paddle.float32) elif isinstance(preds, list): for k in range(len(preds)): @@ -167,12 +168,13 @@ def to_float32(preds): preds[k] = to_float32(preds[k]) elif isinstance(preds[k], list): preds[k] = to_float32(preds[k]) - else: + elif isinstance(preds[k], paddle.Tensor): preds[k] = preds[k].astype(paddle.float32) - else: + elif isinstance(preds[k], paddle.Tensor): preds = preds.astype(paddle.float32) return preds + def train(config, train_dataloader, valid_dataloader, @@ -370,7 +372,8 @@ def train(config, post_process_class, eval_class, model_type, - extra_input=extra_input) + extra_input=extra_input, + scaler=scaler) cur_metric_str = 'cur metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metric.items()])) logger.info(cur_metric_str) @@ -460,7 +463,8 @@ def eval(model, post_process_class, eval_class, model_type=None, - extra_input=False): + extra_input=False, + scaler=None): model.eval() with paddle.no_grad(): total_frame = 0.0 @@ -477,12 +481,24 @@ def eval(model, break images = batch[0] start = time.time() - if model_type == 'table' or extra_input: - preds = model(images, data=batch[1:]) - elif model_type in ["kie", 'vqa']: - preds = model(batch) + + # use amp + if scaler: + with paddle.amp.auto_cast(level='O2'): + if model_type == 'table' or extra_input: + preds = model(images, data=batch[1:]) + elif model_type in ["kie", 'vqa']: + preds = model(batch) + else: + preds = model(images) else: - preds = model(images) + if model_type == 'table' or extra_input: + preds = model(images, data=batch[1:]) + elif model_type in ["kie", 'vqa']: + preds = model(batch) + else: + preds = model(images) + batch_numpy = [] for item in batch: if isinstance(item, paddle.Tensor): -- GitLab