未验证 提交 d496c8e8 编写于 作者: Z Zhang Ting 提交者: GitHub

fix performance (#9772)

上级 43abe2fa
...@@ -21,6 +21,7 @@ Global: ...@@ -21,6 +21,7 @@ Global:
save_res_path: ./output/ser/xfund_zh/res save_res_path: ./output/ser/xfund_zh/res
kie_rec_model_dir: kie_rec_model_dir:
kie_det_model_dir: kie_det_model_dir:
amp_custom_white_list: ['scale', 'concat', 'elementwise_add']
Architecture: Architecture:
model_type: kie model_type: kie
......
...@@ -22,6 +22,7 @@ Global: ...@@ -22,6 +22,7 @@ Global:
use_sync_bn: True use_sync_bn: True
save_res_path: 'output/infer' save_res_path: 'output/infer'
d2s_train_image_shape: [3, -1, -1] d2s_train_image_shape: [3, -1, -1]
amp_custom_white_list: ['concat', 'elementwise_sub', 'set_value']
Optimizer: Optimizer:
name: Adam name: Adam
......
...@@ -22,6 +22,7 @@ Global: ...@@ -22,6 +22,7 @@ Global:
use_sync_bn: True use_sync_bn: True
save_res_path: 'output/infer' save_res_path: 'output/infer'
d2s_train_image_shape: [3, -1, -1] d2s_train_image_shape: [3, -1, -1]
amp_custom_white_list: ['concat', 'elementwise_sub', 'set_value']
Optimizer: Optimizer:
name: Adam name: Adam
......
...@@ -188,7 +188,8 @@ def train(config, ...@@ -188,7 +188,8 @@ def train(config,
log_writer=None, log_writer=None,
scaler=None, scaler=None,
amp_level='O2', amp_level='O2',
amp_custom_black_list=[]): amp_custom_black_list=[],
amp_custom_white_list=[]):
cal_metric_during_train = config['Global'].get('cal_metric_during_train', cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False) False)
calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1) calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
...@@ -277,7 +278,8 @@ def train(config, ...@@ -277,7 +278,8 @@ def train(config,
if scaler: if scaler:
with paddle.amp.auto_cast( with paddle.amp.auto_cast(
level=amp_level, level=amp_level,
custom_black_list=amp_custom_black_list): custom_black_list=amp_custom_black_list,
custom_white_list=amp_custom_white_list):
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie"]: elif model_type in ["kie"]:
......
...@@ -161,6 +161,7 @@ def main(config, device, logger, vdl_writer): ...@@ -161,6 +161,7 @@ def main(config, device, logger, vdl_writer):
use_amp = config["Global"].get("use_amp", False) use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", 'O2') amp_level = config["Global"].get("amp_level", 'O2')
amp_custom_black_list = config['Global'].get('amp_custom_black_list', []) amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
amp_custom_white_list = config['Global'].get('amp_custom_white_list', [])
if use_amp: if use_amp:
AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, } AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
if paddle.is_compiled_with_cuda(): if paddle.is_compiled_with_cuda():
...@@ -194,7 +195,7 @@ def main(config, device, logger, vdl_writer): ...@@ -194,7 +195,7 @@ def main(config, device, logger, vdl_writer):
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class, loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer, scaler, eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
amp_level, amp_custom_black_list) amp_level, amp_custom_black_list, amp_custom_white_list)
def test_reader(config, device, logger): def test_reader(config, device, logger):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册