未验证 提交 587dd15b 编写于 作者: W whs 提交者: GitHub

Support for pruning training graph with auxiliary parameters. (#604) (#617)

上级 2a4270b0
......@@ -9,8 +9,8 @@ import functools
import math
import time
import numpy as np
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
sys.path.append(
os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
import paddleslim
from paddleslim.common import get_logger
from paddleslim.analysis import dygraph_flops as flops
......
......@@ -9,8 +9,8 @@ import functools
import math
import time
import numpy as np
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
sys.path.append(
os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
import paddleslim
from paddleslim.common import get_logger
import paddle.vision.models as models
......
......@@ -9,8 +9,8 @@ import functools
import math
import time
import numpy as np
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
sys.path.append(
os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
import paddleslim
from paddleslim.common import get_logger
from paddleslim.analysis import dygraph_flops as flops
......
......@@ -2,6 +2,7 @@ import os
import logging
import numpy as np
import pickle
import copy
import paddle
from ..common import get_logger
from .var_group import *
......@@ -17,6 +18,7 @@ _logger = get_logger(__name__, logging.INFO)
CONV_OP_TYPE = paddle.nn.Conv2D
FILTER_DIM = [0]
CONV_WEIGHT_NAME = "weight"
SKIP_LAYERS = (paddle.nn.Conv2DTranspose, paddle.nn.layer.conv.Conv2DTranspose)
class Status():
......@@ -64,9 +66,9 @@ class FilterPruner(Pruner):
# 1. depthwise conv2d layer
self.skip_vars = []
for sub_layer in model.sublayers():
if isinstance(
sub_layer,
paddle.nn.layer.conv.Conv2D) and sub_layer._groups > 1:
if isinstance(sub_layer, SKIP_LAYERS) or (isinstance(
sub_layer, paddle.nn.layer.conv.Conv2D) and
sub_layer._groups > 1):
for param in sub_layer.parameters():
self.skip_vars.append(param.name)
......@@ -236,7 +238,7 @@ class FilterPruner(Pruner):
target_vars=None,
skip_vars=None):
sensitivities = self._status.sensitivies
baseline = eval_func()
baseline = None
ratios = np.arange(0.1, 1, step=0.1)
for group in self.var_group.groups:
var_name = group[0][0]
......@@ -255,6 +257,8 @@ class FilterPruner(Pruner):
_logger.debug("{}, {} has computed.".format(var_name,
ratio))
continue
if baseline is None:
baseline = eval_func()
plan = self.prune_var(var_name, dims, ratio, apply="lazy")
pruned_metric = eval_func()
loss = (baseline - pruned_metric) / baseline
......@@ -342,14 +346,15 @@ class FilterPruner(Pruner):
for _name in group_dict:
# Varibales can be pruned on multiple axies.
for _item in group_dict[_name]:
src_mask = copy.deepcopy(mask)
dims = _item['pruned_dims']
transforms = _item['transforms']
var_shape = _item['var'].shape
if isinstance(dims, int):
dims = [dims]
for trans in transforms:
mask = self._transform_mask(mask, trans)
current_mask = mask
src_mask = self._transform_mask(src_mask, trans)
current_mask = src_mask
assert len(current_mask) == var_shape[dims[
0]], f"The length of current_mask must be equal to the size of dimension to be pruned on. But get: len(current_mask): {len(current_mask)}; var_shape: {var_shape}; dims: {dims}; var name: {_name}; len(mask): {len(mask)}"
plan.add(_name, PruningMask(dims, current_mask, pruned_ratio))
......@@ -360,6 +365,7 @@ class FilterPruner(Pruner):
return plan
def _transform_mask(self, mask, transform):
if "src_start" in transform:
src_start = transform['src_start']
src_end = transform['src_end']
target_start = transform['target_start']
......@@ -367,6 +373,7 @@ class FilterPruner(Pruner):
target_len = transform['target_len']
stride = transform['stride']
mask = mask[src_start:src_end]
mask = mask.repeat(stride) if stride > 1 else mask
dst_mask = np.ones([target_len])
......@@ -377,4 +384,10 @@ class FilterPruner(Pruner):
# if we pruning input channels by 50%(from 4 to 2), the output channel should be 50% * 4 * 8.
expand = int((target_end - target_start) / len(mask))
dst_mask[target_start:target_end] = list(mask) * expand
elif "stride" in transform:
stride = transform['stride']
mask = mask.repeat(stride) if stride > 1 else mask
return mask
else:
return mask
return dst_mask
......@@ -20,9 +20,6 @@ class FPGMFilterPruner(FilterPruner):
if _item['pruned_dims'] == [0]:
value = _item['value']
pruned_dims = _item['pruned_dims']
assert (pruned_dims == [0])
dist_sum_list = []
for out_i in range(value.shape[0]):
dist_sum = self.get_distance_sum(value, out_i)
......
......@@ -77,7 +77,8 @@ class PruningPlan():
for _mask in self._masks[var_name]:
if pruning_mask.dims == _mask.dims:
_mask.mask = list(
np.array(_mask.mask) | np.array(pruning_mask.mask))
np.array(_mask.mask).astype(np.int64) & np.array(
pruning_mask.mask).astype(np.int64))
else:
self._masks[var_name].append(pruning_mask)
self._dims[var_name].append(pruning_mask.dims)
......@@ -171,7 +172,7 @@ class PruningPlan():
paddle.to_tensor(value))
_logger.debug("Backup values of {} into buffers.".
format(param.name))
bool_mask = mask.astype(bool)
bool_mask = np.array(mask).astype(bool)
pruned_value = np.apply_along_axis(
lambda data: data[bool_mask], dims[0], value)
p = t_value._place()
......
......@@ -44,12 +44,9 @@ class VarGroup():
def _parse_model(self, model, inputs):
_logger.debug("Parsing model with input: {}".format(inputs))
model.eval()
# model can be in training mode, because some model contains auxiliary parameters for training.
program = dygraph2program(model, inputs=inputs)
graph = GraphWrapper(program)
visited = {}
for name, param in model.named_parameters():
group = collect_convs([param.name], graph,
......
......@@ -55,10 +55,15 @@ def collect_convs(params, graph, visited={}):
if not isinstance(graph, GraphWrapper):
graph = GraphWrapper(graph)
groups = []
for param in params:
for _param in params:
pruned_params = []
param = graph.var(param)
param = graph.var(_param)
if param is None:
_logger.warning(
f"Cann't found relative variables of {_param} because {_param} is not in target program or model. Please make sure {_param} is in your program if you are using static API of PaddlePaddle. And make sure your model in correctly mode and contains {_param} if you are using dynamic API of PaddlePaddle."
)
groups.append([])
continue
target_op = param.outputs()[0]
if target_op.type() == 'conditional_block':
for op in param.outputs():
......
......@@ -344,15 +344,10 @@ class activation(PruneWorker):
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.outputs(self.output_name):
in_var = self.op.inputs(self.input_name)[0]
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, pruned_idx)
if var in self.op.inputs(self.input_name):
out_var = self.op.outputs(self.output_name)[0]
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
......@@ -364,16 +359,10 @@ class default_walker(PruneWorker):
if var in self.op.all_outputs():
for in_var in self.op.all_inputs():
if len(in_var.shape()) == len(var.shape()):
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, pruned_idx)
for out_var in self.op.all_outputs():
if len(out_var.shape()) == len(var.shape()):
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
......
......@@ -98,8 +98,8 @@ class Pruner():
visited)[0] # [(name, axis, pruned_idx)]
if group is None or len(group) == 0:
continue
assert ((not self.pruned_weights),
"The weights have been pruned once.")
assert (
not self.pruned_weights), "The weights have been pruned once."
group_values = []
for name, axis, pruned_idx in group:
var = scope.find_var(name)
......
import sys
sys.path.append("../../")
import unittest
import numpy as np
import paddle
from paddleslim.dygraph import L1NormFilterPruner
from paddle.nn import Conv2D, Linear, Layer
class Net(Layer):
def __init__(self):
super(Net, self).__init__()
self.conv1 = Conv2D(3, 8, 3)
self.linear = Linear(8 * 30 * 30, 5)
def forward(self, x):
tmp = self.conv1(x)
tmp = paddle.flatten(tmp, 1)
return self.linear(tmp)
class TestWalker(unittest.TestCase):
def runTest(self):
x_shape = (1, 3, 32, 32)
net = Net()
x = np.random.uniform(-1, 1, x_shape).astype('float32')
pruner = L1NormFilterPruner(net, [paddle.to_tensor(x)])
pruner.prune_vars({"conv2d_0.w_0": 0.2}, [0])
self.assertTrue(net.linear.weight.shape == [5400, 5])
if __name__ == '__main__':
unittest.main()
......@@ -31,7 +31,7 @@ class Model(paddle.nn.Layer):
super(Model, self).__init__()
self.conv = paddle.nn.Conv2D(
in_channels=1, out_channels=256, kernel_size=3, stride=1, padding=1)
self.pool2d_avg = paddle.nn.Pool2D(pool_type='avg', global_pooling=True)
self.pool2d_avg = paddle.nn.AdaptiveAvgPool2D([1, 1])
self.out = paddle.nn.Linear(256, 10)
def forward(self, inputs):
......
......@@ -8,13 +8,13 @@ from paddleslim.dygraph.pruning_plan import PruningPlan, PruningMask
class TestPruningPlan(unittest.TestCase):
def testAdd(self):
plan = PruningPlan()
mask = PruningMask([0], [0, 0, 1], 0.33)
mask = PruningMask([0], [0, 1, 1], 0.33)
plan.add("a", mask)
mask = PruningMask([0], [0, 1, 0], 0.33)
plan.add("a", mask)
a_mask = plan.masks["a"]
self.assertTrue(len(a_mask) == 1)
self.assertTrue(a_mask[0].mask == [0, 1, 1])
self.assertTrue(a_mask[0].mask == [0, 1, 0])
self.assertTrue(a_mask[0].dims == [0])
......
......@@ -42,7 +42,8 @@ class TestPrune(StaticCase):
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
collected_groups = collect_convs(
["conv1_weights", "conv2_weights", "conv3_weights"], main_program)
["conv1_weights", "conv2_weights", "conv3_weights", "dummy"],
main_program)
while [] in collected_groups:
collected_groups.remove([])
print(collected_groups)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册