“e983cc90fcee4e5b73bce9d4853b85aac4661e3a”上不存在“...paddle/v2/fluid/tests/unittests/test_sequence_expand.py”
未验证 提交 7e3e14ae 编写于 作者: Z zhouzj 提交者: GitHub

Fix the bug of pruning dw_conv. (#1311)

上级 419ffd3b
...@@ -522,7 +522,6 @@ class depthwise_conv2d(PruneWorker): ...@@ -522,7 +522,6 @@ class depthwise_conv2d(PruneWorker):
channel_axis = 1 channel_axis = 1
if data_format == "NHWC": if data_format == "NHWC":
channel_axis = 3 channel_axis = 3
if var == _in_var: if var == _in_var:
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format( assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
pruned_axis) pruned_axis)
...@@ -533,7 +532,6 @@ class depthwise_conv2d(PruneWorker): ...@@ -533,7 +532,6 @@ class depthwise_conv2d(PruneWorker):
"repeat": repeat "repeat": repeat
}]) }])
# kernel_number * groups will be pruned by reducing groups # kernel_number * groups will be pruned by reducing groups
self.append_pruned_vars(_filter, 1, transforms)
self._visit_and_search(_filter, 0, transforms + [{ self._visit_and_search(_filter, 0, transforms + [{
"repeat": repeat "repeat": repeat
}]) }])
...@@ -546,14 +544,13 @@ class depthwise_conv2d(PruneWorker): ...@@ -546,14 +544,13 @@ class depthwise_conv2d(PruneWorker):
}]) }])
elif var == _filter: elif var == _filter:
assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0." assert pruned_axis == 0, "The filter of depthwise conv2d can only be pruned at axis 0."
self.append_pruned_vars(_filter, 1, transforms) self.append_pruned_vars(_filter, 0, transforms)
self._visit_and_search(_in_var, channel_axis, transforms) self._visit_and_search(_in_var, channel_axis, transforms)
self._visit_and_search(_out, channel_axis, transforms) self._visit_and_search(_out, channel_axis, transforms)
elif var == _out: elif var == _out:
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format( assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}".format(
pruned_axis) pruned_axis)
self.append_pruned_vars(_filter, 0, transforms) self.append_pruned_vars(_filter, 0, transforms)
self.append_pruned_vars(_filter, 1, transforms)
self._visit_and_search(_filter, 0, transforms) self._visit_and_search(_filter, 0, transforms)
# It will not pruning number of kernels in depthwise conv2d, # It will not pruning number of kernels in depthwise conv2d,
# so it is not neccesary to search succeed operators. # so it is not neccesary to search succeed operators.
......
...@@ -117,7 +117,7 @@ class Pruner(): ...@@ -117,7 +117,7 @@ class Pruner():
_groups = 1 _groups = 1
if not lazy: if not lazy:
# update groups of conv2d # update groups of conv2d
if pruned_axis == 1: if pruned_axis == 1 or pruned_axis == 0:
for op in param.outputs(): for op in param.outputs():
if op.type() in [ if op.type() in [
"conv2d", "conv2d_grad", "depthwise_conv2d", "conv2d", "conv2d_grad", "depthwise_conv2d",
...@@ -132,7 +132,7 @@ class Pruner(): ...@@ -132,7 +132,7 @@ class Pruner():
f"change groups of {op.type()}({param.name()}) from {op.attr('groups')} to {new_groups};" f"change groups of {op.type()}({param.name()}) from {op.attr('groups')} to {new_groups};"
) )
op.set_attr("groups", new_groups) op.set_attr("groups", new_groups)
if _groups == 1:
origin_shape = copy.deepcopy(param.shape()) origin_shape = copy.deepcopy(param.shape())
if param_shape_backup is not None: if param_shape_backup is not None:
param_shape_backup[param.name()] = origin_shape param_shape_backup[param.name()] = origin_shape
......
...@@ -473,9 +473,9 @@ class TestDepthwiseConv2d(TestPruneWorker): ...@@ -473,9 +473,9 @@ class TestDepthwiseConv2d(TestPruneWorker):
def set_cases(self): def set_cases(self):
weight_var = self.graph.var('conv1.w_0') weight_var = self.graph.var('conv1.w_0')
self.cases.append((self.in_var, 1, {'conv1.w_0': [0, 1]})) self.cases.append((self.in_var, 1, {'conv1.w_0': [0]}))
self.cases.append((self.out_var, 1, {'conv1.w_0': [0, 1]})) self.cases.append((self.out_var, 1, {'conv1.w_0': [0]}))
self.cases.append((weight_var, 0, {'conv1.w_0': [1]})) self.cases.append((weight_var, 0, {'conv1.w_0': [0]}))
def test_prune(self): def test_prune(self):
self.check_in_out() self.check_in_out()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册