未验证 提交 75e3def5 编写于 作者: C cnn 提交者: GitHub

[dev] rcnn support bs>1 (#3174)

* rcnn bs>1

* delete redundant comments
上级 fa6c5a11
......@@ -342,7 +342,11 @@ class RCNNBox(object):
origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
scale_list = []
origin_shape_list = []
for idx, roi_per_im in enumerate(roi):
batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
# bbox_pred.shape: [N, C*4]
for idx in range(batch_size):
roi_per_im = roi[idx]
rois_num_per_im = rois_num[idx]
expand_im_shape = paddle.expand(im_shape[idx, :],
[rois_num_per_im, 2])
......
......@@ -35,7 +35,7 @@ __all__ = [
@register
class BBoxPostProcess(object):
class BBoxPostProcess(nn.Layer):
__shared__ = ['num_classes']
__inject__ = ['decode', 'nms']
......@@ -44,8 +44,14 @@ class BBoxPostProcess(object):
self.num_classes = num_classes
self.decode = decode
self.nms = nms
self.fake_bboxes = paddle.to_tensor(
np.array(
[[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
dtype='float32'))
self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
def __call__(self, head_out, rois, im_shape, scale_factor):
def forward(self, head_out, rois, im_shape, scale_factor):
"""
Decode the bbox and do NMS if needed.
......@@ -90,10 +96,8 @@ class BBoxPostProcess(object):
"""
if bboxes.shape[0] == 0:
bboxes = paddle.to_tensor(
np.array(
[[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
bboxes = self.fake_bboxes
bbox_num = self.fake_bbox_num
origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
......
......@@ -133,21 +133,14 @@ class RPNHead(nn.Layer):
anchors = self.anchor_generator(rpn_feats)
# TODO: Fix batch_size > 1 when testing.
if self.training:
batch_size = inputs['im_shape'].shape[0]
else:
batch_size = 1
rois, rois_num = self._gen_proposal(scores, deltas, anchors, inputs,
batch_size)
rois, rois_num = self._gen_proposal(scores, deltas, anchors, inputs)
if self.training:
loss = self.get_loss(scores, deltas, anchors, inputs)
return rois, rois_num, loss
else:
return rois, rois_num, None
def _gen_proposal(self, scores, bbox_deltas, anchors, inputs, batch_size):
def _gen_proposal(self, scores, bbox_deltas, anchors, inputs):
"""
scores (list[Tensor]): Multi-level scores prediction
bbox_deltas (list[Tensor]): Multi-level deltas prediction
......@@ -161,6 +154,7 @@ class RPNHead(nn.Layer):
# Get 'topk' of them as final output
bs_rois_collect = []
bs_rois_num_collect = []
batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
# Generate proposals for each level and each batch.
# Discard batch-computing to avoid sorting bbox cross different batches.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册