Unverified commit 9e14508c, authored by zhouzj, committed by GitHub

fix the bug of flatten op in pruning. (#1639)

Parent 82da1f14
@@ -89,8 +89,8 @@ class PruneWorker(object):
             transforms(list<dict>): The transforms applied the the current variable/mask.
         """
         if var.name() in self.skip_vars:
-            raise UnsupportOpError("Variable {} was skipped.".format(var.name(
-            )))
+            raise UnsupportOpError(
+                "Variable {} was skipped.".format(var.name()))
         if self._visit(var, pruned_axis):
             self._prune(var, pruned_axis, transforms)
@@ -109,8 +109,8 @@ class PruneWorker(object):
     def _visit_and_search(self, var, axis, transforms):
         self._visit(var, axis)
         if var.name() in self.skip_vars:
-            raise UnsupportOpError("Variable {} was skipped.".format(var.name(
-            )))
+            raise UnsupportOpError(
+                "Variable {} was skipped.".format(var.name()))
         pre_ops = var.inputs()
         for op in pre_ops:
             self._prune_op(op, var, axis, transforms)
@@ -127,8 +127,8 @@ class PruneWorker(object):
         if visited is not None:
             self.visited = visited
         if op.type() in self.ops_unsupported:
-            raise UnsupportOpError("Unsupported operator named {}".format(
-                op.type()))
+            raise UnsupportOpError(
+                "Unsupported operator named {}".format(op.type()))
         cls = PRUNE_WORKER.get(op.type())
         if cls is None:
             if op.type() in SKIPPED_OPS:
@@ -136,8 +136,8 @@ class PruneWorker(object):
             if op.type() in OPS_UNCHANGE_SHAPE or not self.skip_stranger:
                 cls = PRUNE_WORKER.get("default_worker")
             else:
-                raise UnsupportOpError("Unsupported operator named {}".format(
-                    op.type()))
+                raise UnsupportOpError(
+                    "Unsupported operator named {}".format(op.type()))
         _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}\ntrans: {}".
                       format(self.op, op, pruned_axis, var.name(), transforms))
@@ -665,7 +665,8 @@ class depthwise_conv2d(PruneWorker):
                 # so it is not neccesary to search succeed operators.
                 # self._visit_and_search(_filter, 1, transforms)
                 self._visit(_filter, 1)
-                self._visit_and_search(_out, channel_axis, transforms + [{
-                    "repeat": repeat
-                }])
+                self._visit_and_search(_out, channel_axis,
+                                       transforms + [{
+                                           "repeat": repeat
+                                       }])
             elif var == _filter:
@@ -733,8 +734,9 @@ class mul(PruneWorker):
                     }])
             elif var == y:
                 if (pruned_axis < y_num_col_dims) and (
-                        1 < len(x_shape) - x_num_col_dims) and max(x_shape[
-                            x_num_col_dims:]) != np.prod(y_shape[:y_num_col_dims]):
+                        1 < len(x_shape) - x_num_col_dims
+                ) and max(x_shape[x_num_col_dims:]) != np.prod(
+                        y_shape[:y_num_col_dims]):
                     raise UnsupportOpError(
                         "Unsupport pruning y of mul when pruned_axis < y_num_col_dims and 1 < len(x_shape) - x_num_col_dims."
                     )
@@ -763,8 +765,8 @@ class mul(PruneWorker):
                         tile *= y_shape[i]
                     for i in range(pruned_axis + 1, y_num_col_dims):
                         repeat *= y_shape[i]
-                    new_pruned_axis = int(np.argmax(x_shape[
-                        x_num_col_dims:])) + x_num_col_dims
+                    new_pruned_axis = int(
+                        np.argmax(x_shape[x_num_col_dims:])) + x_num_col_dims
                     self.append_pruned_vars(
                         x,
                         # len(x_shape) - 1, trans + [{
@@ -825,8 +827,8 @@ class matmul(PruneWorker):
             mappings = [(1, 1, 1)]
         elif x_shape_len >= 3 and y_shape_len >= 3:
             mappings = [(x_shape_len - 2, -1, x_shape_len - 2),
-                        (x_shape_len - 1, x_shape_len - 2, -1),
-                        (-1, x_shape_len - 1, x_shape_len - 1)]
+                        (x_shape_len - 1, x_shape_len - 2,
+                         -1), (-1, x_shape_len - 1, x_shape_len - 1)]
         if var == x:
             for x_i, y_i, out_i in mappings:
                 if pruned_axis == x_i:
@@ -953,7 +955,8 @@ class flatten_contiguous_range(PruneWorker):
             out_pruned_axis = pruned_axis
             if pruned_axis >= start_axis and pruned_axis <= stop_axis:
                 out_pruned_axis = start_axis
-                for i in range(pruned_axis + 1, stop_axis + 1):
-                    stride *= in_var.shape()[i]
+                for i in range(start_axis, stop_axis + 1):
+                    if i != pruned_axis:
+                        stride *= in_var.shape()[i]
             elif pruned_axis > stop_axis:
                 out_pruned_axis = start_axis + pruned_axis - stop_axis
......
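One way to see the effect of the corrected loop in flatten_contiguous_range: when an axis inside the flattened range [start_axis, stop_axis] is pruned, each pruned index corresponds to as many positions along the flattened output axis as the product of the *other* dimensions in that range, which is exactly what the new loop accumulates into `stride`; the old loop only multiplied the dimensions after `pruned_axis`. The sketch below is plain numpy with made-up shapes (none of them come from this PR) and does not touch PaddleSlim's internals; it just counts those positions.

    import numpy as np

    # Hypothetical example: flatten axes [0, 2] of an (N, C, H) index space
    # and prune along axis 1 (C), as a conv channel would be.
    shape = (2, 4, 3)                 # dims covered by [start_axis, stop_axis]
    pruned_axis, pruned_index = 1, 2  # remove channel 2

    # Enumerate flattened positions in row-major order (the order the flatten
    # op uses) and keep those belonging to the pruned channel.
    positions = [flat for flat, idx in enumerate(np.ndindex(*shape))
                 if idx[pruned_axis] == pruned_index]

    # One pruned channel touches N * H = 6 flattened positions: the product of
    # the other dims in the flattened range, i.e. what the fixed loop puts in
    # `stride`. The old loop only multiplied the dims after pruned_axis (H = 3).
    expected = int(np.prod([d for i, d in enumerate(shape) if i != pruned_axis]))
    assert len(positions) == expected == 6
    print(positions)  # [6, 7, 8, 18, 19, 20]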
@@ -149,10 +149,9 @@ class MulNet(paddle.nn.Layer):
     def forward(self, x):
         conv_a = self.conv_a(x)
-        return paddle.fluid.layers.mul(self.b,
-                                       conv_a,
-                                       x_num_col_dims=1,
-                                       y_num_col_dims=3)
+        tmp = paddle.flatten(conv_a, start_axis=0, stop_axis=2)
+        res = paddle.matmul(self.b, tmp)
+        return res
 class TestPruningMul(unittest.TestCase):
......
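The test now avoids the deprecated paddle.fluid.layers.mul call. That op flattens each operand into a 2-D matrix according to x_num_col_dims / y_num_col_dims before multiplying, so for a 2-D self.b the explicit paddle.flatten followed by paddle.matmul computes the same result. A minimal numpy sketch of that equivalence, with made-up shapes (none of these values come from the test):

    import numpy as np

    # Made-up shapes: b stands in for self.b (assumed 2-D), conv_a for the conv output (4-D).
    b = np.random.rand(5, 24)
    conv_a = np.random.rand(2, 3, 4, 6)          # 2 * 3 * 4 == 24

    # fluid.layers.mul(b, conv_a, x_num_col_dims=1, y_num_col_dims=3) multiplies
    # b reshaped to [5, 24] with conv_a reshaped to [2*3*4, 6].
    old_style = b.reshape(5, -1) @ conv_a.reshape(-1, conv_a.shape[3])

    # The rewritten forward makes that flatten explicit (axes 0..2) and then uses
    # a plain matmul, mirroring paddle.flatten(..., stop_axis=2) + paddle.matmul.
    tmp = conv_a.reshape(-1, conv_a.shape[3])
    new_style = b @ tmp

    assert np.allclose(old_style, new_style)     # both have shape (5, 6)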