diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index f86125521d19342f63a9fcb3bdcaed02cc4c6463..aa65f290c0a5f4f13b3103fb4404815e2ae74a88 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -104,8 +104,9 @@ def load_model(config, model, optimizer=None, model_type='det'): continue pre_value = params[key] if pre_value.dtype == paddle.float16: - pre_value = pre_value.astype(paddle.float32) is_float16 = True + if pre_value.dtype != value.dtype: + pre_value = pre_value.astype(value.dtype) if list(value.shape) == list(pre_value.shape): new_state_dict[key] = pre_value else: @@ -162,8 +163,9 @@ def load_pretrained_params(model, path): logger.warning("The pretrained params {} not in model".format(k1)) else: if params[k1].dtype == paddle.float16: - params[k1] = params[k1].astype(paddle.float32) is_float16 = True + if params[k1].dtype != state_dict[k1].dtype: + params[k1] = params[k1].astype(state_dict[k1].dtype) if list(state_dict[k1].shape) == list(params[k1].shape): new_state_dict[k1] = params[k1] else: diff --git a/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt index a96b87dede1e1b4c7b3ed59c4bd9c0470402e7e2..6d20b2df7420371ce964cf8fd5cb29726c000d1d 100644 --- a/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt +++ b/test_tipc/configs/ch_PP-OCRv2_rec/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt index 59fc1bd4160ec77edb0b781c8ffa9845c6a3d5c7..fee08b08ede0f61ae4f57fd42dba303301798a3e 100644 --- a/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt +++ b/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_image_shape="3,48,320" --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_infer_python.txt index 40f397948936beba0a3a4bdce9aa4a9953ec9d0f..dc490cdc60c2c012549e6fd00c13ec18676ede20 100644 --- a/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_infer_python.txt +++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_infer_python.txt index b9a1ae4984c30a08d75b73b884ceb97658eb11c7..85741f98c3fd645a64d8820a046030f1bb7e03c7 100644 --- a/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_infer_python.txt +++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/en_table_structure/table_mv3.yml b/test_tipc/configs/en_table_structure/table_mv3.yml index 6ff31fc262b4380b4cc5258a7b2e098ada39dba0..edcbe2c3b00e8d8a56ad8dd9f208e283b511b86e 100755 --- a/test_tipc/configs/en_table_structure/table_mv3.yml +++ b/test_tipc/configs/en_table_structure/table_mv3.yml @@ -4,7 +4,7 @@ Global: log_smooth_window: 20 print_batch_step: 5 save_model_dir: ./output/table_mv3/ - save_epoch_step: 3 + save_epoch_step: 400 # evaluation is run every 400 iterations after the 0th iteration eval_batch_step: [0, 40000] cal_metric_during_train: True @@ -17,7 +17,8 @@ Global: # for data or label process character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_type: en - max_text_length: 800 + max_text_length: &max_text_length 500 + box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy' infer_mode: False Optimizer: @@ -37,12 +38,14 @@ Architecture: Backbone: name: MobileNetV3 scale: 1.0 - model_name: large + model_name: small + disable_se: true Head: name: TableAttentionHead hidden_size: 256 loc_type: 2 - max_text_length: 800 + max_text_length: *max_text_length + loc_reg_num: &loc_reg_num 4 Loss: name: TableAttentionLoss @@ -70,6 +73,8 @@ Train: learn_empty_box: False merge_no_span_structure: False replace_empty_cell_token: False + loc_reg_num: *loc_reg_num + max_text_length: *max_text_length - TableBoxEncode: - ResizeTableImage: max_len: 488 @@ -102,6 +107,8 @@ Eval: learn_empty_box: False merge_no_span_structure: False replace_empty_cell_token: False + loc_reg_num: *loc_reg_num + max_text_length: *max_text_length - TableBoxEncode: - ResizeTableImage: max_len: 488 diff --git a/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt b/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt index fed8ba26753bb770e062f751a9ba1e8e35fc6843..4a8fda0fea76da41a0a13b61f35d96a4d230d488 100644 --- a/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt +++ b/test_tipc/configs/rec_mtb_nrtr/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbo --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/train_infer_python.txt b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/train_infer_python.txt index db89b4c78d72d1853096d6b44b73a7ca61792dfe..22c29c9b233ac908741accd7eb85fb3832fb0c0f 100644 --- a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/train_infer_python.txt +++ b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_mv3_none_none_ctc_v2_0/train_infer_python.txt b/test_tipc/configs/rec_mv3_none_none_ctc_v2_0/train_infer_python.txt index 003e91ff3d95e62d4353d7c4545e780ecd2f9708..d91c55e8852eee2cc7913235308f6d1f31e1f2e9 100644 --- a/test_tipc/configs/rec_mv3_none_none_ctc_v2_0/train_infer_python.txt +++ b/test_tipc/configs/rec_mv3_none_none_ctc_v2_0/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/train_infer_python.txt b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/train_infer_python.txt index c7b416c83323863a905929a2effcb1d3ad856422..77dc79cdae8bf4843ad17282885b46a33e64ce53 100644 --- a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/train_infer_python.txt +++ b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/train_infer_python.txt b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/train_infer_python.txt index 0c6e2d1da7f163521e8859bd8c96436b2a6bac64..f38c8d8d67bae84232749e60952a5c73871f9a88 100644 --- a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/train_infer_python.txt +++ b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt b/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt index 07498c9e81ada9652343b8d8fff0f102d4684380..336e6c73fdcef9a06c540a3a28a706d3aff716c7 100644 --- a/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt +++ b/test_tipc/configs/rec_r31_robustscanner/train_infer_python.txt @@ -39,11 +39,11 @@ infer_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/re infer_quant:False inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.txt --rec_image_shape="3,48,48,160" --use_space_char=False --rec_algorithm="RobustScanner" --use_gpu:True|False ---enable_mkldnn:True|False ---cpu_threads:1|6 ---rec_batch_num:1|6 ---use_tensorrt:False|False ---precision:fp32|int8 +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 --rec_model_dir: --image_dir:./inference/rec_inference --save_log_path:./test/output/ diff --git a/test_tipc/configs/rec_r31_sar/train_infer_python.txt b/test_tipc/configs/rec_r31_sar/train_infer_python.txt index 03ec54abb65ac41d3b5ad4f6e2fdcf7abb34c344..4acc6223e3b65211d62f2f128150e1c76f286674 100644 --- a/test_tipc/configs/rec_r31_sar/train_infer_python.txt +++ b/test_tipc/configs/rec_r31_sar/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.t --use_gpu:True --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt index 115dfd661abc64db9e14c629f79099be7b6ff0e0..ac378b36046d532a887056183de9c7788f628b76 100644 --- a/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt +++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/spi --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/train_infer_python.txt index 07a6190b0ef09da5cd20b9dd8ea922544c578710..b53efbd6ba5db36813733f6682bde1cfd614c6ee 100644 --- a/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/train_infer_python.txt +++ b/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/train_infer_python.txt index 145793aa472d8330daf9321f44692a03e7ef6354..7d953968b8a9d3f62f7c6fb48ed65bd9743d5ba3 100644 --- a/test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/train_infer_python.txt +++ b/test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/train_infer_python.txt index 759518a4a11a17e076401bb8dd193617c9f10530..0910ff840e350333a26de9b959229b6f8d39c19e 100644 --- a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/train_infer_python.txt +++ b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/train_infer_python.txt index ecc898341ce14dfed0de4290b798dd70078ae2da..33144e622e5fbb399e6dd274196812e2d44dc0fd 100644 --- a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/train_infer_python.txt +++ b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_r45_abinet/train_infer_python.txt b/test_tipc/configs/rec_r45_abinet/train_infer_python.txt index ecab1bcbbde11fc6d14357b6715033704c2c3316..04fc188649c77c62b43307cb2fff2249f28bddae 100644 --- a/test_tipc/configs/rec_r45_abinet/train_infer_python.txt +++ b/test_tipc/configs/rec_r45_abinet/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_r45_visionlan/train_infer_python.txt b/test_tipc/configs/rec_r45_visionlan/train_infer_python.txt index c08ae7beb6c867bf36283e60dc1e70cfd9ee06a7..79618edafa794a683e085fb1b8050358342e1f77 100644 --- a/test_tipc/configs/rec_r45_visionlan/train_infer_python.txt +++ b/test_tipc/configs/rec_r45_visionlan/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt b/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt index b5a5286010a5830dc23031b3e0885247fb6ae53f..c1cfd1fcd930c6992982feeb3c118dbc5a56f226 100644 --- a/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt +++ b/test_tipc/configs/rec_r50_fpn_vd_none_srn/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_svtrnet/train_infer_python.txt b/test_tipc/configs/rec_svtrnet/train_infer_python.txt index a7e4a24063b2e248f2ab92d5efd257a2837c0a34..5508c0411cfdc7102ccec7a00c59c2a5e1a54998 100644 --- a/test_tipc/configs/rec_svtrnet/train_infer_python.txt +++ b/test_tipc/configs/rec_svtrnet/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/configs/rec_vitstr_none_ce/train_infer_python.txt b/test_tipc/configs/rec_vitstr_none_ce/train_infer_python.txt index 04c5742ea2ddaf01e782d8b39c21bcbcfa0a7ce7..187c11544998626af556e3eeef5f958fbe42fea0 100644 --- a/test_tipc/configs/rec_vitstr_none_ce/train_infer_python.txt +++ b/test_tipc/configs/rec_vitstr_none_ce/train_infer_python.txt @@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbo --use_gpu:True|False --enable_mkldnn:False --cpu_threads:6 ---rec_batch_num:1|6 +--rec_batch_num:1 --use_tensorrt:False --precision:fp32 --rec_model_dir: diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index 6aea98a734e0fce8df00293b5362851144a7b119..4ea1803738d3c80f07dbdf1916e68809210c5b52 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -221,7 +221,6 @@ if [ ${MODE} = "lite_train_lite_infer" ];then fi if [ ${model_name} == "layoutxlm_ser" ] || [ ${model_name} == "vi_layoutxlm_ser" ]; then pip install -r ppstructure/kie/requirements.txt - pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate cd ./train_data/ && tar xf XFUND.tar cd ../ diff --git a/tools/eval.py b/tools/eval.py index 38d72d178db45a4787ddc09c865afba9222f385a..3d1d3813d33e251ec83a9729383fe772bc4cc225 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -23,6 +23,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, __dir__) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) +import paddle from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process @@ -86,6 +87,30 @@ def main(): else: model_type = None + # build metric + eval_class = build_metric(config['Metric']) + # amp + use_amp = config["Global"].get("use_amp", False) + amp_level = config["Global"].get("amp_level", 'O2') + amp_custom_black_list = config['Global'].get('amp_custom_black_list',[]) + if use_amp: + AMP_RELATED_FLAGS_SETTING = { + 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, + 'FLAGS_max_inplace_grad_add': 8, + } + paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) + scale_loss = config["Global"].get("scale_loss", 1.0) + use_dynamic_loss_scaling = config["Global"].get( + "use_dynamic_loss_scaling", False) + scaler = paddle.amp.GradScaler( + init_loss_scaling=scale_loss, + use_dynamic_loss_scaling=use_dynamic_loss_scaling) + if amp_level == "O2": + model = paddle.amp.decorate( + models=model, level=amp_level, master_weight=True) + else: + scaler = None + best_model_dict = load_model( config, model, model_type=config['Architecture']["model_type"]) if len(best_model_dict): @@ -93,11 +118,9 @@ def main(): for k, v in best_model_dict.items(): logger.info('{}:{}'.format(k, v)) - # build metric - eval_class = build_metric(config['Metric']) # start eval metric = program.eval(model, valid_dataloader, post_process_class, - eval_class, model_type, extra_input) + eval_class, model_type, extra_input, scaler, amp_level, amp_custom_black_list) logger.info('metric eval ***************') for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) diff --git a/tools/program.py b/tools/program.py index 7af1fe7354106f06b4384abb56de7675e4dbe053..16d3d4035af933cda01b422ea56e9e2895ec2b88 100755 --- a/tools/program.py +++ b/tools/program.py @@ -191,7 +191,8 @@ def train(config, logger, log_writer=None, scaler=None, - amp_level='O2'): + amp_level='O2', + amp_custom_black_list=[]): cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1) @@ -278,10 +279,7 @@ def train(config, model_average = True # use amp if scaler: - custom_black_list = config['Global'].get( - 'amp_custom_black_list', []) - with paddle.amp.auto_cast( - level=amp_level, custom_black_list=custom_black_list): + with paddle.amp.auto_cast(level=amp_level, custom_black_list=amp_custom_black_list): if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) elif model_type in ["kie"]: @@ -386,7 +384,9 @@ def train(config, eval_class, model_type, extra_input=extra_input, - scaler=scaler) + scaler=scaler, + amp_level=amp_level, + amp_custom_black_list=amp_custom_black_list) cur_metric_str = 'cur metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metric.items()])) logger.info(cur_metric_str) @@ -477,7 +477,9 @@ def eval(model, eval_class, model_type=None, extra_input=False, - scaler=None): + scaler=None, + amp_level='O2', + amp_custom_black_list = []): model.eval() with paddle.no_grad(): total_frame = 0.0 @@ -498,7 +500,7 @@ def eval(model, # use amp if scaler: - with paddle.amp.auto_cast(level='O2'): + with paddle.amp.auto_cast(level=amp_level, custom_black_list=amp_custom_black_list): if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) elif model_type in ["kie"]: diff --git a/tools/train.py b/tools/train.py index 5f310938f3ae3488281b47ccdb436697595b5578..d0f200189e34265b3c080ac9e25eb80d29c705b7 100755 --- a/tools/train.py +++ b/tools/train.py @@ -138,9 +138,7 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) - # load pretrain model - pre_best_model_dict = load_model(config, model, optimizer, - config['Architecture']["model_type"]) + logger.info('train dataloader has {} iters'.format(len(train_dataloader))) if valid_dataloader is not None: logger.info('valid dataloader has {} iters'.format( @@ -148,6 +146,7 @@ def main(config, device, logger, vdl_writer): use_amp = config["Global"].get("use_amp", False) amp_level = config["Global"].get("amp_level", 'O2') + amp_custom_black_list = config['Global'].get('amp_custom_black_list',[]) if use_amp: AMP_RELATED_FLAGS_SETTING = { 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, @@ -166,12 +165,16 @@ def main(config, device, logger, vdl_writer): else: scaler = None + # load pretrain model + pre_best_model_dict = load_model(config, model, optimizer, + config['Architecture']["model_type"]) + if config['Global']['distributed']: model = paddle.DataParallel(model) # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, - eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level) + eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level, amp_custom_black_list) def test_reader(config, device, logger):