diff --git a/paddleslim/prune/prune_walker.py b/paddleslim/prune/prune_walker.py index 7d2cafea408f384cf7ae627e254760fba0317c54..a20cef9b01924bd866675b164013bcdd81ec333e 100644 --- a/paddleslim/prune/prune_walker.py +++ b/paddleslim/prune/prune_walker.py @@ -23,7 +23,7 @@ _logger = get_logger(__name__, level=logging.INFO) PRUNE_WORKER = Registry('prune_worker') -SKIP_OPS = ["conditional_block","shape","reshape2"] +SKIP_OPS = ["conditional_block", "shape", "reshape2"] class PruneWorker(object): @@ -643,6 +643,7 @@ class roi_align(PruneWorker): for op in next_ops: self._prune_op(op, out_var, pruned_axis, pruned_idx) + @PRUNE_WORKER.register class lod_reset(PruneWorker): def __init__(self, op, pruned_params, visited={}): @@ -661,10 +662,12 @@ class lod_reset(PruneWorker): for op in next_ops: self._prune_op(op, out_var, pruned_axis, pruned_idx) + @PRUNE_WORKER.register class gather(PruneWorker): def __init__(self, op, pruned_params, visited): super(gather, self).__init__(op, pruned_params, visited) + def _prune(self, var, pruned_axis, pruned_idx): if var in self.op.outputs("Out"): in_var = self.op.inputs("X")[0] @@ -676,4 +679,4 @@ class gather(PruneWorker): self._visit(out_var, pruned_axis) next_ops = out_var.outputs() for op in next_ops: - self._prune_op(op, out_var, pruned_axis, pruned_idx) \ No newline at end of file + self._prune_op(op, out_var, pruned_axis, pruned_idx) diff --git a/tests/test_add_ops.py b/tests/test_add_ops.py index cb671df0ada28272b6b6104e82541fee3dd00d97..d5314ec0881709cacceb4d95cf6fdca65e08721b 100644 --- a/tests/test_add_ops.py +++ b/tests/test_add_ops.py @@ -42,21 +42,22 @@ class TestPrune(unittest.TestCase): conv3 = conv_bn_layer(sum1, 8, 3, "conv3") conv4 = conv_bn_layer(conv3, 8, 3, "conv4") sum2 = conv4 + sum1 - #test roi_align - rois = fluid.data( - name='rois', shape=[None, 4], dtype='float32') - align_out = fluid.layers.roi_align(input=sum2, - rois=rois, - pooled_height=7, - pooled_width=7, - spatial_scale=0.5, - sampling_ratio=-1) + # test roi_align + rois = fluid.data(name='rois', shape=[None, 4], dtype='float32') + align_out = fluid.layers.roi_align( + input=sum2, + rois=rois, + pooled_height=7, + pooled_width=7, + spatial_scale=0.5, + sampling_ratio=-1) conv5 = conv_bn_layer(align_out, 8, 3, "conv5") - #test gather - index = fluid.layers.data(name='index', shape=[-1, 1], dtype='int32') + # test gather + index = fluid.layers.data( + name='index', shape=[-1, 1], dtype='int32') gather_out = fluid.layers.gather(sum2, index) conv6 = conv_bn_layer(gather_out, 8, 3, "conv6") - #test lod_reset + # test lod_reset y = fluid.layers.data(name='y', shape=[6], lod_level=2) lodset_out = fluid.layers.lod_reset(x=sum2, y=y) conv7 = conv_bn_layer(lodset_out, 8, 3, "conv7") @@ -79,4 +80,4 @@ class TestPrune(unittest.TestCase): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()