未验证 提交 1dbff736 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #7298 from WenmuZhou/tipc

Simpify train infer python
...@@ -104,8 +104,9 @@ def load_model(config, model, optimizer=None, model_type='det'): ...@@ -104,8 +104,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
continue continue
pre_value = params[key] pre_value = params[key]
if pre_value.dtype == paddle.float16: if pre_value.dtype == paddle.float16:
pre_value = pre_value.astype(paddle.float32)
is_float16 = True is_float16 = True
if pre_value.dtype != value.dtype:
pre_value = pre_value.astype(value.dtype)
if list(value.shape) == list(pre_value.shape): if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value new_state_dict[key] = pre_value
else: else:
...@@ -162,8 +163,9 @@ def load_pretrained_params(model, path): ...@@ -162,8 +163,9 @@ def load_pretrained_params(model, path):
logger.warning("The pretrained params {} not in model".format(k1)) logger.warning("The pretrained params {} not in model".format(k1))
else: else:
if params[k1].dtype == paddle.float16: if params[k1].dtype == paddle.float16:
params[k1] = params[k1].astype(paddle.float32)
is_float16 = True is_float16 = True
if params[k1].dtype != state_dict[k1].dtype:
params[k1] = params[k1].astype(state_dict[k1].dtype)
if list(state_dict[k1].shape) == list(params[k1].shape): if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1] new_state_dict[k1] = params[k1]
else: else:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_image_shape="3,48,320" ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_image_shape="3,48,320"
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 5 print_batch_step: 5
save_model_dir: ./output/table_mv3/ save_model_dir: ./output/table_mv3/
save_epoch_step: 3 save_epoch_step: 400
# evaluation is run every 400 iterations after the 0th iteration # evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 40000] eval_batch_step: [0, 40000]
cal_metric_during_train: True cal_metric_during_train: True
...@@ -17,7 +17,8 @@ Global: ...@@ -17,7 +17,8 @@ Global:
# for data or label process # for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en character_type: en
max_text_length: 800 max_text_length: &max_text_length 500
box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
infer_mode: False infer_mode: False
Optimizer: Optimizer:
...@@ -37,12 +38,14 @@ Architecture: ...@@ -37,12 +38,14 @@ Architecture:
Backbone: Backbone:
name: MobileNetV3 name: MobileNetV3
scale: 1.0 scale: 1.0
model_name: large model_name: small
disable_se: true
Head: Head:
name: TableAttentionHead name: TableAttentionHead
hidden_size: 256 hidden_size: 256
loc_type: 2 loc_type: 2
max_text_length: 800 max_text_length: *max_text_length
loc_reg_num: &loc_reg_num 4
Loss: Loss:
name: TableAttentionLoss name: TableAttentionLoss
...@@ -70,6 +73,8 @@ Train: ...@@ -70,6 +73,8 @@ Train:
learn_empty_box: False learn_empty_box: False
merge_no_span_structure: False merge_no_span_structure: False
replace_empty_cell_token: False replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- TableBoxEncode: - TableBoxEncode:
- ResizeTableImage: - ResizeTableImage:
max_len: 488 max_len: 488
...@@ -102,6 +107,8 @@ Eval: ...@@ -102,6 +107,8 @@ Eval:
learn_empty_box: False learn_empty_box: False
merge_no_span_structure: False merge_no_span_structure: False
replace_empty_cell_token: False replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
- TableBoxEncode: - TableBoxEncode:
- ResizeTableImage: - ResizeTableImage:
max_len: 488 max_len: 488
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbo ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbo
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -39,11 +39,11 @@ infer_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/re ...@@ -39,11 +39,11 @@ infer_export:tools/export_model.py -c test_tipc/configs/rec_r31_robustscanner/re
infer_quant:False infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.txt --rec_image_shape="3,48,48,160" --use_space_char=False --rec_algorithm="RobustScanner" inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.txt --rec_image_shape="3,48,48,160" --use_space_char=False --rec_algorithm="RobustScanner"
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:True|False --enable_mkldnn:False
--cpu_threads:1|6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False|False --use_tensorrt:False
--precision:fp32|int8 --precision:fp32
--rec_model_dir: --rec_model_dir:
--image_dir:./inference/rec_inference --image_dir:./inference/rec_inference
--save_log_path:./test/output/ --save_log_path:./test/output/
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.t ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict90.t
--use_gpu:True --use_gpu:True
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/spi ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/spi
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbo ...@@ -41,7 +41,7 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/EN_symbo
--use_gpu:True|False --use_gpu:True|False
--enable_mkldnn:False --enable_mkldnn:False
--cpu_threads:6 --cpu_threads:6
--rec_batch_num:1|6 --rec_batch_num:1
--use_tensorrt:False --use_tensorrt:False
--precision:fp32 --precision:fp32
--rec_model_dir: --rec_model_dir:
......
...@@ -221,7 +221,6 @@ if [ ${MODE} = "lite_train_lite_infer" ];then ...@@ -221,7 +221,6 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
fi fi
if [ ${model_name} == "layoutxlm_ser" ] || [ ${model_name} == "vi_layoutxlm_ser" ]; then if [ ${model_name} == "layoutxlm_ser" ] || [ ${model_name} == "vi_layoutxlm_ser" ]; then
pip install -r ppstructure/kie/requirements.txt pip install -r ppstructure/kie/requirements.txt
pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
cd ./train_data/ && tar xf XFUND.tar cd ./train_data/ && tar xf XFUND.tar
cd ../ cd ../
......
...@@ -23,6 +23,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) ...@@ -23,6 +23,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, __dir__) sys.path.insert(0, __dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
import paddle
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
...@@ -86,6 +87,30 @@ def main(): ...@@ -86,6 +87,30 @@ def main():
else: else:
model_type = None model_type = None
# build metric
eval_class = build_metric(config['Metric'])
# amp
use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", 'O2')
amp_custom_black_list = config['Global'].get('amp_custom_black_list',[])
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
'FLAGS_max_inplace_grad_add': 8,
}
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["Global"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["Global"].get(
"use_dynamic_loss_scaling", False)
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
if amp_level == "O2":
model = paddle.amp.decorate(
models=model, level=amp_level, master_weight=True)
else:
scaler = None
best_model_dict = load_model( best_model_dict = load_model(
config, model, model_type=config['Architecture']["model_type"]) config, model, model_type=config['Architecture']["model_type"])
if len(best_model_dict): if len(best_model_dict):
...@@ -93,11 +118,9 @@ def main(): ...@@ -93,11 +118,9 @@ def main():
for k, v in best_model_dict.items(): for k, v in best_model_dict.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
# build metric
eval_class = build_metric(config['Metric'])
# start eval # start eval
metric = program.eval(model, valid_dataloader, post_process_class, metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, extra_input) eval_class, model_type, extra_input, scaler, amp_level, amp_custom_black_list)
logger.info('metric eval ***************') logger.info('metric eval ***************')
for k, v in metric.items(): for k, v in metric.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
......
...@@ -191,7 +191,8 @@ def train(config, ...@@ -191,7 +191,8 @@ def train(config,
logger, logger,
log_writer=None, log_writer=None,
scaler=None, scaler=None,
amp_level='O2'): amp_level='O2',
amp_custom_black_list=[]):
cal_metric_during_train = config['Global'].get('cal_metric_during_train', cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False) False)
calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1) calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
...@@ -278,10 +279,7 @@ def train(config, ...@@ -278,10 +279,7 @@ def train(config,
model_average = True model_average = True
# use amp # use amp
if scaler: if scaler:
custom_black_list = config['Global'].get( with paddle.amp.auto_cast(level=amp_level, custom_black_list=amp_custom_black_list):
'amp_custom_black_list', [])
with paddle.amp.auto_cast(
level=amp_level, custom_black_list=custom_black_list):
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie"]: elif model_type in ["kie"]:
...@@ -386,7 +384,9 @@ def train(config, ...@@ -386,7 +384,9 @@ def train(config,
eval_class, eval_class,
model_type, model_type,
extra_input=extra_input, extra_input=extra_input,
scaler=scaler) scaler=scaler,
amp_level=amp_level,
amp_custom_black_list=amp_custom_black_list)
cur_metric_str = 'cur metric, {}'.format(', '.join( cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()])) ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str) logger.info(cur_metric_str)
...@@ -477,7 +477,9 @@ def eval(model, ...@@ -477,7 +477,9 @@ def eval(model,
eval_class, eval_class,
model_type=None, model_type=None,
extra_input=False, extra_input=False,
scaler=None): scaler=None,
amp_level='O2',
amp_custom_black_list = []):
model.eval() model.eval()
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
...@@ -498,7 +500,7 @@ def eval(model, ...@@ -498,7 +500,7 @@ def eval(model,
# use amp # use amp
if scaler: if scaler:
with paddle.amp.auto_cast(level='O2'): with paddle.amp.auto_cast(level=amp_level, custom_black_list=amp_custom_black_list):
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie"]: elif model_type in ["kie"]:
......
...@@ -138,9 +138,7 @@ def main(config, device, logger, vdl_writer): ...@@ -138,9 +138,7 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = load_model(config, model, optimizer,
config['Architecture']["model_type"])
logger.info('train dataloader has {} iters'.format(len(train_dataloader))) logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None: if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format( logger.info('valid dataloader has {} iters'.format(
...@@ -148,6 +146,7 @@ def main(config, device, logger, vdl_writer): ...@@ -148,6 +146,7 @@ def main(config, device, logger, vdl_writer):
use_amp = config["Global"].get("use_amp", False) use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", 'O2') amp_level = config["Global"].get("amp_level", 'O2')
amp_custom_black_list = config['Global'].get('amp_custom_black_list',[])
if use_amp: if use_amp:
AMP_RELATED_FLAGS_SETTING = { AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1, 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
...@@ -166,12 +165,16 @@ def main(config, device, logger, vdl_writer): ...@@ -166,12 +165,16 @@ def main(config, device, logger, vdl_writer):
else: else:
scaler = None scaler = None
# load pretrain model
pre_best_model_dict = load_model(config, model, optimizer,
config['Architecture']["model_type"])
if config['Global']['distributed']: if config['Global']['distributed']:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
# start train # start train
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class, loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level) eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level, amp_custom_black_list)
def test_reader(config, device, logger): def test_reader(config, device, logger):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册