提交 e7b16011 编写于 作者: 文幕地方's avatar 文幕地方

fix fp16 bug

上级 28daed6f
......@@ -119,9 +119,6 @@ def main(config, device, logger, vdl_writer):
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture'])
if config['Global']['distributed']:
model = paddle.DataParallel(model)
model = apply_to_static(model, config, logger)
# build loss
......@@ -157,10 +154,13 @@ def main(config, device, logger, vdl_writer):
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2', master_weight=True)
model, optimizer = paddle.amp.decorate(
models=model, optimizers=optimizer, level='O2', master_weight=True)
else:
scaler = None
if config['Global']['distributed']:
model = paddle.DataParallel(model)
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册