Unverified commit 14f1d583, authored by Bai Yifan, committed by GitHub

Merge branch 'develop' into pact_clip

@@ -166,9 +166,7 @@ class OpWrapper(object):
         """
         Get all the varibales by the output name.
         """
-        return [
-            self._graph.var(var_name) for var_name in self._op.output(name)
-        ]
+        return [self._graph.var(var_name) for var_name in self._op.output(name)]
 
     def set_attr(self, key, value):
         """
@@ -354,16 +352,6 @@ class GraphWrapper(object):
             ret += np.product(param.shape())
         return ret
 
-    def update_param_shape(self, scope):
-        """
-        Update the shape of parameters in the graph according to tensors in scope.
-        It is used after loading pruned parameters from file.
-        """
-        for param in self.all_parameters():
-            tensor_shape = np.array(
-                scope.find_var(param.name()).get_tensor()).shape
-            param.set_shape(tensor_shape)
-
     def infer_shape(self):
         """
         Update the groups of convolution layer according to current filters.
@@ -375,6 +363,6 @@ class GraphWrapper(object):
     def update_groups_of_conv(self):
         for op in self.ops():
-            if op.type() == 'depthwise_conv2d' or op.type(
-            ) == 'depthwise_conv2d_grad':
+            if 'conv2d' in op.type() and op.attr('groups') >= op.inputs(
+                    'Filter')[0].shape()[0]:
                 op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
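
With this change, update_groups_of_conv no longer matches on the depthwise op type alone: any conv2d variant whose 'groups' attribute is at least the current number of filters has 'groups' reset to the (possibly pruned) filter count. A minimal sketch of that condition in isolation, using hypothetical op types and channel counts rather than a real GraphWrapper:

# Sketch only -- the op type strings and channel counts below are hypothetical.
def needs_group_update(op_type, groups_attr, num_filters):
    # Rule from the hunk above: a conv2d-like op whose 'groups' is >= the
    # current filter count is treated as depthwise and re-synced to it.
    return 'conv2d' in op_type and groups_attr >= num_filters

# A depthwise conv pruned from 32 to 24 filters still carries groups=32,
# so the check fires and groups would be reset to 24.
assert needs_group_update('depthwise_conv2d', 32, 24)
# A regular conv (groups=1, 64 filters) is left untouched.
assert not needs_group_update('conv2d', 1, 64)

Checking 'conv2d' in op.type() also covers variants such as 'depthwise_conv2d_grad' that the old equality test had to enumerate by hand.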
@@ -58,8 +58,7 @@ def collect_convs(params, graph, visited={}):
         walker = conv2d_walker(
             conv_op, pruned_params=pruned_params, visited=visited)
         walker.prune(param, pruned_axis=0, pruned_idx=[0])
-        if len(pruned_params) > 0:
-            groups.append(pruned_params)
+        groups.append(pruned_params)
     visited = set()
     uniq_groups = []
     for group in groups:
...
@@ -42,6 +42,8 @@ class TestPrune(unittest.TestCase):
         conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
         groups = collect_convs(
             ["conv1_weights", "conv2_weights", "conv3_weights"], main_program)
+        while [] in groups:
+            groups.remove([])
         self.assertTrue(len(groups) == 2)
         self.assertTrue(len(groups[0]) == 18)
         self.assertTrue(len(groups[1]) == 6)
...
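
Because collect_convs now appends pruned_params unconditionally, empty groups can appear in its result, and the test strips them before asserting on the group count. A sketch of the equivalent call-site cleanup on assumed data (not from the repo):

# Hypothetical groups list; the comprehension is equivalent to the test's
# `while [] in groups: groups.remove([])` loop.
groups = [["conv1_weights"], [], ["conv2_weights", "conv3_weights"], []]
groups = [group for group in groups if len(group) > 0]
assert groups == [["conv1_weights"], ["conv2_weights", "conv3_weights"]]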