Unverified commit f45f9ee4, authored by Wei Shengyu, committed by GitHub

Merge pull request #1585 from zhangbo9674/dev/resnet50_optimize

Accelerate dynamic graph amp training
@@ -250,6 +250,8 @@ class Engine(object):
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
+            if self.config['AMP']['use_pure_fp16'] is True:
+                self.model = paddle.amp.decorate(models=self.model, level='O2')
         self.max_iter = len(self.train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(self.train_dataloader)
......
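The engine change above enables "pure fp16" (AMP O2) training: when `AMP.use_pure_fp16` is set, the model is wrapped once with `paddle.amp.decorate(..., level='O2')` so its parameters and ops run in fp16, while the existing `GradScaler` guards against gradient underflow. Below is a minimal sketch of how these pieces fit together in a plain Paddle dygraph step; the toy model, data, and hyperparameters are illustrative stand-ins (not part of this PR), and it assumes a GPU with fp16 support.

```python
import paddle
from paddle import nn, optimizer

# Toy model/optimizer, illustrative only (not from this PR).
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))
opt = optimizer.Momentum(learning_rate=0.1, momentum=0.9,
                         parameters=model.parameters(), multi_precision=True)

# O2 ("pure fp16"): cast the model once up front; the optimizer keeps fp32 master weights.
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn([8, 16])
with paddle.amp.auto_cast(level='O2'):
    # A real loss function would go here; .mean() keeps the sketch dtype-safe.
    loss = model(x).mean()
scaled = scaler.scale(loss)      # scale the loss to avoid fp16 gradient underflow
scaled.backward()
scaler.minimize(opt, scaled)     # unscale grads, skip the step on inf/nan, update the scale
opt.clear_grad()
```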
@@ -21,6 +21,7 @@ from ppcls.utils import profiler
 def train_epoch(engine, epoch_id, print_batch_step):
     tic = time.time()
+    v_current = [int(i) for i in paddle.__version__.split(".")]
     for iter_id, batch in enumerate(engine.train_dataloader):
         if iter_id >= engine.max_iter:
             break
@@ -41,14 +42,15 @@ def train_epoch(engine, epoch_id, print_batch_step):
         # image input
         if engine.amp:
-            with paddle.amp.auto_cast(custom_black_list={
-                    "flatten_contiguous_range", "greater_than"
-            }):
+            amp_level = 'O1'
+            if engine.config['AMP']['use_pure_fp16'] is True:
+                amp_level = 'O2'
+            with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
                 out = forward(engine, batch)
+                loss_dict = engine.train_loss_func(out, batch[1])
         else:
             out = forward(engine, batch)
-        loss_dict = engine.train_loss_func(out, batch[1])
+            loss_dict = engine.train_loss_func(out, batch[1])
         # step opt and lr
         if engine.amp:
......
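In the training loop, the auto_cast level is now chosen from the config and the loss computation moves inside the auto_cast scope, so under O2 the loss is evaluated with fp16 kernels as well. The sketch below restates that branch structure as a standalone helper; `engine`, `forward`, and the `.get(...)` default are stand-ins added only to keep the example self-contained, not the project's actual objects.

```python
import paddle

def run_step(engine, batch, forward, use_amp: bool):
    """Sketch of the per-iteration branch introduced by this change."""
    if use_amp:
        # O1: selected ops autocast to fp16; O2: nearly everything runs in fp16.
        amp_level = 'O2' if engine.config['AMP'].get('use_pure_fp16', False) else 'O1'
        with paddle.amp.auto_cast(
                custom_black_list={"flatten_contiguous_range", "greater_than"},
                level=amp_level):
            out = forward(engine, batch)
            # Loss is computed inside auto_cast so it also benefits from fp16 kernels.
            loss_dict = engine.train_loss_func(out, batch[1])
    else:
        out = forward(engine, batch)
        loss_dict = engine.train_loss_func(out, batch[1])
    return out, loss_dict
```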
@@ -17,6 +17,7 @@ from __future__ import division
 from __future__ import print_function
 from paddle import optimizer as optim
+import paddle
 from ppcls.utils import logger
@@ -36,7 +37,7 @@ class Momentum(object):
                  momentum,
                  weight_decay=None,
                  grad_clip=None,
-                 multi_precision=False):
+                 multi_precision=True):
         super().__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
@@ -55,6 +56,15 @@ class Momentum(object):
             grad_clip=self.grad_clip,
             multi_precision=self.multi_precision,
             parameters=parameters)
+        if hasattr(opt, '_use_multi_tensor'):
+            opt = optim.Momentum(
+                learning_rate=self.learning_rate,
+                momentum=self.momentum,
+                weight_decay=self.weight_decay,
+                grad_clip=self.grad_clip,
+                multi_precision=self.multi_precision,
+                parameters=parameters,
+                use_multi_tensor=True)
         return opt
......
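The optimizer change defaults `multi_precision` to True (fp32 master weights when parameters are fp16) and, when the installed Paddle exposes multi-tensor momentum (probed through the private `_use_multi_tensor` attribute), rebuilds the optimizer with `use_multi_tensor=True` so per-parameter updates are fused into fewer kernel launches. A minimal sketch of that probe, assuming a Paddle release whose `paddle.optimizer.Momentum` accepts `use_multi_tensor`:

```python
import paddle
from paddle import optimizer as optim

def build_momentum(parameters, lr=0.1, momentum=0.9):
    """Build Momentum, opting into the fused multi-tensor path when available."""
    kwargs = dict(learning_rate=lr, momentum=momentum,
                  multi_precision=True,   # keep fp32 master weights for fp16 params
                  parameters=parameters)
    opt = optim.Momentum(**kwargs)
    # Older Paddle releases do not accept `use_multi_tensor`; probe before rebuilding.
    if hasattr(opt, '_use_multi_tensor'):
        opt = optim.Momentum(use_multi_tensor=True, **kwargs)
    return opt

# Usage sketch:
model = paddle.nn.Linear(4, 4)
opt = build_momentum(model.parameters())
```

Probing an attribute rather than pinning a Paddle version keeps the wrapper compatible with older installs while still picking up the fused path where it exists.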