未验证 提交 3a48c562 编写于 作者: W wangguanzhong 提交者: GitHub

fix cascade series models (#1588)

上级 a30b8c3f
......@@ -135,6 +135,7 @@ class CascadeMaskRCNN(object):
proposals = None
bbox_pred = None
max_overlap = None
for i in range(3):
if i > 0:
refined_bbox = self._decode_box(
......@@ -146,10 +147,14 @@ class CascadeMaskRCNN(object):
if mode == 'train':
outs = self.bbox_assigner(
input_rois=refined_bbox, feed_vars=feed_vars, curr_stage=i)
input_rois=refined_bbox,
feed_vars=feed_vars,
curr_stage=i,
max_overlap=max_overlap)
proposals = outs[0]
rcnn_target_list.append(outs)
max_overlap = outs[-1]
rcnn_target_list.append(outs[:-1])
else:
proposals = refined_bbox
proposal_list.append(proposals)
......
......@@ -128,6 +128,7 @@ class CascadeRCNN(object):
proposals = None
bbox_pred = None
max_overlap = None
for i in range(3):
if i > 0:
refined_bbox = self._decode_box(
......@@ -139,10 +140,14 @@ class CascadeRCNN(object):
if mode == 'train':
outs = self.bbox_assigner(
input_rois=refined_bbox, feed_vars=feed_vars, curr_stage=i)
input_rois=refined_bbox,
feed_vars=feed_vars,
curr_stage=i,
max_overlap=max_overlap)
proposals = outs[0]
rcnn_target_list.append(outs)
max_overlap = outs[-1]
rcnn_target_list.append(outs[:-1])
else:
proposals = refined_bbox
proposal_list.append(proposals)
......
......@@ -117,6 +117,7 @@ class CascadeRCNNClsAware(object):
self.cascade_decoded_box = []
self.cascade_cls_prob = []
max_overlap = None
for stage in range(3):
if stage > 0:
......@@ -126,9 +127,13 @@ class CascadeRCNNClsAware(object):
if mode == "train":
self.cascade_var_v[stage].stop_gradient = True
outs = self.bbox_assigner(
input_rois=pool_rois, feed_vars=feed_vars, curr_stage=stage)
input_rois=pool_rois,
feed_vars=feed_vars,
curr_stage=stage,
max_overlap=max_overlap)
pool_rois = outs[0]
rcnn_target_list.append(outs)
max_overlap = outs[-1]
rcnn_target_list.append(outs[:-1])
# extract roi features
roi_feat = self.roi_extractor(body_feats, pool_rois, spatial_scale)
......
......@@ -158,13 +158,18 @@ class HybridTaskCascade(object):
bbox_pred = None
outs = None
refined_bbox = rpn_rois
max_overlap = None
for i in range(self.num_stage):
# BBox Branch
if mode == 'train':
outs = self.bbox_assigner(
input_rois=refined_bbox, feed_vars=feed_vars, curr_stage=i)
input_rois=refined_bbox,
feed_vars=feed_vars,
curr_stage=i,
max_overlap=max_overlap)
proposals = outs[0]
rcnn_target_list.append(outs)
max_overlap = outs[-1]
rcnn_target_list.append(outs[:-1])
else:
proposals = refined_bbox
proposal_list.append(proposals)
......
......@@ -53,7 +53,7 @@ class CascadeBBoxAssigner(object):
self.use_random = shuffle_before_sample
self.class_aware = class_aware
def __call__(self, input_rois, feed_vars, curr_stage):
def __call__(self, input_rois, feed_vars, curr_stage, max_overlap=None):
curr_bbox_reg_w = [
1. / self.bbox_reg_weights[curr_stage],
......@@ -76,5 +76,7 @@ class CascadeBBoxAssigner(object):
class_nums=self.class_nums if self.class_aware else 2,
is_cls_agnostic=not self.class_aware,
is_cascade_rcnn=True
if curr_stage > 0 and not self.class_aware else False)
if curr_stage > 0 and not self.class_aware else False,
max_overlap=max_overlap,
return_max_overlap=True)
return outs
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册