Unverified commit 5b0346c5, authored by minghaoBD, committed by GitHub

Support FC (#1551)

* support fc in pruning

* support matmul and matmul_v2 pruning

* support ffn pruning with gelu activations in the middle

* support reshape2, transpose and split in transformer block

* support pattern: qkv gemm -> batched gemm -> out linear

* prune all fc layer

* support fc pruning

* almost done with few hardcode

* remove hardcode

* fix UT

* fix UT

* avoid setting attributes for reshape in fc pruning

* fix UT

* fix UT
Parent commit: 78efe1d9
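For orientation, here is a minimal usage sketch of the options this commit introduces (`prune_type`, `num_head`, `input_dtype`). The model, input shape, parameter name, and ratio are placeholders, not part of the commit, and the import path assumes PaddleSlim's dygraph pruning API:

```python
import paddle
from paddleslim.dygraph import L1NormFilterPruner

# Placeholder encoder; any paddle.nn.Layer whose prunable weights are 2-D
# Linear/FC parameters works here.
model = build_my_encoder()  # hypothetical helper, not part of this commit

# prune_type='fc' selects 2-D weights instead of 4-D conv filters,
# num_head groups fused-QKV weights so whole attention heads are pruned,
# input_dtype is forwarded to dygraph2program for non-float inputs (token ids).
pruner = L1NormFilterPruner(
    model,
    inputs=[1, 128],        # example input shape: (batch, seq_len)
    prune_type='fc',
    num_head=12,            # placeholder head count
    input_dtype='int64')

# Prune 25% of one Linear weight's output neurons (axis 1 of an [in, out] weight);
# 'linear_0.w_0' is an illustrative parameter name.
plan = pruner.prune_var('linear_0.w_0', pruned_axis=1, pruned_ratio=0.25)
```

Setting `num_head` lets the pruner drop whole attention heads rather than arbitrary columns, which keeps the qkv gemm -> batched gemm -> out linear pattern shape-consistent after pruning.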
......@@ -13,6 +13,7 @@ from paddleslim.analysis import flops
sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
import models
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
_logger = get_logger(__name__, level=logging.INFO)
......@@ -37,6 +38,12 @@ def eval(args):
val_dataset = paddle.vision.datasets.MNIST(mode='test')
class_dim = 10
image_shape = "1,28,28"
elif args.data == "cifar10":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
val_dataset = paddle.vision.datasets.Cifar10(
mode="test", backend="cv2", transform=transform)
class_dim = 10
image_shape = "3, 32, 32"
elif args.data == "imagenet":
import imagenet_reader as reader
val_dataset = reader.ImageNetDataset(mode='val')
......
......@@ -130,6 +130,6 @@ def dygraph_flops(model, inputs, dtypes=None, only_conv=False, detail=False):
detail(bool): Whether to return the details of each convolution layer.
"""
program = dygraph2program(model, inputs)
program = dygraph2program(model, inputs, dtypes=dtypes)
graph = GraphWrapper(program)
return _graph_flops(graph, only_conv=only_conv, detail=detail)
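The change above forwards `dtypes` to `dygraph2program`, so FLOPs can be measured for models whose inputs are not float images. A hedged sketch, assuming `dygraph_flops` is importable from `paddleslim.analysis` and using a placeholder model:

```python
from paddleslim.analysis import dygraph_flops

# Token-id input: shape (batch, seq_len), dtype int64 instead of the default float32.
flops = dygraph_flops(model, inputs=[1, 128], dtypes=['int64'])
print("FLOPs:", flops)
```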
......@@ -58,14 +58,29 @@ class FilterPruner(Pruner):
"""
def __init__(self, model, inputs, sen_file=None, opt=None,
skip_leaves=True):
def __init__(self,
model,
inputs,
sen_file=None,
opt=None,
skip_leaves=True,
prune_type='conv',
num_head=-1,
input_dtype='float32'):
super(FilterPruner, self).__init__(model, inputs, opt=opt)
self.num_head = num_head
self.prune_type = prune_type
self._status = Status(sen_file)
self.skip_leaves = skip_leaves
# sensitive and collections are just used in filter pruning
self.collections = DygraphPruningCollections(
model, inputs, skip_leaves=self.skip_leaves)
model,
inputs,
skip_leaves=self.skip_leaves,
prune_type=prune_type,
input_dtype=input_dtype)
# skip vars in:
# 1. depthwise conv2d layer
......@@ -294,7 +309,7 @@ class FilterPruner(Pruner):
if self.plan is not None:
self.plan.restore(self.model, opt=self.opt)
def cal_mask(self, pruned_ratio, collection):
def cal_mask(self, pruned_ratio, collection, num_head=-1):
raise NotImplementedError("cal_mask is not implemented")
def prune_var(self, var_name, pruned_axis, pruned_ratio, apply="impretive"):
......@@ -332,13 +347,12 @@ class FilterPruner(Pruner):
f"Pruning variable [{var_name}] and its relatives {list(collection.variables())}"
)
mask = self.cal_mask(pruned_ratio, collection)
mask = self.cal_mask(pruned_ratio, collection, num_head=self.num_head)
for _detail in collection.all_pruning_details():
# Variables can be pruned on multiple axes.
src_mask = copy.deepcopy(mask)
var_shape = _detail.var.shape()
for tran in _detail.transform:
src_mask = self._transform_mask(src_mask, tran)
current_mask = src_mask
groups = _detail.op.attr('groups')
......@@ -352,7 +366,11 @@ class FilterPruner(Pruner):
if apply == "lazy":
plan.apply(self.model, lazy=True)
elif apply == "impretive":
plan.apply(self.model, lazy=False, opt=self.opt)
plan.apply(
self.model,
lazy=False,
opt=self.opt,
prune_type=self.prune_type)
return plan
def _transform_mask(self, mask, transform):
......@@ -372,6 +390,17 @@ class FilterPruner(Pruner):
return mask
elif "repeat" in transform and "tile" in transform:
return np.tile(mask.repeat(transform["repeat"]), transform["tile"])
elif "squeeze" in transform:
squeeze = transform['squeeze']
tmp = mask.reshape((-1, squeeze))
tmp = np.max(tmp, axis=1)
tmp[tmp > 0] = 1
return tmp
elif "repeat" in transform:
repeat = transform['repeat']
tmp = np.repeat(mask[:, np.newaxis], repeat, axis=1)
return tmp.flatten()
else:
return mask
return dst_mask
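The new `squeeze` and `repeat` branches convert a mask between a per-group (per-head) view and a flat per-channel view. A small numpy illustration with made-up values:

```python
import numpy as np

# 'squeeze' collapses a flat channel mask into one flag per group of
# `squeeze` consecutive channels; a group survives if any channel survives.
flat_mask = np.array([1, 1, 0, 0, 0, 0, 0, 0])        # 8 channels
grouped = np.max(flat_mask.reshape((-1, 4)), axis=1)   # groups of 4
grouped[grouped > 0] = 1
print(grouped)                                         # [1 0]

# 'repeat' expands a per-group (per-head) mask back to per-channel form.
expanded = np.repeat(grouped[:, np.newaxis], 4, axis=1).flatten()
print(expanded)                                        # [1 1 1 1 0 0 0 0]
```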
......@@ -17,7 +17,7 @@ class FPGMFilterPruner(FilterPruner):
super(FPGMFilterPruner, self).__init__(
model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)
def cal_mask(self, pruned_ratio, collection):
def cal_mask(self, pruned_ratio, collection, num_head=-1):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
......
......@@ -12,15 +12,42 @@ _logger = get_logger(__name__, logging.INFO)
class L1NormFilterPruner(FilterPruner):
def __init__(self, model, inputs, sen_file=None, opt=None,
skip_leaves=True):
def __init__(self,
model,
inputs,
sen_file=None,
opt=None,
skip_leaves=True,
prune_type='conv',
input_dtype='float32',
num_head=-1):
super(L1NormFilterPruner, self).__init__(
model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)
model,
inputs,
sen_file=sen_file,
opt=opt,
skip_leaves=skip_leaves,
prune_type=prune_type,
input_dtype=input_dtype,
num_head=num_head)
def cal_mask(self, pruned_ratio, collection):
def cal_mask(self, pruned_ratio, collection, num_head=-1):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
if (value.shape[1] == 3 * value.shape[0] or
value.shape[1] == value.shape[0]) and num_head != -1:
num_head = num_head
assert value.size % (
value.shape[0] * num_head
) == 0, "weight shape must be divisible by num_head."
_logger.debug(
"fused-qkv or query/key/value weight detected, we will prune on num_head: {} -> {}."
.format(num_head, int(num_head * (1 - pruned_ratio))))
else:
num_head = -1
groups = 1
for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int))
......@@ -30,6 +57,10 @@ class L1NormFilterPruner(FilterPruner):
groups = _groups
break
if num_head != -1:
k = int(value.size / value.shape[0] / num_head)
value = value.reshape((value.shape[0], num_head, k))
reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
l1norm = np.mean(np.abs(value), axis=tuple(reduce_dims))
if groups > 1:
......@@ -46,4 +77,9 @@ class L1NormFilterPruner(FilterPruner):
if groups > 1:
mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0
if num_head != -1:
mask = np.repeat(mask[:, np.newaxis], k, axis=1)
return mask.flatten()
return mask.reshape(mask_shape)
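To make the head-aware branch concrete, here is a numpy sketch of what this `cal_mask` does for a fused-QKV weight. The shapes and the ranking step are illustrative rather than a copy of PaddleSlim's internals:

```python
import numpy as np

hidden, num_head, ratio = 768, 12, 0.25
value = np.random.rand(hidden, 3 * hidden)   # fused-QKV weight, [in, 3*in]
pruned_axis = 1

# One score per attention head: view axis 1 as (num_head, k) and take the
# mean absolute weight over every axis except the pruned one.
k = value.size // (value.shape[0] * num_head)
value = value.reshape((value.shape[0], num_head, k))
reduce_dims = tuple(i for i in range(value.ndim) if i != pruned_axis)
l1norm = np.mean(np.abs(value), axis=reduce_dims)      # shape: (num_head,)

# Zero out the lowest-scoring heads, then expand the head mask back to channels.
pruned_num = round(num_head * ratio)
mask = np.ones(num_head, dtype='float32')
mask[np.argsort(l1norm)[:pruned_num]] = 0
mask = np.repeat(mask[:, np.newaxis], k, axis=1).flatten()
print(mask.shape)                                      # (2304,) == 3 * hidden
```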
......@@ -12,15 +12,42 @@ _logger = get_logger(__name__, logging.INFO)
class L2NormFilterPruner(FilterPruner):
def __init__(self, model, inputs, sen_file=None, opt=None,
skip_leaves=True):
def __init__(self,
model,
inputs,
sen_file=None,
opt=None,
skip_leaves=True,
prune_type='conv',
input_dtype="float32",
num_head=-1):
super(L2NormFilterPruner, self).__init__(
model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)
model,
inputs,
sen_file=sen_file,
opt=opt,
skip_leaves=skip_leaves,
prune_type=prune_type,
input_dtype=input_dtype,
num_head=num_head)
def cal_mask(self, pruned_ratio, collection):
def cal_mask(self, pruned_ratio, collection, num_head=-1):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
if (value.shape[1] == 3 * value.shape[0] or
value.shape[1] == value.shape[0]) and num_head != -1:
num_head = num_head
assert value.size % (
value.shape[0] * num_head
) == 0, "weight shape must be divisible by num_head"
_logger.debug(
"fused-qkv or query/key/value weight detected, we will prune on num_head: {} -> {}"
.format(num_head, int(num_head * (1 - pruned_ratio))))
else:
num_head = -1
groups = 1
for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int))
......@@ -30,6 +57,10 @@ class L2NormFilterPruner(FilterPruner):
groups = _groups
break
if num_head != -1:
k = int(value.size / value.shape[0] / num_head)
value = value.reshape((value.shape[0], num_head, k))
reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))
if groups > 1:
......@@ -46,4 +77,9 @@ class L2NormFilterPruner(FilterPruner):
if groups > 1:
mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0
if num_head != -1:
mask = np.repeat(mask[:, np.newaxis], k, axis=1)
return mask.flatten()
return mask.reshape(mask_shape)
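The L2 pruner above uses the same head-wise grouping as the L1 sketch earlier; only the scoring differs, roughly:

```python
import numpy as np

# Same [in, num_head, k] view as the L1 pruner, scored with the L2 norm per head.
value = np.random.rand(768, 12, 192)
scores = np.sqrt(np.sum(np.square(value), axis=(0, 2)))  # one score per head
```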
......@@ -19,7 +19,7 @@ class Pruner(object):
opt(paddle.optimizer.Optimizer): The model's optimizer. Default: None.
"""
def __init__(self, model, inputs, opt=None):
def __init__(self, model, inputs, opt=None, prune_type='conv'):
self.model = model
self.inputs = inputs
self._var_shapes = {}
......@@ -27,6 +27,7 @@ class Pruner(object):
self._var_shapes[var.name] = var.shape
self.plan = None
self.opt = opt
self.prune_type = prune_type
def status(self, data=None, eval_func=None, status_file=None):
raise NotImplementedError("status is not implemented")
......@@ -54,6 +55,10 @@ class Pruner(object):
if apply == "lazy":
global_plan.apply(self.model, lazy=True)
elif apply == "impretive":
global_plan.apply(self.model, lazy=False, opt=self.opt)
global_plan.apply(
self.model,
lazy=False,
opt=self.opt,
prune_type=self.prune_type)
self.plan = global_plan
return global_plan
......@@ -102,7 +102,6 @@ class PruningPlan():
var_tmp = v.get(param_name)
#NOTE: var_tmp.shape == [1] is used to skip variables like beta1_pow_acc in Adam optimizer. Its shape is [1] and there's no need to prune this one-value variable.
if var_tmp is None or var_tmp.shape == [1]:
if var_tmp is not None: print(var_tmp.name, var_tmp.shape)
continue
t_value = var_tmp.value().get_tensor()
value = np.array(t_value).astype("float32")
......@@ -162,11 +161,11 @@ class PruningPlan():
t_value.set(np.array(t_backup).astype("float32"), place)
del sub_layer._buffers[backup_name]
def apply(self, model, lazy=False, opt=None):
def apply(self, model, lazy=False, opt=None, prune_type='conv'):
if lazy:
self.lazy_apply(model)
else:
self.imperative_apply(model, opt)
self.imperative_apply(model, opt, prune_type=prune_type)
def lazy_apply(self, model):
for name, sub_layer in model.named_sublayers():
......@@ -203,12 +202,11 @@ class PruningPlan():
t_value.set(value * expand_mask, place)
def imperative_apply(self, model, opt=None):
def imperative_apply(self, model, opt=None, prune_type='conv'):
"""
Prune the values of variables imperatively. This is only valid when pruning
along one dimension.
"""
for name, sub_layer in model.named_sublayers(include_self=True):
for param in sub_layer.parameters(include_sublayers=False):
if param.name in self._masks:
......@@ -240,11 +238,9 @@ class PruningPlan():
format(param.name))
# save optimizer accumulators into layer buffer
self._buffer_opt(param.name, sub_layer, opt)
pruned_value = np.apply_along_axis(
lambda data: data[bool_mask], dims, value)
self._prune_opt(param.name, dims, bool_mask, opt)
p = t_value._place()
if p.is_cpu_place():
place = paddle.CPUPlace()
......
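`imperative_apply` physically shrinks each parameter by selecting the surviving slices along the pruned axis. A numpy sketch of just that selection (optimizer-state handling and device placement are omitted):

```python
import numpy as np

value = np.random.rand(768, 2304)        # parameter value before pruning
bool_mask = np.arange(2304) % 4 != 3     # keep 3 of every 4 output channels
dims = 1                                 # pruned axis

# Keep only the surviving slices along the pruned axis -- this is what the
# np.apply_along_axis call in the hunk above computes before t_value.set().
pruned_value = np.apply_along_axis(lambda data: data[bool_mask], dims, value)
print(value.shape, '->', pruned_value.shape)   # (768, 2304) -> (768, 1728)
```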
......@@ -20,17 +20,38 @@ class DygraphPruningCollections(PruningCollections):
- skip_leaves(bool): Whether to skip the last convolution layers.
"""
def __init__(self, model, inputs, skip_leaves=True):
def __init__(self,
model,
inputs,
skip_leaves=True,
prune_type='conv',
input_dtype='float32'):
assert prune_type in ['conv', 'fc'
], "Please select conv or fc as your prune type."
_logger.debug("Parsing model with input: {}".format(inputs))
# The model can be in training mode, because some models contain auxiliary parameters used only for training.
program = dygraph2program(model, inputs=inputs)
#TODO(minghaoBD): support dictionary input
if isinstance(inputs[0], int):
dtypes = [input_dtype]
elif isinstance(inputs[0], list):
dtypes = [input_dtype] * len(inputs)
else:
dtypes = [input_dtype]
program = dygraph2program(model, inputs=inputs, dtypes=dtypes)
graph = GraphWrapper(program)
params = [
_param.name for _param in model.parameters()
if len(_param.shape) == 4
]
if prune_type == 'conv':
params = [
_param.name for _param in model.parameters()
if len(_param.shape) == 4
]
elif prune_type == 'fc':
params = [
_param.name for _param in model.parameters()
if len(_param.shape) == 2
]
self._collections = self.create_pruning_collections(
params, graph, skip_leaves=skip_leaves)
params, graph, skip_leaves=skip_leaves, prune_type=prune_type)
_logger.info("Found {} collections.".format(len(self._collections)))
_name2values = {}
......
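The constructor above decides which parameters are prunable purely by weight rank: 4-D weights for conv, 2-D weights for fc. A small self-contained check of that selection (the toy model is illustrative):

```python
import paddle

model = paddle.nn.Sequential(
    paddle.nn.Conv2D(3, 8, kernel_size=3),
    paddle.nn.Flatten(),
    paddle.nn.Linear(8 * 30 * 30, 10))

def prunable_params(model, prune_type='conv'):
    # 4-D weights are conv filters [out_c, in_c, kh, kw];
    # 2-D weights are Linear/FC weights [in_features, out_features].
    ndim = 4 if prune_type == 'conv' else 2
    return [p.name for p in model.parameters() if len(p.shape) == ndim]

print(prunable_params(model, 'conv'))   # the Conv2D weight only
print(prunable_params(model, 'fc'))     # the Linear weight only
```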
......@@ -142,7 +142,8 @@ class PruningCollections(object):
graph,
skip_stranger=True,
skip_vars=None,
skip_leaves=True):
skip_leaves=True,
prune_type='conv'):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
......@@ -186,6 +187,12 @@ class PruningCollections(object):
visited = {}
collections = []
unsupported_warnings = set()
if prune_type == 'conv':
prune_axis = 0
elif prune_type == 'fc':
prune_axis = 1
for _param in params:
pruned_params = []
param = graph.var(_param)
......@@ -216,7 +223,7 @@ class PruningCollections(object):
worker.skip_vars = skip_vars
try:
visited_backup = copy.deepcopy(worker.visited)
worker.prune(param, pruned_axis=0, pruned_idx=[])
worker.prune(param, pruned_axis=prune_axis, transforms=[])
except UnsupportOpError as e:
visited.clear()
visited.update(visited_backup)
......@@ -225,15 +232,16 @@ class PruningCollections(object):
if len(pruned_params) != 0:
collection = PruningCollection(master=({
"name": param.name(),
"axis": 0
"axis": prune_axis,
}))
for _param, _axis, _transform, _op in pruned_params:
collection.add(
PruningDetails(_param, _axis, _transform, _op))
tmp = PruningDetails(_param, _axis, _transform, _op)
collection.add(tmp)
collections.append(collection)
for warn in unsupported_warnings:
_logger.warning(warn)
self._collections = collections
return self._collections
......
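The axis choice follows Paddle's weight layouts: conv2d weights are `[out_channels, in_channels, kh, kw]`, so output filters sit on axis 0, while Linear/FC weights are `[in_features, out_features]`, so output neurons sit on axis 1. A quick check:

```python
import paddle

conv = paddle.nn.Conv2D(in_channels=16, out_channels=32, kernel_size=3)
fc = paddle.nn.Linear(in_features=128, out_features=256)

print(conv.weight.shape)   # [32, 16, 3, 3] -> output filters on axis 0
print(fc.weight.shape)     # [128, 256]     -> output neurons on axis 1
```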
(This file's diff has been collapsed.)
......@@ -188,6 +188,13 @@ class Pruner():
for idx in src:
idx = idx * repeat
target.extend(range(idx, idx + repeat))
elif "squeeze" in trans:
repeat = trans['repeat']
targets_set = set()
for idx in src:
targets_set.add(idx // repeat)  # integer division keeps indices integral
target = list(targets_set)
src = target
ret.append((name, axis, src))
......
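The added `squeeze` branch maps pruned channel indices down to group indices, the inverse of the `repeat` expansion above it. An illustration with made-up indices, assuming integer division is intended:

```python
repeat = 4
src = [0, 1, 2, 3, 9, 11]                 # pruned channel indices

# Every group of `repeat` consecutive channels collapses to one group index.
target = sorted({idx // repeat for idx in src})
print(target)                             # [0, 2]
```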