未验证 提交 f5692c3f 编写于 作者: A andyjpaddle 提交者: GitHub

Merge pull request #7143 from WenmuZhou/tttt

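Fix bug in AMP eval: pass the AMP scaler from train() into eval() so that, when a scaler is provided, the evaluation forward pass runs under paddle.amp.auto_cast(level='O2')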
...@@ -8,7 +8,7 @@ Global: ...@@ -8,7 +8,7 @@ Global:
# evaluation is run every 4000 iterations (see eval_batch_step) # evaluation is run every 4000 iterations (see eval_batch_step)
eval_batch_step: [0, 4000] eval_batch_step: [0, 4000]
cal_metric_during_train: False cal_metric_during_train: False
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained pretrained_model: pretrain_models/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy.pdparams
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
......
...@@ -13,7 +13,7 @@ train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg ...@@ -13,7 +13,7 @@ train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false norm_train:tools/train.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -27,7 +27,7 @@ null:null ...@@ -27,7 +27,7 @@ null:null
===========================infer_params=========================== ===========================infer_params===========================
Global.save_inference_dir:./output/ Global.save_inference_dir:./output/
Architecture.Backbone.checkpoints: Architecture.Backbone.checkpoints:
norm_export:tools/export_model.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o norm_export:tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o
quant_export: quant_export:
fpgm_export: fpgm_export:
distill_export:null distill_export:null
......
...@@ -372,7 +372,8 @@ def train(config, ...@@ -372,7 +372,8 @@ def train(config,
post_process_class, post_process_class,
eval_class, eval_class,
model_type, model_type,
extra_input=extra_input) extra_input=extra_input,
scaler=scaler)
cur_metric_str = 'cur metric, {}'.format(', '.join( cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()])) ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str) logger.info(cur_metric_str)
...@@ -462,7 +463,8 @@ def eval(model, ...@@ -462,7 +463,8 @@ def eval(model,
post_process_class, post_process_class,
eval_class, eval_class,
model_type=None, model_type=None,
extra_input=False): extra_input=False,
scaler=None):
model.eval() model.eval()
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
...@@ -479,12 +481,24 @@ def eval(model, ...@@ -479,12 +481,24 @@ def eval(model,
break break
images = batch[0] images = batch[0]
start = time.time() start = time.time()
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) # use amp
elif model_type in ["kie", 'vqa']: if scaler:
preds = model(batch) with paddle.amp.auto_cast(level='O2'):
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)
else: else:
preds = model(images) if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)
batch_numpy = [] batch_numpy = []
for item in batch: for item in batch:
if isinstance(item, paddle.Tensor): if isinstance(item, paddle.Tensor):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册