Unverified commit 587dd15b authored by whs, committed by GitHub

Support for pruning training graph with auxiliary parameters. (#604) (#617)

Parent 2a4270b0
@@ -9,8 +9,8 @@ import functools
 import math
 import time
 import numpy as np
-sys.path[0] = os.path.join(
-    os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
+sys.path.append(
+    os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
 import paddleslim
 from paddleslim.common import get_logger
 from paddleslim.analysis import dygraph_flops as flops
...
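Note on the hunk above (and the two identical hunks below, which apply the same fix to two more demo scripts): assigning to sys.path[0] overwrote the interpreter's script-directory entry, which can break imports of modules that live next to the script; append() adds the package root without clobbering anything. A minimal sketch of the new behavior:

import os
import sys

# Keep sys.path[0] (the script directory) intact and append the repository
# root two levels up instead, so `import paddleslim` resolves from a checkout.
pkg_root = os.path.join(
    os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
sys.path.append(pkg_root)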
@@ -9,8 +9,8 @@ import functools
 import math
 import time
 import numpy as np
-sys.path[0] = os.path.join(
-    os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
+sys.path.append(
+    os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
 import paddleslim
 from paddleslim.common import get_logger
 import paddle.vision.models as models
...
@@ -9,8 +9,8 @@ import functools
 import math
 import time
 import numpy as np
-sys.path[0] = os.path.join(
-    os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
+sys.path.append(
+    os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
 import paddleslim
 from paddleslim.common import get_logger
 from paddleslim.analysis import dygraph_flops as flops
...
@@ -2,6 +2,7 @@ import os
 import logging
 import numpy as np
 import pickle
+import copy
 import paddle
 from ..common import get_logger
 from .var_group import *
@@ -17,6 +18,7 @@ _logger = get_logger(__name__, logging.INFO)
 CONV_OP_TYPE = paddle.nn.Conv2D
 FILTER_DIM = [0]
 CONV_WEIGHT_NAME = "weight"
+SKIP_LAYERS = (paddle.nn.Conv2DTranspose, paddle.nn.layer.conv.Conv2DTranspose)


 class Status():
@@ -64,9 +66,9 @@ class FilterPruner(Pruner):
         # 1. depthwise conv2d layer
         self.skip_vars = []
         for sub_layer in model.sublayers():
-            if isinstance(
-                    sub_layer,
-                    paddle.nn.layer.conv.Conv2D) and sub_layer._groups > 1:
+            if isinstance(sub_layer, SKIP_LAYERS) or (isinstance(
+                    sub_layer, paddle.nn.layer.conv.Conv2D) and
+                    sub_layer._groups > 1):
                 for param in sub_layer.parameters():
                     self.skip_vars.append(param.name)
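The new SKIP_LAYERS tuple extends the skip list from grouped (depthwise) Conv2D layers to transposed convolutions, whose filters cannot be safely pruned along axis 0. A minimal restatement of the new condition, assuming the private `_groups` attribute used above (`should_skip` is a hypothetical name for illustration):

import paddle

SKIP_LAYERS = (paddle.nn.Conv2DTranspose, paddle.nn.layer.conv.Conv2DTranspose)

def should_skip(sub_layer):
    # Transposed convs are skipped outright; grouped (e.g. depthwise) convs
    # are skipped because their filters are tied to their input channels.
    is_grouped_conv = (isinstance(sub_layer, paddle.nn.layer.conv.Conv2D) and
                       sub_layer._groups > 1)
    return isinstance(sub_layer, SKIP_LAYERS) or is_grouped_conv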
...@@ -236,7 +238,7 @@ class FilterPruner(Pruner): ...@@ -236,7 +238,7 @@ class FilterPruner(Pruner):
target_vars=None, target_vars=None,
skip_vars=None): skip_vars=None):
sensitivities = self._status.sensitivies sensitivities = self._status.sensitivies
baseline = eval_func() baseline = None
ratios = np.arange(0.1, 1, step=0.1) ratios = np.arange(0.1, 1, step=0.1)
for group in self.var_group.groups: for group in self.var_group.groups:
var_name = group[0][0] var_name = group[0][0]
...@@ -255,6 +257,8 @@ class FilterPruner(Pruner): ...@@ -255,6 +257,8 @@ class FilterPruner(Pruner):
_logger.debug("{}, {} has computed.".format(var_name, _logger.debug("{}, {} has computed.".format(var_name,
ratio)) ratio))
continue continue
if baseline is None:
baseline = eval_func()
plan = self.prune_var(var_name, dims, ratio, apply="lazy") plan = self.prune_var(var_name, dims, ratio, apply="lazy")
pruned_metric = eval_func() pruned_metric = eval_func()
loss = (baseline - pruned_metric) / baseline loss = (baseline - pruned_metric) / baseline
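Moving the baseline evaluation inside the loop, guarded by `baseline is None`, defers the (possibly expensive) `eval_func()` call until the first (variable, ratio) pair that is not already cached in the sensitivities file; a fully cached run now performs no evaluation at all. A sketch of the pattern with hypothetical names:

def compute_sensitivities(eval_func, cached, candidates):
    baseline = None  # no longer computed up front
    result = dict(cached)
    for var_name, ratio in candidates:
        if (var_name, ratio) in result:
            continue  # cached from an earlier, possibly interrupted, run
        if baseline is None:
            baseline = eval_func()  # paid once, and only if something is uncached
        pruned_metric = eval_func()  # stands in for: prune var, evaluate, restore
        result[(var_name, ratio)] = (baseline - pruned_metric) / baseline
    return result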
...@@ -342,14 +346,15 @@ class FilterPruner(Pruner): ...@@ -342,14 +346,15 @@ class FilterPruner(Pruner):
for _name in group_dict: for _name in group_dict:
# Varibales can be pruned on multiple axies. # Varibales can be pruned on multiple axies.
for _item in group_dict[_name]: for _item in group_dict[_name]:
src_mask = copy.deepcopy(mask)
dims = _item['pruned_dims'] dims = _item['pruned_dims']
transforms = _item['transforms'] transforms = _item['transforms']
var_shape = _item['var'].shape var_shape = _item['var'].shape
if isinstance(dims, int): if isinstance(dims, int):
dims = [dims] dims = [dims]
for trans in transforms: for trans in transforms:
mask = self._transform_mask(mask, trans) src_mask = self._transform_mask(src_mask, trans)
current_mask = mask current_mask = src_mask
assert len(current_mask) == var_shape[dims[ assert len(current_mask) == var_shape[dims[
0]], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; dims: {dims}; var name: {_name}; len(mask): {len(mask)}" 0]], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; dims: {dims}; var name: {_name}; len(mask): {len(mask)}"
plan.add(_name, PruningMask(dims, current_mask, pruned_ratio)) plan.add(_name, PruningMask(dims, current_mask, pruned_ratio))
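The `copy.deepcopy(mask)` is the substantive fix here: previously each `_item`'s transforms mutated the shared `mask`, so the transforms applied for one variable leaked into the starting mask of the next variable in the same group. A toy illustration of the corrected flow, with hypothetical names:

import copy

def masks_per_item(mask, items, transform_fn):
    results = []
    for item in items:
        src_mask = copy.deepcopy(mask)  # each item starts from the group's mask
        for trans in item['transforms']:
            src_mask = transform_fn(src_mask, trans)
        results.append(src_mask)  # the old code appended the accumulated mask
    return results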
@@ -360,21 +365,29 @@ class FilterPruner(Pruner):
         return plan

     def _transform_mask(self, mask, transform):
-        src_start = transform['src_start']
-        src_end = transform['src_end']
-        target_start = transform['target_start']
-        target_end = transform['target_end']
-        target_len = transform['target_len']
-        stride = transform['stride']
-        mask = mask[src_start:src_end]
-        mask = mask.repeat(stride) if stride > 1 else mask
-        dst_mask = np.ones([target_len])
-        # for depthwise conv2d with:
-        # input shape: (1, 4, 32, 32)
-        # filter shape: (32, 1, 3, 3)
-        # groups: 4
-        # if we prune input channels by 50% (from 4 to 2), the output channels should be 50% * 4 * 8.
-        expand = int((target_end - target_start) / len(mask))
-        dst_mask[target_start:target_end] = list(mask) * expand
+        if "src_start" in transform:
+            src_start = transform['src_start']
+            src_end = transform['src_end']
+            target_start = transform['target_start']
+            target_end = transform['target_end']
+            target_len = transform['target_len']
+            stride = transform['stride']
+            mask = mask[src_start:src_end]
+            mask = mask.repeat(stride) if stride > 1 else mask
+            dst_mask = np.ones([target_len])
+            # for depthwise conv2d with:
+            # input shape: (1, 4, 32, 32)
+            # filter shape: (32, 1, 3, 3)
+            # groups: 4
+            # if we prune input channels by 50% (from 4 to 2), the output channels should be 50% * 4 * 8.
+            expand = int((target_end - target_start) / len(mask))
+            dst_mask[target_start:target_end] = list(mask) * expand
+        elif "stride" in transform:
+            stride = transform['stride']
+            mask = mask.repeat(stride) if stride > 1 else mask
+            return mask
+        else:
+            return mask
         return dst_mask
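The new `elif "stride" in transform` branch handles transforms that carry only a stride and no src/target window (the depthwise case described in the comment): the mask is simply repeated channel-wise. A worked example of that repeat step:

import numpy as np

mask = np.array([1, 0])     # keep-mask over 2 input channels
stride = 2
print(mask.repeat(stride))  # [1 1 0 0]: each flag now covers `stride` channels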
@@ -20,9 +20,6 @@ class FPGMFilterPruner(FilterPruner):
             if _item['pruned_dims'] == [0]:
                 value = _item['value']
                 pruned_dims = _item['pruned_dims']
-                assert (pruned_dims == [0])
-
                 dist_sum_list = []
                 for out_i in range(value.shape[0]):
                     dist_sum = self.get_distance_sum(value, out_i)
...
@@ -77,7 +77,8 @@ class PruningPlan():
                 for _mask in self._masks[var_name]:
                     if pruning_mask.dims == _mask.dims:
                         _mask.mask = list(
-                            np.array(_mask.mask) | np.array(pruning_mask.mask))
+                            np.array(_mask.mask).astype(np.int64) & np.array(
+                                pruning_mask.mask).astype(np.int64))
             else:
                 self._masks[var_name].append(pruning_mask)
                 self._dims[var_name].append(pruning_mask.dims)
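Switching `|` to `&` changes the merge semantics for two keep-masks on the same dims: an element now survives only if every plan keeps it, i.e. the pruned sets are unioned rather than intersected (the updated TestPruningPlan near the bottom of this diff checks exactly this). For example:

import numpy as np

a = np.array([0, 1, 1])  # keep-mask from one plan (1 = keep, 0 = prune)
b = np.array([0, 1, 0])  # keep-mask from a second plan on the same dims
merged = list(a.astype(np.int64) & b.astype(np.int64))
print(merged)            # [0, 1, 0]: only elements kept by BOTH plans survive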
@@ -171,7 +172,7 @@ class PruningPlan():
                             paddle.to_tensor(value))
                     _logger.debug("Backup values of {} into buffers.".
                                   format(param.name))
-                bool_mask = mask.astype(bool)
+                bool_mask = np.array(mask).astype(bool)
                 pruned_value = np.apply_along_axis(
                     lambda data: data[bool_mask], dims[0], value)
                 p = t_value._place()
...
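Wrapping the mask in `np.array(...)` matters because the plan stores masks as plain Python lists (see the `_mask.mask = list(...)` merge above), and lists have no `.astype` method. A two-line illustration:

import numpy as np

mask = [0, 1, 0]                         # PruningPlan stores masks as lists
bool_mask = np.array(mask).astype(bool)  # mask.astype(bool) would raise AttributeError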
@@ -44,12 +44,9 @@ class VarGroup():
     def _parse_model(self, model, inputs):
         _logger.debug("Parsing model with input: {}".format(inputs))
-        # model can be in training mode, because some model contains auxiliary parameters for training.
-        model.eval()
-
         program = dygraph2program(model, inputs=inputs)
         graph = GraphWrapper(program)
         visited = {}
         for name, param in model.named_parameters():
             group = collect_convs([param.name], graph,
...
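Dropping the forced `model.eval()` is the headline change of this commit: the model is now traced in whatever mode it is in, so a training graph whose auxiliary parameters exist only in train mode can still be parsed and pruned. A hedged usage sketch (the model choice is illustrative; `import paddle.vision.models` mirrors the demo script touched above):

import numpy as np
import paddle
import paddle.vision.models as models
from paddleslim.dygraph import L1NormFilterPruner

net = models.mobilenet_v1()
net.train()  # graph parsing no longer flips the model back to eval mode
x = np.random.uniform(-1, 1, (1, 3, 224, 224)).astype('float32')
pruner = L1NormFilterPruner(net, [paddle.to_tensor(x)])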
@@ -55,10 +55,15 @@ def collect_convs(params, graph, visited={}):
     if not isinstance(graph, GraphWrapper):
         graph = GraphWrapper(graph)
     groups = []
-    for param in params:
+    for _param in params:
         pruned_params = []
-        param = graph.var(param)
+        param = graph.var(_param)
+        if param is None:
+            _logger.warning(
+                f"Can't find variables related to {_param} because {_param} is not in the target program or model. Please make sure {_param} is in your program if you are using the static API of PaddlePaddle, and make sure your model is in the correct mode and contains {_param} if you are using the dynamic API of PaddlePaddle."
+            )
+            groups.append([])
+            continue
         target_op = param.outputs()[0]
         if target_op.type() == 'conditional_block':
             for op in param.outputs():
...
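With this change, collect_convs degrades gracefully: a parameter name that does not exist in the graph yields a warning and an empty group instead of an AttributeError from calling `.outputs()` on `None`. Callers can then filter the empty groups out, as the updated TestPrune at the bottom of this diff does (`main_program` is assumed from that test's setup):

from paddleslim.prune import collect_convs

groups = collect_convs(
    ["conv1_weights", "conv2_weights", "conv3_weights", "dummy"], main_program)
groups = [g for g in groups if g]  # drop the empty group emitted for "dummy"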
@@ -344,15 +344,10 @@ class activation(PruneWorker):
     def _prune(self, var, pruned_axis, pruned_idx):
         if var in self.op.outputs(self.output_name):
             in_var = self.op.inputs(self.input_name)[0]
-            pre_ops = in_var.inputs()
-            for op in pre_ops:
-                self._prune_op(op, in_var, pruned_axis, pruned_idx)
-
-            out_var = self.op.outputs(self.output_name)[0]
-            self._visit(out_var, pruned_axis)
-            next_ops = out_var.outputs()
-            for op in next_ops:
-                self._prune_op(op, out_var, pruned_axis, pruned_idx)
+            self._visit_and_search(in_var, pruned_axis, pruned_idx)
+        if var in self.op.inputs(self.input_name):
+            out_var = self.op.outputs(self.output_name)[0]
+            self._visit_and_search(out_var, pruned_axis, pruned_idx)

 @PRUNE_WORKER.register
@@ -364,16 +359,10 @@ class default_walker(PruneWorker):
         if var in self.op.all_outputs():
             for in_var in self.op.all_inputs():
                 if len(in_var.shape()) == len(var.shape()):
-                    pre_ops = in_var.inputs()
-                    for op in pre_ops:
-                        self._prune_op(op, in_var, pruned_axis, pruned_idx)
+                    self._visit_and_search(in_var, pruned_axis, pruned_idx)
         for out_var in self.op.all_outputs():
             if len(out_var.shape()) == len(var.shape()):
-                self._visit(out_var, pruned_axis)
-                next_ops = out_var.outputs()
-                for op in next_ops:
-                    self._prune_op(op, out_var, pruned_axis, pruned_idx)
+                self._visit_and_search(out_var, pruned_axis, pruned_idx)

 @PRUNE_WORKER.register
...
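Both walkers above now delegate to `_visit_and_search`, collapsing the duplicated visit-then-walk-the-neighbors loops into one helper. Its body is not shown in this diff; presumably it behaves roughly like this sketch, reconstructed from the code it replaces:

def _visit_and_search(self, var, pruned_axis, pruned_idx):
    # Mark `var` as visited on this axis, then continue the walk through
    # both its producer ops and its consumer ops.
    self._visit(var, pruned_axis)
    for op in var.inputs():    # ops that produce `var`
        self._prune_op(op, var, pruned_axis, pruned_idx)
    for op in var.outputs():   # ops that consume `var`
        self._prune_op(op, var, pruned_axis, pruned_idx)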
@@ -98,8 +98,8 @@ class Pruner():
                                   visited)[0]  # [(name, axis, pruned_idx)]
             if group is None or len(group) == 0:
                 continue
-            assert ((not self.pruned_weights),
-                    "The weights have been pruned once.")
+            assert (
+                not self.pruned_weights), "The weights have been pruned once."
             group_values = []
             for name, axis, pruned_idx in group:
                 var = scope.find_var(name)
...
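The old form asserted a two-element tuple, which is always truthy, so the guard could never fire; the rewrite asserts the condition and uses the string as the assertion message. The general pitfall, in isolation:

pruned_weights = True

# Always passes: the tuple (condition, message) is non-empty, hence truthy.
assert (not pruned_weights, "The weights have been pruned once.")

# Actually evaluates the condition; this form raises AssertionError here.
assert not pruned_weights, "The weights have been pruned once."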
import sys
sys.path.append("../../")
import unittest
import numpy as np
import paddle
from paddleslim.dygraph import L1NormFilterPruner
from paddle.nn import Conv2D, Linear, Layer


class Net(Layer):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = Conv2D(3, 8, 3)
        self.linear = Linear(8 * 30 * 30, 5)

    def forward(self, x):
        tmp = self.conv1(x)
        tmp = paddle.flatten(tmp, 1)
        return self.linear(tmp)


class TestWalker(unittest.TestCase):
    def runTest(self):
        x_shape = (1, 3, 32, 32)
        net = Net()
        x = np.random.uniform(-1, 1, x_shape).astype('float32')
        pruner = L1NormFilterPruner(net, [paddle.to_tensor(x)])
        pruner.prune_vars({"conv2d_0.w_0": 0.2}, [0])
        self.assertTrue(net.linear.weight.shape == [5400, 5])


if __name__ == '__main__':
    unittest.main()
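The expected shape in this new test is worth unpacking: conv1 has 8 filters, and a 32x32 input yields 30x30 feature maps. Pruning "conv2d_0.w_0" at ratio 0.2 evidently removes 2 of the 8 filters (the exact rounding is the library's choice), and the walker propagates that through paddle.flatten into the Linear layer's input dimension:

kept = 6                       # 8 filters, 2 pruned at ratio 0.2
assert kept * 30 * 30 == 5400  # matches the expected linear weight shape [5400, 5]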
@@ -31,7 +31,7 @@ class Model(paddle.nn.Layer):
         super(Model, self).__init__()
         self.conv = paddle.nn.Conv2D(
             in_channels=1, out_channels=256, kernel_size=3, stride=1, padding=1)
-        self.pool2d_avg = paddle.nn.Pool2D(pool_type='avg', global_pooling=True)
+        self.pool2d_avg = paddle.nn.AdaptiveAvgPool2D([1, 1])
         self.out = paddle.nn.Linear(256, 10)

     def forward(self, inputs):
...
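paddle.nn.Pool2D is an older API that is gone in Paddle 2.x; AdaptiveAvgPool2D([1, 1]) is the 2.x way to express global average pooling, fixing the output at 1x1 per channel regardless of input size. For example:

import paddle

pool = paddle.nn.AdaptiveAvgPool2D([1, 1])  # global average pooling in 2.x
x = paddle.rand([1, 256, 28, 28])
print(pool(x).shape)                        # [1, 256, 1, 1]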
@@ -8,13 +8,13 @@ from paddleslim.dygraph.pruning_plan import PruningPlan, PruningMask
 class TestPruningPlan(unittest.TestCase):
     def testAdd(self):
         plan = PruningPlan()
-        mask = PruningMask([0], [0, 0, 1], 0.33)
+        mask = PruningMask([0], [0, 1, 1], 0.33)
         plan.add("a", mask)
         mask = PruningMask([0], [0, 1, 0], 0.33)
         plan.add("a", mask)
         a_mask = plan.masks["a"]
         self.assertTrue(len(a_mask) == 1)
-        self.assertTrue(a_mask[0].mask == [0, 1, 1])
+        self.assertTrue(a_mask[0].mask == [0, 1, 0])
         self.assertTrue(a_mask[0].dims == [0])
...
@@ -42,7 +42,8 @@ class TestPrune(StaticCase):
         conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
         conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
         collected_groups = collect_convs(
-            ["conv1_weights", "conv2_weights", "conv3_weights"], main_program)
+            ["conv1_weights", "conv2_weights", "conv3_weights", "dummy"],
+            main_program)
         while [] in collected_groups:
             collected_groups.remove([])
         print(collected_groups)
...