diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index b086705acf0f8850c1d8cc1e6f6c33bdbe4bc589..cd4e7d9a3776462e4e302ac55e616cb7b9f11e54 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -65,6 +65,8 @@ class Trainer(object): self.mode = mode.lower() self.optimizer = None self.is_loaded_weights = False + self.use_amp = self.cfg.get('amp', False) + self.amp_level = self.cfg.get('amp_level', 'O1') # build data loader capital_mode = self.mode.capitalize() @@ -124,17 +126,6 @@ class Trainer(object): else: self.model.load_meanstd(cfg['TestReader']['sample_transforms']) - self.use_ema = ('use_ema' in cfg and cfg['use_ema']) - if self.use_ema: - ema_decay = self.cfg.get('ema_decay', 0.9998) - cycle_epoch = self.cfg.get('cycle_epoch', -1) - ema_decay_type = self.cfg.get('ema_decay_type', 'threshold') - self.ema = ModelEMA( - self.model, - decay=ema_decay, - ema_decay_type=ema_decay_type, - cycle_epoch=cycle_epoch) - # EvalDataset build with BatchSampler to evaluate in single device # TODO: multi-device evaluate if self.mode == 'eval': @@ -162,6 +153,20 @@ class Trainer(object): self.pruner = create('UnstructuredPruner')(self.model, steps_per_epoch) + if self.use_amp and self.amp_level == 'O2': + self.model = paddle.amp.decorate( + models=self.model, level=self.amp_level) + self.use_ema = ('use_ema' in cfg and cfg['use_ema']) + if self.use_ema: + ema_decay = self.cfg.get('ema_decay', 0.9998) + cycle_epoch = self.cfg.get('cycle_epoch', -1) + ema_decay_type = self.cfg.get('ema_decay_type', 'threshold') + self.ema = ModelEMA( + self.model, + decay=ema_decay, + ema_decay_type=ema_decay_type, + cycle_epoch=cycle_epoch) + self._nranks = dist.get_world_size() self._local_rank = dist.get_rank() @@ -387,13 +392,10 @@ class Trainer(object): model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) # enabel auto mixed precision mode - use_amp = self.cfg.get('amp', False) - amp_level = self.cfg.get('amp_level', 'O1') - if use_amp: + if self.use_amp: scaler = paddle.amp.GradScaler( enable=self.cfg.use_gpu or self.cfg.use_npu, init_loss_scaling=self.cfg.get('init_loss_scaling', 1024)) - model = paddle.amp.decorate(models=model, level=amp_level) # get distributed model if self.cfg.get('fleet', False): model = fleet.distributed_model(model) @@ -438,9 +440,9 @@ class Trainer(object): self._compose_callback.on_step_begin(self.status) data['epoch_id'] = epoch_id - if use_amp: + if self.use_amp: with paddle.amp.auto_cast( - enable=self.cfg.use_gpu, level=amp_level): + enable=self.cfg.use_gpu, level=self.amp_level): # model forward outputs = model(data) loss = outputs['loss'] @@ -532,7 +534,12 @@ class Trainer(object): self.status['step_id'] = step_id self._compose_callback.on_step_begin(self.status) # forward - outs = self.model(data) + if self.use_amp: + with paddle.amp.auto_cast( + enable=self.cfg.use_gpu, level=self.amp_level): + outs = self.model(data) + else: + outs = self.model(data) # update metrics for metric in self._metrics: diff --git a/ppdet/modeling/assigners/utils.py b/ppdet/modeling/assigners/utils.py index 6a89593a316da8cc9221fee872f34a8452200751..0bc399315797b4be04954858fac5cccbbd73ee33 100644 --- a/ppdet/modeling/assigners/utils.py +++ b/ppdet/modeling/assigners/utils.py @@ -176,7 +176,8 @@ def compute_max_iou_gt(ious): def generate_anchors_for_grid_cell(feats, fpn_strides, grid_cell_size=5.0, - grid_cell_offset=0.5): + grid_cell_offset=0.5, + dtype='float32'): r""" Like ATSS, generate anchors based on grid size. Args: @@ -206,16 +207,15 @@ def generate_anchors_for_grid_cell(feats, shift_x - cell_half_size, shift_y - cell_half_size, shift_x + cell_half_size, shift_y + cell_half_size ], - axis=-1).astype(feat.dtype) - anchor_point = paddle.stack( - [shift_x, shift_y], axis=-1).astype(feat.dtype) + axis=-1).astype(dtype) + anchor_point = paddle.stack([shift_x, shift_y], axis=-1).astype(dtype) anchors.append(anchor.reshape([-1, 4])) anchor_points.append(anchor_point.reshape([-1, 2])) num_anchors_list.append(len(anchors[-1])) stride_tensor.append( paddle.full( - [num_anchors_list[-1], 1], stride, dtype=feat.dtype)) + [num_anchors_list[-1], 1], stride, dtype=dtype)) anchors = paddle.concat(anchors) anchors.stop_gradient = True anchor_points = paddle.concat(anchor_points) diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py index 15eff3d4d5d2275de6332ceb328cf6ba6445c9ad..4e9c303dc64252b26a9ff3153cf65f34b53196a4 100644 --- a/ppdet/modeling/heads/ppyoloe_head.py +++ b/ppdet/modeling/heads/ppyoloe_head.py @@ -160,7 +160,7 @@ class PPYOLOEHead(nn.Layer): num_anchors_list, stride_tensor ], targets) - def _generate_anchors(self, feats=None): + def _generate_anchors(self, feats=None, dtype='float32'): # just use in eval time anchor_points = [] stride_tensor = [] @@ -175,11 +175,9 @@ class PPYOLOEHead(nn.Layer): shift_y, shift_x = paddle.meshgrid(shift_y, shift_x) anchor_point = paddle.cast( paddle.stack( - [shift_x, shift_y], axis=-1), dtype='float32') + [shift_x, shift_y], axis=-1), dtype=dtype) anchor_points.append(anchor_point.reshape([-1, 2])) - stride_tensor.append( - paddle.full( - [h * w, 1], stride, dtype='float32')) + stride_tensor.append(paddle.full([h * w, 1], stride, dtype=dtype)) anchor_points = paddle.concat(anchor_points) stride_tensor = paddle.concat(stride_tensor) return anchor_points, stride_tensor diff --git a/tools/eval.py b/tools/eval.py index 3128261b068aba9f5b8cf2ba882d11e5a9646143..f2e4fd0490ab44e40a2b0479ef60bb1f9cbbb76b 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -77,6 +77,12 @@ def parse_args(): default=False, help='Whether to save the evaluation results only') + parser.add_argument( + "--amp", + action='store_true', + default=False, + help="Enable auto mixed precision eval.") + args = parser.parse_args() return args