提交 c69584a6 编写于 作者: C chenluyan

Fix the code style problem. test=develop

上级 99b47985
...@@ -23,7 +23,7 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -23,7 +23,7 @@ _logger = get_logger(__name__, level=logging.INFO)
PRUNE_WORKER = Registry('prune_worker') PRUNE_WORKER = Registry('prune_worker')
SKIP_OPS = ["conditional_block","shape","reshape2"] SKIP_OPS = ["conditional_block", "shape", "reshape2"]
class PruneWorker(object): class PruneWorker(object):
...@@ -643,6 +643,7 @@ class roi_align(PruneWorker): ...@@ -643,6 +643,7 @@ class roi_align(PruneWorker):
for op in next_ops: for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx) self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class lod_reset(PruneWorker): class lod_reset(PruneWorker):
def __init__(self, op, pruned_params, visited={}): def __init__(self, op, pruned_params, visited={}):
...@@ -661,10 +662,12 @@ class lod_reset(PruneWorker): ...@@ -661,10 +662,12 @@ class lod_reset(PruneWorker):
for op in next_ops: for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx) self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register @PRUNE_WORKER.register
class gather(PruneWorker): class gather(PruneWorker):
def __init__(self, op, pruned_params, visited): def __init__(self, op, pruned_params, visited):
super(gather, self).__init__(op, pruned_params, visited) super(gather, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx): def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs("Out"): if var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0] in_var = self.op.inputs("X")[0]
...@@ -676,4 +679,4 @@ class gather(PruneWorker): ...@@ -676,4 +679,4 @@ class gather(PruneWorker):
self._visit(out_var, pruned_axis) self._visit(out_var, pruned_axis)
next_ops = out_var.outputs() next_ops = out_var.outputs()
for op in next_ops: for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx) self._prune_op(op, out_var, pruned_axis, pruned_idx)
\ No newline at end of file
...@@ -42,21 +42,22 @@ class TestPrune(unittest.TestCase): ...@@ -42,21 +42,22 @@ class TestPrune(unittest.TestCase):
conv3 = conv_bn_layer(sum1, 8, 3, "conv3") conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4") conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1 sum2 = conv4 + sum1
#test roi_align # test roi_align
rois = fluid.data( rois = fluid.data(name='rois', shape=[None, 4], dtype='float32')
name='rois', shape=[None, 4], dtype='float32') align_out = fluid.layers.roi_align(
align_out = fluid.layers.roi_align(input=sum2, input=sum2,
rois=rois, rois=rois,
pooled_height=7, pooled_height=7,
pooled_width=7, pooled_width=7,
spatial_scale=0.5, spatial_scale=0.5,
sampling_ratio=-1) sampling_ratio=-1)
conv5 = conv_bn_layer(align_out, 8, 3, "conv5") conv5 = conv_bn_layer(align_out, 8, 3, "conv5")
#test gather # test gather
index = fluid.layers.data(name='index', shape=[-1, 1], dtype='int32') index = fluid.layers.data(
name='index', shape=[-1, 1], dtype='int32')
gather_out = fluid.layers.gather(sum2, index) gather_out = fluid.layers.gather(sum2, index)
conv6 = conv_bn_layer(gather_out, 8, 3, "conv6") conv6 = conv_bn_layer(gather_out, 8, 3, "conv6")
#test lod_reset # test lod_reset
y = fluid.layers.data(name='y', shape=[6], lod_level=2) y = fluid.layers.data(name='y', shape=[6], lod_level=2)
lodset_out = fluid.layers.lod_reset(x=sum2, y=y) lodset_out = fluid.layers.lod_reset(x=sum2, y=y)
conv7 = conv_bn_layer(lodset_out, 8, 3, "conv7") conv7 = conv_bn_layer(lodset_out, 8, 3, "conv7")
...@@ -79,4 +80,4 @@ class TestPrune(unittest.TestCase): ...@@ -79,4 +80,4 @@ class TestPrune(unittest.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册