From 587dd15b55bc5e71eea3b4caf2dfa61545648c11 Mon Sep 17 00:00:00 2001
From: whs
Date: Mon, 25 Jan 2021 11:40:11 +0800
Subject: [PATCH] Support for pruning training graph with auxiliary
 parameters. (#604) (#617)

---
 demo/dygraph/pruning/eval.py         |  4 +-
 demo/dygraph/pruning/export_model.py |  4 +-
 demo/dygraph/pruning/train.py        |  4 +-
 paddleslim/dygraph/filter_pruner.py  | 59 +++++++++++++++++-----------
 paddleslim/dygraph/fpgm_pruner.py    |  3 --
 paddleslim/dygraph/pruning_plan.py   |  5 ++-
 paddleslim/dygraph/var_group.py      |  5 +--
 paddleslim/prune/group_param.py      | 11 ++++--
 paddleslim/prune/prune_walker.py     | 23 +++--------
 paddleslim/prune/pruner.py           |  4 +-
 requirements.txt                     |  3 ++
 tests/dygraph/test_prune_walker.py   | 33 ++++++++++++++++
 tests/test_deep_mutual_learning.py   |  2 +-
 tests/test_dygraph_pruning_plan.py   |  4 +-
 tests/test_group_param.py            |  3 +-
 15 files changed, 103 insertions(+), 64 deletions(-)
 create mode 100644 tests/dygraph/test_prune_walker.py

diff --git a/demo/dygraph/pruning/eval.py b/demo/dygraph/pruning/eval.py
index dac721d3..37024f35 100644
--- a/demo/dygraph/pruning/eval.py
+++ b/demo/dygraph/pruning/eval.py
@@ -9,8 +9,8 @@ import functools
 import math
 import time
 import numpy as np
-sys.path[0] = os.path.join(
-    os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
+sys.path.append(
+    os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
 import paddleslim
 from paddleslim.common import get_logger
 from paddleslim.analysis import dygraph_flops as flops
diff --git a/demo/dygraph/pruning/export_model.py b/demo/dygraph/pruning/export_model.py
index fb9245c6..99f6afdb 100644
--- a/demo/dygraph/pruning/export_model.py
+++ b/demo/dygraph/pruning/export_model.py
@@ -9,8 +9,8 @@ import functools
 import math
 import time
 import numpy as np
-sys.path[0] = os.path.join(
-    os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
+sys.path.append(
+    os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
 import paddleslim
 from paddleslim.common import get_logger
 import paddle.vision.models as models
diff --git a/demo/dygraph/pruning/train.py b/demo/dygraph/pruning/train.py
index ed3ccac6..64a93666 100644
--- a/demo/dygraph/pruning/train.py
+++ b/demo/dygraph/pruning/train.py
@@ -9,8 +9,8 @@ import functools
 import math
 import time
 import numpy as np
-sys.path[0] = os.path.join(
-    os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
+sys.path.append(
+    os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
 import paddleslim
 from paddleslim.common import get_logger
 from paddleslim.analysis import dygraph_flops as flops
diff --git a/paddleslim/dygraph/filter_pruner.py b/paddleslim/dygraph/filter_pruner.py
index 0cb22d0a..a9c20958 100644
--- a/paddleslim/dygraph/filter_pruner.py
+++ b/paddleslim/dygraph/filter_pruner.py
@@ -2,6 +2,7 @@ import os
 import logging
 import numpy as np
 import pickle
+import copy
 import paddle
 from ..common import get_logger
 from .var_group import *
@@ -17,6 +18,7 @@ _logger = get_logger(__name__, logging.INFO)
 CONV_OP_TYPE = paddle.nn.Conv2D
 FILTER_DIM = [0]
 CONV_WEIGHT_NAME = "weight"
+SKIP_LAYERS = (paddle.nn.Conv2DTranspose, paddle.nn.layer.conv.Conv2DTranspose)
 
 
 class Status():
@@ -64,9 +66,9 @@ class FilterPruner(Pruner):
         # 1. depthwise conv2d layer
         self.skip_vars = []
         for sub_layer in model.sublayers():
-            if isinstance(
-                    sub_layer,
-                    paddle.nn.layer.conv.Conv2D) and sub_layer._groups > 1:
+            if isinstance(sub_layer, SKIP_LAYERS) or (isinstance(
+                    sub_layer, paddle.nn.layer.conv.Conv2D) and
+                    sub_layer._groups > 1):
                 for param in sub_layer.parameters():
                     self.skip_vars.append(param.name)
 
@@ -236,7 +238,7 @@
                   target_vars=None,
                   skip_vars=None):
         sensitivities = self._status.sensitivies
-        baseline = eval_func()
+        baseline = None
         ratios = np.arange(0.1, 1, step=0.1)
         for group in self.var_group.groups:
             var_name = group[0][0]
@@ -255,6 +257,8 @@
                     _logger.debug("{}, {} has computed.".format(var_name,
                                                                 ratio))
                     continue
+                if baseline is None:
+                    baseline = eval_func()
                 plan = self.prune_var(var_name, dims, ratio, apply="lazy")
                 pruned_metric = eval_func()
                 loss = (baseline - pruned_metric) / baseline
@@ -342,14 +346,15 @@
         for _name in group_dict:
             # Variables can be pruned on multiple axes.
             for _item in group_dict[_name]:
+                src_mask = copy.deepcopy(mask)
                 dims = _item['pruned_dims']
                 transforms = _item['transforms']
                 var_shape = _item['var'].shape
                 if isinstance(dims, int):
                     dims = [dims]
                 for trans in transforms:
-                    mask = self._transform_mask(mask, trans)
-                current_mask = mask
+                    src_mask = self._transform_mask(src_mask, trans)
+                current_mask = src_mask
                 assert len(current_mask) == var_shape[dims[
                     0]], f"The length of current_mask must be equal to the size of the dimension to be pruned. But got: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; dims: {dims}; var name: {_name}; len(mask): {len(mask)}"
                 plan.add(_name, PruningMask(dims, current_mask, pruned_ratio))
@@ -360,21 +365,29 @@
         return plan
 
     def _transform_mask(self, mask, transform):
-        src_start = transform['src_start']
-        src_end = transform['src_end']
-        target_start = transform['target_start']
-        target_end = transform['target_end']
-        target_len = transform['target_len']
-        stride = transform['stride']
-        mask = mask[src_start:src_end]
-        mask = mask.repeat(stride) if stride > 1 else mask
-
-        dst_mask = np.ones([target_len])
-        # for depthwise conv2d with:
-        # input shape: (1, 4, 32, 32)
-        # filter shape: (32, 1, 3, 3)
-        # groups: 4
-        # if we prune the input channels by 50% (from 4 to 2), the output channels should also be pruned by 50% (from 4 * 8 = 32 to 16).
-        expand = int((target_end - target_start) / len(mask))
-        dst_mask[target_start:target_end] = list(mask) * expand
+        if "src_start" in transform:
+            src_start = transform['src_start']
+            src_end = transform['src_end']
+            target_start = transform['target_start']
+            target_end = transform['target_end']
+            target_len = transform['target_len']
+            stride = transform['stride']
+            mask = mask[src_start:src_end]
+
+            mask = mask.repeat(stride) if stride > 1 else mask
+
+            dst_mask = np.ones([target_len])
+            # for depthwise conv2d with:
+            # input shape: (1, 4, 32, 32)
+            # filter shape: (32, 1, 3, 3)
+            # groups: 4
+            # if we prune the input channels by 50% (from 4 to 2), the output channels should also be pruned by 50% (from 4 * 8 = 32 to 16).
+            expand = int((target_end - target_start) / len(mask))
+            dst_mask[target_start:target_end] = list(mask) * expand
+        elif "stride" in transform:
+            stride = transform['stride']
+            mask = mask.repeat(stride) if stride > 1 else mask
+            return mask
+        else:
+            return mask
         return dst_mask
diff --git a/paddleslim/dygraph/fpgm_pruner.py b/paddleslim/dygraph/fpgm_pruner.py
index 45bc48f9..1bff3424 100644
--- a/paddleslim/dygraph/fpgm_pruner.py
+++ b/paddleslim/dygraph/fpgm_pruner.py
@@ -20,9 +20,6 @@ class FPGMFilterPruner(FilterPruner):
             if _item['pruned_dims'] == [0]:
                 value = _item['value']
                 pruned_dims = _item['pruned_dims']
-
-                assert (pruned_dims == [0])
-
         dist_sum_list = []
         for out_i in range(value.shape[0]):
             dist_sum = self.get_distance_sum(value, out_i)
diff --git a/paddleslim/dygraph/pruning_plan.py b/paddleslim/dygraph/pruning_plan.py
index 2e35437b..185d0194 100644
--- a/paddleslim/dygraph/pruning_plan.py
+++ b/paddleslim/dygraph/pruning_plan.py
@@ -77,7 +77,8 @@
             for _mask in self._masks[var_name]:
                 if pruning_mask.dims == _mask.dims:
                     _mask.mask = list(
-                        np.array(_mask.mask) | np.array(pruning_mask.mask))
+                        np.array(_mask.mask).astype(np.int64) & np.array(
+                            pruning_mask.mask).astype(np.int64))
                 else:
                     self._masks[var_name].append(pruning_mask)
                     self._dims[var_name].append(pruning_mask.dims)
@@ -171,7 +172,7 @@
                             paddle.to_tensor(value))
                         _logger.debug("Backup values of {} into buffers.".
                                       format(param.name))
-                bool_mask = mask.astype(bool)
+                bool_mask = np.array(mask).astype(bool)
                 pruned_value = np.apply_along_axis(
                     lambda data: data[bool_mask], dims[0], value)
                 p = t_value._place()
diff --git a/paddleslim/dygraph/var_group.py b/paddleslim/dygraph/var_group.py
index 6ee02d7d..894de662 100644
--- a/paddleslim/dygraph/var_group.py
+++ b/paddleslim/dygraph/var_group.py
@@ -44,12 +44,9 @@
 class VarGroup():
     def _parse_model(self, model, inputs):
         _logger.debug("Parsing model with input: {}".format(inputs))
-
-        model.eval()
+        # The model can be in training mode, because some models contain auxiliary parameters that are used only for training.
         program = dygraph2program(model, inputs=inputs)
-
         graph = GraphWrapper(program)
-
         visited = {}
         for name, param in model.named_parameters():
             group = collect_convs([param.name], graph,
diff --git a/paddleslim/prune/group_param.py b/paddleslim/prune/group_param.py
index f8f8db2a..9a406e31 100644
--- a/paddleslim/prune/group_param.py
+++ b/paddleslim/prune/group_param.py
@@ -55,10 +55,15 @@ def collect_convs(params, graph, visited={}):
     if not isinstance(graph, GraphWrapper):
         graph = GraphWrapper(graph)
     groups = []
-    for param in params:
+    for _param in params:
         pruned_params = []
-        param = graph.var(param)
-
+        param = graph.var(_param)
+        if param is None:
+            _logger.warning(
+                f"Can't find the variables related to {_param}, because {_param} is not in the target program or model. If you are using the static API of PaddlePaddle, please make sure {_param} is in your program. If you are using the dynamic API, please make sure your model is in the correct mode and contains {_param}."
+            )
+            groups.append([])
+            continue
         target_op = param.outputs()[0]
         if target_op.type() == 'conditional_block':
             for op in param.outputs():
diff --git a/paddleslim/prune/prune_walker.py b/paddleslim/prune/prune_walker.py
index 59b8b5ac..85439a49 100644
--- a/paddleslim/prune/prune_walker.py
+++ b/paddleslim/prune/prune_walker.py
@@ -344,15 +344,10 @@ class activation(PruneWorker):
     def _prune(self, var, pruned_axis, pruned_idx):
         if var in self.op.outputs(self.output_name):
             in_var = self.op.inputs(self.input_name)[0]
-            pre_ops = in_var.inputs()
-            for op in pre_ops:
-                self._prune_op(op, in_var, pruned_axis, pruned_idx)
-
-        out_var = self.op.outputs(self.output_name)[0]
-        self._visit(out_var, pruned_axis)
-        next_ops = out_var.outputs()
-        for op in next_ops:
-            self._prune_op(op, out_var, pruned_axis, pruned_idx)
+            self._visit_and_search(in_var, pruned_axis, pruned_idx)
+        if var in self.op.inputs(self.input_name):
+            out_var = self.op.outputs(self.output_name)[0]
+            self._visit_and_search(out_var, pruned_axis, pruned_idx)
 
 
 @PRUNE_WORKER.register
@@ -364,16 +359,10 @@ def _prune(self, var, pruned_axis, pruned_idx):
         if var in self.op.all_outputs():
             for in_var in self.op.all_inputs():
                 if len(in_var.shape()) == len(var.shape()):
-                    pre_ops = in_var.inputs()
-                    for op in pre_ops:
-                        self._prune_op(op, in_var, pruned_axis, pruned_idx)
-
+                    self._visit_and_search(in_var, pruned_axis, pruned_idx)
         for out_var in self.op.all_outputs():
             if len(out_var.shape()) == len(var.shape()):
-                self._visit(out_var, pruned_axis)
-                next_ops = out_var.outputs()
-                for op in next_ops:
-                    self._prune_op(op, out_var, pruned_axis, pruned_idx)
+                self._visit_and_search(out_var, pruned_axis, pruned_idx)
 
 
 @PRUNE_WORKER.register
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 9ebcc5ee..60fd885e 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -98,8 +98,8 @@
                 visited)[0]  # [(name, axis, pruned_idx)]
             if group is None or len(group) == 0:
                 continue
-            assert ((not self.pruned_weights),
-                    "The weights have been pruned once.")
+            assert (
+                not self.pruned_weights), "The weights have been pruned once."
             group_values = []
             for name, axis, pruned_idx in group:
                 var = scope.find_var(name)
diff --git a/requirements.txt b/requirements.txt
index 8b4fa549..3619953e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,6 @@
 #paddlepaddle == 1.6.0rc0
 tqdm
 pyzmq
+matplotlib
+opencv-python
+pillow
diff --git a/tests/dygraph/test_prune_walker.py b/tests/dygraph/test_prune_walker.py
new file mode 100644
index 00000000..0c18ffc0
--- /dev/null
+++ b/tests/dygraph/test_prune_walker.py
@@ -0,0 +1,33 @@
+import sys
+sys.path.append("../../")
+import unittest
+import numpy as np
+import paddle
+from paddleslim.dygraph import L1NormFilterPruner
+from paddle.nn import Conv2D, Linear, Layer
+
+
+class Net(Layer):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = Conv2D(3, 8, 3)
+        self.linear = Linear(8 * 30 * 30, 5)
+
+    def forward(self, x):
+        tmp = self.conv1(x)
+        tmp = paddle.flatten(tmp, 1)
+        return self.linear(tmp)
+
+
+class TestWalker(unittest.TestCase):
+    def runTest(self):
+        x_shape = (1, 3, 32, 32)
+        net = Net()
+        x = np.random.uniform(-1, 1, x_shape).astype('float32')
+        pruner = L1NormFilterPruner(net, [paddle.to_tensor(x)])
+        pruner.prune_vars({"conv2d_0.w_0": 0.2}, [0])
+        self.assertTrue(net.linear.weight.shape == [5400, 5])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_deep_mutual_learning.py b/tests/test_deep_mutual_learning.py
index 63df036d..e859cbd8 100755
--- a/tests/test_deep_mutual_learning.py
+++ b/tests/test_deep_mutual_learning.py
@@ -31,7 +31,7 @@ class Model(paddle.nn.Layer):
         super(Model, self).__init__()
         self.conv = paddle.nn.Conv2D(
             in_channels=1, out_channels=256, kernel_size=3, stride=1, padding=1)
-        self.pool2d_avg = paddle.nn.Pool2D(pool_type='avg', global_pooling=True)
+        self.pool2d_avg = paddle.nn.AdaptiveAvgPool2D([1, 1])
         self.out = paddle.nn.Linear(256, 10)
 
     def forward(self, inputs):
diff --git a/tests/test_dygraph_pruning_plan.py b/tests/test_dygraph_pruning_plan.py
index b4c97afe..88c4d59a 100644
--- a/tests/test_dygraph_pruning_plan.py
+++ b/tests/test_dygraph_pruning_plan.py
@@ -8,13 +8,13 @@ from paddleslim.dygraph.pruning_plan import PruningPlan, PruningMask
 class TestPruningPlan(unittest.TestCase):
     def testAdd(self):
         plan = PruningPlan()
-        mask = PruningMask([0], [0, 0, 1], 0.33)
+        mask = PruningMask([0], [0, 1, 1], 0.33)
         plan.add("a", mask)
         mask = PruningMask([0], [0, 1, 0], 0.33)
         plan.add("a", mask)
         a_mask = plan.masks["a"]
         self.assertTrue(len(a_mask) == 1)
-        self.assertTrue(a_mask[0].mask == [0, 1, 1])
+        self.assertTrue(a_mask[0].mask == [0, 1, 0])
         self.assertTrue(a_mask[0].dims == [0])
 
 
diff --git a/tests/test_group_param.py b/tests/test_group_param.py
index 9e1c9f6f..dcdf6eab 100644
--- a/tests/test_group_param.py
+++ b/tests/test_group_param.py
@@ -42,7 +42,8 @@ class TestPrune(StaticCase):
         conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
         conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
         collected_groups = collect_convs(
-            ["conv1_weights", "conv2_weights", "conv3_weights"], main_program)
+            ["conv1_weights", "conv2_weights", "conv3_weights", "dummy"],
+            main_program)
         while [] in collected_groups:
             collected_groups.remove([])
         print(collected_groups)
-- 
GitLab