diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index af222444ee6caa8f2dfe89207dee525b85fd6020..65e8d0b8bcffee8180633b5c417c4b74b5197a18 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -455,8 +455,11 @@ class DETRPostProcess(object): use_focal_loss=False, with_mask=False, mask_threshold=0.5, - use_avg_mask_score=False): + use_avg_mask_score=False, + bbox_decode_type='origin'): super(DETRPostProcess, self).__init__() + assert bbox_decode_type in ['origin', 'pad'] + self.num_classes = num_classes self.num_top_queries = num_top_queries self.dual_queries = dual_queries @@ -465,6 +468,7 @@ class DETRPostProcess(object): self.with_mask = with_mask self.mask_threshold = mask_threshold self.use_avg_mask_score = use_avg_mask_score + self.bbox_decode_type = bbox_decode_type def _mask_postprocess(self, mask_pred, score_pred, index): mask_score = F.sigmoid(paddle.gather_nd(mask_pred, index)) @@ -478,7 +482,7 @@ class DETRPostProcess(object): def __call__(self, head_out, im_shape, scale_factor, pad_shape): """ - Decode the bbox. + Decode the bbox and mask. Args: head_out (tuple): bbox_pred, cls_logit and masks of bbox_head output. @@ -502,9 +506,15 @@ class DETRPostProcess(object): # calculate the original shape of the image origin_shape = paddle.floor(im_shape / scale_factor + 0.5) img_h, img_w = paddle.split(origin_shape, 2, axis=-1) - # calculate the shape of the image with padding - out_shape = pad_shape / im_shape * origin_shape - out_shape = out_shape.flip(1).tile([1, 2]).unsqueeze(1) + if self.bbox_decode_type == 'pad': + # calculate the shape of the image with padding + out_shape = pad_shape / im_shape * origin_shape + out_shape = out_shape.flip(1).tile([1, 2]).unsqueeze(1) + elif self.bbox_decode_type == 'origin': + out_shape = origin_shape.flip(1).tile([1, 2]).unsqueeze(1) + else: + raise Exception( + f'Wrong `bbox_decode_type`: {self.bbox_decode_type}.') bbox_pred *= out_shape scores = F.sigmoid(logits) if self.use_focal_loss else F.softmax(