diff --git a/paddleslim/core/graph_wrapper.py b/paddleslim/core/graph_wrapper.py index 04a85d8f15f6fda3c5f871c5845152cab46322bc..970744ffd6fd060beef43504548bf7b3633356ee 100644 --- a/paddleslim/core/graph_wrapper.py +++ b/paddleslim/core/graph_wrapper.py @@ -166,9 +166,7 @@ class OpWrapper(object): """ Get all the varibales by the output name. """ - return [ - self._graph.var(var_name) for var_name in self._op.output(name) - ] + return [self._graph.var(var_name) for var_name in self._op.output(name)] def set_attr(self, key, value): """ @@ -354,16 +352,6 @@ class GraphWrapper(object): ret += np.product(param.shape()) return ret - def update_param_shape(self, scope): - """ - Update the shape of parameters in the graph according to tensors in scope. - It is used after loading pruned parameters from file. - """ - for param in self.all_parameters(): - tensor_shape = np.array( - scope.find_var(param.name()).get_tensor()).shape - param.set_shape(tensor_shape) - def infer_shape(self): """ Update the groups of convolution layer according to current filters. @@ -375,6 +363,6 @@ class GraphWrapper(object): def update_groups_of_conv(self): for op in self.ops(): - if op.type() == 'depthwise_conv2d' or op.type( - ) == 'depthwise_conv2d_grad': + if 'conv2d' in op.type() and op.attr('groups') >= op.inputs( + 'Filter')[0].shape()[0]: op.set_attr('groups', op.inputs('Filter')[0].shape()[0]) diff --git a/paddleslim/prune/group_param.py b/paddleslim/prune/group_param.py index f861a7e5319322d513c93992ade8b762e87818ab..c237b84685b90567738c55e8d1611357481aeb2c 100644 --- a/paddleslim/prune/group_param.py +++ b/paddleslim/prune/group_param.py @@ -58,8 +58,7 @@ def collect_convs(params, graph, visited={}): walker = conv2d_walker( conv_op, pruned_params=pruned_params, visited=visited) walker.prune(param, pruned_axis=0, pruned_idx=[0]) - if len(pruned_params) > 0: - groups.append(pruned_params) + groups.append(pruned_params) visited = set() uniq_groups = [] for group in groups: diff --git a/tests/test_group_param.py b/tests/test_group_param.py index cd699bfd68bf2d28e670a03f9200944be9b1a562..f0fb73611b8143ae7e570f213382cc7931130ff2 100644 --- a/tests/test_group_param.py +++ b/tests/test_group_param.py @@ -42,6 +42,8 @@ class TestPrune(unittest.TestCase): conv6 = conv_bn_layer(conv5, 8, 3, "conv6") groups = collect_convs( ["conv1_weights", "conv2_weights", "conv3_weights"], main_program) + while [] in groups: + groups.remove([]) self.assertTrue(len(groups) == 2) self.assertTrue(len(groups[0]) == 18) self.assertTrue(len(groups[1]) == 6)