未验证 提交 3a48c562 编写于 作者: W wangguanzhong 提交者: GitHub

fix cascade series models (#1588)

上级 a30b8c3f
...@@ -135,6 +135,7 @@ class CascadeMaskRCNN(object): ...@@ -135,6 +135,7 @@ class CascadeMaskRCNN(object):
proposals = None proposals = None
bbox_pred = None bbox_pred = None
max_overlap = None
for i in range(3): for i in range(3):
if i > 0: if i > 0:
refined_bbox = self._decode_box( refined_bbox = self._decode_box(
...@@ -146,10 +147,14 @@ class CascadeMaskRCNN(object): ...@@ -146,10 +147,14 @@ class CascadeMaskRCNN(object):
if mode == 'train': if mode == 'train':
outs = self.bbox_assigner( outs = self.bbox_assigner(
input_rois=refined_bbox, feed_vars=feed_vars, curr_stage=i) input_rois=refined_bbox,
feed_vars=feed_vars,
curr_stage=i,
max_overlap=max_overlap)
proposals = outs[0] proposals = outs[0]
rcnn_target_list.append(outs) max_overlap = outs[-1]
rcnn_target_list.append(outs[:-1])
else: else:
proposals = refined_bbox proposals = refined_bbox
proposal_list.append(proposals) proposal_list.append(proposals)
......
...@@ -128,6 +128,7 @@ class CascadeRCNN(object): ...@@ -128,6 +128,7 @@ class CascadeRCNN(object):
proposals = None proposals = None
bbox_pred = None bbox_pred = None
max_overlap = None
for i in range(3): for i in range(3):
if i > 0: if i > 0:
refined_bbox = self._decode_box( refined_bbox = self._decode_box(
...@@ -139,10 +140,14 @@ class CascadeRCNN(object): ...@@ -139,10 +140,14 @@ class CascadeRCNN(object):
if mode == 'train': if mode == 'train':
outs = self.bbox_assigner( outs = self.bbox_assigner(
input_rois=refined_bbox, feed_vars=feed_vars, curr_stage=i) input_rois=refined_bbox,
feed_vars=feed_vars,
curr_stage=i,
max_overlap=max_overlap)
proposals = outs[0] proposals = outs[0]
rcnn_target_list.append(outs) max_overlap = outs[-1]
rcnn_target_list.append(outs[:-1])
else: else:
proposals = refined_bbox proposals = refined_bbox
proposal_list.append(proposals) proposal_list.append(proposals)
......
...@@ -117,6 +117,7 @@ class CascadeRCNNClsAware(object): ...@@ -117,6 +117,7 @@ class CascadeRCNNClsAware(object):
self.cascade_decoded_box = [] self.cascade_decoded_box = []
self.cascade_cls_prob = [] self.cascade_cls_prob = []
max_overlap = None
for stage in range(3): for stage in range(3):
if stage > 0: if stage > 0:
...@@ -126,9 +127,13 @@ class CascadeRCNNClsAware(object): ...@@ -126,9 +127,13 @@ class CascadeRCNNClsAware(object):
if mode == "train": if mode == "train":
self.cascade_var_v[stage].stop_gradient = True self.cascade_var_v[stage].stop_gradient = True
outs = self.bbox_assigner( outs = self.bbox_assigner(
input_rois=pool_rois, feed_vars=feed_vars, curr_stage=stage) input_rois=pool_rois,
feed_vars=feed_vars,
curr_stage=stage,
max_overlap=max_overlap)
pool_rois = outs[0] pool_rois = outs[0]
rcnn_target_list.append(outs) max_overlap = outs[-1]
rcnn_target_list.append(outs[:-1])
# extract roi features # extract roi features
roi_feat = self.roi_extractor(body_feats, pool_rois, spatial_scale) roi_feat = self.roi_extractor(body_feats, pool_rois, spatial_scale)
......
...@@ -158,13 +158,18 @@ class HybridTaskCascade(object): ...@@ -158,13 +158,18 @@ class HybridTaskCascade(object):
bbox_pred = None bbox_pred = None
outs = None outs = None
refined_bbox = rpn_rois refined_bbox = rpn_rois
max_overlap = None
for i in range(self.num_stage): for i in range(self.num_stage):
# BBox Branch # BBox Branch
if mode == 'train': if mode == 'train':
outs = self.bbox_assigner( outs = self.bbox_assigner(
input_rois=refined_bbox, feed_vars=feed_vars, curr_stage=i) input_rois=refined_bbox,
feed_vars=feed_vars,
curr_stage=i,
max_overlap=max_overlap)
proposals = outs[0] proposals = outs[0]
rcnn_target_list.append(outs) max_overlap = outs[-1]
rcnn_target_list.append(outs[:-1])
else: else:
proposals = refined_bbox proposals = refined_bbox
proposal_list.append(proposals) proposal_list.append(proposals)
......
...@@ -53,7 +53,7 @@ class CascadeBBoxAssigner(object): ...@@ -53,7 +53,7 @@ class CascadeBBoxAssigner(object):
self.use_random = shuffle_before_sample self.use_random = shuffle_before_sample
self.class_aware = class_aware self.class_aware = class_aware
def __call__(self, input_rois, feed_vars, curr_stage): def __call__(self, input_rois, feed_vars, curr_stage, max_overlap=None):
curr_bbox_reg_w = [ curr_bbox_reg_w = [
1. / self.bbox_reg_weights[curr_stage], 1. / self.bbox_reg_weights[curr_stage],
...@@ -76,5 +76,7 @@ class CascadeBBoxAssigner(object): ...@@ -76,5 +76,7 @@ class CascadeBBoxAssigner(object):
class_nums=self.class_nums if self.class_aware else 2, class_nums=self.class_nums if self.class_aware else 2,
is_cls_agnostic=not self.class_aware, is_cls_agnostic=not self.class_aware,
is_cascade_rcnn=True is_cascade_rcnn=True
if curr_stage > 0 and not self.class_aware else False) if curr_stage > 0 and not self.class_aware else False,
max_overlap=max_overlap,
return_max_overlap=True)
return outs return outs
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册