Unverified commit afb3b7a1, authored by wangguanzhong, committed by GitHub

Remove conditional block in RCNN export onnx (#5371)

* support rcnn onnx

* clean code

* update cascade rcnn

* add todo for rpn proposals
Parent 57158544
@@ -121,7 +121,7 @@ def _dump_infer_config(config, path, image_shape, model):
     setup_orderdict()
     use_dynamic_shape = True if image_shape[2] == -1 else False
     infer_cfg = OrderedDict({
-        'mode': 'fluid',
+        'mode': 'paddle',
        'draw_threshold': 0.5,
        'metric': config['metric'],
        'use_dynamic_shape': use_dynamic_shape
...
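For context, this value is written into the exported inference configuration; a minimal sketch of the resulting entry (assumed values noted in comments, not taken from this diff):

# Hedged sketch of the exported inference config entry: the exporter now
# records 'paddle' instead of the legacy 'fluid' mode.
from collections import OrderedDict

infer_cfg = OrderedDict({
    'mode': 'paddle',            # was 'fluid' before this commit
    'draw_threshold': 0.5,
    'metric': 'COCO',            # assumed example value of config['metric']
    'use_dynamic_shape': False,  # True only when image_shape[2] == -1
})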
@@ -117,8 +117,8 @@ class CascadeRCNN(BaseArch):
                 return bbox_pred, bbox_num, None
             mask_out = self.mask_head(body_feats, bbox, bbox_num, self.inputs)
             origin_shape = self.bbox_post_process.get_origin_shape()
-            mask_pred = self.mask_post_process(mask_out[:, 0, :, :], bbox_pred,
-                                               bbox_num, origin_shape)
+            mask_pred = self.mask_post_process(mask_out, bbox_pred, bbox_num,
+                                               origin_shape)
             return bbox_pred, bbox_num, mask_pred

     def get_loss(self, ):
...
@@ -115,8 +115,8 @@ class MaskRCNN(BaseArch):
             bbox_pred = self.bbox_post_process.get_pred(bbox, bbox_num,
                                                         im_shape, scale_factor)
             origin_shape = self.bbox_post_process.get_origin_shape()
-            mask_pred = self.mask_post_process(mask_out[:, 0, :, :], bbox_pred,
-                                               bbox_num, origin_shape)
+            mask_pred = self.mask_post_process(mask_out, bbox_pred, bbox_num,
+                                               origin_shape)
             return bbox_pred, bbox_num, mask_pred

     def get_loss(self, ):
...
@@ -103,7 +103,7 @@ class MaskFeat(nn.Layer):

 @register
 class MaskHead(nn.Layer):
-    __shared__ = ['num_classes']
+    __shared__ = ['num_classes', 'export_onnx']
     __inject__ = ['mask_assigner']
     """
     RCNN mask head
@@ -123,9 +123,11 @@ class MaskHead(nn.Layer):
                  roi_extractor=RoIAlign().__dict__,
                  mask_assigner='MaskAssigner',
                  num_classes=80,
-                 share_bbox_feat=False):
+                 share_bbox_feat=False,
+                 export_onnx=False):
         super(MaskHead, self).__init__()
         self.num_classes = num_classes
+        self.export_onnx = export_onnx
         self.roi_extractor = roi_extractor
         if isinstance(roi_extractor, dict):
@@ -206,7 +208,7 @@ class MaskHead(nn.Layer):
             rois_num (Tensor): The number of prediction for each batch
             scale_factor (Tensor): The scale factor from origin size to input size
         """
-        if rois.shape[0] == 0:
+        if not self.export_onnx and rois.shape[0] == 0:
             mask_out = paddle.full([1, 1, 1, 1], -1)
         else:
             bbox = [rois[:, 2:]]
@@ -218,19 +220,13 @@ class MaskHead(nn.Layer):
             mask_feat = self.head(rois_feat)
             mask_logit = self.mask_fcn_logits(mask_feat)
-            mask_num_class = mask_logit.shape[1]
-            if mask_num_class == 1:
+            if self.num_classes == 1:
                 mask_out = F.sigmoid(mask_logit)
             else:
-                num_masks = mask_logit.shape[0]
-                mask_out = []
-                # TODO: need to optimize gather
-                for i in range(mask_logit.shape[0]):
-                    pred_masks = paddle.unsqueeze(
-                        mask_logit[i, :, :, :], axis=0)
-                    mask = paddle.gather(pred_masks, labels[i], axis=1)
-                    mask_out.append(mask)
-                mask_out = F.sigmoid(paddle.concat(mask_out))
+                num_masks = paddle.shape(mask_logit)[0]
+                index = paddle.arange(num_masks).cast('int32')
+                mask_out = mask_logit[index, labels]
+                mask_out = F.sigmoid(mask_out)
         return mask_out

     def forward(self,
...
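The rewritten branch above replaces the per-RoI gather loop with a single fancy-indexing step. A small NumPy stand-in (shapes and labels invented) showing that pairing a row index with the predicted label picks one mask per RoI, which is the behaviour the Paddle indexing above relies on:

import numpy as np

num_rois, num_classes, h, w = 4, 80, 28, 28
mask_logit = np.random.randn(num_rois, num_classes, h, w).astype('float32')
labels = np.array([3, 0, 7, 3])               # predicted class id per RoI

index = np.arange(num_rois)
mask_out = mask_logit[index, labels]          # mask_logit[i, labels[i]] for each i
assert mask_out.shape == (num_rois, h, w)     # one class-specific mask per RoI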
@@ -363,18 +363,20 @@ class AnchorGeneratorSSD(object):

 @register
 @serializable
 class RCNNBox(object):
-    __shared__ = ['num_classes']
+    __shared__ = ['num_classes', 'export_onnx']

     def __init__(self,
                  prior_box_var=[10., 10., 5., 5.],
                  code_type="decode_center_size",
                  box_normalized=False,
-                 num_classes=80):
+                 num_classes=80,
+                 export_onnx=False):
         super(RCNNBox, self).__init__()
         self.prior_box_var = prior_box_var
         self.code_type = code_type
         self.box_normalized = box_normalized
         self.num_classes = num_classes
+        self.export_onnx = export_onnx

     def __call__(self, bbox_head_out, rois, im_shape, scale_factor):
         bbox_pred = bbox_head_out[0]
@@ -382,39 +384,38 @@ class RCNNBox(object):
         roi = rois[0]
         rois_num = rois[1]
-        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
-        scale_list = []
-        origin_shape_list = []
-
-        batch_size = 1
-        if isinstance(roi, list):
-            batch_size = len(roi)
-        else:
-            batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
-        # bbox_pred.shape: [N, C*4]
-        for idx in range(batch_size):
-            roi_per_im = roi[idx]
-            rois_num_per_im = rois_num[idx]
-            expand_im_shape = paddle.expand(im_shape[idx, :],
-                                            [rois_num_per_im, 2])
-            origin_shape_list.append(expand_im_shape)
-
-        origin_shape = paddle.concat(origin_shape_list)
+        if self.export_onnx:
+            onnx_rois_num_per_im = rois_num[0]
+            origin_shape = paddle.expand(im_shape[0, :],
+                                         [onnx_rois_num_per_im, 2])
+
+        else:
+            origin_shape_list = []
+            if isinstance(roi, list):
+                batch_size = len(roi)
+            else:
+                batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
+
+            # bbox_pred.shape: [N, C*4]
+            for idx in range(batch_size):
+                rois_num_per_im = rois_num[idx]
+                expand_im_shape = paddle.expand(im_shape[idx, :],
+                                                [rois_num_per_im, 2])
+                origin_shape_list.append(expand_im_shape)
+
+            origin_shape = paddle.concat(origin_shape_list)

         # bbox_pred.shape: [N, C*4]
         # C=num_classes in faster/mask rcnn(bbox_head), C=1 in cascade rcnn(cascade_head)
         bbox = paddle.concat(roi)
-        if bbox.shape[0] == 0:
-            bbox = paddle.zeros([0, bbox_pred.shape[1]], dtype='float32')
-        else:
-            bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var)
+        bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var)
         scores = cls_prob[:, :-1]

         # bbox.shape: [N, C, 4]
         # bbox.shape[1] must be equal to scores.shape[1]
-        bbox_num_class = bbox.shape[1]
-        if bbox_num_class == 1:
-            bbox = paddle.tile(bbox, [1, self.num_classes, 1])
+        total_num = bbox.shape[0]
+        bbox_dim = bbox.shape[-1]
+        bbox = paddle.expand(bbox, [total_num, self.num_classes, bbox_dim])

         origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1)
         origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1)
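The tile-or-not branch above becomes a single broadcast: expanding along the class axis is a no-op when C already equals num_classes and replicates the shared box when C == 1 (cascade head), so the exported graph needs no data-dependent condition. A NumPy sketch with invented sizes:

import numpy as np

total_num, num_classes = 5, 80
bbox = np.random.rand(total_num, 1, 4).astype('float32')      # C = 1 (cascade head)
bbox = np.broadcast_to(bbox, (total_num, num_classes, 4))     # same box for every class
assert bbox.shape == (total_num, num_classes, 4)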
@@ -1422,7 +1423,7 @@ class ConvMixer(nn.Layer):
         Seq, ActBn = nn.Sequential, lambda x: Seq(x, nn.GELU(), nn.BatchNorm2D(dim))
         Residual = type('Residual', (Seq, ),
                         {'forward': lambda self, x: self[0](x) + x})
-        return Seq(*[
+        return Seq(* [
             Seq(Residual(
                 ActBn(
                     nn.Conv2D(
...
@@ -34,14 +34,16 @@ __all__ = [

 @register
 class BBoxPostProcess(nn.Layer):
-    __shared__ = ['num_classes']
+    __shared__ = ['num_classes', 'export_onnx']
     __inject__ = ['decode', 'nms']

-    def __init__(self, num_classes=80, decode=None, nms=None):
+    def __init__(self, num_classes=80, decode=None, nms=None,
+                 export_onnx=False):
         super(BBoxPostProcess, self).__init__()
         self.num_classes = num_classes
         self.decode = decode
         self.nms = nms
+        self.export_onnx = export_onnx

     def forward(self, head_out, rois, im_shape, scale_factor):
         """
@@ -52,6 +54,7 @@ class BBoxPostProcess(nn.Layer):
             rois (tuple): roi and rois_num of rpn_head output.
             im_shape (Tensor): The shape of the input image.
             scale_factor (Tensor): The scale factor of the input image.
+            export_onnx (bool): whether export model to onnx
         Returns:
             bbox_pred (Tensor): The output prediction with shape [N, 6], including
                 labels, scores and bboxes. The size of bboxes are corresponding
@@ -62,9 +65,20 @@ class BBoxPostProcess(nn.Layer):
         if self.nms is not None:
             bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
             bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
+
         else:
             bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
                                               scale_factor)
+
+        if self.export_onnx:
+            # add fake box after postprocess when exporting onnx
+            fake_bboxes = paddle.to_tensor(
+                np.array(
+                    [[0., 0.0, 0.0, 0.0, 1.0, 1.0]], dtype='float32'))
+
+            bbox_pred = paddle.concat([bbox_pred, fake_bboxes])
+            bbox_num = bbox_num + 1
+
         return bbox_pred, bbox_num

     def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
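The fake box appended in the forward hunk above keeps the exported graph from ever emitting an empty tensor when NMS removes everything; since its score is 0.0, it should be dropped by any downstream score threshold. A NumPy sketch of the idea (the 0.5 threshold is an assumption, not part of this diff):

import numpy as np

bbox_pred = np.zeros((0, 6), dtype='float32')                      # NMS kept nothing
fake_box = np.array([[0., 0., 0., 0., 1., 1.]], dtype='float32')   # label, score, x0, y0, x1, y1

bbox_pred = np.concatenate([bbox_pred, fake_box])                  # output is never empty
kept = bbox_pred[bbox_pred[:, 1] > 0.5]                            # a later score filter removes it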
@@ -86,45 +100,55 @@ class BBoxPostProcess(nn.Layer):
             pred_result (Tensor): The final prediction results with shape [N, 6]
                 including labels, scores and bboxes.
         """
-        bboxes_list = []
-        bbox_num_list = []
-        id_start = 0
-        fake_bboxes = paddle.to_tensor(
-            np.array(
-                [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
-        fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
-
-        # add fake bbox when output is empty for each batch
-        for i in range(bbox_num.shape[0]):
-            if bbox_num[i] == 0:
-                bboxes_i = fake_bboxes
-                bbox_num_i = fake_bbox_num
-            else:
-                bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
-                bbox_num_i = bbox_num[i]
-                id_start += bbox_num[i]
-            bboxes_list.append(bboxes_i)
-            bbox_num_list.append(bbox_num_i)
-        bboxes = paddle.concat(bboxes_list)
-        bbox_num = paddle.concat(bbox_num_list)
+        if not self.export_onnx:
+            bboxes_list = []
+            bbox_num_list = []
+            id_start = 0
+            fake_bboxes = paddle.to_tensor(
+                np.array(
+                    [[0., 0.0, 0.0, 0.0, 1.0, 1.0]], dtype='float32'))
+            fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
+
+            # add fake bbox when output is empty for each batch
+            for i in range(bbox_num.shape[0]):
+                if bbox_num[i] == 0:
+                    bboxes_i = fake_bboxes
+                    bbox_num_i = fake_bbox_num
+                else:
+                    bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
+                    bbox_num_i = bbox_num[i]
+                    id_start += bbox_num[i]
+                bboxes_list.append(bboxes_i)
+                bbox_num_list.append(bbox_num_i)
+            bboxes = paddle.concat(bboxes_list)
+            bbox_num = paddle.concat(bbox_num_list)

         origin_shape = paddle.floor(im_shape / scale_factor + 0.5)

-        origin_shape_list = []
-        scale_factor_list = []
-        # scale_factor: scale_y, scale_x
-        for i in range(bbox_num.shape[0]):
-            expand_shape = paddle.expand(origin_shape[i:i + 1, :],
-                                         [bbox_num[i], 2])
-            scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
-            scale = paddle.concat([scale_x, scale_y, scale_x, scale_y])
-            expand_scale = paddle.expand(scale, [bbox_num[i], 4])
-            origin_shape_list.append(expand_shape)
-            scale_factor_list.append(expand_scale)
-
-        self.origin_shape_list = paddle.concat(origin_shape_list)
-        scale_factor_list = paddle.concat(scale_factor_list)
+        if not self.export_onnx:
+            origin_shape_list = []
+            scale_factor_list = []
+            # scale_factor: scale_y, scale_x
+            for i in range(bbox_num.shape[0]):
+                expand_shape = paddle.expand(origin_shape[i:i + 1, :],
+                                             [bbox_num[i], 2])
+                scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
+                scale = paddle.concat([scale_x, scale_y, scale_x, scale_y])
+                expand_scale = paddle.expand(scale, [bbox_num[i], 4])
+                origin_shape_list.append(expand_shape)
+                scale_factor_list.append(expand_scale)
+
+            self.origin_shape_list = paddle.concat(origin_shape_list)
+            scale_factor_list = paddle.concat(scale_factor_list)
+
+        else:
+            # simplify the computation for bs=1 when exporting onnx
+            scale_y, scale_x = scale_factor[0][0], scale_factor[0][1]
+            scale = paddle.concat(
+                [scale_x, scale_y, scale_x, scale_y]).unsqueeze(0)
+            self.origin_shape_list = paddle.expand(origin_shape,
+                                                   [bbox_num[0], 2])
+            scale_factor_list = paddle.expand(scale, [bbox_num[0], 4])

         # bboxes: [N, 6], label, score, bbox
         pred_label = bboxes[:, 0:1]
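In the export_onnx branch of get_pred the batch size is assumed to be 1, so the per-image loop collapses into a single broadcast of the image shape and the scale vector over all kept boxes. A NumPy sketch with made-up numbers:

import numpy as np

origin_shape = np.array([[427., 640.]], dtype='float32')    # single image: (h, w)
scale_y, scale_x = 0.6, 0.5
num_boxes = 3

origin_shape_list = np.broadcast_to(origin_shape, (num_boxes, 2))
scale = np.array([[scale_x, scale_y, scale_x, scale_y]], dtype='float32')
scale_factor_list = np.broadcast_to(scale, (num_boxes, 4))   # one scale row per box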
@@ -170,19 +194,20 @@ class MaskPostProcess(object):
         """
         Paste the mask prediction to the original image.
         """
+        x0_int, y0_int = 0, 0
+        x1_int, y1_int = im_w, im_h
+
         x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
-        masks = paddle.unsqueeze(masks, [0, 1])
-        img_y = paddle.arange(0, im_h, dtype='float32') + 0.5
-        img_x = paddle.arange(0, im_w, dtype='float32') + 0.5
+        N = masks.shape[0]
+
+        img_y = paddle.arange(y0_int, y1_int) + 0.5
+        img_x = paddle.arange(x0_int, x1_int) + 0.5
         img_y = (img_y - y0) / (y1 - y0) * 2 - 1
         img_x = (img_x - x0) / (x1 - x0) * 2 - 1
-        img_x = paddle.unsqueeze(img_x, [1])
-        img_y = paddle.unsqueeze(img_y, [2])
-        N = boxes.shape[0]
+        # img_x, img_y have shapes (N, w), (N, h)

-        gx = paddle.expand(img_x, [N, img_y.shape[1], img_x.shape[2]])
-        gy = paddle.expand(img_y, [N, img_y.shape[1], img_x.shape[2]])
+        gx = img_x[:, None, :].expand(
+            [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
+        gy = img_y[:, :, None].expand(
+            [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
         grid = paddle.stack([gx, gy], axis=3)
         img_masks = F.grid_sample(masks, grid, align_corners=False)
         return img_masks[:, 0]
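The rewritten paste_mask builds one sampling grid per box and lets grid_sample do the resize-and-paste in a single call. A self-contained sketch mirroring the code above (box coordinates and sizes are invented):

import paddle
import paddle.nn.functional as F

N, M, im_h, im_w = 2, 28, 64, 48
masks = paddle.rand([N, 1, M, M])                                  # per-RoI mask probabilities
boxes = paddle.to_tensor([[4., 6., 40., 30.],
                          [10., 12., 44., 60.]], dtype='float32')  # x0, y0, x1, y1

x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)                    # each [N, 1]
img_y = paddle.arange(0, im_h, dtype='float32') + 0.5              # pixel centers
img_x = paddle.arange(0, im_w, dtype='float32') + 0.5
img_y = (img_y - y0) / (y1 - y0) * 2 - 1                           # [-1, 1] inside each box
img_x = (img_x - x0) / (x1 - x0) * 2 - 1

gx = img_x[:, None, :].expand([N, im_h, im_w])
gy = img_y[:, :, None].expand([N, im_h, im_w])
grid = paddle.stack([gx, gy], axis=3)                              # [N, im_h, im_w, 2]
img_masks = F.grid_sample(masks, grid, align_corners=False)        # [N, 1, im_h, im_w]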
@@ -208,19 +233,13 @@ class MaskPostProcess(object):
         # TODO: support bs > 1 and mask output dtype is bool
         pred_result = paddle.zeros(
             [num_mask, origin_shape[0][0], origin_shape[0][1]], dtype='int32')
-        if bbox_num == 1 and bboxes[0][0] == -1:
-            return pred_result
-
-        # TODO: optimize chunk paste
-        pred_result = []
-        for i in range(bboxes.shape[0]):
-            im_h, im_w = origin_shape[i][0], origin_shape[i][1]
-            pred_mask = self.paste_mask(mask_out[i], bboxes[i:i + 1, 2:], im_h,
-                                        im_w)
-            pred_mask = pred_mask >= self.binary_thresh
-            pred_mask = paddle.cast(pred_mask, 'int32')
-            pred_result.append(pred_mask)
-        pred_result = paddle.concat(pred_result)
+
+        im_h, im_w = origin_shape[0][0], origin_shape[0][1]
+        pred_mask = self.paste_mask(mask_out[:, None, :, :], bboxes[:, 2:],
+                                    im_h, im_w)
+        pred_mask = pred_mask >= self.binary_thresh
+        pred_result = paddle.cast(pred_mask, 'int32')
         return pred_result
...
@@ -66,18 +66,21 @@ class RPNHead(nn.Layer):
         in_channel (int): channel of input feature maps which can be
             derived by from_config
     """
+    __shared__ = ['export_onnx']

     def __init__(self,
                  anchor_generator=AnchorGenerator().__dict__,
                  rpn_target_assign=RPNTargetAssign().__dict__,
                  train_proposal=ProposalGenerator(12000, 2000).__dict__,
                  test_proposal=ProposalGenerator().__dict__,
-                 in_channel=1024):
+                 in_channel=1024,
+                 export_onnx=False):
         super(RPNHead, self).__init__()
         self.anchor_generator = anchor_generator
         self.rpn_target_assign = rpn_target_assign
         self.train_proposal = train_proposal
         self.test_proposal = test_proposal
+        self.export_onnx = export_onnx
         if isinstance(anchor_generator, dict):
             self.anchor_generator = AnchorGenerator(**anchor_generator)
         if isinstance(rpn_target_assign, dict):
@@ -149,49 +152,90 @@ class RPNHead(nn.Layer):
         # Collect multi-level proposals for each batch
         # Get 'topk' of them as final output
-        bs_rois_collect = []
-        bs_rois_num_collect = []
-        batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
-
-        # Generate proposals for each level and each batch.
-        # Discard batch-computing to avoid sorting bbox cross different batches.
-        for i in range(batch_size):
-            rpn_rois_list = []
-            rpn_prob_list = []
-            rpn_rois_num_list = []
-            for rpn_score, rpn_delta, anchor in zip(scores, bbox_deltas,
-                                                    anchors):
-                rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n = prop_gen(
-                    scores=rpn_score[i:i + 1],
-                    bbox_deltas=rpn_delta[i:i + 1],
-                    anchors=anchor,
-                    im_shape=im_shape[i:i + 1])
-                if rpn_rois.shape[0] > 0:
-                    rpn_rois_list.append(rpn_rois)
-                    rpn_prob_list.append(rpn_rois_prob)
-                    rpn_rois_num_list.append(rpn_rois_num)
-
-            if len(scores) > 1:
-                rpn_rois = paddle.concat(rpn_rois_list)
-                rpn_prob = paddle.concat(rpn_prob_list).flatten()
-                if rpn_prob.shape[0] > post_nms_top_n:
-                    topk_prob, topk_inds = paddle.topk(rpn_prob, post_nms_top_n)
-                    topk_rois = paddle.gather(rpn_rois, topk_inds)
-                else:
-                    topk_rois = rpn_rois
-                    topk_prob = rpn_prob
-            else:
-                topk_rois = rpn_rois_list[0]
-                topk_prob = rpn_prob_list[0].flatten()
-
-            bs_rois_collect.append(topk_rois)
-            bs_rois_num_collect.append(paddle.shape(topk_rois)[0])
-
-        bs_rois_num_collect = paddle.concat(bs_rois_num_collect)
-
-        return bs_rois_collect, bs_rois_num_collect
+
+        if self.export_onnx:
+            # bs = 1 when exporting onnx
+            onnx_rpn_rois_list = []
+            onnx_rpn_prob_list = []
+            onnx_rpn_rois_num_list = []
+
+            for rpn_score, rpn_delta, anchor in zip(scores, bbox_deltas,
+                                                    anchors):
+                onnx_rpn_rois, onnx_rpn_rois_prob, onnx_rpn_rois_num, onnx_post_nms_top_n = prop_gen(
+                    scores=rpn_score[0:1],
+                    bbox_deltas=rpn_delta[0:1],
+                    anchors=anchor,
+                    im_shape=im_shape[0:1])
+                onnx_rpn_rois_list.append(onnx_rpn_rois)
+                onnx_rpn_prob_list.append(onnx_rpn_rois_prob)
+                onnx_rpn_rois_num_list.append(onnx_rpn_rois_num)
+
+            onnx_rpn_rois = paddle.concat(onnx_rpn_rois_list)
+            onnx_rpn_prob = paddle.concat(onnx_rpn_prob_list).flatten()
+
+            onnx_top_n = paddle.to_tensor(onnx_post_nms_top_n).cast('int32')
+            onnx_num_rois = paddle.shape(onnx_rpn_prob)[0].cast('int32')
+            k = paddle.minimum(onnx_top_n, onnx_num_rois)
+            onnx_topk_prob, onnx_topk_inds = paddle.topk(onnx_rpn_prob, k)
+            onnx_topk_rois = paddle.gather(onnx_rpn_rois, onnx_topk_inds)
+            # TODO(wangguanzhong): Now bs_rois_collect in export_onnx is moved outside conditional branch
+            # due to problems in dy2static of paddle. Will fix it when updating paddle framework.
+            # bs_rois_collect = [onnx_topk_rois]
+            # bs_rois_num_collect = paddle.shape(onnx_topk_rois)[0]
+
+        else:
+            bs_rois_collect = []
+            bs_rois_num_collect = []
+
+            batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
+
+            # Generate proposals for each level and each batch.
+            # Discard batch-computing to avoid sorting bbox cross different batches.
+            for i in range(batch_size):
+                rpn_rois_list = []
+                rpn_prob_list = []
+                rpn_rois_num_list = []
+
+                for rpn_score, rpn_delta, anchor in zip(scores, bbox_deltas,
+                                                        anchors):
+                    rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n = prop_gen(
+                        scores=rpn_score[i:i + 1],
+                        bbox_deltas=rpn_delta[i:i + 1],
+                        anchors=anchor,
+                        im_shape=im_shape[i:i + 1])
+                    rpn_rois_list.append(rpn_rois)
+                    rpn_prob_list.append(rpn_rois_prob)
+                    rpn_rois_num_list.append(rpn_rois_num)
+
+                if len(scores) > 1:
+                    rpn_rois = paddle.concat(rpn_rois_list)
+                    rpn_prob = paddle.concat(rpn_prob_list).flatten()
+
+                    num_rois = paddle.shape(rpn_prob)[0].cast('int32')
+                    if num_rois > post_nms_top_n:
+                        topk_prob, topk_inds = paddle.topk(rpn_prob,
+                                                           post_nms_top_n)
+                        topk_rois = paddle.gather(rpn_rois, topk_inds)
+                    else:
+                        topk_rois = rpn_rois
+                        topk_prob = rpn_prob
+                else:
+                    topk_rois = rpn_rois_list[0]
+                    topk_prob = rpn_prob_list[0].flatten()
+
+                bs_rois_collect.append(topk_rois)
+                bs_rois_num_collect.append(paddle.shape(topk_rois)[0])
+
+            bs_rois_num_collect = paddle.concat(bs_rois_num_collect)
+
+        if self.export_onnx:
+            output_rois = [onnx_topk_rois]
+            output_rois_num = paddle.shape(onnx_topk_rois)[0]
+        else:
+            output_rois = bs_rois_collect
+            output_rois_num = bs_rois_num_collect
+
+        return output_rois, output_rois_num

     def get_loss(self, pred_scores, pred_deltas, anchors, inputs):
         """
...
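The export_onnx branch of the RPN merges proposals from all FPN levels and then takes min(post_nms_top_n, num_proposals) by score, avoiding the data-dependent Python branch of the original loop. A NumPy sketch of that selection (proposal counts are invented):

import numpy as np

rois_per_level = [np.random.rand(800, 4), np.random.rand(600, 4)]
probs_per_level = [np.random.rand(800), np.random.rand(600)]
post_nms_top_n = 1000

rois = np.concatenate(rois_per_level)        # [1400, 4]
probs = np.concatenate(probs_per_level)      # [1400]
k = min(post_nms_top_n, probs.shape[0])      # never ask for more than exist
topk_inds = np.argsort(-probs)[:k]           # indices of the k highest scores
topk_rois = rois[topk_inds]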