提交 1b93e92c 编写于 作者: 文幕地方

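Fix AMP bug: cast only paddle.Tensor predictions to float32 (via astype) and make the AMP level configurable through Global.amp_level (default 'O2')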

上级 b742e399
...@@ -162,18 +162,18 @@ def to_float32(preds): ...@@ -162,18 +162,18 @@ def to_float32(preds):
for k in preds: for k in preds:
if isinstance(preds[k], dict) or isinstance(preds[k], list): if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k]) preds[k] = to_float32(preds[k])
else: elif isinstance(preds[k], paddle.Tensor):
preds[k] = paddle.to_tensor(preds[k], dtype='float32') preds[k] = preds[k].astype(paddle.float32)
elif isinstance(preds, list): elif isinstance(preds, list):
for k in range(len(preds)): for k in range(len(preds)):
if isinstance(preds[k], dict): if isinstance(preds[k], dict):
preds[k] = to_float32(preds[k]) preds[k] = to_float32(preds[k])
elif isinstance(preds[k], list): elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k]) preds[k] = to_float32(preds[k])
else: elif isinstance(preds[k], paddle.Tensor):
preds[k] = paddle.to_tensor(preds[k], dtype='float32') preds[k] = preds[k].astype(paddle.float32)
else: elif isinstance(preds, paddle.Tensor):
preds = paddle.to_tensor(preds, dtype='float32') preds = preds.astype(paddle.float32)
return preds return preds
...@@ -190,7 +190,8 @@ def train(config, ...@@ -190,7 +190,8 @@ def train(config,
pre_best_model_dict, pre_best_model_dict,
logger, logger,
log_writer=None, log_writer=None,
scaler=None): scaler=None,
amp_level='O2'):
cal_metric_during_train = config['Global'].get('cal_metric_during_train', cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False) False)
calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1) calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
...@@ -276,7 +277,7 @@ def train(config, ...@@ -276,7 +277,7 @@ def train(config,
model_average = True model_average = True
# use amp # use amp
if scaler: if scaler:
with paddle.amp.auto_cast(level='O2'): with paddle.amp.auto_cast(level=amp_level):
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']: elif model_type in ["kie", 'vqa']:
...@@ -514,6 +515,7 @@ def eval(model, ...@@ -514,6 +515,7 @@ def eval(model,
sum_images, i), fm_lr) sum_images, i), fm_lr)
else: else:
preds = model(images) preds = model(images)
preds = to_float32(preds)
else: else:
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
......
...@@ -147,6 +147,7 @@ def main(config, device, logger, vdl_writer): ...@@ -147,6 +147,7 @@ def main(config, device, logger, vdl_writer):
len(valid_dataloader))) len(valid_dataloader)))
use_amp = config["Global"].get("use_amp", False) use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", 'O2')
if use_amp: if use_amp:
AMP_RELATED_FLAGS_SETTING = { AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1, 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
...@@ -159,8 +160,9 @@ def main(config, device, logger, vdl_writer): ...@@ -159,8 +160,9 @@ def main(config, device, logger, vdl_writer):
scaler = paddle.amp.GradScaler( scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss, init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling) use_dynamic_loss_scaling=use_dynamic_loss_scaling)
if amp_level == "O2":
model, optimizer = paddle.amp.decorate( model, optimizer = paddle.amp.decorate(
models=model, optimizers=optimizer, level='O2', master_weight=True) models=model, optimizers=optimizer, level=amp_level, master_weight=True)
else: else:
scaler = None scaler = None
...@@ -169,7 +171,7 @@ def main(config, device, logger, vdl_writer): ...@@ -169,7 +171,7 @@ def main(config, device, logger, vdl_writer):
# start train # start train
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class, loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer, scaler) eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level)
def test_reader(config, device, logger): def test_reader(config, device, logger):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册