未验证 提交 9e14508c 编写于 作者: Z zhouzj 提交者: GitHub

fix the bug of flatten op in pruning. (#1639)

上级 82da1f14
......@@ -89,8 +89,8 @@ class PruneWorker(object):
transforms(list<dict>): The transforms applied the the current variable/mask.
"""
if var.name() in self.skip_vars:
raise UnsupportOpError("Variable {} was skipped.".format(var.name(
)))
raise UnsupportOpError(
"Variable {} was skipped.".format(var.name()))
if self._visit(var, pruned_axis):
self._prune(var, pruned_axis, transforms)
......@@ -109,8 +109,8 @@ class PruneWorker(object):
def _visit_and_search(self, var, axis, transforms):
self._visit(var, axis)
if var.name() in self.skip_vars:
raise UnsupportOpError("Variable {} was skipped.".format(var.name(
)))
raise UnsupportOpError(
"Variable {} was skipped.".format(var.name()))
pre_ops = var.inputs()
for op in pre_ops:
self._prune_op(op, var, axis, transforms)
......@@ -127,8 +127,8 @@ class PruneWorker(object):
if visited is not None:
self.visited = visited
if op.type() in self.ops_unsupported:
raise UnsupportOpError("Unsupported operator named {}".format(
op.type()))
raise UnsupportOpError(
"Unsupported operator named {}".format(op.type()))
cls = PRUNE_WORKER.get(op.type())
if cls is None:
if op.type() in SKIPPED_OPS:
......@@ -136,8 +136,8 @@ class PruneWorker(object):
if op.type() in OPS_UNCHANGE_SHAPE or not self.skip_stranger:
cls = PRUNE_WORKER.get("default_worker")
else:
raise UnsupportOpError("Unsupported operator named {}".format(
op.type()))
raise UnsupportOpError(
"Unsupported operator named {}".format(op.type()))
_logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}\ntrans: {}".
format(self.op, op, pruned_axis, var.name(), transforms))
......@@ -665,7 +665,8 @@ class depthwise_conv2d(PruneWorker):
# so it is not neccesary to search succeed operators.
# self._visit_and_search(_filter, 1, transforms)
self._visit(_filter, 1)
self._visit_and_search(_out, channel_axis, transforms + [{
self._visit_and_search(_out, channel_axis,
transforms + [{
"repeat": repeat
}])
elif var == _filter:
......@@ -733,8 +734,9 @@ class mul(PruneWorker):
}])
elif var == y:
if (pruned_axis < y_num_col_dims) and (
1 < len(x_shape) - x_num_col_dims) and max(x_shape[
x_num_col_dims:]) != np.prod(y_shape[:y_num_col_dims]):
1 < len(x_shape) - x_num_col_dims
) and max(x_shape[x_num_col_dims:]) != np.prod(
y_shape[:y_num_col_dims]):
raise UnsupportOpError(
"Unsupport pruning y of mul when pruned_axis < y_num_col_dims and 1 < len(x_shape) - x_num_col_dims."
)
......@@ -763,8 +765,8 @@ class mul(PruneWorker):
tile *= y_shape[i]
for i in range(pruned_axis + 1, y_num_col_dims):
repeat *= y_shape[i]
new_pruned_axis = int(np.argmax(x_shape[
x_num_col_dims:])) + x_num_col_dims
new_pruned_axis = int(
np.argmax(x_shape[x_num_col_dims:])) + x_num_col_dims
self.append_pruned_vars(
x,
# len(x_shape) - 1, trans + [{
......@@ -825,8 +827,8 @@ class matmul(PruneWorker):
mappings = [(1, 1, 1)]
elif x_shape_len >= 3 and y_shape_len >= 3:
mappings = [(x_shape_len - 2, -1, x_shape_len - 2),
(x_shape_len - 1, x_shape_len - 2, -1),
(-1, x_shape_len - 1, x_shape_len - 1)]
(x_shape_len - 1, x_shape_len - 2,
-1), (-1, x_shape_len - 1, x_shape_len - 1)]
if var == x:
for x_i, y_i, out_i in mappings:
if pruned_axis == x_i:
......@@ -953,7 +955,8 @@ class flatten_contiguous_range(PruneWorker):
out_pruned_axis = pruned_axis
if pruned_axis >= start_axis and pruned_axis <= stop_axis:
out_pruned_axis = start_axis
for i in range(pruned_axis + 1, stop_axis + 1):
for i in range(start_axis, stop_axis + 1):
if i != pruned_axis:
stride *= in_var.shape()[i]
elif pruned_axis > stop_axis:
out_pruned_axis = start_axis + pruned_axis - stop_axis
......
......@@ -149,10 +149,9 @@ class MulNet(paddle.nn.Layer):
def forward(self, x):
conv_a = self.conv_a(x)
return paddle.fluid.layers.mul(self.b,
conv_a,
x_num_col_dims=1,
y_num_col_dims=3)
tmp = paddle.flatten(conv_a, start_axis=0, stop_axis=2)
res = paddle.matmul(self.b, tmp)
return res
class TestPruningMul(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册