提交 aa7a0cb2 编写于 作者: 文幕地方's avatar 文幕地方

fix bug in amp eval

上级 9e4ae9dc
...@@ -8,7 +8,7 @@ Global: ...@@ -8,7 +8,7 @@ Global:
# evaluation is run every 835 iterations # evaluation is run every 835 iterations
eval_batch_step: [0, 4000] eval_batch_step: [0, 4000]
cal_metric_during_train: False cal_metric_during_train: False
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained pretrained_model: pretrain_models/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy.pdparams
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
......
...@@ -13,7 +13,7 @@ train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg ...@@ -13,7 +13,7 @@ train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false norm_train:tools/train.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -27,7 +27,7 @@ null:null ...@@ -27,7 +27,7 @@ null:null
===========================infer_params=========================== ===========================infer_params===========================
Global.save_inference_dir:./output/ Global.save_inference_dir:./output/
Architecture.Backbone.checkpoints: Architecture.Backbone.checkpoints:
norm_export:tools/export_model.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o norm_export:tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o
quant_export: quant_export:
fpgm_export: fpgm_export:
distill_export:null distill_export:null
......
...@@ -154,12 +154,13 @@ def check_xpu(use_xpu): ...@@ -154,12 +154,13 @@ def check_xpu(use_xpu):
except Exception as e: except Exception as e:
pass pass
def to_float32(preds): def to_float32(preds):
if isinstance(preds, dict): if isinstance(preds, dict):
for k in preds: for k in preds:
if isinstance(preds[k], dict) or isinstance(preds[k], list): if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k]) preds[k] = to_float32(preds[k])
else: elif isinstance(preds[k], paddle.Tensor):
preds[k] = preds[k].astype(paddle.float32) preds[k] = preds[k].astype(paddle.float32)
elif isinstance(preds, list): elif isinstance(preds, list):
for k in range(len(preds)): for k in range(len(preds)):
...@@ -167,12 +168,13 @@ def to_float32(preds): ...@@ -167,12 +168,13 @@ def to_float32(preds):
preds[k] = to_float32(preds[k]) preds[k] = to_float32(preds[k])
elif isinstance(preds[k], list): elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k]) preds[k] = to_float32(preds[k])
else: elif isinstance(preds[k], paddle.Tensor):
preds[k] = preds[k].astype(paddle.float32) preds[k] = preds[k].astype(paddle.float32)
else: elif isinstance(preds[k], paddle.Tensor):
preds = preds.astype(paddle.float32) preds = preds.astype(paddle.float32)
return preds return preds
def train(config, def train(config,
train_dataloader, train_dataloader,
valid_dataloader, valid_dataloader,
...@@ -370,7 +372,8 @@ def train(config, ...@@ -370,7 +372,8 @@ def train(config,
post_process_class, post_process_class,
eval_class, eval_class,
model_type, model_type,
extra_input=extra_input) extra_input=extra_input,
scaler=scaler)
cur_metric_str = 'cur metric, {}'.format(', '.join( cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()])) ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str) logger.info(cur_metric_str)
...@@ -460,7 +463,8 @@ def eval(model, ...@@ -460,7 +463,8 @@ def eval(model,
post_process_class, post_process_class,
eval_class, eval_class,
model_type=None, model_type=None,
extra_input=False): extra_input=False,
scaler=None):
model.eval() model.eval()
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
...@@ -477,12 +481,24 @@ def eval(model, ...@@ -477,12 +481,24 @@ def eval(model,
break break
images = batch[0] images = batch[0]
start = time.time() start = time.time()
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) # use amp
elif model_type in ["kie", 'vqa']: if scaler:
preds = model(batch) with paddle.amp.auto_cast(level='O2'):
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)
else: else:
preds = model(images) if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)
batch_numpy = [] batch_numpy = []
for item in batch: for item in batch:
if isinstance(item, paddle.Tensor): if isinstance(item, paddle.Tensor):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册