diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index d3c4d7d9c51070e03304169418647ca7b9d68cdf..1829bf019a8ca804d337f25817fcdf3b34b03325 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -32,7 +32,6 @@
 import paddle
 import paddle.nn as nn
 import paddle.distributed as dist
 from paddle.distributed import fleet
-from paddle import amp
 from paddle.static import InputSpec
 from ppdet.optimizer import ModelEMA
@@ -380,13 +379,21 @@ class Trainer(object):
             self.cfg['EvalDataset'] = self.cfg.EvalDataset = create(
                 "EvalDataset")()
 
+        model = self.model
         sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
                    self.cfg.use_gpu and self._nranks > 1)
         if sync_bn:
-            self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
-                self.model)
+            model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
-        model = self.model
+        # enable auto mixed precision mode
+        use_amp = self.cfg.get('amp', False)
+        amp_level = self.cfg.get('amp_level', 'O1')
+        if use_amp:
+            scaler = paddle.amp.GradScaler(
+                enable=self.cfg.use_gpu or self.cfg.use_npu,
+                init_loss_scaling=self.cfg.get('init_loss_scaling', 1024))
+            model = paddle.amp.decorate(models=model, level=amp_level)
+        # get distributed model
         if self.cfg.get('fleet', False):
             model = fleet.distributed_model(model)
             self.optimizer = fleet.distributed_optimizer(self.optimizer)
@@ -394,13 +401,7 @@ class Trainer(object):
             find_unused_parameters = self.cfg[
                 'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
             model = paddle.DataParallel(
-                self.model, find_unused_parameters=find_unused_parameters)
-
-        # enabel auto mixed precision mode
-        if self.cfg.get('amp', False):
-            scaler = amp.GradScaler(
-                enable=self.cfg.use_gpu or self.cfg.use_npu,
-                init_loss_scaling=1024)
+                model, find_unused_parameters=find_unused_parameters)
 
         self.status.update({
             'epoch_id': self.start_epoch,
@@ -436,12 +437,12 @@
                 self._compose_callback.on_step_begin(self.status)
                 data['epoch_id'] = epoch_id
 
-                if self.cfg.get('amp', False):
-                    with amp.auto_cast(enable=self.cfg.use_gpu):
+                if use_amp:
+                    with paddle.amp.auto_cast(
+                            enable=self.cfg.use_gpu, level=amp_level):
                         # model forward
                         outputs = model(data)
                         loss = outputs['loss']
-
                         # model backward
                         scaled_loss = scaler.scale(loss)
                         scaled_loss.backward()
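
For context, here is a minimal, self-contained sketch of how the AMP pieces this patch wires into `Trainer` (`paddle.amp.GradScaler`, `paddle.amp.decorate`, `paddle.amp.auto_cast`) compose in one training step. The toy `Linear` model, `SGD` optimizer, MSE loss, and random data are illustrative stand-ins, not PaddleDetection code; `use_amp` and `amp_level` mirror the config keys the patch reads, and the optimizer update via `scaler.minimize` is the standard Paddle pattern, since the hunks above end before the optimizer step.

```python
# Illustrative sketch only -- toy model/optimizer/data, not PaddleDetection code.
import paddle
import paddle.nn.functional as F

model = paddle.nn.Linear(16, 1)  # stand-in for the detection model
optimizer = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=model.parameters())

use_amp = True     # mirrors cfg.get('amp', False)
amp_level = 'O1'   # mirrors cfg.get('amp_level', 'O1')

if use_amp:
    # GradScaler scales the loss so small FP16 gradients don't underflow;
    # init_loss_scaling matches the patch's default of 1024.
    scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024)
    # decorate() is effectively a no-op at O1; at O2 it casts the model's
    # parameters to FP16, which is why the patch calls it before training.
    model = paddle.amp.decorate(models=model, level=amp_level)

x = paddle.randn([8, 16])
y = paddle.randn([8, 1])

if use_amp:
    # auto_cast runs white-listed ops in FP16 inside this context.
    with paddle.amp.auto_cast(enable=True, level=amp_level):
        loss = F.mse_loss(model(x), y)   # model forward
    scaled_loss = scaler.scale(loss)     # model backward on the scaled loss
    scaled_loss.backward()
    # Unscale gradients, skip the step on inf/NaN, update the loss scale.
    scaler.minimize(optimizer, scaled_loss)
else:
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
optimizer.clear_grad()
```

On recent Paddle versions the `scaler.step(optimizer)` + `scaler.update()` pair is an equivalent spelling of `scaler.minimize`; either way, dynamic loss scaling grows or shrinks `init_loss_scaling` over time, so 1024 is only a starting point.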