未验证 提交 5b0346c5 编写于 作者: M minghaoBD 提交者: GitHub

Support FC (#1551)

* support fc in pruning

* support matmul and matmul_v2 pruning

* support ffn pruning with gelu activations in the middle

* support reshape2, transpose and split in transformer block

* support pattern: qkv gemm -> batched gemm -> out linear

* prune all fc layer

* support fc pruning

* almost done with few hardcode

* remove hardcode

* fix UT

* fix UT

* avoid setting attributes for reshape in fc pruning

* fix UT

* fix UT
上级 78efe1d9
...@@ -13,6 +13,7 @@ from paddleslim.analysis import flops ...@@ -13,6 +13,7 @@ from paddleslim.analysis import flops
sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir) sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
import models import models
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
...@@ -37,6 +38,12 @@ def eval(args): ...@@ -37,6 +38,12 @@ def eval(args):
val_dataset = paddle.vision.datasets.MNIST(mode='test') val_dataset = paddle.vision.datasets.MNIST(mode='test')
class_dim = 10 class_dim = 10
image_shape = "1,28,28" image_shape = "1,28,28"
elif args.data == "cifar10":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
val_dataset = paddle.vision.datasets.Cifar10(
mode="test", backend="cv2", transform=transform)
class_dim = 10
image_shape = "3, 32, 32"
elif args.data == "imagenet": elif args.data == "imagenet":
import imagenet_reader as reader import imagenet_reader as reader
val_dataset = reader.ImageNetDataset(mode='val') val_dataset = reader.ImageNetDataset(mode='val')
......
...@@ -130,6 +130,6 @@ def dygraph_flops(model, inputs, dtypes=None, only_conv=False, detail=False): ...@@ -130,6 +130,6 @@ def dygraph_flops(model, inputs, dtypes=None, only_conv=False, detail=False):
detail(bool): Whether to return detail of each convolution layer. detail(bool): Whether to return detail of each convolution layer.
""" """
program = dygraph2program(model, inputs) program = dygraph2program(model, inputs, dtypes=dtypes)
graph = GraphWrapper(program) graph = GraphWrapper(program)
return _graph_flops(graph, only_conv=only_conv, detail=detail) return _graph_flops(graph, only_conv=only_conv, detail=detail)
...@@ -58,14 +58,29 @@ class FilterPruner(Pruner): ...@@ -58,14 +58,29 @@ class FilterPruner(Pruner):
""" """
def __init__(self, model, inputs, sen_file=None, opt=None, def __init__(self,
skip_leaves=True): model,
inputs,
sen_file=None,
opt=None,
skip_leaves=True,
prune_type='conv',
num_head=-1,
input_dtype='float32'):
super(FilterPruner, self).__init__(model, inputs, opt=opt) super(FilterPruner, self).__init__(model, inputs, opt=opt)
self.num_head = num_head
self.prune_type = prune_type
self._status = Status(sen_file) self._status = Status(sen_file)
self.skip_leaves = skip_leaves self.skip_leaves = skip_leaves
# sensitive and collections are just used in filter pruning # sensitive and collections are just used in filter pruning
self.collections = DygraphPruningCollections( self.collections = DygraphPruningCollections(
model, inputs, skip_leaves=self.skip_leaves) model,
inputs,
skip_leaves=self.skip_leaves,
prune_type=prune_type,
input_dtype=input_dtype)
# skip vars in: # skip vars in:
# 1. depthwise conv2d layer # 1. depthwise conv2d layer
...@@ -294,7 +309,7 @@ class FilterPruner(Pruner): ...@@ -294,7 +309,7 @@ class FilterPruner(Pruner):
if self.plan is not None: if self.plan is not None:
self.plan.restore(self.model, opt=self.opt) self.plan.restore(self.model, opt=self.opt)
def cal_mask(self, pruned_ratio, collection): def cal_mask(self, pruned_ratio, collection, num_head=-1):
raise NotImplemented("cal_mask is not implemented") raise NotImplemented("cal_mask is not implemented")
def prune_var(self, var_name, pruned_axis, pruned_ratio, apply="impretive"): def prune_var(self, var_name, pruned_axis, pruned_ratio, apply="impretive"):
...@@ -332,13 +347,12 @@ class FilterPruner(Pruner): ...@@ -332,13 +347,12 @@ class FilterPruner(Pruner):
f"Pruning variable [{var_name}] and its relatives {list(collection.variables())}" f"Pruning variable [{var_name}] and its relatives {list(collection.variables())}"
) )
mask = self.cal_mask(pruned_ratio, collection) mask = self.cal_mask(pruned_ratio, collection, num_head=self.num_head)
for _detail in collection.all_pruning_details(): for _detail in collection.all_pruning_details():
# Varibales can be pruned on multiple axies. # Varibales can be pruned on multiple axies.
src_mask = copy.deepcopy(mask) src_mask = copy.deepcopy(mask)
var_shape = _detail.var.shape() var_shape = _detail.var.shape()
for tran in _detail.transform: for tran in _detail.transform:
src_mask = self._transform_mask(src_mask, tran) src_mask = self._transform_mask(src_mask, tran)
current_mask = src_mask current_mask = src_mask
groups = _detail.op.attr('groups') groups = _detail.op.attr('groups')
...@@ -352,7 +366,11 @@ class FilterPruner(Pruner): ...@@ -352,7 +366,11 @@ class FilterPruner(Pruner):
if apply == "lazy": if apply == "lazy":
plan.apply(self.model, lazy=True) plan.apply(self.model, lazy=True)
elif apply == "impretive": elif apply == "impretive":
plan.apply(self.model, lazy=False, opt=self.opt) plan.apply(
self.model,
lazy=False,
opt=self.opt,
prune_type=self.prune_type)
return plan return plan
def _transform_mask(self, mask, transform): def _transform_mask(self, mask, transform):
...@@ -372,6 +390,17 @@ class FilterPruner(Pruner): ...@@ -372,6 +390,17 @@ class FilterPruner(Pruner):
return mask return mask
elif "repeat" in transform and "tile" in transform: elif "repeat" in transform and "tile" in transform:
return np.tile(mask.repeat(transform["repeat"]), transform["tile"]) return np.tile(mask.repeat(transform["repeat"]), transform["tile"])
elif "squeeze" in transform:
squeeze = transform['squeeze']
tmp = mask.reshape((-1, squeeze))
tmp = np.max(tmp, axis=1)
tmp[tmp > 0] = 1
return tmp
elif "repeat" in transform:
repeat = transform['repeat']
tmp = np.repeat(mask[:, np.newaxis], repeat, axis=1)
return tmp.flatten()
else: else:
return mask return mask
return dst_mask return dst_mask
...@@ -17,7 +17,7 @@ class FPGMFilterPruner(FilterPruner): ...@@ -17,7 +17,7 @@ class FPGMFilterPruner(FilterPruner):
super(FPGMFilterPruner, self).__init__( super(FPGMFilterPruner, self).__init__(
model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves) model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)
def cal_mask(self, pruned_ratio, collection): def cal_mask(self, pruned_ratio, collection, num_head=-1):
var_name = collection.master_name var_name = collection.master_name
pruned_axis = collection.master_axis pruned_axis = collection.master_axis
value = collection.values[var_name] value = collection.values[var_name]
......
...@@ -12,15 +12,42 @@ _logger = get_logger(__name__, logging.INFO) ...@@ -12,15 +12,42 @@ _logger = get_logger(__name__, logging.INFO)
class L1NormFilterPruner(FilterPruner): class L1NormFilterPruner(FilterPruner):
def __init__(self, model, inputs, sen_file=None, opt=None, def __init__(self,
skip_leaves=True): model,
inputs,
sen_file=None,
opt=None,
skip_leaves=True,
prune_type='conv',
input_dtype='float32',
num_head=-1):
super(L1NormFilterPruner, self).__init__( super(L1NormFilterPruner, self).__init__(
model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves) model,
inputs,
sen_file=sen_file,
opt=opt,
skip_leaves=skip_leaves,
prune_type=prune_type,
input_dtype=input_dtype,
num_head=num_head)
def cal_mask(self, pruned_ratio, collection): def cal_mask(self, pruned_ratio, collection, num_head=-1):
var_name = collection.master_name var_name = collection.master_name
pruned_axis = collection.master_axis pruned_axis = collection.master_axis
value = collection.values[var_name] value = collection.values[var_name]
if (value.shape[1] == 3 * value.shape[0] or
value.shape[1] == value.shape[0]) and num_head != -1:
num_head = num_head
assert value.size % (
value.shape[0] * num_head
) == 0, "weight shape must be divisible by num_head."
_logger.debug(
"fused-qkv or query/key/value weight detected, we will prune on num_head: {} -> {}."
.format(num_head, int(num_head * (1 - pruned_ratio))))
else:
num_head = -1
groups = 1 groups = 1
for _detail in collection.all_pruning_details(): for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int)) assert (isinstance(_detail.axis, int))
...@@ -30,6 +57,10 @@ class L1NormFilterPruner(FilterPruner): ...@@ -30,6 +57,10 @@ class L1NormFilterPruner(FilterPruner):
groups = _groups groups = _groups
break break
if num_head != -1:
k = int(value.size / value.shape[0] / num_head)
value = value.reshape((value.shape[0], num_head, k))
reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis] reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
l1norm = np.mean(np.abs(value), axis=tuple(reduce_dims)) l1norm = np.mean(np.abs(value), axis=tuple(reduce_dims))
if groups > 1: if groups > 1:
...@@ -46,4 +77,9 @@ class L1NormFilterPruner(FilterPruner): ...@@ -46,4 +77,9 @@ class L1NormFilterPruner(FilterPruner):
if groups > 1: if groups > 1:
mask = mask.reshape([groups, -1]) mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0 mask[pruned_idx] = 0
if num_head != -1:
mask = np.repeat(mask[:, np.newaxis], k, axis=1)
return mask.flatten()
return mask.reshape(mask_shape) return mask.reshape(mask_shape)
...@@ -12,15 +12,42 @@ _logger = get_logger(__name__, logging.INFO) ...@@ -12,15 +12,42 @@ _logger = get_logger(__name__, logging.INFO)
class L2NormFilterPruner(FilterPruner): class L2NormFilterPruner(FilterPruner):
def __init__(self, model, inputs, sen_file=None, opt=None, def __init__(self,
skip_leaves=True): model,
inputs,
sen_file=None,
opt=None,
skip_leaves=True,
prune_type='conv',
input_dtype="float32",
num_head=-1):
super(L2NormFilterPruner, self).__init__( super(L2NormFilterPruner, self).__init__(
model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves) model,
inputs,
sen_file=sen_file,
opt=opt,
skip_leaves=skip_leaves,
prune_type=prune_type,
input_dtype=input_dtype,
num_head=num_head)
def cal_mask(self, pruned_ratio, collection): def cal_mask(self, pruned_ratio, collection, num_head=-1):
var_name = collection.master_name var_name = collection.master_name
pruned_axis = collection.master_axis pruned_axis = collection.master_axis
value = collection.values[var_name] value = collection.values[var_name]
if (value.shape[1] == 3 * value.shape[0] or
value.shape[1] == value.shape[0]) and num_head != -1:
num_head = num_head
assert value.size % (
value.shape[0] * num_head
) == 0, "weight shape must be divisible by num_head"
_logger.debug(
"fused-qkv or query/key/value weight detected, we will prune on num_head: {} -> {}"
.format(num_head, int(num_head * (1 - pruned_ratio))))
else:
num_head = -1
groups = 1 groups = 1
for _detail in collection.all_pruning_details(): for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int)) assert (isinstance(_detail.axis, int))
...@@ -30,6 +57,10 @@ class L2NormFilterPruner(FilterPruner): ...@@ -30,6 +57,10 @@ class L2NormFilterPruner(FilterPruner):
groups = _groups groups = _groups
break break
if num_head != -1:
k = int(value.size / value.shape[0] / num_head)
value = value.reshape((value.shape[0], num_head, k))
reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis] reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims))) scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))
if groups > 1: if groups > 1:
...@@ -46,4 +77,9 @@ class L2NormFilterPruner(FilterPruner): ...@@ -46,4 +77,9 @@ class L2NormFilterPruner(FilterPruner):
if groups > 1: if groups > 1:
mask = mask.reshape([groups, -1]) mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0 mask[pruned_idx] = 0
if num_head != -1:
mask = np.repeat(mask[:, np.newaxis], k, axis=1)
return mask.flatten()
return mask.reshape(mask_shape) return mask.reshape(mask_shape)
...@@ -19,7 +19,7 @@ class Pruner(object): ...@@ -19,7 +19,7 @@ class Pruner(object):
opt(paddle.optimizer.Optimizer): The model's optimizer. Default: None. opt(paddle.optimizer.Optimizer): The model's optimizer. Default: None.
""" """
def __init__(self, model, inputs, opt=None): def __init__(self, model, inputs, opt=None, prune_type='conv'):
self.model = model self.model = model
self.inputs = inputs self.inputs = inputs
self._var_shapes = {} self._var_shapes = {}
...@@ -27,6 +27,7 @@ class Pruner(object): ...@@ -27,6 +27,7 @@ class Pruner(object):
self._var_shapes[var.name] = var.shape self._var_shapes[var.name] = var.shape
self.plan = None self.plan = None
self.opt = opt self.opt = opt
self.prune_type = prune_type
def status(self, data=None, eval_func=None, status_file=None): def status(self, data=None, eval_func=None, status_file=None):
raise NotImplemented("status is not implemented") raise NotImplemented("status is not implemented")
...@@ -54,6 +55,10 @@ class Pruner(object): ...@@ -54,6 +55,10 @@ class Pruner(object):
if apply == "lazy": if apply == "lazy":
global_plan.apply(self.model, lazy=True) global_plan.apply(self.model, lazy=True)
elif apply == "impretive": elif apply == "impretive":
global_plan.apply(self.model, lazy=False, opt=self.opt) global_plan.apply(
self.model,
lazy=False,
opt=self.opt,
prune_type=self.prune_type)
self.plan = global_plan self.plan = global_plan
return global_plan return global_plan
...@@ -102,7 +102,6 @@ class PruningPlan(): ...@@ -102,7 +102,6 @@ class PruningPlan():
var_tmp = v.get(param_name) var_tmp = v.get(param_name)
#NOTE: var_tmp.shape == [1] is used to skip variables like beta1_pow_acc in Adam optimizer. Its shape is [1] and there's no need to prune this one-value variable. #NOTE: var_tmp.shape == [1] is used to skip variables like beta1_pow_acc in Adam optimizer. Its shape is [1] and there's no need to prune this one-value variable.
if var_tmp is None or var_tmp.shape == [1]: if var_tmp is None or var_tmp.shape == [1]:
if var_tmp is not None: print(var_tmp.name, var_tmp.shape)
continue continue
t_value = var_tmp.value().get_tensor() t_value = var_tmp.value().get_tensor()
value = np.array(t_value).astype("float32") value = np.array(t_value).astype("float32")
...@@ -162,11 +161,11 @@ class PruningPlan(): ...@@ -162,11 +161,11 @@ class PruningPlan():
t_value.set(np.array(t_backup).astype("float32"), place) t_value.set(np.array(t_backup).astype("float32"), place)
del sub_layer._buffers[backup_name] del sub_layer._buffers[backup_name]
def apply(self, model, lazy=False, opt=None): def apply(self, model, lazy=False, opt=None, prune_type='conv'):
if lazy: if lazy:
self.lazy_apply(model) self.lazy_apply(model)
else: else:
self.imperative_apply(model, opt) self.imperative_apply(model, opt, prune_type=prune_type)
def lazy_apply(self, model): def lazy_apply(self, model):
for name, sub_layer in model.named_sublayers(): for name, sub_layer in model.named_sublayers():
...@@ -203,12 +202,11 @@ class PruningPlan(): ...@@ -203,12 +202,11 @@ class PruningPlan():
t_value.set(value * expand_mask, place) t_value.set(value * expand_mask, place)
def imperative_apply(self, model, opt=None): def imperative_apply(self, model, opt=None, prune_type='conv'):
""" """
Pruning values of variable imperatively. It is valid when pruning Pruning values of variable imperatively. It is valid when pruning
on one dimension. on one dimension.
""" """
for name, sub_layer in model.named_sublayers(include_self=True): for name, sub_layer in model.named_sublayers(include_self=True):
for param in sub_layer.parameters(include_sublayers=False): for param in sub_layer.parameters(include_sublayers=False):
if param.name in self._masks: if param.name in self._masks:
...@@ -240,11 +238,9 @@ class PruningPlan(): ...@@ -240,11 +238,9 @@ class PruningPlan():
format(param.name)) format(param.name))
# save optimizer accumulators into layer buffer # save optimizer accumulators into layer buffer
self._buffer_opt(param.name, sub_layer, opt) self._buffer_opt(param.name, sub_layer, opt)
pruned_value = np.apply_along_axis( pruned_value = np.apply_along_axis(
lambda data: data[bool_mask], dims, value) lambda data: data[bool_mask], dims, value)
self._prune_opt(param.name, dims, bool_mask, opt) self._prune_opt(param.name, dims, bool_mask, opt)
p = t_value._place() p = t_value._place()
if p.is_cpu_place(): if p.is_cpu_place():
place = paddle.CPUPlace() place = paddle.CPUPlace()
......
...@@ -20,17 +20,38 @@ class DygraphPruningCollections(PruningCollections): ...@@ -20,17 +20,38 @@ class DygraphPruningCollections(PruningCollections):
- skip_leaves(bool): Whether to skip the last convolution layers. - skip_leaves(bool): Whether to skip the last convolution layers.
""" """
def __init__(self, model, inputs, skip_leaves=True): def __init__(self,
model,
inputs,
skip_leaves=True,
prune_type='conv',
input_dtype='float32'):
assert prune_type in ['conv', 'fc'
], "Please select conv or fc as your prune type."
_logger.debug("Parsing model with input: {}".format(inputs)) _logger.debug("Parsing model with input: {}".format(inputs))
# model can be in training mode, because some model contains auxiliary parameters for training. # model can be in training mode, because some model contains auxiliary parameters for training.
program = dygraph2program(model, inputs=inputs) #TODO(minghaoBD): support dictionary input
if isinstance(inputs[0], int):
dtypes = [input_dtype]
elif isinstance(inputs[0], list):
dtypes = [input_dtype] * len(inputs)
else:
dtypes = [input_dtype]
program = dygraph2program(model, inputs=inputs, dtypes=dtypes)
graph = GraphWrapper(program) graph = GraphWrapper(program)
if prune_type == 'conv':
params = [ params = [
_param.name for _param in model.parameters() _param.name for _param in model.parameters()
if len(_param.shape) == 4 if len(_param.shape) == 4
] ]
elif prune_type == 'fc':
params = [
_param.name for _param in model.parameters()
if len(_param.shape) == 2
]
self._collections = self.create_pruning_collections( self._collections = self.create_pruning_collections(
params, graph, skip_leaves=skip_leaves) params, graph, skip_leaves=skip_leaves, prune_type=prune_type)
_logger.info("Found {} collections.".format(len(self._collections))) _logger.info("Found {} collections.".format(len(self._collections)))
_name2values = {} _name2values = {}
......
...@@ -142,7 +142,8 @@ class PruningCollections(object): ...@@ -142,7 +142,8 @@ class PruningCollections(object):
graph, graph,
skip_stranger=True, skip_stranger=True,
skip_vars=None, skip_vars=None,
skip_leaves=True): skip_leaves=True,
prune_type='conv'):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation. """Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on. A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
...@@ -186,6 +187,12 @@ class PruningCollections(object): ...@@ -186,6 +187,12 @@ class PruningCollections(object):
visited = {} visited = {}
collections = [] collections = []
unsupported_warnings = set() unsupported_warnings = set()
if prune_type == 'conv':
prune_axis = 0
elif prune_type == 'fc':
prune_axis = 1
for _param in params: for _param in params:
pruned_params = [] pruned_params = []
param = graph.var(_param) param = graph.var(_param)
...@@ -216,7 +223,7 @@ class PruningCollections(object): ...@@ -216,7 +223,7 @@ class PruningCollections(object):
worker.skip_vars = skip_vars worker.skip_vars = skip_vars
try: try:
visited_backup = copy.deepcopy(worker.visited) visited_backup = copy.deepcopy(worker.visited)
worker.prune(param, pruned_axis=0, pruned_idx=[]) worker.prune(param, pruned_axis=prune_axis, transforms=[])
except UnsupportOpError as e: except UnsupportOpError as e:
visited.clear() visited.clear()
visited.update(visited_backup) visited.update(visited_backup)
...@@ -225,15 +232,16 @@ class PruningCollections(object): ...@@ -225,15 +232,16 @@ class PruningCollections(object):
if len(pruned_params) != 0: if len(pruned_params) != 0:
collection = PruningCollection(master=({ collection = PruningCollection(master=({
"name": param.name(), "name": param.name(),
"axis": 0 "axis": prune_axis,
})) }))
for _param, _axis, _transform, _op in pruned_params: for _param, _axis, _transform, _op in pruned_params:
collection.add( tmp = PruningDetails(_param, _axis, _transform, _op)
PruningDetails(_param, _axis, _transform, _op)) collection.add(tmp)
collections.append(collection) collections.append(collection)
for warn in unsupported_warnings: for warn in unsupported_warnings:
_logger.warning(warn) _logger.warning(warn)
self._collections = collections self._collections = collections
return self._collections return self._collections
......
此差异已折叠。
...@@ -188,6 +188,13 @@ class Pruner(): ...@@ -188,6 +188,13 @@ class Pruner():
for idx in src: for idx in src:
idx = idx * repeat idx = idx * repeat
target.extend(range(idx, idx + repeat)) target.extend(range(idx, idx + repeat))
elif "squeeze" in trans:
repeat = trans['repeat']
targets_set = set()
for idx in src:
targets_set.add(idx / repeat)
target = list(targets_set)
src = target src = target
ret.append((name, axis, src)) ret.append((name, axis, src))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册