未验证 提交 8d0e52a1 编写于 作者: S shangliang Xu 提交者: GitHub

fix bbox decode in detrpostprocess (#7916)

上级 d9553e34
......@@ -455,8 +455,11 @@ class DETRPostProcess(object):
use_focal_loss=False,
with_mask=False,
mask_threshold=0.5,
use_avg_mask_score=False):
use_avg_mask_score=False,
bbox_decode_type='origin'):
super(DETRPostProcess, self).__init__()
assert bbox_decode_type in ['origin', 'pad']
self.num_classes = num_classes
self.num_top_queries = num_top_queries
self.dual_queries = dual_queries
......@@ -465,6 +468,7 @@ class DETRPostProcess(object):
self.with_mask = with_mask
self.mask_threshold = mask_threshold
self.use_avg_mask_score = use_avg_mask_score
self.bbox_decode_type = bbox_decode_type
def _mask_postprocess(self, mask_pred, score_pred, index):
mask_score = F.sigmoid(paddle.gather_nd(mask_pred, index))
......@@ -478,7 +482,7 @@ class DETRPostProcess(object):
def __call__(self, head_out, im_shape, scale_factor, pad_shape):
"""
Decode the bbox.
Decode the bbox and mask.
Args:
head_out (tuple): bbox_pred, cls_logit and masks of bbox_head output.
......@@ -502,9 +506,15 @@ class DETRPostProcess(object):
# calculate the original shape of the image
origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
img_h, img_w = paddle.split(origin_shape, 2, axis=-1)
# calculate the shape of the image with padding
out_shape = pad_shape / im_shape * origin_shape
out_shape = out_shape.flip(1).tile([1, 2]).unsqueeze(1)
if self.bbox_decode_type == 'pad':
# calculate the shape of the image with padding
out_shape = pad_shape / im_shape * origin_shape
out_shape = out_shape.flip(1).tile([1, 2]).unsqueeze(1)
elif self.bbox_decode_type == 'origin':
out_shape = origin_shape.flip(1).tile([1, 2]).unsqueeze(1)
else:
raise Exception(
f'Wrong `bbox_decode_type`: {self.bbox_decode_type}.')
bbox_pred *= out_shape
scores = F.sigmoid(logits) if self.use_focal_loss else F.softmax(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册