Unverified commit b735a396, authored by: W whs, committed by: GitHub

Fix pruning to support reshape2 for bias (#1700)

Parent 2fd40962
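The functional change advertised in the title lives in the pruning pass; the OFA layer edits below are mostly line re-wrapping plus a switch from `LayerObjectHelper` to `paddle.fluid.layer_helper.LayerHelper`. A minimal sketch of the kind of dygraph module that produces the bias-through-`reshape2` pattern this commit targets (hypothetical module and names, not code from this repository):

```python
# Minimal sketch (hypothetical, not from this repository): a 1-D bias is
# reshaped by a `reshape2` op to [1, C, 1, 1] before being added to the conv
# output, so pruning the conv's output channels must also prune the bias and
# rewrite the channel count baked into the reshape2 `shape` attribute.
import paddle


class ConvWithReshapedBias(paddle.nn.Layer):
    def __init__(self, in_channels=3, out_channels=8):
        super().__init__()
        self.conv = paddle.nn.Conv2D(
            in_channels, out_channels, kernel_size=3, bias_attr=False)
        # the bias lives as a separate 1-D parameter instead of the conv's own bias
        self.bias = self.create_parameter(shape=[out_channels], is_bias=True)

    def forward(self, x):
        y = self.conv(x)
        # lowers to a `reshape2` op whose `shape` attribute hard-codes the
        # channel count; this is the attribute the pruner now keeps in sync
        return y + paddle.reshape(self.bias, [1, self.bias.shape[0], 1, 1])
```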
@@ -208,8 +208,7 @@ class SuperConv2D(paddle.nn.Conv2D):
             filters = self.weight
         else:
             filters = self.weight[:out_nc, :in_nc, start:end, start:end]
-        if self.transform_kernel != False and kernel_size < self._kernel_size[
-                0]:
+        if self.transform_kernel != False and kernel_size < self._kernel_size[0]:
             ### if transform kernel, then use matrix to transform
             start_filter = self.weight[:out_nc, :in_nc, :, :]
             for i in range(len(self.ks_set) - 1, 0, -1):
@@ -223,10 +222,11 @@ class SuperConv2D(paddle.nn.Conv2D):
                     _input_filter,
                     shape=[(_input_filter.shape[0] * _input_filter.shape[1]),
                            -1])
-                _input_filter = paddle.matmul(
-                    _input_filter,
-                    self.__getattr__('%dto%d_matrix' %
-                                     (src_ks, target_ks)), False, False)
+                _input_filter = paddle.matmul(_input_filter,
+                                              self.__getattr__(
+                                                  '%dto%d_matrix' %
+                                                  (src_ks, target_ks)), False,
+                                              False)
                 _input_filter = paddle.reshape(
                     _input_filter,
                     shape=[
@@ -279,11 +279,11 @@ class SuperConv2D(paddle.nn.Conv2D):
             out_nc = int(channel)
         else:
             out_nc = self._out_channels
-        ks = int(self._kernel_size[0]) if kernel_size == None else int(
-            kernel_size)
-        groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
-                                                                        out_nc)
+        ks = int(
+            self._kernel_size[0]) if kernel_size == None else int(kernel_size)
+        groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(
+            in_nc, out_nc)
         weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
@@ -293,7 +293,7 @@ class SuperConv2D(paddle.nn.Conv2D):
             padding = self._padding
         if self.bias is not None:
             ### if conv is depthwise conv, expand_ratio=0, but conv' expand
             ### ratio before depthwise conv is not equal to 1.0, the shape of the weight
             ### about this depthwise conv is changed, but out_nc is not change,
             ### so need to change bias shape according to the weight_out_nc.
@@ -513,8 +513,7 @@ class SuperConv2DTranspose(paddle.nn.Conv2DTranspose):
     def get_active_filter(self, in_nc, out_nc, kernel_size):
         start, end = compute_start_end(self._kernel_size[0], kernel_size)
         filters = self.weight[:in_nc, :out_nc, start:end, start:end]
-        if self.transform_kernel != False and kernel_size < self._kernel_size[
-                0]:
+        if self.transform_kernel != False and kernel_size < self._kernel_size[0]:
             start_filter = self.weight[:in_nc, :out_nc, :, :]
             for i in range(len(self.ks_set) - 1, 0, -1):
                 src_ks = self.ks_set[i]
@@ -527,10 +526,11 @@ class SuperConv2DTranspose(paddle.nn.Conv2DTranspose):
                     _input_filter,
                     shape=[(_input_filter.shape[0] * _input_filter.shape[1]),
                            -1])
-                _input_filter = paddle.matmul(
-                    _input_filter,
-                    self.__getattr__('%dto%d_matrix' %
-                                     (src_ks, target_ks)), False, False)
+                _input_filter = paddle.matmul(_input_filter,
+                                              self.__getattr__(
+                                                  '%dto%d_matrix' %
+                                                  (src_ks, target_ks)), False,
+                                              False)
                 _input_filter = paddle.reshape(
                     _input_filter,
                     shape=[
@@ -590,11 +590,11 @@ class SuperConv2DTranspose(paddle.nn.Conv2DTranspose):
         else:
             out_nc = self._out_channels
-        ks = int(self._kernel_size[0]) if kernel_size == None else int(
-            kernel_size)
-        groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
-                                                                        out_nc)
+        ks = int(
+            self._kernel_size[0]) if kernel_size == None else int(kernel_size)
+        groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(
+            in_nc, out_nc)
         weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
@@ -731,8 +731,8 @@ class SuperSeparableConv2D(paddle.nn.Layer):
             'expand_ratio'] if 'expand_ratio' in candidate_config else None
         self.base_output_dim = self.conv[0]._out_channels
         if self.expand_ratio != None:
-            self.base_output_dim = int(self.conv[0]._out_channels /
-                                        max(self.expand_ratio))
+            self.base_output_dim = int(
+                self.conv[0]._out_channels / max(self.expand_ratio))

     def forward(self, input, expand_ratio=None, channel=None):
         """
@@ -863,8 +863,8 @@ class SuperLinear(paddle.nn.Linear):
             'expand_ratio'] if 'expand_ratio' in candidate_config else None
         self.base_output_dim = self._out_features
         if self.expand_ratio != None:
-            self.base_output_dim = int(self._out_features /
-                                        max(self.expand_ratio))
+            self.base_output_dim = int(
+                self._out_features / max(self.expand_ratio))

     def forward(self, input, expand_ratio=None, channel=None):
         """
@@ -941,9 +941,9 @@ class SuperBatchNorm2D(paddle.nn.BatchNorm2D):
                  data_format='NCHW',
                  use_global_stats=None,
                  name=None):
-        super(SuperBatchNorm2D, self).__init__(
-            num_features, momentum, epsilon, weight_attr, bias_attr,
-            data_format, use_global_stats, name)
+        super(SuperBatchNorm2D,
+              self).__init__(num_features, momentum, epsilon, weight_attr,
+                             bias_attr, data_format, use_global_stats, name)
         self.cur_config = None

     def forward(self, input):
@@ -1047,8 +1047,7 @@ class SuperBatchNorm2D(paddle.nn.BatchNorm2D):
             "Variance": [variance]
         }
-        helper = paddle.fluid.dygraph.layer_object_helper.LayerObjectHelper(
-            'batch_norm')
+        helper = paddle.fluid.layer_helper.LayerHelper('batch_norm')
         param_dtype = input.dtype if input.dtype != 'float16' else 'float32'
         saved_mean = helper.create_variable_for_type_inference(
@@ -1150,8 +1149,7 @@ class SuperSyncBatchNorm(paddle.nn.SyncBatchNorm):
             "Variance": [self._variance]
         }
-        helper = paddle.fluid.dygraph.layer_object_helper.LayerObjectHelper(
-            'sync_batch_norm')
+        helper = paddle.fluid.layer_helper.LayerHelper('sync_batch_norm')
         saved_mean = helper.create_variable_for_type_inference(
             dtype=self._dtype, stop_gradient=True)
@@ -1211,9 +1209,9 @@ class SuperInstanceNorm2D(paddle.nn.InstanceNorm2D):
                  bias_attr=None,
                  data_format='NCHW',
                  name=None):
-        super(SuperInstanceNorm2D, self).__init__(num_features, epsilon,
-                                                  momentum, weight_attr,
-                                                  bias_attr, data_format, name)
+        super(SuperInstanceNorm2D,
+              self).__init__(num_features, epsilon, momentum, weight_attr,
+                             bias_attr, data_format, name)
         self.cur_config = None

     def forward(self, input):
@@ -1319,8 +1317,7 @@ class SuperLayerNorm(paddle.nn.LayerNorm):
             "begin_norm_axis": begin_norm_axis
         }
-        helper = paddle.fluid.dygraph.layer_object_helper.LayerObjectHelper(
-            'layer_norm')
+        helper = paddle.fluid.layer_helper.LayerHelper('layer_norm')
         dtype = input.dtype
         mean_out = helper.create_variable_for_type_inference(
@@ -1399,17 +1396,17 @@ class SuperEmbedding(paddle.nn.Embedding):
                  sparse=False,
                  weight_attr=None,
                  name=None):
-        super(SuperEmbedding, self).__init__(num_embeddings, embedding_dim,
-                                             padding_idx, sparse, weight_attr,
-                                             name)
+        super(SuperEmbedding,
+              self).__init__(num_embeddings, embedding_dim, padding_idx, sparse,
+                             weight_attr, name)
         self.candidate_config = candidate_config
         self.cur_config = None
         self.expand_ratio = candidate_config[
             'expand_ratio'] if 'expand_ratio' in candidate_config else None
         self.base_output_dim = self._embedding_dim
         if self.expand_ratio != None:
-            self.base_output_dim = int(self._embedding_dim /
-                                        max(self.expand_ratio))
+            self.base_output_dim = int(
+                self._embedding_dim / max(self.expand_ratio))

     def forward(self, input, expand_ratio=None, channel=None):
         """
......
@@ -233,7 +233,7 @@ class reshape2(PruneWorker):
         assert self._valid_reshape2(
             shape), "we don't support the shape {} in pruning".format(shape)
         # assert self._valid_pruned_axis(shape, pruned_axis), "we don't support pruned axis is {} when shape is changing from {} to {}".format(pruned_axis, in_shape, out_shape)
-        self.append_pruned_vars(xshape_var, pruned_axis + 1, transforms)
+        # self.append_pruned_vars(xshape_var, pruned_axis + 1, transforms)
         if var in self.op.inputs("X"):
             if (len(out_shape) > len(in_shape)):
                 #self.op.set_attr('shape',
@@ -254,6 +254,10 @@ class reshape2(PruneWorker):
                 #self.op.set_attr('shape',
                 #                 [0, 0, int(shape[2] * 0.875), shape[3]])
                 transform = {"repeat": out_shape[pruned_axis + 1]}
+            elif len(in_shape) == 1 and len(
+                    out_shape) == 4 and out_shape[pruned_axis] == in_shape[0]:
+                transform = {}
+                self.append_pruned_vars(in_var, 0, transforms)
             else:
                 transform = {}
             self._visit_and_search(in_var, pruned_axis,
......
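In the `reshape2` PruneWorker above, the new `elif` branch handles a `reshape2` whose input is 1-D (typically a bias) and whose output is 4-D: when the pruned axis of the 4-D output has the same length as the 1-D input, the input itself is recorded for pruning on axis 0 and no index transform is needed. A small illustration of the condition with assumed shapes (not repository code):

```python
# Assumed shapes for illustration: an 8-channel bias reshaped to NCHW form.
in_shape = [8]            # 1-D input of reshape2 (the bias)
out_shape = [1, 8, 1, 1]  # 4-D output added to the conv result
pruned_axis = 1           # channel axis being pruned on the output

if (len(in_shape) == 1 and len(out_shape) == 4
        and out_shape[pruned_axis] == in_shape[0]):
    # the same channel indices can be pruned directly from the bias on axis 0
    print("prune the reshape2 input on axis 0 with an empty transform")
```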
@@ -50,6 +50,14 @@ class Pruner():
         self.pruned_weights = False

+    def _update_reshape_op(self, param: VarWrapper, op: OpWrapper, new_shape):
+        if op.type() == 'reshape2':
+            _param_shape = param.shape()
+            _shape_attr = op.attr('shape')
+            if len(_param_shape) == 1 and _param_shape[0] == _shape_attr[1]:
+                _shape_attr[1] = new_shape[0]
+                op.set_attr("shape", _shape_attr)
+
     def prune(self,
               program,
               scope,
@@ -111,8 +119,8 @@ class Pruner():
                     merge_pruned_params[param][pruned_axis].append(pruned_idx)
         for param_name in merge_pruned_params:
             for pruned_axis in merge_pruned_params[param_name]:
-                pruned_idx = np.concatenate(merge_pruned_params[param_name][
-                    pruned_axis])
+                pruned_idx = np.concatenate(
+                    merge_pruned_params[param_name][pruned_axis])
                 param = graph.var(param_name)
                 _groups = 1
                 if not lazy:
@@ -138,6 +146,7 @@ class Pruner():
                         param_shape_backup[param.name()] = origin_shape
                     new_shape = list(param.shape())
                     new_shape[pruned_axis] -= len(pruned_idx)
+                    self._update_reshape_op(param, op, new_shape)
                     param.set_shape(new_shape)
                 if not only_graph and (_groups == 1 or pruned_axis != 1):
@@ -159,8 +168,8 @@ class Pruner():
                     except IndexError as e:
                         _logger.error(
                             "Pruning {} with shape {} on axis {}, but get [{}]; ".
-                            format(param.name(),
-                                   param_t.shape(), pruned_axis, e))
+                            format(param.name(), param_t.shape(), pruned_axis,
+                                   e))

         graph.infer_shape()
         self.pruned_weights = (not only_graph)
......
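The new `_update_reshape_op` hook above runs just before a pruned parameter's shape is committed: if the parameter is 1-D and a consuming `reshape2` op hard-codes its length at index 1 of the `shape` attribute, that entry is rewritten to the pruned length so the graph still passes `infer_shape`. A sketch of the attribute rewrite with assumed values (hypothetical numbers, not taken from the diff):

```python
# Assumed values: a bias pruned from 8 to 6 channels and the reshape2 op that
# consumes it; the op's 'shape' attribute must track the new channel count.
param_shape = [8]          # 1-D bias shape before pruning
shape_attr = [1, 8, 1, 1]  # reshape2 'shape' attribute before pruning
new_shape = [6]            # bias shape after pruning two channels

if len(param_shape) == 1 and param_shape[0] == shape_attr[1]:
    shape_attr[1] = new_shape[0]

print(shape_attr)  # [1, 6, 1, 1]
```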
@@ -25,6 +25,7 @@ class TestWalker(unittest.TestCase):
         x = np.random.uniform(-1, 1, x_shape).astype('float32')
         pruner = L1NormFilterPruner(net, [paddle.to_tensor(x)])
         pruner.prune_vars({"conv2d_0.w_0": 0.2}, 0)
+        net(paddle.to_tensor(x))
         self.assertTrue(net.linear.weight.shape == [5400, 5])
......
@@ -32,8 +32,8 @@ class TestEagerDygraph2Program(unittest.TestCase):
     def prepare_inputs(self):
         self.inputs = [3, 28, 28]
         self.ops = [
-            'assign_value', 'reshape2', 'conv2d', 'elementwise_add', 'pool2d',
-            'reshape2', 'matmul_v2', 'elementwise_add'
+            'assign_value', 'reshape2', 'conv2d', 'reshape2', 'elementwise_add',
+            'pool2d', 'reshape2', 'matmul_v2', 'elementwise_add'
         ]

     def prepare_layer(self):
@@ -51,8 +51,8 @@ class TestEagerDygraph2Program2(TestEagerDygraph2Program):
     def prepare_inputs(self):
         self.inputs = [[3, 28, 28]]
         self.ops = [
-            'assign_value', 'reshape2', 'conv2d', 'elementwise_add', 'pool2d',
-            'reshape2', 'matmul_v2', 'elementwise_add'
+            'assign_value', 'reshape2', 'conv2d', 'reshape2', 'elementwise_add',
+            'pool2d', 'reshape2', 'matmul_v2', 'elementwise_add'
         ]
@@ -60,8 +60,8 @@ class TestEagerDygraph2Program3(TestEagerDygraph2Program):
     def prepare_inputs(self):
         self.inputs = paddle.randn([3, 28, 28])
         self.ops = [
-            'reshape2', 'conv2d', 'elementwise_add', 'pool2d', 'reshape2',
-            'matmul_v2', 'elementwise_add'
+            'reshape2', 'conv2d', 'reshape2', 'elementwise_add', 'pool2d',
+            'reshape2', 'matmul_v2', 'elementwise_add'
         ]
@@ -69,8 +69,8 @@ class TestEagerDygraph2Program4(TestEagerDygraph2Program):
     def prepare_inputs(self):
         self.inputs = [paddle.randn([3, 28, 28])]
         self.ops = [
-            'reshape2', 'conv2d', 'elementwise_add', 'pool2d', 'reshape2',
-            'matmul_v2', 'elementwise_add'
+            'reshape2', 'conv2d', 'reshape2', 'elementwise_add', 'pool2d',
+            'reshape2', 'matmul_v2', 'elementwise_add'
         ]
......