未验证 提交 8ce83816 编写于 作者: Y Yang Zhang 提交者: GitHub

Remove leftover `Switch` ops (#240)

* Replace `switch` ops with `cond`

* Use func instead of lambda
上级 582d3c25
......@@ -321,14 +321,14 @@ class CascadeMaskRCNN(object):
dtype='float32',
persistable=False,
name=mask_name)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(cond):
def noop():
fluid.layers.assign(input=bbox_pred, output=mask_pred)
with switch.default():
def process_boxes():
bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6])
im_scale = fluid.layers.slice(
im_info, [1], starts=[2], ends=[3])
im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
im_scale = fluid.layers.sequence_expand(im_scale, bbox)
mask_rois = bbox * im_scale
......@@ -341,6 +341,8 @@ class CascadeMaskRCNN(object):
mask_out = self.mask_head.get_prediction(mask_feat, bbox)
fluid.layers.assign(input=mask_out, output=mask_pred)
fluid.layers.cond(cond, noop, process_boxes)
return mask_pred, bbox_pred
def _input_check(self, require_fields, feed_vars):
......
......@@ -240,14 +240,14 @@ class MaskRCNN(object):
dtype='float32',
persistable=False,
name=mask_name)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(cond):
def noop():
fluid.layers.assign(input=bbox_pred, output=mask_pred)
with switch.default():
def process_boxes():
bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6])
im_scale = fluid.layers.slice(
im_info, [1], starts=[2], ends=[3])
im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
im_scale = fluid.layers.sequence_expand(im_scale, bbox)
mask_rois = bbox * im_scale
......@@ -261,6 +261,8 @@ class MaskRCNN(object):
mask_out = self.mask_head.get_prediction(mask_feat, bbox)
fluid.layers.assign(input=mask_out, output=mask_pred)
fluid.layers.cond(cond, noop, process_boxes)
return mask_pred, bbox_pred
def _input_check(self, require_fields, feed_vars):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册