未验证 提交 36b38fc3 编写于 作者: Y yukavio 提交者: GitHub

Fix bug when prune depthwise convolution layer. (#399)

* fix bug when prune the depthwise convolution layer

* fix bug when prune the depthwise convolution layer

* fix bug when prune depthwise convolution layer

* fix pruner when prune depthwise convolution layer

* remove print from unit test
上级 d00373ae
......@@ -166,9 +166,7 @@ class OpWrapper(object):
"""
Get all the varibales by the output name.
"""
return [
self._graph.var(var_name) for var_name in self._op.output(name)
]
return [self._graph.var(var_name) for var_name in self._op.output(name)]
def set_attr(self, key, value):
"""
......@@ -354,16 +352,6 @@ class GraphWrapper(object):
ret += np.product(param.shape())
return ret
def update_param_shape(self, scope):
"""
Update the shape of parameters in the graph according to tensors in scope.
It is used after loading pruned parameters from file.
"""
for param in self.all_parameters():
tensor_shape = np.array(
scope.find_var(param.name()).get_tensor()).shape
param.set_shape(tensor_shape)
def infer_shape(self):
"""
Update the groups of convolution layer according to current filters.
......@@ -375,6 +363,6 @@ class GraphWrapper(object):
def update_groups_of_conv(self):
for op in self.ops():
if op.type() == 'depthwise_conv2d' or op.type(
) == 'depthwise_conv2d_grad':
if 'conv2d' in op.type() and op.attr('groups') >= op.inputs(
'Filter')[0].shape()[0]:
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
......@@ -58,8 +58,7 @@ def collect_convs(params, graph, visited={}):
walker = conv2d_walker(
conv_op, pruned_params=pruned_params, visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=[0])
if len(pruned_params) > 0:
groups.append(pruned_params)
groups.append(pruned_params)
visited = set()
uniq_groups = []
for group in groups:
......
......@@ -42,6 +42,8 @@ class TestPrune(unittest.TestCase):
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
groups = collect_convs(
["conv1_weights", "conv2_weights", "conv3_weights"], main_program)
while [] in groups:
groups.remove([])
self.assertTrue(len(groups) == 2)
self.assertTrue(len(groups[0]) == 18)
self.assertTrue(len(groups[1]) == 6)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册