提交 c69584a6 编写于 作者: C chenluyan

Fix the code style problem. test=develop

上级 99b47985
......@@ -23,7 +23,7 @@ _logger = get_logger(__name__, level=logging.INFO)
PRUNE_WORKER = Registry('prune_worker')
SKIP_OPS = ["conditional_block","shape","reshape2"]
SKIP_OPS = ["conditional_block", "shape", "reshape2"]
class PruneWorker(object):
......@@ -643,6 +643,7 @@ class roi_align(PruneWorker):
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class lod_reset(PruneWorker):
def __init__(self, op, pruned_params, visited={}):
......@@ -661,10 +662,12 @@ class lod_reset(PruneWorker):
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class gather(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(gather, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0]
......@@ -676,4 +679,4 @@ class gather(PruneWorker):
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
\ No newline at end of file
self._prune_op(op, out_var, pruned_axis, pruned_idx)
......@@ -42,21 +42,22 @@ class TestPrune(unittest.TestCase):
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
#test roi_align
rois = fluid.data(
name='rois', shape=[None, 4], dtype='float32')
align_out = fluid.layers.roi_align(input=sum2,
rois=rois,
pooled_height=7,
pooled_width=7,
spatial_scale=0.5,
sampling_ratio=-1)
# test roi_align
rois = fluid.data(name='rois', shape=[None, 4], dtype='float32')
align_out = fluid.layers.roi_align(
input=sum2,
rois=rois,
pooled_height=7,
pooled_width=7,
spatial_scale=0.5,
sampling_ratio=-1)
conv5 = conv_bn_layer(align_out, 8, 3, "conv5")
#test gather
index = fluid.layers.data(name='index', shape=[-1, 1], dtype='int32')
# test gather
index = fluid.layers.data(
name='index', shape=[-1, 1], dtype='int32')
gather_out = fluid.layers.gather(sum2, index)
conv6 = conv_bn_layer(gather_out, 8, 3, "conv6")
#test lod_reset
# test lod_reset
y = fluid.layers.data(name='y', shape=[6], lod_level=2)
lodset_out = fluid.layers.lod_reset(x=sum2, y=y)
conv7 = conv_bn_layer(lodset_out, 8, 3, "conv7")
......@@ -79,4 +80,4 @@ class TestPrune(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册