diff --git a/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml b/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml index d77951785132cb21b29819317acd27a18c234175..50b04ba0dd139060d50aa70421221dfc2c66067f 100644 --- a/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml +++ b/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml @@ -21,6 +21,7 @@ Global: save_res_path: ./output/ser/xfund_zh/res kie_rec_model_dir: kie_det_model_dir: + amp_custom_white_list: ['scale', 'concat', 'elementwise_add'] Architecture: model_type: kie diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml index 4a8c35d8b681b793ebcf116c3c568e73fe388aad..3f2d6b28eabf35da227a981e7783ccd59fa62333 100644 --- a/configs/table/SLANet.yml +++ b/configs/table/SLANet.yml @@ -22,6 +22,7 @@ Global: use_sync_bn: True save_res_path: 'output/infer' d2s_train_image_shape: [3, -1, -1] + amp_custom_white_list: ['concat', 'elementwise_sub', 'set_value'] Optimizer: name: Adam diff --git a/test_tipc/configs/slanet/SLANet.yml b/test_tipc/configs/slanet/SLANet.yml index 813363fb180e1eaf8214a19133916fcdeede6648..76e8cc5bb7e2abe6658544f3c29f38e2c2f69c15 100644 --- a/test_tipc/configs/slanet/SLANet.yml +++ b/test_tipc/configs/slanet/SLANet.yml @@ -22,6 +22,7 @@ Global: use_sync_bn: True save_res_path: 'output/infer' d2s_train_image_shape: [3, -1, -1] + amp_custom_white_list: ['concat', 'elementwise_sub', 'set_value'] Optimizer: name: Adam diff --git a/tools/program.py b/tools/program.py index b11d1a09739ed590d0aea696bcb81f8d671a6730..2761e1273392c953fdbe6b1e0cc8dca92533b6df 100755 --- a/tools/program.py +++ b/tools/program.py @@ -188,7 +188,8 @@ def train(config, log_writer=None, scaler=None, amp_level='O2', - amp_custom_black_list=[]): + amp_custom_black_list=[], + amp_custom_white_list=[]): cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1) @@ -277,7 +278,8 @@ def train(config, if scaler: with paddle.amp.auto_cast( level=amp_level, - custom_black_list=amp_custom_black_list): + custom_black_list=amp_custom_black_list, + custom_white_list=amp_custom_white_list): if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) elif model_type in ["kie"]: diff --git a/tools/train.py b/tools/train.py index 3d5bf8447c786eeaf84c2877358dd2f48910e9ea..a162fa6d932657c4bd5c928ede6921ce7a162b16 100755 --- a/tools/train.py +++ b/tools/train.py @@ -161,6 +161,7 @@ def main(config, device, logger, vdl_writer): use_amp = config["Global"].get("use_amp", False) amp_level = config["Global"].get("amp_level", 'O2') amp_custom_black_list = config['Global'].get('amp_custom_black_list', []) + amp_custom_white_list = config['Global'].get('amp_custom_white_list', []) if use_amp: AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, } if paddle.is_compiled_with_cuda(): @@ -194,7 +195,7 @@ def main(config, device, logger, vdl_writer): program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, eval_class, pre_best_model_dict, logger, vdl_writer, scaler, - amp_level, amp_custom_black_list) + amp_level, amp_custom_black_list, amp_custom_white_list) def test_reader(config, device, logger):