Unverified commit 2dc058ad, authored by shangliang Xu, committed by GitHub

[dev] update amp, add amp_level (#6054)

Parent: 84faecbc
@@ -32,7 +32,6 @@
 import paddle
 import paddle.nn as nn
 import paddle.distributed as dist
 from paddle.distributed import fleet
-from paddle import amp
 from paddle.static import InputSpec
 from ppdet.optimizer import ModelEMA
@@ -380,13 +379,21 @@ class Trainer(object):
             self.cfg['EvalDataset'] = self.cfg.EvalDataset = create(
                 "EvalDataset")()
 
+        model = self.model
         sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
                    self.cfg.use_gpu and self._nranks > 1)
         if sync_bn:
-            self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
-                self.model)
+            model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
-        model = self.model
+        # enable auto mixed precision mode
+        use_amp = self.cfg.get('amp', False)
+        amp_level = self.cfg.get('amp_level', 'O1')
+        if use_amp:
+            scaler = paddle.amp.GradScaler(
+                enable=self.cfg.use_gpu or self.cfg.use_npu,
+                init_loss_scaling=self.cfg.get('init_loss_scaling', 1024))
+            model = paddle.amp.decorate(models=model, level=amp_level)
+        # get distributed model
         if self.cfg.get('fleet', False):
             model = fleet.distributed_model(model)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
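The new paddle.amp.decorate call is what gives amp_level an effect on the model itself: under 'O2' ("pure fp16") the parameters are cast to float16, while under 'O1' the model is returned unchanged and precision is handled per operator inside auto_cast. A minimal standalone sketch of the difference (not PaddleDetection code; the Linear layer is a stand-in, assuming a Paddle 2.x install):

import paddle

# O1 leaves parameters in float32; casting happens per op inside auto_cast
layer_o1 = paddle.amp.decorate(models=paddle.nn.Linear(4, 4), level='O1')
print(layer_o1.weight.dtype)  # paddle.float32

# O2 converts the parameters themselves to float16 ("pure fp16" mode)
layer_o2 = paddle.amp.decorate(models=paddle.nn.Linear(4, 4), level='O2')
print(layer_o2.weight.dtype)  # paddle.float16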
@@ -394,13 +401,7 @@ class Trainer(object):
             find_unused_parameters = self.cfg[
                 'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
             model = paddle.DataParallel(
-                self.model, find_unused_parameters=find_unused_parameters)
-
-        # enable auto mixed precision mode
-        if self.cfg.get('amp', False):
-            scaler = amp.GradScaler(
-                enable=self.cfg.use_gpu or self.cfg.use_npu,
-                init_loss_scaling=1024)
+                model, find_unused_parameters=find_unused_parameters)
 
         self.status.update({
             'epoch_id': self.start_epoch,
...@@ -436,12 +437,12 @@ class Trainer(object): ...@@ -436,12 +437,12 @@ class Trainer(object):
self._compose_callback.on_step_begin(self.status) self._compose_callback.on_step_begin(self.status)
data['epoch_id'] = epoch_id data['epoch_id'] = epoch_id
if self.cfg.get('amp', False): if use_amp:
with amp.auto_cast(enable=self.cfg.use_gpu): with paddle.amp.auto_cast(
enable=self.cfg.use_gpu, level=amp_level):
# model forward # model forward
outputs = model(data) outputs = model(data)
loss = outputs['loss'] loss = outputs['loss']
# model backward # model backward
scaled_loss = scaler.scale(loss) scaled_loss = scaler.scale(loss)
scaled_loss.backward() scaled_loss.backward()
......
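Taken together, the flow the trainer converges on is: build a GradScaler from config, decorate the model for the configured level, run the forward pass under auto_cast, then scale the loss before backward. Below is a minimal self-contained sketch of that pattern, not PaddleDetection code: the toy model, data, and SGD optimizer are stand-ins, and since the diff is truncated before the trainer's own optimizer step, the scaler.minimize call is simply the generic Paddle way to finish the iteration.

import paddle

use_gpu = paddle.is_compiled_with_cuda()
amp_level = 'O1'  # the trainer reads this from cfg.get('amp_level', 'O1')

model = paddle.nn.Linear(16, 1)  # stand-in for the detection model
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())

scaler = paddle.amp.GradScaler(enable=use_gpu, init_loss_scaling=1024)
model = paddle.amp.decorate(models=model, level=amp_level)

data = paddle.randn([8, 16])
with paddle.amp.auto_cast(enable=use_gpu, level=amp_level):
    loss = model(data).mean()      # forward pass runs under autocast

scaled_loss = scaler.scale(loss)   # scale up to avoid fp16 gradient underflow
scaled_loss.backward()
scaler.minimize(opt, scaled_loss)  # unscale grads; skip the step on inf/nan
opt.clear_grad()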