Commit b54ee044 authored by zhangbo9674

Accelerate dynamic graph amp training

Parent 7732a69f
@@ -250,6 +250,8 @@ class Engine(object):
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
+            if self.config['AMP']['use_pure_fp16'] is True:
+                self.model = paddle.amp.decorate(models=self.model, level='O2')
         self.max_iter = len(self.train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(self.train_dataloader)
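A note on the decorate call added above: under O2 ("pure fp16"), paddle.amp.decorate casts the model's parameters to float16 once, up front, while the GradScaler handles loss scaling. A minimal sketch of the same setup (the toy layer and hyperparameters are illustrative, not from this commit):

import paddle

# Toy layer standing in for the real model; numbers are arbitrary.
net = paddle.nn.Linear(4, 4)
# O2 ("pure fp16"): cast the model's parameters to float16 up front.
net = paddle.amp.decorate(models=net, level='O2')
opt = paddle.optimizer.Momentum(learning_rate=0.1,
                                momentum=0.9,
                                parameters=net.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024,
                               use_dynamic_loss_scaling=True)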
@@ -41,14 +41,15 @@ def train_epoch(engine, epoch_id, print_batch_step):
         # image input
         if engine.amp:
-            with paddle.amp.auto_cast(custom_black_list={
-                    "flatten_contiguous_range", "greater_than"
-            }):
+            amp_level = 'O1'
+            if engine.config['AMP']['use_pure_fp16'] is True:
+                amp_level = 'O2'
+            with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
                 out = forward(engine, batch)
                 loss_dict = engine.train_loss_func(out, batch[1])
         else:
             out = forward(engine, batch)
             loss_dict = engine.train_loss_func(out, batch[1])

         # step opt and lr
         if engine.amp:
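The hunk above makes the auto_cast level selectable per config. Continuing the toy setup from the earlier sketch, one training step might look like this (use_pure_fp16 stands in for the config['AMP']['use_pure_fp16'] lookup):

use_pure_fp16 = True  # stand-in for the config lookup
amp_level = 'O2' if use_pure_fp16 else 'O1'
data = paddle.rand([2, 4])
with paddle.amp.auto_cast(
        custom_black_list={"flatten_contiguous_range", "greater_than"},
        level=amp_level):
    loss = net(data).mean()
# Scale the loss before backward, then let the scaler apply the update.
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(opt, scaled)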
@@ -58,7 +59,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
         else:
             loss_dict["loss"].backward()
             engine.optimizer.step()
-        engine.optimizer.clear_grad()
+        engine.optimizer.clear_grad(set_to_zero=True)
         engine.lr_sch.step()

         # below code just for logging
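On the clear_grad change above: with set_to_zero=True, Paddle zeroes the existing gradient tensors in place instead of releasing them, so gradient buffers are reused across iterations rather than re-allocated each step. Continuing the same sketch:

# Zero gradients in place rather than releasing them, so the grad
# buffers stay allocated and are reused on the next iteration.
opt.clear_grad(set_to_zero=True)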
@@ -36,13 +36,15 @@ class Momentum(object):
                  momentum,
                  weight_decay=None,
                  grad_clip=None,
-                 multi_precision=False):
+                 multi_precision=True,
+                 use_multi_tensor=True):
         super().__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
         self.multi_precision = multi_precision
+        self.use_multi_tensor = use_multi_tensor

     def __call__(self, model_list):
         # model_list is None in static graph
@@ -54,6 +56,7 @@ class Momentum(object):
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
             multi_precision=self.multi_precision,
+            use_multi_tensor=self.use_multi_tensor,
             parameters=parameters)
         return opt
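For reference, a hypothetical direct construction showing where the two new defaults end up; both multi_precision and use_multi_tensor are real paddle.optimizer.Momentum arguments:

import paddle

net = paddle.nn.Linear(4, 4)  # placeholder model, for illustration only
opt = paddle.optimizer.Momentum(
    learning_rate=0.1,
    momentum=0.9,
    multi_precision=True,   # keep float32 master weights for float16 params
    use_multi_tensor=True,  # fuse per-parameter updates into fewer kernel launches
    parameters=net.parameters())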