未验证 提交 75e3def5 编写于 作者: C cnn 提交者: GitHub

[dev] rcnn support bs>1 (#3174)

* rcnn bs>1

* delete redundant comments
上级 fa6c5a11
...@@ -342,7 +342,11 @@ class RCNNBox(object): ...@@ -342,7 +342,11 @@ class RCNNBox(object):
origin_shape = paddle.floor(im_shape / scale_factor + 0.5) origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
scale_list = [] scale_list = []
origin_shape_list = [] origin_shape_list = []
for idx, roi_per_im in enumerate(roi):
batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
# bbox_pred.shape: [N, C*4]
for idx in range(batch_size):
roi_per_im = roi[idx]
rois_num_per_im = rois_num[idx] rois_num_per_im = rois_num[idx]
expand_im_shape = paddle.expand(im_shape[idx, :], expand_im_shape = paddle.expand(im_shape[idx, :],
[rois_num_per_im, 2]) [rois_num_per_im, 2])
......
...@@ -35,7 +35,7 @@ __all__ = [ ...@@ -35,7 +35,7 @@ __all__ = [
@register @register
class BBoxPostProcess(object): class BBoxPostProcess(nn.Layer):
__shared__ = ['num_classes'] __shared__ = ['num_classes']
__inject__ = ['decode', 'nms'] __inject__ = ['decode', 'nms']
...@@ -44,8 +44,14 @@ class BBoxPostProcess(object): ...@@ -44,8 +44,14 @@ class BBoxPostProcess(object):
self.num_classes = num_classes self.num_classes = num_classes
self.decode = decode self.decode = decode
self.nms = nms self.nms = nms
self.fake_bboxes = paddle.to_tensor(
np.array(
[[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
dtype='float32'))
self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
def __call__(self, head_out, rois, im_shape, scale_factor): def forward(self, head_out, rois, im_shape, scale_factor):
""" """
Decode the bbox and do NMS if needed. Decode the bbox and do NMS if needed.
...@@ -90,10 +96,8 @@ class BBoxPostProcess(object): ...@@ -90,10 +96,8 @@ class BBoxPostProcess(object):
""" """
if bboxes.shape[0] == 0: if bboxes.shape[0] == 0:
bboxes = paddle.to_tensor( bboxes = self.fake_bboxes
np.array( bbox_num = self.fake_bbox_num
[[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
origin_shape = paddle.floor(im_shape / scale_factor + 0.5) origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
......
...@@ -133,21 +133,14 @@ class RPNHead(nn.Layer): ...@@ -133,21 +133,14 @@ class RPNHead(nn.Layer):
anchors = self.anchor_generator(rpn_feats) anchors = self.anchor_generator(rpn_feats)
# TODO: Fix batch_size > 1 when testing. rois, rois_num = self._gen_proposal(scores, deltas, anchors, inputs)
if self.training:
batch_size = inputs['im_shape'].shape[0]
else:
batch_size = 1
rois, rois_num = self._gen_proposal(scores, deltas, anchors, inputs,
batch_size)
if self.training: if self.training:
loss = self.get_loss(scores, deltas, anchors, inputs) loss = self.get_loss(scores, deltas, anchors, inputs)
return rois, rois_num, loss return rois, rois_num, loss
else: else:
return rois, rois_num, None return rois, rois_num, None
def _gen_proposal(self, scores, bbox_deltas, anchors, inputs, batch_size): def _gen_proposal(self, scores, bbox_deltas, anchors, inputs):
""" """
scores (list[Tensor]): Multi-level scores prediction scores (list[Tensor]): Multi-level scores prediction
bbox_deltas (list[Tensor]): Multi-level deltas prediction bbox_deltas (list[Tensor]): Multi-level deltas prediction
...@@ -161,6 +154,7 @@ class RPNHead(nn.Layer): ...@@ -161,6 +154,7 @@ class RPNHead(nn.Layer):
# Get 'topk' of them as final output # Get 'topk' of them as final output
bs_rois_collect = [] bs_rois_collect = []
bs_rois_num_collect = [] bs_rois_num_collect = []
batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
# Generate proposals for each level and each batch. # Generate proposals for each level and each batch.
# Discard batch-computing to avoid sorting bbox cross different batches. # Discard batch-computing to avoid sorting bbox cross different batches.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册