Unverified commit b735a396, authored by whs, committed by GitHub

Fix pruning to support reshape2 for bias (#1700)

Parent 2fd40962
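Background for this fix: in the traced program, a Conv2D bias is a 1-D parameter that gets broadcast through a reshape2 op ([C] -> [1, C, 1, 1]) before the elementwise_add; this is the extra reshape2 the updated dygraph2program tests below now expect. Pruning the conv's output channels therefore has to shrink that 1-D bias as well and rewrite the reshape2 target shape. A minimal sketch of the pattern, using an illustrative model rather than the one in the tests:

import paddle

class TinyNet(paddle.nn.Layer):
    def __init__(self):
        super(TinyNet, self).__init__()
        # bias has shape [8]; in the program it is reshaped to [1, 8, 1, 1]
        # and added to the conv output, so pruning axis 0 of the conv weight
        # must also prune the bias and patch that reshape2's shape attribute.
        self.conv = paddle.nn.Conv2D(3, 8, 3, padding=1)
        self.linear = paddle.nn.Linear(8 * 28 * 28, 5)

    def forward(self, x):
        y = self.conv(x)
        y = paddle.reshape(y, [y.shape[0], -1])
        return self.linear(y)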
@@ -208,8 +208,7 @@ class SuperConv2D(paddle.nn.Conv2D):
filters = self.weight
else:
filters = self.weight[:out_nc, :in_nc, start:end, start:end]
if self.transform_kernel != False and kernel_size < self._kernel_size[
0]:
if self.transform_kernel != False and kernel_size < self._kernel_size[0]:
### if transform kernel, then use matrix to transform
start_filter = self.weight[:out_nc, :in_nc, :, :]
for i in range(len(self.ks_set) - 1, 0, -1):
@@ -223,10 +222,11 @@ class SuperConv2D(paddle.nn.Conv2D):
_input_filter,
shape=[(_input_filter.shape[0] * _input_filter.shape[1]),
-1])
_input_filter = paddle.matmul(
_input_filter,
self.__getattr__('%dto%d_matrix' %
(src_ks, target_ks)), False, False)
_input_filter = paddle.matmul(_input_filter,
self.__getattr__(
'%dto%d_matrix' %
(src_ks, target_ks)), False,
False)
_input_filter = paddle.reshape(
_input_filter,
shape=[
@@ -279,11 +279,11 @@ class SuperConv2D(paddle.nn.Conv2D):
out_nc = int(channel)
else:
out_nc = self._out_channels
ks = int(self._kernel_size[0]) if kernel_size == None else int(
kernel_size)
ks = int(
self._kernel_size[0]) if kernel_size == None else int(kernel_size)
groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
out_nc)
groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(
in_nc, out_nc)
weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
@@ -513,8 +513,7 @@ class SuperConv2DTranspose(paddle.nn.Conv2DTranspose):
def get_active_filter(self, in_nc, out_nc, kernel_size):
start, end = compute_start_end(self._kernel_size[0], kernel_size)
filters = self.weight[:in_nc, :out_nc, start:end, start:end]
if self.transform_kernel != False and kernel_size < self._kernel_size[
0]:
if self.transform_kernel != False and kernel_size < self._kernel_size[0]:
start_filter = self.weight[:in_nc, :out_nc, :, :]
for i in range(len(self.ks_set) - 1, 0, -1):
src_ks = self.ks_set[i]
@@ -527,10 +526,11 @@ class SuperConv2DTranspose(paddle.nn.Conv2DTranspose):
_input_filter,
shape=[(_input_filter.shape[0] * _input_filter.shape[1]),
-1])
_input_filter = paddle.matmul(
_input_filter,
self.__getattr__('%dto%d_matrix' %
(src_ks, target_ks)), False, False)
_input_filter = paddle.matmul(_input_filter,
self.__getattr__(
'%dto%d_matrix' %
(src_ks, target_ks)), False,
False)
_input_filter = paddle.reshape(
_input_filter,
shape=[
@@ -590,11 +590,11 @@ class SuperConv2DTranspose(paddle.nn.Conv2DTranspose):
else:
out_nc = self._out_channels
ks = int(self._kernel_size[0]) if kernel_size == None else int(
kernel_size)
ks = int(
self._kernel_size[0]) if kernel_size == None else int(kernel_size)
groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
out_nc)
groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(
in_nc, out_nc)
weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
@@ -731,8 +731,8 @@ class SuperSeparableConv2D(paddle.nn.Layer):
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.base_output_dim = self.conv[0]._out_channels
if self.expand_ratio != None:
self.base_output_dim = int(self.conv[0]._out_channels /
max(self.expand_ratio))
self.base_output_dim = int(
self.conv[0]._out_channels / max(self.expand_ratio))
def forward(self, input, expand_ratio=None, channel=None):
"""
@@ -863,8 +863,8 @@ class SuperLinear(paddle.nn.Linear):
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.base_output_dim = self._out_features
if self.expand_ratio != None:
self.base_output_dim = int(self._out_features /
max(self.expand_ratio))
self.base_output_dim = int(
self._out_features / max(self.expand_ratio))
def forward(self, input, expand_ratio=None, channel=None):
"""
@@ -941,9 +941,9 @@ class SuperBatchNorm2D(paddle.nn.BatchNorm2D):
data_format='NCHW',
use_global_stats=None,
name=None):
super(SuperBatchNorm2D, self).__init__(
num_features, momentum, epsilon, weight_attr, bias_attr,
data_format, use_global_stats, name)
super(SuperBatchNorm2D,
self).__init__(num_features, momentum, epsilon, weight_attr,
bias_attr, data_format, use_global_stats, name)
self.cur_config = None
def forward(self, input):
@@ -1047,8 +1047,7 @@ class SuperBatchNorm2D(paddle.nn.BatchNorm2D):
"Variance": [variance]
}
helper = paddle.fluid.dygraph.layer_object_helper.LayerObjectHelper(
'batch_norm')
helper = paddle.fluid.layer_helper.LayerHelper('batch_norm')
param_dtype = input.dtype if input.dtype != 'float16' else 'float32'
saved_mean = helper.create_variable_for_type_inference(
@@ -1150,8 +1149,7 @@ class SuperSyncBatchNorm(paddle.nn.SyncBatchNorm):
"Variance": [self._variance]
}
helper = paddle.fluid.dygraph.layer_object_helper.LayerObjectHelper(
'sync_batch_norm')
helper = paddle.fluid.layer_helper.LayerHelper('sync_batch_norm')
saved_mean = helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
@@ -1211,8 +1209,8 @@ class SuperInstanceNorm2D(paddle.nn.InstanceNorm2D):
bias_attr=None,
data_format='NCHW',
name=None):
super(SuperInstanceNorm2D, self).__init__(num_features, epsilon,
momentum, weight_attr,
super(SuperInstanceNorm2D,
self).__init__(num_features, epsilon, momentum, weight_attr,
bias_attr, data_format, name)
self.cur_config = None
@@ -1319,8 +1317,7 @@ class SuperLayerNorm(paddle.nn.LayerNorm):
"begin_norm_axis": begin_norm_axis
}
helper = paddle.fluid.dygraph.layer_object_helper.LayerObjectHelper(
'layer_norm')
helper = paddle.fluid.layer_helper.LayerHelper('layer_norm')
dtype = input.dtype
mean_out = helper.create_variable_for_type_inference(
@@ -1399,17 +1396,17 @@ class SuperEmbedding(paddle.nn.Embedding):
sparse=False,
weight_attr=None,
name=None):
super(SuperEmbedding, self).__init__(num_embeddings, embedding_dim,
padding_idx, sparse, weight_attr,
name)
super(SuperEmbedding,
self).__init__(num_embeddings, embedding_dim, padding_idx, sparse,
weight_attr, name)
self.candidate_config = candidate_config
self.cur_config = None
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.base_output_dim = self._embedding_dim
if self.expand_ratio != None:
self.base_output_dim = int(self._embedding_dim /
max(self.expand_ratio))
self.base_output_dim = int(
self._embedding_dim / max(self.expand_ratio))
def forward(self, input, expand_ratio=None, channel=None):
"""
@@ -233,7 +233,7 @@ class reshape2(PruneWorker):
assert self._valid_reshape2(
shape), "we don't support the shape {} in pruning".format(shape)
# assert self._valid_pruned_axis(shape, pruned_axis), "we don't support pruned axis is {} when shape is changing from {} to {}".format(pruned_axis, in_shape, out_shape)
self.append_pruned_vars(xshape_var, pruned_axis + 1, transforms)
# self.append_pruned_vars(xshape_var, pruned_axis + 1, transforms)
if var in self.op.inputs("X"):
if (len(out_shape) > len(in_shape)):
#self.op.set_attr('shape',
@@ -254,6 +254,10 @@ class reshape2(PruneWorker):
#self.op.set_attr('shape',
# [0, 0, int(shape[2] * 0.875), shape[3]])
transform = {"repeat": out_shape[pruned_axis + 1]}
elif len(in_shape) == 1 and len(
out_shape) == 4 and out_shape[pruned_axis] == in_shape[0]:
transform = {}
self.append_pruned_vars(in_var, 0, transforms)
else:
transform = {}
self._visit_and_search(in_var, pruned_axis,
@@ -50,6 +50,14 @@ class Pruner():
self.pruned_weights = False
def _update_reshape_op(self, param: VarWrapper, op: OpWrapper, new_shape):
if op.type() == 'reshape2':
_param_shape = param.shape()
_shape_attr = op.attr('shape')
if len(_param_shape) == 1 and _param_shape[0] == _shape_attr[1]:
_shape_attr[1] = new_shape[0]
op.set_attr("shape", _shape_attr)
def prune(self,
program,
scope,
@@ -111,8 +119,8 @@ class Pruner():
merge_pruned_params[param][pruned_axis].append(pruned_idx)
for param_name in merge_pruned_params:
for pruned_axis in merge_pruned_params[param_name]:
pruned_idx = np.concatenate(merge_pruned_params[param_name][
pruned_axis])
pruned_idx = np.concatenate(
merge_pruned_params[param_name][pruned_axis])
param = graph.var(param_name)
_groups = 1
if not lazy:
@@ -138,6 +146,7 @@ class Pruner():
param_shape_backup[param.name()] = origin_shape
new_shape = list(param.shape())
new_shape[pruned_axis] -= len(pruned_idx)
self._update_reshape_op(param, op, new_shape)
param.set_shape(new_shape)
if not only_graph and (_groups == 1 or pruned_axis != 1):
@@ -159,8 +168,8 @@ class Pruner():
except IndexError as e:
_logger.error(
"Pruning {} with shape {} on axis {}, but get [{}]; ".
format(param.name(),
param_t.shape(), pruned_axis, e))
format(param.name(), param_t.shape(), pruned_axis,
e))
graph.infer_shape()
self.pruned_weights = (not only_graph)
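Put together, the reshape2 prune worker forwards the pruned channel indices to axis 0 of the 1-D bias, and the new _update_reshape_op hook keeps the consuming reshape2 consistent by rewriting its shape attribute. A sketch of that hook's effect, with made-up numbers:

# Illustrative values only: an 8-channel bias pruned down to 6 channels.
param_shape = [8]            # 1-D bias parameter feeding the reshape2
shape_attr = [1, 8, 1, 1]    # the op's 'shape' attribute before pruning
new_shape = [6]              # bias shape after removing 2 channels

# Mirrors the guard in _update_reshape_op: patch the attribute only when the
# pruned parameter is 1-D and its length matches the channel entry.
if len(param_shape) == 1 and param_shape[0] == shape_attr[1]:
    shape_attr[1] = new_shape[0]

assert shape_attr == [1, 6, 1, 1]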
@@ -25,6 +25,7 @@ class TestWalker(unittest.TestCase):
x = np.random.uniform(-1, 1, x_shape).astype('float32')
pruner = L1NormFilterPruner(net, [paddle.to_tensor(x)])
pruner.prune_vars({"conv2d_0.w_0": 0.2}, 0)
net(paddle.to_tensor(x))
self.assertTrue(net.linear.weight.shape == [5400, 5])
@@ -32,8 +32,8 @@ class TestEagerDygraph2Program(unittest.TestCase):
def prepare_inputs(self):
self.inputs = [3, 28, 28]
self.ops = [
'assign_value', 'reshape2', 'conv2d', 'elementwise_add', 'pool2d',
'reshape2', 'matmul_v2', 'elementwise_add'
'assign_value', 'reshape2', 'conv2d', 'reshape2', 'elementwise_add',
'pool2d', 'reshape2', 'matmul_v2', 'elementwise_add'
]
def prepare_layer(self):
@@ -51,8 +51,8 @@ class TestEagerDygraph2Program2(TestEagerDygraph2Program):
def prepare_inputs(self):
self.inputs = [[3, 28, 28]]
self.ops = [
'assign_value', 'reshape2', 'conv2d', 'elementwise_add', 'pool2d',
'reshape2', 'matmul_v2', 'elementwise_add'
'assign_value', 'reshape2', 'conv2d', 'reshape2', 'elementwise_add',
'pool2d', 'reshape2', 'matmul_v2', 'elementwise_add'
]
@@ -60,8 +60,8 @@ class TestEagerDygraph2Program3(TestEagerDygraph2Program):
def prepare_inputs(self):
self.inputs = paddle.randn([3, 28, 28])
self.ops = [
'reshape2', 'conv2d', 'elementwise_add', 'pool2d', 'reshape2',
'matmul_v2', 'elementwise_add'
'reshape2', 'conv2d', 'reshape2', 'elementwise_add', 'pool2d',
'reshape2', 'matmul_v2', 'elementwise_add'
]
@@ -69,8 +69,8 @@ class TestEagerDygraph2Program4(TestEagerDygraph2Program):
def prepare_inputs(self):
self.inputs = [paddle.randn([3, 28, 28])]
self.ops = [
'reshape2', 'conv2d', 'elementwise_add', 'pool2d', 'reshape2',
'matmul_v2', 'elementwise_add'
'reshape2', 'conv2d', 'reshape2', 'elementwise_add', 'pool2d',
'reshape2', 'matmul_v2', 'elementwise_add'
]