From df55cb9b8cb3b3ef701738f82068e51528101748 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Thu, 13 Jan 2022 14:25:40 +0800 Subject: [PATCH] update PicoDet and GFL post_process (#5101) --- configs/picodet/_base_/picodet_esnet.yml | 2 +- ppdet/engine/trainer.py | 7 +- ppdet/modeling/architectures/picodet.py | 8 +- ppdet/modeling/bbox_utils.py | 30 ++++--- ppdet/modeling/heads/gfl_head.py | 110 ++++++++--------------- ppdet/modeling/heads/pico_head.py | 20 ++++- 6 files changed, 79 insertions(+), 98 deletions(-) diff --git a/configs/picodet/_base_/picodet_esnet.yml b/configs/picodet/_base_/picodet_esnet.yml index 150d25e9f..24b213f2a 100644 --- a/configs/picodet/_base_/picodet_esnet.yml +++ b/configs/picodet/_base_/picodet_esnet.yml @@ -1,6 +1,6 @@ architecture: PicoDet pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams -export_post_process: False # Whether post-processing is included in the network +export_post_process: False # Whether post-processing is included in the network when export model. PicoDet: backbone: ESNet diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 360a71efc..fcf935302 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -631,9 +631,12 @@ class Trainer(object): im_shape = [image_shape[0], 2] scale_factor = [image_shape[0], 2] - export_post_process = self.cfg.get('export_post_process', False) - if hasattr(self.model, 'deploy') and not export_post_process: + if hasattr(self.model, 'deploy'): self.model.deploy = True + export_post_process = self.cfg.get('export_post_process', False) + if hasattr(self.model, 'export_post_process'): + self.model.export_post_process = export_post_process + image_shape = [None] + image_shape[1:] if hasattr(self.model, 'fuse_norm'): self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize', False) diff --git a/ppdet/modeling/architectures/picodet.py b/ppdet/modeling/architectures/picodet.py index cd807a9fa..7331fb712 100644 --- a/ppdet/modeling/architectures/picodet.py +++ b/ppdet/modeling/architectures/picodet.py @@ -41,7 +41,7 @@ class PicoDet(BaseArch): self.backbone = backbone self.neck = neck self.head = head - self.deploy = False + self.export_post_process = True @classmethod def from_config(cls, cfg, *args, **kwargs): @@ -62,8 +62,8 @@ class PicoDet(BaseArch): def _forward(self): body_feats = self.backbone(self.inputs) fpn_feats = self.neck(body_feats) - head_outs = self.head(fpn_feats, self.deploy) - if self.training or self.deploy: + head_outs = self.head(fpn_feats, self.export_post_process) + if self.training or not self.export_post_process: return head_outs, None else: im_shape = self.inputs['im_shape'] @@ -83,7 +83,7 @@ class PicoDet(BaseArch): return loss def get_pred(self): - if self.deploy: + if not self.export_post_process: return {'picodet': self._forward()[0]} else: bbox_pred, bbox_num = self._forward() diff --git a/ppdet/modeling/bbox_utils.py b/ppdet/modeling/bbox_utils.py index 49a2e281f..cc3d7bb05 100644 --- a/ppdet/modeling/bbox_utils.py +++ b/ppdet/modeling/bbox_utils.py @@ -756,20 +756,22 @@ def bbox_center(boxes): def batch_distance2bbox(points, distance, max_shapes=None): """Decode distance prediction to bounding box for batch. Args: - points (Tensor): [B, ..., 2] - distance (Tensor): [B, ..., 4] - max_shapes (tuple): [B, 2], "h,w" format, Shape of the image. + points (Tensor): [B, ..., 2], "xy" format + distance (Tensor): [B, ..., 4], "ltrb" format + max_shapes (Tensor): [B, 2], "h,w" format, Shape of the image. Returns: - Tensor: Decoded bboxes. + Tensor: Decoded bboxes, "x1y1x2y2" format. """ - x1 = points[..., 0] - distance[..., 0] - y1 = points[..., 1] - distance[..., 1] - x2 = points[..., 0] + distance[..., 2] - y2 = points[..., 1] + distance[..., 3] + lt, rb = paddle.split(distance, 2, -1) + x1y1 = points - lt + x2y2 = points + rb + out_bbox = paddle.concat([x1y1, x2y2], -1) if max_shapes is not None: - for i, max_shape in enumerate(max_shapes): - x1[i] = x1[i].clip(min=0, max=max_shape[1]) - y1[i] = y1[i].clip(min=0, max=max_shape[0]) - x2[i] = x2[i].clip(min=0, max=max_shape[1]) - y2[i] = y2[i].clip(min=0, max=max_shape[0]) - return paddle.stack([x1, y1, x2, y2], -1) + max_shapes = max_shapes.flip(-1).tile([1, 2]) + delta_dim = out_bbox.ndim - max_shapes.ndim + for _ in range(delta_dim): + max_shapes.unsqueeze_(1) + out_bbox = paddle.where(out_bbox < max_shapes, out_bbox, max_shapes) + out_bbox = paddle.where(out_bbox > 0, out_bbox, + paddle.zeros_like(out_bbox)) + return out_bbox diff --git a/ppdet/modeling/heads/gfl_head.py b/ppdet/modeling/heads/gfl_head.py index 17e87a4ef..95980397d 100644 --- a/ppdet/modeling/heads/gfl_head.py +++ b/ppdet/modeling/heads/gfl_head.py @@ -29,7 +29,7 @@ from paddle.nn.initializer import Normal, Constant from ppdet.core.workspace import register from ppdet.modeling.layers import ConvNormLayer -from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance +from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance, batch_distance2bbox from ppdet.data.transform.atss_assigner import bbox_overlaps @@ -241,18 +241,34 @@ class GFLHead(nn.Layer): ), "The size of fpn_feats is not equal to size of fpn_stride" cls_logits_list = [] bboxes_reg_list = [] - for scale_reg, fpn_feat in zip(self.scales_regs, fpn_feats): + for stride, scale_reg, fpn_feat in zip(self.fpn_stride, + self.scales_regs, fpn_feats): conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat) - cls_logits = self.gfl_head_cls(conv_cls_feat) - bbox_reg = scale_reg(self.gfl_head_reg(conv_reg_feat)) + cls_score = self.gfl_head_cls(conv_cls_feat) + bbox_pred = scale_reg(self.gfl_head_reg(conv_reg_feat)) if self.dgqp_module: - quality_score = self.dgqp_module(bbox_reg) - cls_logits = F.sigmoid(cls_logits) * quality_score + quality_score = self.dgqp_module(bbox_pred) + cls_score = F.sigmoid(cls_score) * quality_score if not self.training: - cls_logits = F.sigmoid(cls_logits.transpose([0, 2, 3, 1])) - bbox_reg = bbox_reg.transpose([0, 2, 3, 1]) - cls_logits_list.append(cls_logits) - bboxes_reg_list.append(bbox_reg) + cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1])) + bbox_pred = bbox_pred.transpose([0, 2, 3, 1]) + b, cell_h, cell_w, _ = paddle.shape(cls_score) + y, x = self.get_single_level_center_point( + [cell_h, cell_w], stride, cell_offset=self.cell_offset) + center_points = paddle.stack([x, y], axis=-1) + cls_score = cls_score.reshape([b, -1, self.cls_out_channels]) + bbox_pred = self.distribution_project(bbox_pred) * stride + bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4]) + + # NOTE: If keep_ratio=False and image shape value that + # multiples of 32, distance2bbox not set max_shapes parameter + # to speed up model prediction. If need to set max_shapes, + # please use inputs['im_shape']. + bbox_pred = batch_distance2bbox( + center_points, bbox_pred, max_shapes=None) + + cls_logits_list.append(cls_score) + bboxes_reg_list.append(bbox_pred) return (cls_logits_list, bboxes_reg_list) @@ -410,71 +426,15 @@ class GFLHead(nn.Layer): x = x.flatten() return y, x - def get_bboxes_single(self, - cls_scores, - bbox_preds, - img_shape, - scale_factor, - rescale=True, - cell_offset=0): - assert len(cls_scores) == len(bbox_preds) - mlvl_bboxes = [] - mlvl_scores = [] - for stride, cls_score, bbox_pred in zip(self.fpn_stride, cls_scores, - bbox_preds): - featmap_size = [ - paddle.shape(cls_score)[0], paddle.shape(cls_score)[1] - ] - y, x = self.get_single_level_center_point( - featmap_size, stride, cell_offset=cell_offset) - center_points = paddle.stack([x, y], axis=-1) - scores = cls_score.reshape([-1, self.cls_out_channels]) - bbox_pred = self.distribution_project(bbox_pred) * stride - - if scores.shape[0] > self.nms_pre: - max_scores = scores.max(axis=1) - _, topk_inds = max_scores.topk(self.nms_pre) - center_points = center_points.gather(topk_inds) - bbox_pred = bbox_pred.gather(topk_inds) - scores = scores.gather(topk_inds) - - bboxes = distance2bbox( - center_points, bbox_pred, max_shape=img_shape) - mlvl_bboxes.append(bboxes) - mlvl_scores.append(scores) - mlvl_bboxes = paddle.concat(mlvl_bboxes) - if rescale: - # [h_scale, w_scale] to [w_scale, h_scale, w_scale, h_scale] - im_scale = paddle.concat([scale_factor[::-1], scale_factor[::-1]]) - mlvl_bboxes /= im_scale - mlvl_scores = paddle.concat(mlvl_scores) - mlvl_scores = mlvl_scores.transpose([1, 0]) - return mlvl_bboxes, mlvl_scores - - def decode(self, cls_scores, bbox_preds, im_shape, scale_factor, - cell_offset): - batch_bboxes = [] - batch_scores = [] - for img_id in range(cls_scores[0].shape[0]): - num_levels = len(cls_scores) - cls_score_list = [cls_scores[i][img_id] for i in range(num_levels)] - bbox_pred_list = [bbox_preds[i][img_id] for i in range(num_levels)] - bboxes, scores = self.get_bboxes_single( - cls_score_list, - bbox_pred_list, - im_shape[img_id], - scale_factor[img_id], - cell_offset=cell_offset) - batch_bboxes.append(bboxes) - batch_scores.append(scores) - batch_bboxes = paddle.stack(batch_bboxes, axis=0) - batch_scores = paddle.stack(batch_scores, axis=0) - - return batch_bboxes, batch_scores - def post_process(self, gfl_head_outs, im_shape, scale_factor): cls_scores, bboxes_reg = gfl_head_outs - bboxes, score = self.decode(cls_scores, bboxes_reg, im_shape, - scale_factor, self.cell_offset) - bbox_pred, bbox_num, _ = self.nms(bboxes, score) + bboxes = paddle.concat(bboxes_reg, axis=1) + # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale] + im_scale = paddle.concat( + [scale_factor[:, ::-1], scale_factor[:, ::-1]], + axis=-1).unsqueeze(1) + bboxes /= im_scale + mlvl_scores = paddle.concat(cls_scores, axis=1) + mlvl_scores = mlvl_scores.transpose([0, 2, 1]) + bbox_pred, bbox_num, _ = self.nms(bboxes, mlvl_scores) return bbox_pred, bbox_num diff --git a/ppdet/modeling/heads/pico_head.py b/ppdet/modeling/heads/pico_head.py index 7cfd24c3c..022acb045 100644 --- a/ppdet/modeling/heads/pico_head.py +++ b/ppdet/modeling/heads/pico_head.py @@ -26,6 +26,7 @@ from paddle.nn.initializer import Normal, Constant from ppdet.core.workspace import register from ppdet.modeling.layers import ConvNormLayer +from ppdet.modeling.bbox_utils import batch_distance2bbox from .simota_head import OTAVFLHead @@ -238,7 +239,7 @@ class PicoHead(OTAVFLHead): bias_attr=ParamAttr(initializer=Constant(value=0)))) self.head_reg_list.append(head_reg) - def forward(self, fpn_feats, deploy=False): + def forward(self, fpn_feats, export_post_process=True): assert len(fpn_feats) == len( self.fpn_stride ), "The size of fpn_feats is not equal to size of fpn_stride" @@ -260,7 +261,7 @@ class PicoHead(OTAVFLHead): quality_score = self.dgqp_module(bbox_pred) cls_score = F.sigmoid(cls_score) * quality_score - if deploy: + if not export_post_process: # Now only supports batch size = 1 in deploy # TODO(ygh): support batch size > 1 cls_score = F.sigmoid(cls_score).reshape( @@ -270,6 +271,21 @@ class PicoHead(OTAVFLHead): elif not self.training: cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1])) bbox_pred = bbox_pred.transpose([0, 2, 3, 1]) + stride = self.fpn_stride[i] + b, cell_h, cell_w, _ = paddle.shape(cls_score) + y, x = self.get_single_level_center_point( + [cell_h, cell_w], stride, cell_offset=self.cell_offset) + center_points = paddle.stack([x, y], axis=-1) + cls_score = cls_score.reshape([b, -1, self.cls_out_channels]) + bbox_pred = self.distribution_project(bbox_pred) * stride + bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4]) + + # NOTE: If keep_ratio=False and image shape value that + # multiples of 32, distance2bbox not set max_shapes parameter + # to speed up model prediction. If need to set max_shapes, + # please use inputs['im_shape']. + bbox_pred = batch_distance2bbox( + center_points, bbox_pred, max_shapes=None) cls_logits_list.append(cls_score) bboxes_reg_list.append(bbox_pred) -- GitLab