未验证 提交 2dc058ad 编写于 作者: S shangliang Xu 提交者: GitHub

[dev] update amp, add amp_level (#6054)

上级 84faecbc
......@@ -32,7 +32,6 @@ import paddle
import paddle.nn as nn
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle import amp
from paddle.static import InputSpec
from ppdet.optimizer import ModelEMA
......@@ -380,13 +379,21 @@ class Trainer(object):
self.cfg['EvalDataset'] = self.cfg.EvalDataset = create(
"EvalDataset")()
model = self.model
sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
self.cfg.use_gpu and self._nranks > 1)
if sync_bn:
self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
self.model)
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = self.model
# enabel auto mixed precision mode
use_amp = self.cfg.get('amp', False)
amp_level = self.cfg.get('amp_level', 'O1')
if use_amp:
scaler = paddle.amp.GradScaler(
enable=self.cfg.use_gpu or self.cfg.use_npu,
init_loss_scaling=self.cfg.get('init_loss_scaling', 1024))
model = paddle.amp.decorate(models=model, level=amp_level)
# get distributed model
if self.cfg.get('fleet', False):
model = fleet.distributed_model(model)
self.optimizer = fleet.distributed_optimizer(self.optimizer)
......@@ -394,13 +401,7 @@ class Trainer(object):
find_unused_parameters = self.cfg[
'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
model = paddle.DataParallel(
self.model, find_unused_parameters=find_unused_parameters)
# enabel auto mixed precision mode
if self.cfg.get('amp', False):
scaler = amp.GradScaler(
enable=self.cfg.use_gpu or self.cfg.use_npu,
init_loss_scaling=1024)
model, find_unused_parameters=find_unused_parameters)
self.status.update({
'epoch_id': self.start_epoch,
......@@ -436,12 +437,12 @@ class Trainer(object):
self._compose_callback.on_step_begin(self.status)
data['epoch_id'] = epoch_id
if self.cfg.get('amp', False):
with amp.auto_cast(enable=self.cfg.use_gpu):
if use_amp:
with paddle.amp.auto_cast(
enable=self.cfg.use_gpu, level=amp_level):
# model forward
outputs = model(data)
loss = outputs['loss']
# model backward
scaled_loss = scaler.scale(loss)
scaled_loss.backward()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册