未验证 提交 7c8ba912 编写于 作者: W whs 提交者: GitHub

Skip pruning output channels of conv2d_transpose. (#411)

上级 18411ec3
......@@ -14,7 +14,7 @@
# limitations under the License.
from ..core import GraphWrapper
from .prune_walker import conv2d as conv2d_walker
from .prune_walker import PRUNE_WORKER
__all__ = ["collect_convs"]
......@@ -55,8 +55,9 @@ def collect_convs(params, graph, visited={}):
pruned_params = []
param = graph.var(param)
conv_op = param.outputs()[0]
walker = conv2d_walker(
conv_op, pruned_params=pruned_params, visited=visited)
cls = PRUNE_WORKER.get(conv_op.type())
walker = cls(conv_op, pruned_params=pruned_params, visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=[0])
groups.append(pruned_params)
visited = set()
......
......@@ -84,9 +84,7 @@ class PruneWorker(object):
cls = PRUNE_WORKER.get("default_walker")
_logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
self.op, op, pruned_axis, var.name()))
walker = cls(op,
pruned_params=self.pruned_params,
visited=self.visited)
walker = cls(op, pruned_params=self.pruned_params, visited=self.visited)
walker.prune(var, pruned_axis, pruned_idx)
......@@ -175,29 +173,8 @@ class conv2d_transpose(PruneWorker):
self._prune_op(op, filter_var, 0, pruned_idx)
elif var in self.op.inputs("Filter"):
assert pruned_axis in [0, 1]
self.pruned_params.append((var, pruned_axis, pruned_idx))
for op in var.outputs():
self._prune_op(op, var, pruned_axis, pruned_idx)
if pruned_axis == 1:
if len(self.op.inputs("Bias")) > 0:
self.pruned_params.append(
(self.op.inputs("Bias"), channel_axis, pruned_idx))
output_var = self.op.outputs("Output")[0]
self._visit(output_var, channel_axis)
next_ops = output_var.outputs()
for op in next_ops:
self._prune_op(op, output_var, channel_axis, pruned_idx)
elif pruned_axis == 0:
input_var = self.op.inputs("Input")[0]
self._visit(input_var, channel_axis)
pre_ops = input_var.inputs()
for op in pre_ops:
self._prune_op(op, input_var, channel_axis, pruned_idx)
_logger.warn("Skip pruning output channels of conv2d_transpose!")
return
elif var in self.op.outputs("Output"):
assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format(
pruned_axis, var.name())
......
......@@ -41,6 +41,9 @@ class TestPrune(unittest.TestCase):
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
conv7 = fluid.layers.conv2d_transpose(
input=conv6, num_filters=16, filter_size=2, stride=2)
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
......@@ -53,8 +56,8 @@ class TestPrune(unittest.TestCase):
main_program, _, _ = pruner.prune(
main_program,
scope,
params=["conv4_weights"],
ratios=[0.5],
params=["conv4_weights", "conv2d_transpose_0.w_0"],
ratios=[0.5, 0.6],
place=place,
lazy=False,
only_graph=False,
......@@ -67,11 +70,12 @@ class TestPrune(unittest.TestCase):
"conv3_weights": (8, 4, 3, 3),
"conv4_weights": (4, 8, 3, 3),
"conv5_weights": (8, 4, 3, 3),
"conv6_weights": (8, 8, 3, 3)
"conv6_weights": (8, 8, 3, 3),
"conv2d_transpose_0.w_0": (8, 16, 2, 2),
}
for param in main_program.global_block().all_parameters():
if "weights" in param.name:
if param.name in shapes:
print("param: {}; param shape: {}".format(param.name,
param.shape))
self.assertTrue(param.shape == shapes[param.name])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册