Commit 205592a3 authored by zhangbo9674, committed by Tingquan Gao

fix amp with distribute bug

Parent 15cffcc0
......@@ -211,6 +211,14 @@ class Engine(object):
         self.optimizer, self.lr_sch = build_optimizer(
             self.config["Optimizer"], self.config["Global"]["epochs"],
             len(self.train_dataloader), [self.model])
+        # for amp training
+        if self.amp:
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=self.scale_loss,
+                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
+            if self.config['AMP']['use_pure_fp16'] is True:
+                self.model = paddle.amp.decorate(models=self.model, level='O2')
         # for distributed
         self.config["Global"][
......
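
The crux of the fix is ordering: the GradScaler creation and the O2 (pure fp16) decoration now happen right after the optimizer is built, before the "# for distributed" section that follows, so the distributed wrapper sees the already-decorated model. Below is a minimal sketch of that initialization order, assuming a stand-in resnet18 model and a use_pure_fp16 flag in place of config['AMP']['use_pure_fp16'] (neither is from this commit):

import paddle
from paddle.vision.models import resnet18  # stand-in model, not from the PR

model = resnet18()
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1, parameters=model.parameters())

# 1) AMP setup first: create the loss scaler and, for pure fp16,
#    decorate the model so its parameters are cast to float16.
scaler = paddle.amp.GradScaler(
    init_loss_scaling=2.**16, use_dynamic_loss_scaling=True)
use_pure_fp16 = True  # stands in for config['AMP']['use_pure_fp16']
if use_pure_fp16:
    model, optimizer = paddle.amp.decorate(
        models=model, optimizers=optimizer, level='O2')

# 2) Only then wrap for distributed training, so paddle.DataParallel
#    wraps the decorated (fp16) parameters rather than fp32 originals.
if paddle.distributed.get_world_size() > 1:
    paddle.distributed.init_parallel_env()
    model = paddle.DataParallel(model)

# A typical AMP training step with the scaler (illustrative data):
x = paddle.randn([8, 3, 224, 224])
with paddle.amp.auto_cast(level='O2' if use_pure_fp16 else 'O1'):
    loss = model(x).mean()
scaled = scaler.scale(loss)         # scale loss to avoid fp16 underflow
scaled.backward()
scaler.minimize(optimizer, scaled)  # unscale grads and apply the update
optimizer.clear_grad()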