Unverified commit 5b0346c5, authored by minghaoBD, committed by GitHub

Support FC (#1551)

* support fc in pruning

* support matmul and matmul_v2 pruning

* support ffn pruning with gelu activations in the middle

* support reshape2, transpose and split in transformer block

* support pattern: qkv gemm -> batched gemm -> out linear

* prune all fc layers

* support fc pruning

* almost done, with a few hardcoded values remaining

* remove hardcoded values

* fix UT

* fix UT

* avoid setting attributes for reshape in fc pruning

* fix UT

* fix UT
Parent 78efe1d9
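For context, a minimal usage sketch of the new FC pruning entry points added in this commit; the model, input spec, import path, and the parameter name `linear_0.w_0` are illustrative assumptions, not taken from the diff:

```python
# A minimal sketch, assuming `model` is a dygraph transformer-style network
# whose single input is a [batch, seq_len] tensor of token ids.
from paddleslim.dygraph import L1NormFilterPruner

pruner = L1NormFilterPruner(
    model,
    inputs=[[1, 128]],      # traced input shape
    prune_type='fc',        # collect 2-D FC weights instead of 4-D conv filters
    input_dtype='int64',    # dtype forwarded to dygraph2program when tracing
    num_head=12)            # enables head-wise pruning of fused-QKV weights

# Prune 25% of the output features of one FC weight; related variables
# (reshape2/transpose2/matmul inputs, optimizer moments, ...) follow the plan.
plan = pruner.prune_var('linear_0.w_0', pruned_axis=1, pruned_ratio=0.25)
```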
......@@ -13,6 +13,7 @@ from paddleslim.analysis import flops
sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
import models
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
_logger = get_logger(__name__, level=logging.INFO)
......@@ -37,6 +38,12 @@ def eval(args):
val_dataset = paddle.vision.datasets.MNIST(mode='test')
class_dim = 10
image_shape = "1,28,28"
elif args.data == "cifar10":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
val_dataset = paddle.vision.datasets.Cifar10(
mode="test", backend="cv2", transform=transform)
class_dim = 10
image_shape = "3, 32, 32"
elif args.data == "imagenet":
import imagenet_reader as reader
val_dataset = reader.ImageNetDataset(mode='val')
......
......@@ -130,6 +130,6 @@ def dygraph_flops(model, inputs, dtypes=None, only_conv=False, detail=False):
detail(bool): Whether to return detail of each convolution layer.
"""
program = dygraph2program(model, inputs)
program = dygraph2program(model, inputs, dtypes=dtypes)
graph = GraphWrapper(program)
return _graph_flops(graph, only_conv=only_conv, detail=detail)
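A hedged usage note: with `dtypes` now forwarded to `dygraph2program`, FLOPs can be counted for models whose inputs are not float32. The model and input spec below are assumptions:

```python
# Sketch only: `model` takes a single [batch, seq_len] int64 input.
from paddleslim.analysis import dygraph_flops

total_flops = dygraph_flops(model, inputs=[[1, 128]], dtypes=['int64'])
print("FLOPs:", total_flops)
```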
......@@ -58,14 +58,29 @@ class FilterPruner(Pruner):
"""
def __init__(self, model, inputs, sen_file=None, opt=None,
skip_leaves=True):
def __init__(self,
model,
inputs,
sen_file=None,
opt=None,
skip_leaves=True,
prune_type='conv',
num_head=-1,
input_dtype='float32'):
super(FilterPruner, self).__init__(model, inputs, opt=opt)
self.num_head = num_head
self.prune_type = prune_type
self._status = Status(sen_file)
self.skip_leaves = skip_leaves
# sensitive and collections are just used in filter pruning
self.collections = DygraphPruningCollections(
model, inputs, skip_leaves=self.skip_leaves)
model,
inputs,
skip_leaves=self.skip_leaves,
prune_type=prune_type,
input_dtype=input_dtype)
# skip vars in:
# 1. depthwise conv2d layer
......@@ -294,7 +309,7 @@ class FilterPruner(Pruner):
if self.plan is not None:
self.plan.restore(self.model, opt=self.opt)
def cal_mask(self, pruned_ratio, collection):
def cal_mask(self, pruned_ratio, collection, num_head=-1):
raise NotImplemented("cal_mask is not implemented")
def prune_var(self, var_name, pruned_axis, pruned_ratio, apply="impretive"):
......@@ -332,13 +347,12 @@ class FilterPruner(Pruner):
f"Pruning variable [{var_name}] and its relatives {list(collection.variables())}"
)
mask = self.cal_mask(pruned_ratio, collection)
mask = self.cal_mask(pruned_ratio, collection, num_head=self.num_head)
for _detail in collection.all_pruning_details():
# Variables can be pruned on multiple axes.
src_mask = copy.deepcopy(mask)
var_shape = _detail.var.shape()
for tran in _detail.transform:
src_mask = self._transform_mask(src_mask, tran)
current_mask = src_mask
groups = _detail.op.attr('groups')
......@@ -352,7 +366,11 @@ class FilterPruner(Pruner):
if apply == "lazy":
plan.apply(self.model, lazy=True)
elif apply == "impretive":
plan.apply(self.model, lazy=False, opt=self.opt)
plan.apply(
self.model,
lazy=False,
opt=self.opt,
prune_type=self.prune_type)
return plan
def _transform_mask(self, mask, transform):
......@@ -372,6 +390,17 @@ class FilterPruner(Pruner):
return mask
elif "repeat" in transform and "tile" in transform:
return np.tile(mask.repeat(transform["repeat"]), transform["tile"])
elif "squeeze" in transform:
squeeze = transform['squeeze']
tmp = mask.reshape((-1, squeeze))
tmp = np.max(tmp, axis=1)
tmp[tmp > 0] = 1
return tmp
elif "repeat" in transform:
repeat = transform['repeat']
tmp = np.repeat(mask[:, np.newaxis], repeat, axis=1)
return tmp.flatten()
else:
return mask
return dst_mask
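A small numeric sketch of the two new transforms above (values are illustrative): "squeeze" folds every `squeeze` consecutive mask entries into one, keeping the group if any entry is kept, and "repeat" expands each entry back into `repeat` copies:

```python
import numpy as np

mask = np.array([1, 1, 0, 0, 1, 0])  # mask over a dimension of size 6

# "squeeze": e.g. the mask crosses a reshape from [B, 6] to [B, 3, 2]
squeezed = np.max(mask.reshape((-1, 2)), axis=1)  # -> array([1, 0, 1])

# "repeat": e.g. the mask crosses a reshape from [B, 3, 2] back to [B, 6]
repeated = np.repeat(squeezed[:, np.newaxis], 2, axis=1).flatten()
# -> array([1, 1, 0, 0, 1, 1])
```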
......@@ -17,7 +17,7 @@ class FPGMFilterPruner(FilterPruner):
super(FPGMFilterPruner, self).__init__(
model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)
def cal_mask(self, pruned_ratio, collection):
def cal_mask(self, pruned_ratio, collection, num_head=-1):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
......
......@@ -12,15 +12,42 @@ _logger = get_logger(__name__, logging.INFO)
class L1NormFilterPruner(FilterPruner):
def __init__(self, model, inputs, sen_file=None, opt=None,
skip_leaves=True):
def __init__(self,
model,
inputs,
sen_file=None,
opt=None,
skip_leaves=True,
prune_type='conv',
input_dtype='float32',
num_head=-1):
super(L1NormFilterPruner, self).__init__(
model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)
model,
inputs,
sen_file=sen_file,
opt=opt,
skip_leaves=skip_leaves,
prune_type=prune_type,
input_dtype=input_dtype,
num_head=num_head)
def cal_mask(self, pruned_ratio, collection):
def cal_mask(self, pruned_ratio, collection, num_head=-1):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
if (value.shape[1] == 3 * value.shape[0] or
value.shape[1] == value.shape[0]) and num_head != -1:
num_head = num_head
assert value.size % (
value.shape[0] * num_head
) == 0, "weight shape must be divisible by num_head."
_logger.debug(
"fused-qkv or query/key/value weight detected, we will prune on num_head: {} -> {}."
.format(num_head, int(num_head * (1 - pruned_ratio))))
else:
num_head = -1
groups = 1
for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int))
......@@ -30,6 +57,10 @@ class L1NormFilterPruner(FilterPruner):
groups = _groups
break
if num_head != -1:
k = int(value.size / value.shape[0] / num_head)
value = value.reshape((value.shape[0], num_head, k))
reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
l1norm = np.mean(np.abs(value), axis=tuple(reduce_dims))
if groups > 1:
......@@ -46,4 +77,9 @@ class L1NormFilterPruner(FilterPruner):
if groups > 1:
mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0
if num_head != -1:
mask = np.repeat(mask[:, np.newaxis], k, axis=1)
return mask.flatten()
return mask.reshape(mask_shape)
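A short sketch of the head-wise scoring added above, using assumed shapes: for a fused-QKV weight of shape [hidden, 3 * hidden] pruned on axis 1, importance is scored per attention head and the head mask is then expanded back to columns:

```python
import numpy as np

hidden, num_head = 8, 4
value = np.random.rand(hidden, 3 * hidden)        # fused QKV weight

k = value.size // (value.shape[0] * num_head)     # 3 * head_dim
value = value.reshape((value.shape[0], num_head, k))

# L1 score per head: reduce over every axis except the pruned one (axis 1).
scores = np.mean(np.abs(value), axis=(0, 2))      # shape (num_head,)

# Keep the top half of the heads, then expand the head mask to columns.
kept = np.argsort(scores)[num_head // 2:]
mask = np.zeros(num_head, dtype='float32')
mask[kept] = 1
mask = np.repeat(mask[:, np.newaxis], k, axis=1).flatten()  # length 3 * hidden
```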
......@@ -12,15 +12,42 @@ _logger = get_logger(__name__, logging.INFO)
class L2NormFilterPruner(FilterPruner):
def __init__(self, model, inputs, sen_file=None, opt=None,
skip_leaves=True):
def __init__(self,
model,
inputs,
sen_file=None,
opt=None,
skip_leaves=True,
prune_type='conv',
input_dtype="float32",
num_head=-1):
super(L2NormFilterPruner, self).__init__(
model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)
model,
inputs,
sen_file=sen_file,
opt=opt,
skip_leaves=skip_leaves,
prune_type=prune_type,
input_dtype=input_dtype,
num_head=num_head)
def cal_mask(self, pruned_ratio, collection):
def cal_mask(self, pruned_ratio, collection, num_head=-1):
var_name = collection.master_name
pruned_axis = collection.master_axis
value = collection.values[var_name]
if (value.shape[1] == 3 * value.shape[0] or
value.shape[1] == value.shape[0]) and num_head != -1:
num_head = num_head
assert value.size % (
value.shape[0] * num_head
) == 0, "weight shape must be divisible by num_head"
_logger.debug(
"fused-qkv or query/key/value weight detected, we will prune on num_head: {} -> {}"
.format(num_head, int(num_head * (1 - pruned_ratio))))
else:
num_head = -1
groups = 1
for _detail in collection.all_pruning_details():
assert (isinstance(_detail.axis, int))
......@@ -30,6 +57,10 @@ class L2NormFilterPruner(FilterPruner):
groups = _groups
break
if num_head != -1:
k = int(value.size / value.shape[0] / num_head)
value = value.reshape((value.shape[0], num_head, k))
reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))
if groups > 1:
......@@ -46,4 +77,9 @@ class L2NormFilterPruner(FilterPruner):
if groups > 1:
mask = mask.reshape([groups, -1])
mask[pruned_idx] = 0
if num_head != -1:
mask = np.repeat(mask[:, np.newaxis], k, axis=1)
return mask.flatten()
return mask.reshape(mask_shape)
......@@ -19,7 +19,7 @@ class Pruner(object):
opt(paddle.optimizer.Optimizer): The model's optimizer. Default: None.
"""
def __init__(self, model, inputs, opt=None):
def __init__(self, model, inputs, opt=None, prune_type='conv'):
self.model = model
self.inputs = inputs
self._var_shapes = {}
......@@ -27,6 +27,7 @@ class Pruner(object):
self._var_shapes[var.name] = var.shape
self.plan = None
self.opt = opt
self.prune_type = prune_type
def status(self, data=None, eval_func=None, status_file=None):
raise NotImplemented("status is not implemented")
......@@ -54,6 +55,10 @@ class Pruner(object):
if apply == "lazy":
global_plan.apply(self.model, lazy=True)
elif apply == "impretive":
global_plan.apply(self.model, lazy=False, opt=self.opt)
global_plan.apply(
self.model,
lazy=False,
opt=self.opt,
prune_type=self.prune_type)
self.plan = global_plan
return global_plan
......@@ -102,7 +102,6 @@ class PruningPlan():
var_tmp = v.get(param_name)
#NOTE: var_tmp.shape == [1] is used to skip variables like beta1_pow_acc in Adam optimizer. Its shape is [1] and there's no need to prune this one-value variable.
if var_tmp is None or var_tmp.shape == [1]:
if var_tmp is not None: print(var_tmp.name, var_tmp.shape)
continue
t_value = var_tmp.value().get_tensor()
value = np.array(t_value).astype("float32")
......@@ -162,11 +161,11 @@ class PruningPlan():
t_value.set(np.array(t_backup).astype("float32"), place)
del sub_layer._buffers[backup_name]
def apply(self, model, lazy=False, opt=None):
def apply(self, model, lazy=False, opt=None, prune_type='conv'):
if lazy:
self.lazy_apply(model)
else:
self.imperative_apply(model, opt)
self.imperative_apply(model, opt, prune_type=prune_type)
def lazy_apply(self, model):
for name, sub_layer in model.named_sublayers():
......@@ -203,12 +202,11 @@ class PruningPlan():
t_value.set(value * expand_mask, place)
def imperative_apply(self, model, opt=None):
def imperative_apply(self, model, opt=None, prune_type='conv'):
"""
Prune values of variables imperatively. It is only valid when pruning
on one dimension.
"""
for name, sub_layer in model.named_sublayers(include_self=True):
for param in sub_layer.parameters(include_sublayers=False):
if param.name in self._masks:
......@@ -240,11 +238,9 @@ class PruningPlan():
format(param.name))
# save optimizer accumulators into layer buffer
self._buffer_opt(param.name, sub_layer, opt)
pruned_value = np.apply_along_axis(
lambda data: data[bool_mask], dims, value)
self._prune_opt(param.name, dims, bool_mask, opt)
p = t_value._place()
if p.is_cpu_place():
place = paddle.CPUPlace()
......
......@@ -20,17 +20,38 @@ class DygraphPruningCollections(PruningCollections):
- skip_leaves(bool): Whether to skip the last convolution layers.
"""
def __init__(self, model, inputs, skip_leaves=True):
def __init__(self,
model,
inputs,
skip_leaves=True,
prune_type='conv',
input_dtype='float32'):
assert prune_type in ['conv', 'fc'
], "Please select conv or fc as your prune type."
_logger.debug("Parsing model with input: {}".format(inputs))
# The model can be in training mode, because some models contain auxiliary parameters used for training.
program = dygraph2program(model, inputs=inputs)
#TODO(minghaoBD): support dictionary input
if isinstance(inputs[0], int):
dtypes = [input_dtype]
elif isinstance(inputs[0], list):
dtypes = [input_dtype] * len(inputs)
else:
dtypes = [input_dtype]
program = dygraph2program(model, inputs=inputs, dtypes=dtypes)
graph = GraphWrapper(program)
params = [
_param.name for _param in model.parameters()
if len(_param.shape) == 4
]
if prune_type == 'conv':
params = [
_param.name for _param in model.parameters()
if len(_param.shape) == 4
]
elif prune_type == 'fc':
params = [
_param.name for _param in model.parameters()
if len(_param.shape) == 2
]
self._collections = self.create_pruning_collections(
params, graph, skip_leaves=skip_leaves)
params, graph, skip_leaves=skip_leaves, prune_type=prune_type)
_logger.info("Found {} collections.".format(len(self._collections)))
_name2values = {}
......
......@@ -142,7 +142,8 @@ class PruningCollections(object):
graph,
skip_stranger=True,
skip_vars=None,
skip_leaves=True):
skip_leaves=True,
prune_type='conv'):
"""Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
......@@ -186,6 +187,12 @@ class PruningCollections(object):
visited = {}
collections = []
unsupported_warnings = set()
if prune_type == 'conv':
prune_axis = 0
elif prune_type == 'fc':
prune_axis = 1
for _param in params:
pruned_params = []
param = graph.var(_param)
......@@ -216,7 +223,7 @@ class PruningCollections(object):
worker.skip_vars = skip_vars
try:
visited_backup = copy.deepcopy(worker.visited)
worker.prune(param, pruned_axis=0, pruned_idx=[])
worker.prune(param, pruned_axis=prune_axis, transforms=[])
except UnsupportOpError as e:
visited.clear()
visited.update(visited_backup)
......@@ -225,15 +232,16 @@ class PruningCollections(object):
if len(pruned_params) != 0:
collection = PruningCollection(master=({
"name": param.name(),
"axis": 0
"axis": prune_axis,
}))
for _param, _axis, _transform, _op in pruned_params:
collection.add(
PruningDetails(_param, _axis, _transform, _op))
tmp = PruningDetails(_param, _axis, _transform, _op)
collection.add(tmp)
collections.append(collection)
for warn in unsupported_warnings:
_logger.warning(warn)
self._collections = collections
return self._collections
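As a hedged reading of the `prune_type` switch above (shapes are assumptions): conv collects 4-D filters and starts the search on axis 0 (output channels), while fc collects 2-D weights and starts on axis 1 (output features):

```python
conv_filter_shape = (64, 32, 3, 3)   # [out_channels, in_channels, kh, kw] -> prune axis 0
fc_weight_shape = (768, 3072)        # [in_features, out_features]         -> prune axis 1

prune_axis = {'conv': 0, 'fc': 1}['fc']
assert fc_weight_shape[prune_axis] == 3072   # the columns that get masked
```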
......
......@@ -41,6 +41,8 @@ OPS_UNCHANGE_SHAPE += [
'dropout',
'cast',
'hard_swish',
'fused_softmax_mask_upper_triangle',
'softmax',
'hard_sigmoid',
]
......@@ -77,20 +79,21 @@ class PruneWorker(object):
).split(",")
self.skip_vars = skip_vars
def prune(self, var, pruned_axis, pruned_idx):
def prune(self, var, pruned_axis, transforms):
"""
Infer the shapes of variables related to the current operator, its predecessors and successors.
It searches the graph to find all variables related to `var` and records the pruning information.
Args:
var(Variable): The root variable of searching. It can be the input or output of current operator.
pruned_axis(int): The axis to be pruned of root variable.
pruned_idx(int): The indices to be pruned in `pruned_axis` of root variable.
transforms(list<dict>): The transforms applied to the current variable/mask.
"""
if var.name() in self.skip_vars:
raise UnsupportOpError("Variable {} was skipped.".format(var.name(
)))
if self._visit(var, pruned_axis):
self._prune(var, pruned_axis, pruned_idx)
self._prune(var, pruned_axis, transforms)
def _visit(self, var, pruned_axis):
key = "_".join([str(self.op.idx()), var.name()])
......@@ -115,10 +118,10 @@ class PruneWorker(object):
for op in next_ops:
self._prune_op(op, var, axis, transforms)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
raise NotImplementedError('Abstract method.')
def _prune_op(self, op, var, pruned_axis, pruned_idx, visited=None):
def _prune_op(self, op, var, pruned_axis, transforms, visited=None):
if op.type().endswith("_grad"):
return
if visited is not None:
......@@ -137,18 +140,144 @@ class PruneWorker(object):
op.type()))
_logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}\ntrans: {}".
format(self.op, op, pruned_axis, var.name(), pruned_idx))
format(self.op, op, pruned_axis, var.name(), transforms))
_logger.debug(
f"visit {op.type()} by var [{var.name()}] on axis [{pruned_axis}];\t visited={self.visited}\n"
)
worker = cls(op, self.pruned_params, self.visited, self.skip_stranger)
worker.skip_vars = self.skip_vars
worker.prune(var, pruned_axis, pruned_idx)
worker.prune(var, pruned_axis, transforms)
def append_pruned_vars(self, var, axis, transforms):
self.pruned_params.append((var, axis, transforms, self.op))
@PRUNE_WORKER.register
class reshape2(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(reshape2, self).__init__(op, pruned_params, visited,
skip_stranger)
def _valid_reshape2(self, shape):
# case1: only reshape last several dimensions. e.g. [0,0,1,2] returns True while [1,0,0,1] returns False.
changed = False
for sh in shape:
if sh == 0 and changed:
return False
if sh != 0:
changed = True
return True
#NOTE: we might not need this assertion
def _valid_pruned_axis(self, shape, pruned_axis):
last_zero_index = -1
for i in range(len(shape)):
if shape[i] == 0: last_zero_index = i
if pruned_axis <= last_zero_index:
return pruned_axis
elif pruned_axis > last_zero_index:
return pruned_axis
def _get_idx_after_expanding_dims(self, pruned_axis, transforms,
shorter_shape, longer_shape):
assert len(shorter_shape) < len(
longer_shape
), "length of {} must be smaller than length of {}.".format(
shorter_shape, longer_shape)
dim_old = shorter_shape[pruned_axis]
dim_new = longer_shape[pruned_axis]
k = dim_old / dim_new
assert k == longer_shape[
pruned_axis + 1], "k is {} while longer shape is {}[{}]".format(
k, longer_shape, pruned_axis + 1)
transforms = np.array(transforms, dtype='int32')
pruned_rows = transforms / k
pruned_cols = transforms % k
new_transforms = []
for row in range(dim_new):
prune_this_row = row in pruned_rows
prune_this_row = prune_this_row and (len(pruned_rows) == k)
if prune_this_row:
new_transforms.append(row)
return new_transforms
def _get_idx_after_shrinking_dims(self, pruned_axis, transforms,
longer_shape, shorter_shape):
assert len(shorter_shape) < len(
longer_shape
), "length of {} must be smaller than length of {}.".format(
shorter_shape, longer_shape)
dim_old = longer_shape[pruned_axis]
dim_new = shorter_shape[pruned_axis]
k = dim_new // dim_old
assert k == longer_shape[pruned_axis + 1]
new_transforms = []
for row in range(dim_old):
if row in transforms:
new_transforms.extend(
[i for i in range(row * k, (row + 1) * k)])
return new_transforms
def _prune(self, var, pruned_axis, transforms):
in_var = self.op.inputs("X")[0]
out_var = self.op.outputs("Out")[0]
xshape_var = self.op.outputs("XShape")[0]
in_shape = in_var.shape()
out_shape = out_var.shape()
shape = self.op.attr("shape")
assert self._valid_reshape2(
shape), "we don't support the shape {} in pruning".format(shape)
# assert self._valid_pruned_axis(shape, pruned_axis), "we don't support pruned axis is {} when shape is changing from {} to {}".format(pruned_axis, in_shape, out_shape)
self.append_pruned_vars(xshape_var, pruned_axis + 1, transforms)
if var in self.op.inputs("X"):
if (len(out_shape) > len(in_shape)):
#self.op.set_attr('shape',
# [0, 0, int(shape[2] * 0.875), shape[3]])
transform = {"squeeze": out_shape[pruned_axis + 1]}
elif (len(out_shape) < len(in_shape)):
# self.op.set_attr('shape', [0, 0, int(shape[2] * 0.875)])
transform = {"repeat": in_shape[pruned_axis + 1]}
else:
transform = {}
self._visit_and_search(out_var, pruned_axis,
transforms + [transform])
elif var in self.op.outputs("Out"):
if (len(in_shape) > len(out_shape)):
# self.op.set_attr('shape', [0, 0, int(shape[2] * 0.875)])
transform = {"squeeze": in_shape[pruned_axis + 1]}
elif (len(in_shape) < len(out_shape)):
#self.op.set_attr('shape',
# [0, 0, int(shape[2] * 0.875), shape[3]])
transform = {"repeat": out_shape[pruned_axis + 1]}
else:
transform = {}
self._visit_and_search(in_var, pruned_axis,
transforms + [transform])
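A hedged sketch of the transform the reshape2 worker records when a mask crosses it, using assumed transformer shapes:

```python
in_shape = (1, 128, 768)        # [batch, seq_len, hidden]
out_shape = (1, 128, 12, 64)    # [batch, seq_len, num_head, head_dim]
pruned_axis = 2                 # pruning the hidden / num_head dimension

if len(out_shape) > len(in_shape):       # hidden is split into heads
    transform = {"squeeze": out_shape[pruned_axis + 1]}   # {"squeeze": 64}
elif len(out_shape) < len(in_shape):     # heads are merged back into hidden
    transform = {"repeat": in_shape[pruned_axis + 1]}
else:
    transform = {}
```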
@PRUNE_WORKER.register
class transpose2(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(transpose2, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, transforms):
axis = self.op.attr('axis')
in_var = self.op.inputs("X")[0]
out_var = self.op.outputs("Out")[0]
if var in self.op.inputs("X"):
new_pruned_axis = axis[pruned_axis]
self._visit_and_search(out_var, new_pruned_axis, transforms)
elif var in self.op.outputs("Out"):
new_pruned_axis = axis[pruned_axis]
self._visit_and_search(in_var, new_pruned_axis, transforms)
@PRUNE_WORKER.register
class conv2d(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
......@@ -168,7 +297,7 @@ class conv2d(PruneWorker):
return (num_channels == groups and num_channels != 1 and
num_filters % num_channels == 0)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
if self._is_depthwise_conv(self.op):
_logger.debug(f"Meet conv2d who is depthwise conv2d actually.")
worker = depthwise_conv2d(
......@@ -176,7 +305,7 @@ class conv2d(PruneWorker):
self.pruned_params,
visited=self.visited,
skip_stranger=self.skip_stranger)
return worker._prune(var, pruned_axis, pruned_idx)
return worker._prune(var, pruned_axis, transforms)
data_format = self.op.attr("data_format")
groups = self.op.attr("groups")
......@@ -187,39 +316,39 @@ class conv2d(PruneWorker):
assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
self.append_pruned_vars(filter_var, 1, pruned_idx)
self.append_pruned_vars(filter_var, 1, transforms)
if groups is None or groups == 1:
self._visit_and_search(filter_var, 1, pruned_idx)
self._visit_and_search(filter_var, 1, transforms)
elif var in self.op.inputs("Filter"):
assert pruned_axis in [0, 1]
self.append_pruned_vars(var, pruned_axis, pruned_idx)
self.append_pruned_vars(var, pruned_axis, transforms)
if groups is None or groups == 1 or pruned_axis == 0:
self._visit_and_search(var, pruned_axis, pruned_idx)
self._visit_and_search(var, pruned_axis, transforms)
if pruned_axis == 0:
if len(self.op.inputs("Bias")) > 0:
self.append_pruned_vars(
self.op.inputs("Bias"), channel_axis, pruned_idx)
self.op.inputs("Bias"), channel_axis, transforms)
output_var = self.op.outputs("Output")[0]
self._visit_and_search(output_var, channel_axis, pruned_idx)
self._visit_and_search(output_var, channel_axis, transforms)
elif pruned_axis == 1:
input_var = self.op.inputs("Input")[0]
self._visit_and_search(input_var, channel_axis, pruned_idx)
self._visit_and_search(input_var, channel_axis, transforms)
elif var in self.op.outputs("Output"):
assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format(
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 0)
self.append_pruned_vars(filter_var, 0, pruned_idx)
self._visit_and_search(filter_var, 0, pruned_idx)
self.append_pruned_vars(filter_var, 0, transforms)
self._visit_and_search(filter_var, 0, transforms)
if len(self.op.inputs("Bias")) > 0:
self.append_pruned_vars(
self.op.inputs("Bias")[0], channel_axis, pruned_idx)
self.op.inputs("Bias")[0], channel_axis, transforms)
@PRUNE_WORKER.register
......@@ -228,7 +357,7 @@ class conv2d_transpose(PruneWorker):
super(conv2d_transpose, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
data_format = self.op.attr("data_format")
channel_axis = 1
if data_format == "NHWC":
......@@ -238,8 +367,8 @@ class conv2d_transpose(PruneWorker):
pruned_axis, var.name())
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 0)
self.append_pruned_vars(filter_var, 0, pruned_idx)
self._visit_and_search(filter_var, 0, pruned_idx)
self.append_pruned_vars(filter_var, 0, transforms)
self._visit_and_search(filter_var, 0, transforms)
elif var in self.op.inputs("Filter"):
_logger.warn("Skip pruning output channels of conv2d_transpose!")
......@@ -250,15 +379,15 @@ class conv2d_transpose(PruneWorker):
filter_var = self.op.inputs("Filter")[0]
self._visit(filter_var, 1)
self.append_pruned_vars(filter_var, 1, pruned_idx)
self.append_pruned_vars(filter_var, 1, transforms)
self._visit_and_search(filter_var, 1, pruned_idx)
self._visit_and_search(filter_var, 1, transforms)
if len(self.op.inputs("Bias")) > 0:
self.append_pruned_vars(
self.op.inputs("Bias")[0], channel_axis, pruned_idx)
self.op.inputs("Bias")[0], channel_axis, transforms)
output_var = self.op.outputs("Output")[0]
self._visit_and_search(output_var, channel_axis, pruned_idx)
self._visit_and_search(output_var, channel_axis, transforms)
@PRUNE_WORKER.register
......@@ -267,22 +396,22 @@ class batch_norm(PruneWorker):
super(batch_norm, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
if (var not in self.op.outputs("Y")) and (
var not in self.op.inputs("X")):
return
if var in self.op.outputs("Y"):
in_var = self.op.inputs("X")[0]
self._visit_and_search(in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, transforms)
for param in ["Scale", "Bias", "Mean", "Variance"]:
param_var = self.op.inputs(param)[0]
self._visit_and_search(param_var, 0, pruned_idx)
self.append_pruned_vars(param_var, 0, pruned_idx)
self._visit_and_search(param_var, 0, transforms)
self.append_pruned_vars(param_var, 0, transforms)
out_var = self.op.outputs("Y")[0]
self._visit_and_search(out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, transforms)
@PRUNE_WORKER.register
......@@ -297,7 +426,7 @@ class elementwise_op(PruneWorker):
super(elementwise_op, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
axis = self.op.attr("axis")
if axis == -1:
x = self.op.inputs("X")[0]
......@@ -316,8 +445,8 @@ class elementwise_op(PruneWorker):
# for bias
if name == "Y" and actual_axis >= 0 and not (
len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
self.append_pruned_vars(in_var, actual_axis, pruned_idx)
self._visit_and_search(in_var, actual_axis, pruned_idx)
self.append_pruned_vars(in_var, actual_axis, transforms)
self._visit_and_search(in_var, actual_axis, transforms)
else:
if var in self.op.inputs("X"):
......@@ -331,18 +460,18 @@ class elementwise_op(PruneWorker):
if y_pruned_axis >= 0 and not (len(in_var.shape()) == 1 and
in_var.shape()[0] == 1):
self.append_pruned_vars(in_var, y_pruned_axis, pruned_idx)
self._visit_and_search(in_var, y_pruned_axis, pruned_idx)
self.append_pruned_vars(in_var, y_pruned_axis, transforms)
self._visit_and_search(in_var, y_pruned_axis, transforms)
elif var in self.op.inputs("Y"):
in_var = self.op.inputs("X")[0]
if len(in_var.shape()) != len(var.shape()):
assert (len(var.shape()) < len(in_var.shape()))
pruned_axis = pruned_axis + axis
if pruned_axis <= len(in_var.shape()):
self._visit_and_search(in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, transforms)
out_var = self.op.outputs("Out")[0]
self._visit_and_search(out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, transforms)
@PRUNE_WORKER.register
......@@ -374,13 +503,13 @@ class activation(PruneWorker):
self.input_name = "X"
self.output_name = "Out"
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
if var in self.op.outputs(self.output_name):
in_var = self.op.inputs(self.input_name)[0]
self._visit_and_search(in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, transforms)
if var in self.op.inputs(self.input_name):
out_var = self.op.outputs(self.output_name)[0]
self._visit_and_search(out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, transforms)
@PRUNE_WORKER.register
......@@ -389,14 +518,14 @@ class default_worker(PruneWorker):
super(default_worker, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
if var in self.op.all_outputs():
for in_var in self.op.all_inputs():
if len(in_var.shape()) == len(var.shape()):
self._visit_and_search(in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, transforms)
for out_var in self.op.all_outputs():
if len(out_var.shape()) == len(var.shape()):
self._visit_and_search(out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, transforms)
@PRUNE_WORKER.register
......@@ -428,6 +557,12 @@ class relu(activation):
super(relu, self).__init__(op, pruned_params, visited, skip_stranger)
@PRUNE_WORKER.register
class gelu(activation):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(gelu, self).__init__(op, pruned_params, visited, skip_stranger)
@PRUNE_WORKER.register
class leaky_relu(activation):
def __init__(self, op, pruned_params, visited, skip_stranger):
......@@ -458,16 +593,16 @@ class sum(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(sum, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
if var in self.op.outputs("Out"):
for in_var in self.op.inputs("X"):
self._visit_and_search(in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, transforms)
elif var in self.op.inputs("X"):
for in_var in self.op.inputs("X"):
if in_var != var:
self._visit_and_search(in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, transforms)
out_var = self.op.outputs("Out")[0]
self._visit_and_search(out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, transforms)
@PRUNE_WORKER.register
......@@ -482,7 +617,7 @@ class split(PruneWorker):
def _prune(self, var, pruned_axis, transforms):
if var == self.in_var:
if pruned_axis != self.axis:
for out_var in self.out_vars:
for i, out_var in enumerate(self.out_vars):
self._visit_and_search(out_var, pruned_axis, transforms)
else:
raise UnsupportOpError(
......@@ -607,13 +742,18 @@ class mul(PruneWorker):
}])
elif var == y:
if (pruned_axis < y_num_col_dims) and (
1 < len(x_shape) - x_num_col_dims):
1 < len(x_shape) - x_num_col_dims) and max(x_shape[
x_num_col_dims:]) != np.prod(y_shape[:y_num_col_dims]):
raise UnsupportOpError(
"Unsupport pruning y of mul when pruned_axis < y_num_col_dims and 1 < len(x_shape) - x_num_col_dims."
)
tile = 1
repeat = 1
self.append_pruned_vars(var, pruned_axis, trans)
self._visit_and_search(var, pruned_axis, trans)
if pruned_axis >= y_num_col_dims:
for i in range(y_num_col_dims, pruned_axis):
tile *= y_shape[i]
......@@ -632,16 +772,24 @@ class mul(PruneWorker):
tile *= y_shape[i]
for i in range(pruned_axis + 1, y_num_col_dims):
repeat *= y_shape[i]
self.append_pruned_vars(x,
len(x_shape) - 1, trans + [{
"tile": tile,
"repeat": repeat
}])
self._visit_and_search(x,
len(x_shape) - 1, trans + [{
"tile": tile,
"repeat": repeat
}])
new_pruned_axis = int(np.argmax(x_shape[
x_num_col_dims:])) + x_num_col_dims
self.append_pruned_vars(
x,
# len(x_shape) - 1, trans + [{
new_pruned_axis,
trans + [{
"tile": tile,
"repeat": repeat
}])
self._visit_and_search(
x,
# len(x_shape) - 1, trans + [{
new_pruned_axis,
trans + [{
"tile": tile,
"repeat": repeat
}])
elif var == out:
if (pruned_axis == 0 and x_num_col_dims != 1) or (
pruned_axis == 1 and (len(y_shape) - y_num_col_dims) != 1):
......@@ -663,7 +811,7 @@ class matmul(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(matmul, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
x = self.op.inputs("X")[0]
y = self.op.inputs("Y")[0]
out = self.op.outputs("Out")[0]
......@@ -680,40 +828,44 @@ class matmul(PruneWorker):
mappings = [(1, -1, 1), (2, 0, -1)]
elif x_shape_len == 2 and y_shape_len == 3:
mappings = [(0, -1, 1), (1, 1, -1), (-1, 2, 2)]
elif x_shape_len == 3 and y_shape_len == 2:
mappings = [(2, 0, -1), (-1, 1, 2)]
elif x_shape_len == 4 and y_shape_len == 4:
mappings = [(1, 1, 1)]
elif x_shape_len >= 3 and y_shape_len >= 3:
mappings = [(x_shape_len - 2, -1, x_shape_len - 2),
(x_shape_len - 1, x_shape_len - 2, -1),
(-1, x_shape_len - 1, x_shape_len - 1)]
if var == x:
for x_i, y_i, out_i in mappings:
if pruned_axis == x_i:
if y_i != -1:
self.append_pruned_vars(y, y_i, pruned_idx)
self._visit_and_search(y, y_i, pruned_idx)
self.append_pruned_vars(y, y_i, transforms)
self._visit_and_search(y, y_i, transforms)
if out_i != -1:
#self.append_pruned_vars(out, out_i, pruned_idx)
self._visit_and_search(out, out_i, pruned_idx)
self._visit_and_search(out, out_i, transforms)
break
if var == y:
for x_i, y_i, out_i in mappings:
if pruned_axis == y_i:
if x_i != -1:
self.append_pruned_vars(x, x_i, pruned_idx)
self._visit_and_search(x, x_i, pruned_idx)
self.append_pruned_vars(x, x_i, transforms)
self._visit_and_search(x, x_i, transforms)
if 'w_' in var.name():
self.append_pruned_vars(var, y_i, transforms)
self._visit_and_search(var, y_i, transforms)
if out_i != -1:
#self.append_pruned_vars(out, out_i, pruned_idx)
self._visit_and_search(out, out_i, pruned_idx)
self._visit_and_search(out, out_i, transforms)
break
if var == out:
for x_i, y_i, out_i in mappings:
if pruned_axis == out_i:
if x_i != -1:
self.append_pruned_vars(x, x_i, pruned_idx)
self._visit_and_search(x, x_i, pruned_idx)
self.append_pruned_vars(x, x_i, transforms)
self._visit_and_search(x, x_i, transforms)
if y_i != -1:
self.append_pruned_vars(y, y_i, pruned_idx)
self._visit_and_search(y, y_i, pruned_idx)
self.append_pruned_vars(y, y_i, transforms)
self._visit_and_search(y, y_i, transforms)
break
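For reference, each mapping tuple above is (x_axis, y_axis, out_axis): pruning the listed axis of one operand propagates to the listed axes of the others, and -1 means that operand is unaffected. For example, with x of shape [B, S, K] and y of shape [K, N]:

```python
# Taken from the x_shape_len == 3, y_shape_len == 2 branch above.
mappings = [(2, 0, -1),   # pruning K on x also prunes axis 0 of y
            (-1, 1, 2)]   # pruning axis 1 of y also prunes axis 2 of out
```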
......@@ -729,13 +881,13 @@ class scale(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(scale, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
if var in self.op.inputs("X"):
out_var = self.op.outputs("Out")[0]
self._visit_and_search(out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, transforms)
elif var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0]
self._visit_and_search(in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, transforms)
@PRUNE_WORKER.register
......@@ -744,10 +896,10 @@ class momentum(PruneWorker):
super(momentum, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
if var in self.op.inputs("Param"):
velocity_var = self.op.inputs("Velocity")[0]
self.append_pruned_vars(velocity_var, pruned_axis, pruned_idx)
self.append_pruned_vars(velocity_var, pruned_axis, transforms)
@PRUNE_WORKER.register
......@@ -755,12 +907,12 @@ class adam(PruneWorker):
def __init__(self, op, pruned_params, visited, skip_stranger):
super(adam, self).__init__(op, pruned_params, visited, skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
if var in self.op.inputs("Param"):
moment1_var = self.op.inputs("Moment1")[0]
self.append_pruned_vars(moment1_var, pruned_axis, pruned_idx)
self.append_pruned_vars(moment1_var, pruned_axis, transforms)
moment2_var = self.op.inputs("Moment2")[0]
self.append_pruned_vars(moment2_var, pruned_axis, pruned_idx)
self.append_pruned_vars(moment2_var, pruned_axis, transforms)
@PRUNE_WORKER.register
......@@ -769,22 +921,22 @@ class affine_channel(PruneWorker):
super(affine_channel, self).__init__(op, pruned_params, visited,
skip_stranger)
def _prune(self, var, pruned_axis, pruned_idx):
def _prune(self, var, pruned_axis, transforms):
if (var not in self.op.outputs("Out")) and (
var not in self.op.inputs("X")):
return
if var in self.op.outputs("Out"):
in_var = self.op.inputs("X")[0]
self._visit_and_search(in_var, pruned_axis, pruned_idx)
self._visit_and_search(in_var, pruned_axis, transforms)
for param in ["Scale", "Bias"]:
param_var = self.op.inputs(param)[0]
self._visit_and_search(param_var, 0, pruned_idx)
self.append_pruned_vars(param_var, 0, pruned_idx)
self._visit_and_search(param_var, 0, transforms)
self.append_pruned_vars(param_var, 0, transforms)
out_var = self.op.outputs("Out")[0]
self._visit_and_search(out_var, pruned_axis, pruned_idx)
self._visit_and_search(out_var, pruned_axis, transforms)
@PRUNE_WORKER.register
......@@ -794,9 +946,15 @@ class flatten_contiguous_range(PruneWorker):
visited, skip_stranger)
def _prune(self, var, pruned_axis, transforms):
start_axis = self.op.attr("start_axis")
stop_axis = self.op.attr("stop_axis")
if var in self.op.outputs("Out"):
out_var = self.op.outputs("Out")[0]
in_var = self.op.inputs("X")[0]
assert pruned_axis == start_axis and pruned_axis == int(
np.argmax(in_var.shape()[1:]) + 1)
self._visit_and_search(in_var, pruned_axis, transforms)
if var in self.op.inputs("X"):
out_var = self.op.outputs("Out")[0]
in_var = self.op.inputs("X")[0]
......
......@@ -188,6 +188,13 @@ class Pruner():
for idx in src:
idx = idx * repeat
target.extend(range(idx, idx + repeat))
elif "squeeze" in trans:
repeat = trans['repeat']
targets_set = set()
for idx in src:
targets_set.add(idx // repeat)
target = list(targets_set)
src = target
ret.append((name, axis, src))
......
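A small numeric sketch of the static-graph index mapping above (values illustrative): a "repeat" transform expands each pruned index into `repeat` consecutive indices, while the "squeeze" branch maps indices back to their group index:

```python
repeat = 4
src = [1, 3]

expanded = [i for idx in src for i in range(idx * repeat, (idx + 1) * repeat)]
# -> [4, 5, 6, 7, 12, 13, 14, 15]

restored = sorted({idx // repeat for idx in expanded})
# -> [1, 3]
```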