Commit 205592a3 authored by zhangbo9674, committed by Tingquan Gao

fix AMP bug when used with distributed training

Parent 15cffcc0
@@ -212,6 +212,14 @@ class Engine(object):
                 self.config["Optimizer"], self.config["Global"]["epochs"],
                 len(self.train_dataloader), [self.model])
+
+        # for amp training
+        if self.amp:
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=self.scale_loss,
+                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
+            if self.config['AMP']['use_pure_fp16'] is True:
+                self.model = paddle.amp.decorate(models=self.model, level='O2')
         # for distributed
         self.config["Global"][
             "distributed"] = paddle.distributed.get_world_size() != 1
...
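The added block creates the dynamic loss scaler and, for pure-fp16 runs, decorates the model so its parameters run at the 'O2' AMP level. Placing this before the "# for distributed" section presumably ensures paddle.amp.decorate wraps the raw model rather than one already wrapped for distributed execution, which matches the bug named in the commit title. Below is a minimal sketch of how these objects are typically driven during a training step; `train_step`, `engine`, and its arguments are hypothetical names for illustration, not part of this commit.

```python
# Sketch of one training step using the AMP objects set up in the diff above.
# `engine` stands in for the Engine instance; the loop body is illustrative.
import paddle
import paddle.nn.functional as F

def train_step(engine, images, labels, optimizer):
    if engine.amp:
        # Run forward pass and loss under auto_cast so eligible ops
        # execute in fp16 ('O2' keeps nearly all ops in fp16).
        level = 'O2' if engine.config['AMP']['use_pure_fp16'] else 'O1'
        with paddle.amp.auto_cast(level=level):
            logits = engine.model(images)
            loss = F.cross_entropy(logits, labels)
        # Scale the loss before backward so fp16 gradients do not underflow;
        # minimize() unscales the gradients, applies the update, and adjusts
        # the loss scale when dynamic loss scaling is enabled.
        scaled = engine.scaler.scale(loss)
        scaled.backward()
        engine.scaler.minimize(optimizer, scaled)
    else:
        loss = F.cross_entropy(engine.model(images), labels)
        loss.backward()
        optimizer.step()
    optimizer.clear_grad()
    return loss
```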