未验证 提交 8ce83816 编写于 作者: Y Yang Zhang 提交者: GitHub

Remove leftover `Switch` ops (#240)

* Replace `switch` ops with `cond`

* Use func instead of lambda
上级 582d3c25
......@@ -321,26 +321,28 @@ class CascadeMaskRCNN(object):
dtype='float32',
persistable=False,
name=mask_name)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(cond):
fluid.layers.assign(input=bbox_pred, output=mask_pred)
with switch.default():
bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6])
im_scale = fluid.layers.slice(
im_info, [1], starts=[2], ends=[3])
im_scale = fluid.layers.sequence_expand(im_scale, bbox)
def noop():
fluid.layers.assign(input=bbox_pred, output=mask_pred)
mask_rois = bbox * im_scale
if self.fpn is None:
mask_feat = self.roi_extractor(last_feat, mask_rois)
mask_feat = self.bbox_head.get_head_feat(mask_feat)
else:
mask_feat = self.roi_extractor(
body_feats, mask_rois, spatial_scale, is_mask=True)
mask_out = self.mask_head.get_prediction(mask_feat, bbox)
fluid.layers.assign(input=mask_out, output=mask_pred)
def process_boxes():
bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6])
im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
im_scale = fluid.layers.sequence_expand(im_scale, bbox)
mask_rois = bbox * im_scale
if self.fpn is None:
mask_feat = self.roi_extractor(last_feat, mask_rois)
mask_feat = self.bbox_head.get_head_feat(mask_feat)
else:
mask_feat = self.roi_extractor(
body_feats, mask_rois, spatial_scale, is_mask=True)
mask_out = self.mask_head.get_prediction(mask_feat, bbox)
fluid.layers.assign(input=mask_out, output=mask_pred)
fluid.layers.cond(cond, noop, process_boxes)
return mask_pred, bbox_pred
def _input_check(self, require_fields, feed_vars):
......
......@@ -240,27 +240,29 @@ class MaskRCNN(object):
dtype='float32',
persistable=False,
name=mask_name)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(cond):
fluid.layers.assign(input=bbox_pred, output=mask_pred)
with switch.default():
bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6])
im_scale = fluid.layers.slice(
im_info, [1], starts=[2], ends=[3])
im_scale = fluid.layers.sequence_expand(im_scale, bbox)
def noop():
fluid.layers.assign(input=bbox_pred, output=mask_pred)
mask_rois = bbox * im_scale
if self.fpn is None:
last_feat = body_feats[list(body_feats.keys())[-1]]
mask_feat = self.roi_extractor(last_feat, mask_rois)
mask_feat = self.bbox_head.get_head_feat(mask_feat)
else:
mask_feat = self.roi_extractor(
body_feats, mask_rois, spatial_scale, is_mask=True)
def process_boxes():
bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6])
im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
im_scale = fluid.layers.sequence_expand(im_scale, bbox)
mask_rois = bbox * im_scale
if self.fpn is None:
last_feat = body_feats[list(body_feats.keys())[-1]]
mask_feat = self.roi_extractor(last_feat, mask_rois)
mask_feat = self.bbox_head.get_head_feat(mask_feat)
else:
mask_feat = self.roi_extractor(
body_feats, mask_rois, spatial_scale, is_mask=True)
mask_out = self.mask_head.get_prediction(mask_feat, bbox)
fluid.layers.assign(input=mask_out, output=mask_pred)
mask_out = self.mask_head.get_prediction(mask_feat, bbox)
fluid.layers.assign(input=mask_out, output=mask_pred)
fluid.layers.cond(cond, noop, process_boxes)
return mask_pred, bbox_pred
def _input_check(self, require_fields, feed_vars):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册