提交 aa7a0cb2 编写于 作者: 文幕地方's avatar 文幕地方

fix bug in amp eval

上级 9e4ae9dc
......@@ -8,7 +8,7 @@ Global:
# evaluation is run every 835 iterations
eval_batch_step: [0, 4000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
pretrained_model: pretrain_models/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy.pdparams
checkpoints:
save_inference_dir:
use_visualdl: False
......
......@@ -13,7 +13,7 @@ train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
norm_train:tools/train.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
......@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Architecture.Backbone.checkpoints:
norm_export:tools/export_model.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o
norm_export:tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o
quant_export:
fpgm_export:
distill_export:null
......
......@@ -154,12 +154,13 @@ def check_xpu(use_xpu):
except Exception as e:
pass
def to_float32(preds):
if isinstance(preds, dict):
for k in preds:
if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
elif isinstance(preds[k], paddle.Tensor):
preds[k] = preds[k].astype(paddle.float32)
elif isinstance(preds, list):
for k in range(len(preds)):
......@@ -167,12 +168,13 @@ def to_float32(preds):
preds[k] = to_float32(preds[k])
elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
elif isinstance(preds[k], paddle.Tensor):
preds[k] = preds[k].astype(paddle.float32)
else:
elif isinstance(preds[k], paddle.Tensor):
preds = preds.astype(paddle.float32)
return preds
def train(config,
train_dataloader,
valid_dataloader,
......@@ -370,7 +372,8 @@ def train(config,
post_process_class,
eval_class,
model_type,
extra_input=extra_input)
extra_input=extra_input,
scaler=scaler)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str)
......@@ -460,7 +463,8 @@ def eval(model,
post_process_class,
eval_class,
model_type=None,
extra_input=False):
extra_input=False,
scaler=None):
model.eval()
with paddle.no_grad():
total_frame = 0.0
......@@ -477,12 +481,24 @@ def eval(model,
break
images = batch[0]
start = time.time()
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
# use amp
if scaler:
with paddle.amp.auto_cast(level='O2'):
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)
else:
preds = model(images)
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)
batch_numpy = []
for item in batch:
if isinstance(item, paddle.Tensor):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册