diff --git a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml index df429314cd0ec058aa6779a0ff55656f1b211bbf..acf438950a43af3356c7ab0aadf956fdf226814e 100644 --- a/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml +++ b/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml @@ -14,6 +14,9 @@ Global: use_visualdl: False infer_img: doc/imgs_en/img_10.jpg save_res_path: ./output/det_db/predicts_db.txt + use_amp: False + amp_level: O2 + amp_custom_black_list: ['exp'] Architecture: name: DistillationModel diff --git a/tools/program.py b/tools/program.py index 8de15ee0121511a9e3ea665e9d6e8ab0733b19c3..5a4d3ea4d2ec6832e6735d15096d46fbb62f86dd 100755 --- a/tools/program.py +++ b/tools/program.py @@ -278,7 +278,8 @@ def train(config, model_average = True # use amp if scaler: - with paddle.amp.auto_cast(level=amp_level): + custom_black_list = config['Global'].get('amp_custom_black_list',[]) + with paddle.amp.auto_cast(level=amp_level, custom_black_list=custom_black_list): if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) elif model_type in ["kie", 'vqa']: