提交 1b93e92c 编写于 作者: 文幕地方

fix amp bug

上级 b742e399
......@@ -162,18 +162,18 @@ def to_float32(preds):
for k in preds:
if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
preds[k] = paddle.to_tensor(preds[k], dtype='float32')
elif isinstance(preds[k], paddle.Tensor):
preds[k] = preds[k].astype(paddle.float32)
elif isinstance(preds, list):
for k in range(len(preds)):
if isinstance(preds[k], dict):
preds[k] = to_float32(preds[k])
elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
preds[k] = paddle.to_tensor(preds[k], dtype='float32')
else:
preds = paddle.to_tensor(preds, dtype='float32')
elif isinstance(preds[k], paddle.Tensor):
preds[k] = preds[k].astype(paddle.float32)
elif isinstance(preds, paddle.Tensor):
preds = preds.astype(paddle.float32)
return preds
......@@ -190,7 +190,8 @@ def train(config,
pre_best_model_dict,
logger,
log_writer=None,
scaler=None):
scaler=None,
amp_level='O2'):
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False)
calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
......@@ -276,7 +277,7 @@ def train(config,
model_average = True
# use amp
if scaler:
with paddle.amp.auto_cast(level='O2'):
with paddle.amp.auto_cast(level=amp_level):
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
......@@ -514,6 +515,7 @@ def eval(model,
sum_images, i), fm_lr)
else:
preds = model(images)
preds = to_float32(preds)
else:
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
......
......@@ -147,6 +147,7 @@ def main(config, device, logger, vdl_writer):
len(valid_dataloader)))
use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", 'O2')
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
......@@ -159,8 +160,9 @@ def main(config, device, logger, vdl_writer):
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
model, optimizer = paddle.amp.decorate(
models=model, optimizers=optimizer, level='O2', master_weight=True)
if amp_level == "O2":
model, optimizer = paddle.amp.decorate(
models=model, optimizers=optimizer, level=amp_level, master_weight=True)
else:
scaler = None
......@@ -169,7 +171,7 @@ def main(config, device, logger, vdl_writer):
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level)
def test_reader(config, device, logger):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册