Unverified commit 5b0346c5, authored by minghaoBD and committed by GitHub

Support FC (#1551)

* support fc in pruning

* support matmul and matmul_v2 pruning

* support ffn pruning with gelu activations in the middle

* support reshape2, transpose and split in transformer block

* support pattern: qkv gemm -> batched gemm -> out linear

* prune all fc layers

* support fc pruning

* almost done, with a few hardcoded values remaining

* remove hardcoded values

* fix UT

* fix UT

* avoid setting attributes for reshape in fc pruning

* fix UT

* fix UT
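
The diff below threads three new options through the pruners: prune_type ('conv' or 'fc'), num_head for head-aware pruning of attention weights, and input_dtype for non-float inputs. A minimal usage sketch of the FC path, assuming a hypothetical transformer model (the model name and shapes are illustrative, not from this commit):

    import paddle
    from paddleslim.dygraph import L1NormFilterPruner

    model = MyTransformerEncoder()   # placeholder paddle.nn.Layer
    pruner = L1NormFilterPruner(
        model,
        inputs=[1, 128],             # token-id input shape
        prune_type='fc',             # walk 2-D (FC) weights instead of 4-D conv
        num_head=12,                 # head-aware pruning of (fused-)QKV weights
        input_dtype='int64')         # dtype forwarded to dygraph2program
    plan = pruner.prune_var('linear_0.w_0', pruned_axis=1, pruned_ratio=0.25)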
Parent 78efe1d9
@@ -13,6 +13,7 @@ from paddleslim.analysis import flops
 sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
 import models
 from utility import add_arguments, print_arguments
+import paddle.vision.transforms as T
 _logger = get_logger(__name__, level=logging.INFO)
@@ -37,6 +38,12 @@ def eval(args):
         val_dataset = paddle.vision.datasets.MNIST(mode='test')
         class_dim = 10
         image_shape = "1,28,28"
+    elif args.data == "cifar10":
+        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
+        val_dataset = paddle.vision.datasets.Cifar10(
+            mode="test", backend="cv2", transform=transform)
+        class_dim = 10
+        image_shape = "3, 32, 32"
     elif args.data == "imagenet":
         import imagenet_reader as reader
         val_dataset = reader.ImageNetDataset(mode='val')
@@ -130,6 +130,6 @@ def dygraph_flops(model, inputs, dtypes=None, only_conv=False, detail=False):
         detail(bool): Whether to return detail of each convolution layer.
     """
-    program = dygraph2program(model, inputs)
+    program = dygraph2program(model, inputs, dtypes=dtypes)
     graph = GraphWrapper(program)
     return _graph_flops(graph, only_conv=only_conv, detail=detail)
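
With dtypes forwarded to dygraph2program, FLOPs can now be measured for models whose inputs are not float32. A small sketch, assuming the usual paddleslim.analysis.dygraph_flops entry point and a placeholder model:

    from paddleslim.analysis import dygraph_flops

    model = MyTransformerEncoder()   # placeholder paddle.nn.Layer
    flops = dygraph_flops(model, inputs=[1, 128], dtypes=['int64'])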
@@ -58,14 +58,29 @@ class FilterPruner(Pruner):
     """

-    def __init__(self, model, inputs, sen_file=None, opt=None,
-                 skip_leaves=True):
+    def __init__(self,
+                 model,
+                 inputs,
+                 sen_file=None,
+                 opt=None,
+                 skip_leaves=True,
+                 prune_type='conv',
+                 num_head=-1,
+                 input_dtype='float32'):
         super(FilterPruner, self).__init__(model, inputs, opt=opt)
+        self.num_head = num_head
+        self.prune_type = prune_type
         self._status = Status(sen_file)
         self.skip_leaves = skip_leaves
         # sensitive and collections are just used in filter pruning
         self.collections = DygraphPruningCollections(
-            model, inputs, skip_leaves=self.skip_leaves)
+            model,
+            inputs,
+            skip_leaves=self.skip_leaves,
+            prune_type=prune_type,
+            input_dtype=input_dtype)
         # skip vars in:
         # 1. depthwise conv2d layer
@@ -294,7 +309,7 @@ class FilterPruner(Pruner):
         if self.plan is not None:
             self.plan.restore(self.model, opt=self.opt)

-    def cal_mask(self, pruned_ratio, collection):
+    def cal_mask(self, pruned_ratio, collection, num_head=-1):
         raise NotImplemented("cal_mask is not implemented")

     def prune_var(self, var_name, pruned_axis, pruned_ratio, apply="impretive"):
@@ -332,13 +347,12 @@ class FilterPruner(Pruner):
             f"Pruning variable [{var_name}] and its relatives {list(collection.variables())}"
         )
-        mask = self.cal_mask(pruned_ratio, collection)
+        mask = self.cal_mask(pruned_ratio, collection, num_head=self.num_head)
         for _detail in collection.all_pruning_details():
             # Variables can be pruned on multiple axes.
             src_mask = copy.deepcopy(mask)
             var_shape = _detail.var.shape()
             for tran in _detail.transform:
                 src_mask = self._transform_mask(src_mask, tran)
             current_mask = src_mask
             groups = _detail.op.attr('groups')
@@ -352,7 +366,11 @@ class FilterPruner(Pruner):
         if apply == "lazy":
             plan.apply(self.model, lazy=True)
         elif apply == "impretive":
-            plan.apply(self.model, lazy=False, opt=self.opt)
+            plan.apply(
+                self.model,
+                lazy=False,
+                opt=self.opt,
+                prune_type=self.prune_type)
         return plan

     def _transform_mask(self, mask, transform):
@@ -372,6 +390,17 @@ class FilterPruner(Pruner):
             return mask
         elif "repeat" in transform and "tile" in transform:
             return np.tile(mask.repeat(transform["repeat"]), transform["tile"])
+        elif "squeeze" in transform:
+            squeeze = transform['squeeze']
+            tmp = mask.reshape((-1, squeeze))
+            tmp = np.max(tmp, axis=1)
+            tmp[tmp > 0] = 1
+            return tmp
+        elif "repeat" in transform:
+            repeat = transform['repeat']
+            tmp = np.repeat(mask[:, np.newaxis], repeat, axis=1)
+            return tmp.flatten()
         else:
             return mask
         return dst_mask
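
The new squeeze and repeat branches map a mask across a reshape: squeeze collapses every `squeeze` consecutive entries into one (a group is kept if any member is kept), and repeat expands each entry into `repeat` copies. A standalone numpy illustration (values are made up):

    import numpy as np

    mask = np.array([1, 1, 0, 0, 1, 0])   # mask over an axis of size 6

    # "squeeze": 3, i.e. the tensor reshaped from (..., 6) to (..., 2, 3),
    # so the mask shrinks from 6 entries to 2
    squeezed = np.max(mask.reshape((-1, 3)), axis=1)   # array([1, 1])

    # "repeat": 3, the inverse direction, one flag per expanded element
    repeated = np.repeat(squeezed[:, np.newaxis], 3, axis=1).flatten()
    # array([1, 1, 1, 1, 1, 1])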
@@ -17,7 +17,7 @@ class FPGMFilterPruner(FilterPruner):
         super(FPGMFilterPruner, self).__init__(
             model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)

-    def cal_mask(self, pruned_ratio, collection):
+    def cal_mask(self, pruned_ratio, collection, num_head=-1):
         var_name = collection.master_name
         pruned_axis = collection.master_axis
         value = collection.values[var_name]
@@ -12,15 +12,42 @@ _logger = get_logger(__name__, logging.INFO)
 class L1NormFilterPruner(FilterPruner):

-    def __init__(self, model, inputs, sen_file=None, opt=None,
-                 skip_leaves=True):
+    def __init__(self,
+                 model,
+                 inputs,
+                 sen_file=None,
+                 opt=None,
+                 skip_leaves=True,
+                 prune_type='conv',
+                 input_dtype='float32',
+                 num_head=-1):
         super(L1NormFilterPruner, self).__init__(
-            model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)
+            model,
+            inputs,
+            sen_file=sen_file,
+            opt=opt,
+            skip_leaves=skip_leaves,
+            prune_type=prune_type,
+            input_dtype=input_dtype,
+            num_head=num_head)

-    def cal_mask(self, pruned_ratio, collection):
+    def cal_mask(self, pruned_ratio, collection, num_head=-1):
         var_name = collection.master_name
         pruned_axis = collection.master_axis
         value = collection.values[var_name]
+        if (value.shape[1] == 3 * value.shape[0] or
+                value.shape[1] == value.shape[0]) and num_head != -1:
+            assert value.size % (
+                value.shape[0] * num_head
+            ) == 0, "weight shape must be divisible by num_head."
+            _logger.debug(
+                "fused-qkv or query/key/value weight detected, we will prune on num_head: {} -> {}."
+                .format(num_head, int(num_head * (1 - pruned_ratio))))
+        else:
+            num_head = -1
         groups = 1
         for _detail in collection.all_pruning_details():
             assert (isinstance(_detail.axis, int))
@@ -30,6 +57,10 @@ class L1NormFilterPruner(FilterPruner):
                 groups = _groups
                 break

+        if num_head != -1:
+            k = int(value.size / value.shape[0] / num_head)
+            value = value.reshape((value.shape[0], num_head, k))
         reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
         l1norm = np.mean(np.abs(value), axis=tuple(reduce_dims))
         if groups > 1:
@@ -46,4 +77,9 @@ class L1NormFilterPruner(FilterPruner):
         if groups > 1:
             mask = mask.reshape([groups, -1])
         mask[pruned_idx] = 0
+        if num_head != -1:
+            mask = np.repeat(mask[:, np.newaxis], k, axis=1)
+            return mask.flatten()
         return mask.reshape(mask_shape)
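
For a fused-QKV weight of shape [hidden, 3*hidden], the branch above scores and prunes whole heads instead of individual columns. A standalone numpy sketch of that flow (sizes are illustrative):

    import numpy as np

    hidden, num_head, pruned_ratio = 8, 4, 0.25
    value = np.random.rand(hidden, 3 * hidden)           # fused-QKV weight

    k = value.size // (value.shape[0] * num_head)        # elements per head
    per_head = value.reshape((value.shape[0], num_head, k))
    scores = np.mean(np.abs(per_head), axis=(0, 2))      # one L1 score per head

    mask = np.ones(num_head)
    mask[np.argsort(scores)[:int(num_head * pruned_ratio)]] = 0
    col_mask = np.repeat(mask[:, np.newaxis], k, axis=1).flatten()
    assert col_mask.size == value.shape[1]               # one flag per column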
@@ -12,15 +12,42 @@ _logger = get_logger(__name__, logging.INFO)
 class L2NormFilterPruner(FilterPruner):

-    def __init__(self, model, inputs, sen_file=None, opt=None,
-                 skip_leaves=True):
+    def __init__(self,
+                 model,
+                 inputs,
+                 sen_file=None,
+                 opt=None,
+                 skip_leaves=True,
+                 prune_type='conv',
+                 input_dtype="float32",
+                 num_head=-1):
         super(L2NormFilterPruner, self).__init__(
-            model, inputs, sen_file=sen_file, opt=opt, skip_leaves=skip_leaves)
+            model,
+            inputs,
+            sen_file=sen_file,
+            opt=opt,
+            skip_leaves=skip_leaves,
+            prune_type=prune_type,
+            input_dtype=input_dtype,
+            num_head=num_head)

-    def cal_mask(self, pruned_ratio, collection):
+    def cal_mask(self, pruned_ratio, collection, num_head=-1):
         var_name = collection.master_name
         pruned_axis = collection.master_axis
         value = collection.values[var_name]
+        if (value.shape[1] == 3 * value.shape[0] or
+                value.shape[1] == value.shape[0]) and num_head != -1:
+            assert value.size % (
+                value.shape[0] * num_head
+            ) == 0, "weight shape must be divisible by num_head"
+            _logger.debug(
+                "fused-qkv or query/key/value weight detected, we will prune on num_head: {} -> {}"
+                .format(num_head, int(num_head * (1 - pruned_ratio))))
+        else:
+            num_head = -1
         groups = 1
         for _detail in collection.all_pruning_details():
             assert (isinstance(_detail.axis, int))
@@ -30,6 +57,10 @@ class L2NormFilterPruner(FilterPruner):
                 groups = _groups
                 break

+        if num_head != -1:
+            k = int(value.size / value.shape[0] / num_head)
+            value = value.reshape((value.shape[0], num_head, k))
         reduce_dims = [i for i in range(len(value.shape)) if i != pruned_axis]
         scores = np.sqrt(np.sum(np.square(value), axis=tuple(reduce_dims)))
         if groups > 1:
@@ -46,4 +77,9 @@ class L2NormFilterPruner(FilterPruner):
         if groups > 1:
             mask = mask.reshape([groups, -1])
         mask[pruned_idx] = 0
+        if num_head != -1:
+            mask = np.repeat(mask[:, np.newaxis], k, axis=1)
+            return mask.flatten()
         return mask.reshape(mask_shape)
@@ -19,7 +19,7 @@ class Pruner(object):
         opt(paddle.optimizer.Optimizer): The model's optimizer. Default: None.
     """

-    def __init__(self, model, inputs, opt=None):
+    def __init__(self, model, inputs, opt=None, prune_type='conv'):
         self.model = model
         self.inputs = inputs
         self._var_shapes = {}
@@ -27,6 +27,7 @@ class Pruner(object):
             self._var_shapes[var.name] = var.shape
         self.plan = None
         self.opt = opt
+        self.prune_type = prune_type

     def status(self, data=None, eval_func=None, status_file=None):
         raise NotImplemented("status is not implemented")
@@ -54,6 +55,10 @@ class Pruner(object):
         if apply == "lazy":
             global_plan.apply(self.model, lazy=True)
         elif apply == "impretive":
-            global_plan.apply(self.model, lazy=False, opt=self.opt)
+            global_plan.apply(
+                self.model,
+                lazy=False,
+                opt=self.opt,
+                prune_type=self.prune_type)
         self.plan = global_plan
         return global_plan
@@ -102,7 +102,6 @@ class PruningPlan():
                 var_tmp = v.get(param_name)
                 #NOTE: var_tmp.shape == [1] is used to skip variables like beta1_pow_acc in Adam optimizer. Its shape is [1] and there's no need to prune this one-value variable.
                 if var_tmp is None or var_tmp.shape == [1]:
-                    if var_tmp is not None: print(var_tmp.name, var_tmp.shape)
                     continue
                 t_value = var_tmp.value().get_tensor()
                 value = np.array(t_value).astype("float32")
@@ -162,11 +161,11 @@ class PruningPlan():
                     t_value.set(np.array(t_backup).astype("float32"), place)
                     del sub_layer._buffers[backup_name]

-    def apply(self, model, lazy=False, opt=None):
+    def apply(self, model, lazy=False, opt=None, prune_type='conv'):
         if lazy:
             self.lazy_apply(model)
         else:
-            self.imperative_apply(model, opt)
+            self.imperative_apply(model, opt, prune_type=prune_type)

     def lazy_apply(self, model):
         for name, sub_layer in model.named_sublayers():
@@ -203,12 +202,11 @@ class PruningPlan():
                     t_value.set(value * expand_mask, place)

-    def imperative_apply(self, model, opt=None):
+    def imperative_apply(self, model, opt=None, prune_type='conv'):
         """
         Pruning values of variable imperatively. It is valid when pruning
         on one dimension.
         """
         for name, sub_layer in model.named_sublayers(include_self=True):
             for param in sub_layer.parameters(include_sublayers=False):
                 if param.name in self._masks:
@@ -240,11 +238,9 @@ class PruningPlan():
                             format(param.name))
                         # save optimizer accumulators into layer buffer
                         self._buffer_opt(param.name, sub_layer, opt)
                         pruned_value = np.apply_along_axis(
                             lambda data: data[bool_mask], dims, value)
                         self._prune_opt(param.name, dims, bool_mask, opt)
                         p = t_value._place()
                         if p.is_cpu_place():
                             place = paddle.CPUPlace()
@@ -20,17 +20,38 @@ class DygraphPruningCollections(PruningCollections):
     - skip_leaves(bool): Whether to skip the last convolution layers.
     """

-    def __init__(self, model, inputs, skip_leaves=True):
+    def __init__(self,
+                 model,
+                 inputs,
+                 skip_leaves=True,
+                 prune_type='conv',
+                 input_dtype='float32'):
+        assert prune_type in ['conv', 'fc'
+                              ], "Please select conv or fc as your prune type."
         _logger.debug("Parsing model with input: {}".format(inputs))
         # model can be in training mode, because some model contains auxiliary parameters for training.
-        program = dygraph2program(model, inputs=inputs)
+        #TODO(minghaoBD): support dictionary input
+        if isinstance(inputs[0], int):
+            dtypes = [input_dtype]
+        elif isinstance(inputs[0], list):
+            dtypes = [input_dtype] * len(inputs)
+        else:
+            dtypes = [input_dtype]
+        program = dygraph2program(model, inputs=inputs, dtypes=dtypes)
         graph = GraphWrapper(program)
-        params = [
-            _param.name for _param in model.parameters()
-            if len(_param.shape) == 4
-        ]
+        if prune_type == 'conv':
+            params = [
+                _param.name for _param in model.parameters()
+                if len(_param.shape) == 4
+            ]
+        elif prune_type == 'fc':
+            params = [
+                _param.name for _param in model.parameters()
+                if len(_param.shape) == 2
+            ]
         self._collections = self.create_pruning_collections(
-            params, graph, skip_leaves=skip_leaves)
+            params, graph, skip_leaves=skip_leaves, prune_type=prune_type)
         _logger.info("Found {} collections.".format(len(self._collections)))
         _name2values = {}
@@ -142,7 +142,8 @@ class PruningCollections(object):
                                    graph,
                                    skip_stranger=True,
                                    skip_vars=None,
-                                   skip_leaves=True):
+                                   skip_leaves=True,
+                                   prune_type='conv'):
         """Collect convolution layers of graph into groups. The layers in the same group is relative on pruning operation.
         A group is a list of tuple with format (param_name, axis) in which `param_name` is the name of parameter and `axis` is the axis to be pruned on.
@@ -186,6 +187,12 @@ class PruningCollections(object):
         visited = {}
         collections = []
         unsupported_warnings = set()
+        if prune_type == 'conv':
+            prune_axis = 0
+        elif prune_type == 'fc':
+            prune_axis = 1
         for _param in params:
             pruned_params = []
             param = graph.var(_param)
@@ -216,7 +223,7 @@ class PruningCollections(object):
             worker.skip_vars = skip_vars
             try:
                 visited_backup = copy.deepcopy(worker.visited)
-                worker.prune(param, pruned_axis=0, pruned_idx=[])
+                worker.prune(param, pruned_axis=prune_axis, transforms=[])
             except UnsupportOpError as e:
                 visited.clear()
                 visited.update(visited_backup)
@@ -225,15 +232,16 @@ class PruningCollections(object):
             if len(pruned_params) != 0:
                 collection = PruningCollection(master=({
                     "name": param.name(),
-                    "axis": 0
+                    "axis": prune_axis,
                 }))
                 for _param, _axis, _transform, _op in pruned_params:
-                    collection.add(
-                        PruningDetails(_param, _axis, _transform, _op))
+                    tmp = PruningDetails(_param, _axis, _transform, _op)
+                    collection.add(tmp)
                 collections.append(collection)
         for warn in unsupported_warnings:
             _logger.warning(warn)
         self._collections = collections
         return self._collections
@@ -41,6 +41,8 @@ OPS_UNCHANGE_SHAPE += [
     'dropout',
     'cast',
     'hard_swish',
+    'fused_softmax_mask_upper_triangle',
+    'softmax',
     'hard_sigmoid',
 ]
@@ -77,20 +79,21 @@ class PruneWorker(object):
                 ).split(",")
         self.skip_vars = skip_vars

-    def prune(self, var, pruned_axis, pruned_idx):
+    def prune(self, var, pruned_axis, transforms):
         """
         Infer the shape of variables related with current operator, predecessor and successor.
         It will search the graph to find all variables related with `var` and record the information of pruning.
         Args:
             var(Variable): The root variable of searching. It can be the input or output of current operator.
             pruned_axis(int): The axis to be pruned of root variable.
-            pruned_idx(int): The indices to be pruned in `pruned_axis` of root variable.
+            transforms(list<dict>): The transforms applied to the current variable/mask.
         """
         if var.name() in self.skip_vars:
             raise UnsupportOpError("Variable {} was skipped.".format(var.name(
             )))
         if self._visit(var, pruned_axis):
-            self._prune(var, pruned_axis, pruned_idx)
+            self._prune(var, pruned_axis, transforms)

     def _visit(self, var, pruned_axis):
         key = "_".join([str(self.op.idx()), var.name()])
@@ -115,10 +118,10 @@ class PruneWorker(object):
         for op in next_ops:
             self._prune_op(op, var, axis, transforms)

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         raise NotImplementedError('Abstract method.')

-    def _prune_op(self, op, var, pruned_axis, pruned_idx, visited=None):
+    def _prune_op(self, op, var, pruned_axis, transforms, visited=None):
         if op.type().endswith("_grad"):
             return
         if visited is not None:
@@ -137,18 +140,144 @@ class PruneWorker(object):
                     op.type()))
         _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}\ntrans: {}".
-                      format(self.op, op, pruned_axis, var.name(), pruned_idx))
+                      format(self.op, op, pruned_axis, var.name(), transforms))
         _logger.debug(
             f"visit {op.type()} by var [{var.name()}] on axis [{pruned_axis}];\t visited={self.visited}\n"
         )
         worker = cls(op, self.pruned_params, self.visited, self.skip_stranger)
         worker.skip_vars = self.skip_vars
-        worker.prune(var, pruned_axis, pruned_idx)
+        worker.prune(var, pruned_axis, transforms)

     def append_pruned_vars(self, var, axis, transforms):
         self.pruned_params.append((var, axis, transforms, self.op))
+@PRUNE_WORKER.register
+class reshape2(PruneWorker):
+    def __init__(self, op, pruned_params, visited, skip_stranger):
+        super(reshape2, self).__init__(op, pruned_params, visited,
+                                       skip_stranger)
+
+    def _valid_reshape2(self, shape):
+        # case1: only reshape last several dimensions. e.g. [0,0,1,2] returns True while [1,0,0,1] returns False.
+        changed = False
+        for sh in shape:
+            if sh == 0 and changed:
+                return False
+            if sh != 0:
+                changed = True
+        return True
+
+    #NOTE: we might not need this assertion
+    def _valid_pruned_axis(self, shape, pruned_axis):
+        last_zero_index = -1
+        for i in range(len(shape)):
+            if shape[i] == 0: last_zero_index = i
+        if pruned_axis <= last_zero_index:
+            return pruned_axis
+        elif pruned_axis > last_zero_index:
+            return pruned_axis
+
+    def _get_idx_after_expanding_dims(self, pruned_axis, transforms,
+                                      shorter_shape, longer_shape):
+        assert len(shorter_shape) < len(
+            longer_shape
+        ), "length of {} must be smaller than length of {}.".format(
+            shorter_shape, longer_shape)
+        dim_old = shorter_shape[pruned_axis]
+        dim_new = longer_shape[pruned_axis]
+        k = dim_old // dim_new
+        assert k == longer_shape[
+            pruned_axis + 1], "k is {} while longer shape is {}[{}]".format(
+                k, longer_shape, pruned_axis + 1)
+        transforms = np.array(transforms, dtype='int32')
+        pruned_rows = transforms // k
+        new_transforms = []
+        for row in range(dim_new):
+            # prune a row only if all k of its elements were pruned
+            if np.sum(pruned_rows == row) == k:
+                new_transforms.append(row)
+        return new_transforms
+
+    def _get_idx_after_shrinking_dims(self, pruned_axis, transforms,
+                                      longer_shape, shorter_shape):
+        assert len(shorter_shape) < len(
+            longer_shape
+        ), "length of {} must be smaller than length of {}.".format(
+            shorter_shape, longer_shape)
+        dim_old = longer_shape[pruned_axis]
+        dim_new = shorter_shape[pruned_axis]
+        k = dim_new // dim_old
+        assert k == longer_shape[pruned_axis + 1]
+        new_transforms = []
+        for row in range(dim_old):
+            if row in transforms:
+                new_transforms.extend(
+                    [i for i in range(row * k, (row + 1) * k)])
+        return new_transforms
+
+    def _prune(self, var, pruned_axis, transforms):
+        in_var = self.op.inputs("X")[0]
+        out_var = self.op.outputs("Out")[0]
+        xshape_var = self.op.outputs("XShape")[0]
+        in_shape = in_var.shape()
+        out_shape = out_var.shape()
+        shape = self.op.attr("shape")
+        assert self._valid_reshape2(
+            shape), "we don't support the shape {} in pruning".format(shape)
+        # assert self._valid_pruned_axis(shape, pruned_axis), "we don't support pruned axis is {} when shape is changing from {} to {}".format(pruned_axis, in_shape, out_shape)
+        self.append_pruned_vars(xshape_var, pruned_axis + 1, transforms)
+        if var in self.op.inputs("X"):
+            if (len(out_shape) > len(in_shape)):
+                transform = {"squeeze": out_shape[pruned_axis + 1]}
+            elif (len(out_shape) < len(in_shape)):
+                transform = {"repeat": in_shape[pruned_axis + 1]}
+            else:
+                transform = {}
+            self._visit_and_search(out_var, pruned_axis,
+                                   transforms + [transform])
+        elif var in self.op.outputs("Out"):
+            if (len(in_shape) > len(out_shape)):
+                transform = {"squeeze": in_shape[pruned_axis + 1]}
+            elif (len(in_shape) < len(out_shape)):
+                transform = {"repeat": out_shape[pruned_axis + 1]}
+            else:
+                transform = {}
+            self._visit_and_search(in_var, pruned_axis,
+                                   transforms + [transform])
+@PRUNE_WORKER.register
+class transpose2(PruneWorker):
+    def __init__(self, op, pruned_params, visited, skip_stranger):
+        super(transpose2, self).__init__(op, pruned_params, visited,
+                                         skip_stranger)
+
+    def _prune(self, var, pruned_axis, transforms):
+        axis = self.op.attr('axis')
+        in_var = self.op.inputs("X")[0]
+        out_var = self.op.outputs("Out")[0]
+        if var in self.op.inputs("X"):
+            # out.shape[i] == in.shape[axis[i]], so an input axis maps to
+            # its position in the permutation
+            new_pruned_axis = axis.index(pruned_axis)
+            self._visit_and_search(out_var, new_pruned_axis, transforms)
+        elif var in self.op.outputs("Out"):
+            new_pruned_axis = axis[pruned_axis]
+            self._visit_and_search(in_var, new_pruned_axis, transforms)
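
A quick check of the transpose axis bookkeeping (the axis.index correction above assumes the standard transpose2 semantics out.shape[i] == in.shape[axis[i]]):

    perm = [0, 2, 1, 3]   # the typical transformer transpose
    # pruning input axis 2 maps to output axis perm.index(2) == 1
    # pruning output axis 2 maps back to input axis perm[2] == 1
    assert perm.index(2) == 1 and perm[2] == 1   # symmetric perm: both agree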
 @PRUNE_WORKER.register
 class conv2d(PruneWorker):
     def __init__(self, op, pruned_params, visited, skip_stranger):
@@ -168,7 +297,7 @@ class conv2d(PruneWorker):
         return (num_channels == groups and num_channels != 1 and
                 num_filters % num_channels == 0)

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         if self._is_depthwise_conv(self.op):
             _logger.debug(f"Meet conv2d who is depthwise conv2d actually.")
             worker = depthwise_conv2d(
@@ -176,7 +305,7 @@ class conv2d(PruneWorker):
                 self.pruned_params,
                 visited=self.visited,
                 skip_stranger=self.skip_stranger)
-            return worker._prune(var, pruned_axis, pruned_idx)
+            return worker._prune(var, pruned_axis, transforms)
         data_format = self.op.attr("data_format")
         groups = self.op.attr("groups")
@@ -187,39 +316,39 @@ class conv2d(PruneWorker):
             assert pruned_axis == channel_axis, "The Input of conv2d can only be pruned at channel axis, but got {}; var: {}".format(
                 pruned_axis, var.name())
             filter_var = self.op.inputs("Filter")[0]
-            self.append_pruned_vars(filter_var, 1, pruned_idx)
+            self.append_pruned_vars(filter_var, 1, transforms)
             if groups is None or groups == 1:
-                self._visit_and_search(filter_var, 1, pruned_idx)
+                self._visit_and_search(filter_var, 1, transforms)
         elif var in self.op.inputs("Filter"):
             assert pruned_axis in [0, 1]
-            self.append_pruned_vars(var, pruned_axis, pruned_idx)
+            self.append_pruned_vars(var, pruned_axis, transforms)
             if groups is None or groups == 1 or pruned_axis == 0:
-                self._visit_and_search(var, pruned_axis, pruned_idx)
+                self._visit_and_search(var, pruned_axis, transforms)
             if pruned_axis == 0:
                 if len(self.op.inputs("Bias")) > 0:
                     self.append_pruned_vars(
-                        self.op.inputs("Bias"), channel_axis, pruned_idx)
+                        self.op.inputs("Bias"), channel_axis, transforms)
                 output_var = self.op.outputs("Output")[0]
-                self._visit_and_search(output_var, channel_axis, pruned_idx)
+                self._visit_and_search(output_var, channel_axis, transforms)
             elif pruned_axis == 1:
                 input_var = self.op.inputs("Input")[0]
-                self._visit_and_search(input_var, channel_axis, pruned_idx)
+                self._visit_and_search(input_var, channel_axis, transforms)
         elif var in self.op.outputs("Output"):
             assert pruned_axis == channel_axis, "pruned_axis: {}; var: {}".format(
                 pruned_axis, var.name())
             filter_var = self.op.inputs("Filter")[0]
             self._visit(filter_var, 0)
-            self.append_pruned_vars(filter_var, 0, pruned_idx)
-            self._visit_and_search(filter_var, 0, pruned_idx)
+            self.append_pruned_vars(filter_var, 0, transforms)
+            self._visit_and_search(filter_var, 0, transforms)
             if len(self.op.inputs("Bias")) > 0:
                 self.append_pruned_vars(
-                    self.op.inputs("Bias")[0], channel_axis, pruned_idx)
+                    self.op.inputs("Bias")[0], channel_axis, transforms)
 @PRUNE_WORKER.register
@@ -228,7 +357,7 @@ class conv2d_transpose(PruneWorker):
         super(conv2d_transpose, self).__init__(op, pruned_params, visited,
                                                skip_stranger)

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         data_format = self.op.attr("data_format")
         channel_axis = 1
         if data_format == "NHWC":
@@ -238,8 +367,8 @@ class conv2d_transpose(PruneWorker):
                 pruned_axis, var.name())
             filter_var = self.op.inputs("Filter")[0]
             self._visit(filter_var, 0)
-            self.append_pruned_vars(filter_var, 0, pruned_idx)
-            self._visit_and_search(filter_var, 0, pruned_idx)
+            self.append_pruned_vars(filter_var, 0, transforms)
+            self._visit_and_search(filter_var, 0, transforms)
         elif var in self.op.inputs("Filter"):
             _logger.warn("Skip pruning output channels of conv2d_transpose!")
@@ -250,15 +379,15 @@ class conv2d_transpose(PruneWorker):
             filter_var = self.op.inputs("Filter")[0]
             self._visit(filter_var, 1)
-            self.append_pruned_vars(filter_var, 1, pruned_idx)
-            self._visit_and_search(filter_var, 1, pruned_idx)
+            self.append_pruned_vars(filter_var, 1, transforms)
+            self._visit_and_search(filter_var, 1, transforms)
             if len(self.op.inputs("Bias")) > 0:
                 self.append_pruned_vars(
-                    self.op.inputs("Bias")[0], channel_axis, pruned_idx)
+                    self.op.inputs("Bias")[0], channel_axis, transforms)
             output_var = self.op.outputs("Output")[0]
-            self._visit_and_search(output_var, channel_axis, pruned_idx)
+            self._visit_and_search(output_var, channel_axis, transforms)
 @PRUNE_WORKER.register
@@ -267,22 +396,22 @@ class batch_norm(PruneWorker):
         super(batch_norm, self).__init__(op, pruned_params, visited,
                                          skip_stranger)

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         if (var not in self.op.outputs("Y")) and (
                 var not in self.op.inputs("X")):
             return
         if var in self.op.outputs("Y"):
             in_var = self.op.inputs("X")[0]
-            self._visit_and_search(in_var, pruned_axis, pruned_idx)
+            self._visit_and_search(in_var, pruned_axis, transforms)
         for param in ["Scale", "Bias", "Mean", "Variance"]:
             param_var = self.op.inputs(param)[0]
-            self._visit_and_search(param_var, 0, pruned_idx)
-            self.append_pruned_vars(param_var, 0, pruned_idx)
+            self._visit_and_search(param_var, 0, transforms)
+            self.append_pruned_vars(param_var, 0, transforms)
         out_var = self.op.outputs("Y")[0]
-        self._visit_and_search(out_var, pruned_axis, pruned_idx)
+        self._visit_and_search(out_var, pruned_axis, transforms)

 @PRUNE_WORKER.register
@@ -297,7 +426,7 @@ class elementwise_op(PruneWorker):
         super(elementwise_op, self).__init__(op, pruned_params, visited,
                                              skip_stranger)

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         axis = self.op.attr("axis")
         if axis == -1:
             x = self.op.inputs("X")[0]
@@ -316,8 +445,8 @@ class elementwise_op(PruneWorker):
                 # for bias
                 if name == "Y" and actual_axis >= 0 and not (
                         len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
-                    self.append_pruned_vars(in_var, actual_axis, pruned_idx)
-                self._visit_and_search(in_var, actual_axis, pruned_idx)
+                    self.append_pruned_vars(in_var, actual_axis, transforms)
+                self._visit_and_search(in_var, actual_axis, transforms)
         else:
             if var in self.op.inputs("X"):
@@ -331,18 +460,18 @@ class elementwise_op(PruneWorker):
                     if y_pruned_axis >= 0 and not (len(in_var.shape()) == 1 and
                                                    in_var.shape()[0] == 1):
-                        self.append_pruned_vars(in_var, y_pruned_axis, pruned_idx)
-                        self._visit_and_search(in_var, y_pruned_axis, pruned_idx)
+                        self.append_pruned_vars(in_var, y_pruned_axis, transforms)
+                        self._visit_and_search(in_var, y_pruned_axis, transforms)
             elif var in self.op.inputs("Y"):
                 in_var = self.op.inputs("X")[0]
                 if len(in_var.shape()) != len(var.shape()):
                     assert (len(var.shape()) < len(in_var.shape()))
                     pruned_axis = pruned_axis + axis
                 if pruned_axis <= len(in_var.shape()):
-                    self._visit_and_search(in_var, pruned_axis, pruned_idx)
+                    self._visit_and_search(in_var, pruned_axis, transforms)
         out_var = self.op.outputs("Out")[0]
-        self._visit_and_search(out_var, pruned_axis, pruned_idx)
+        self._visit_and_search(out_var, pruned_axis, transforms)
 @PRUNE_WORKER.register
@@ -374,13 +503,13 @@ class activation(PruneWorker):
             self.input_name = "X"
             self.output_name = "Out"

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         if var in self.op.outputs(self.output_name):
             in_var = self.op.inputs(self.input_name)[0]
-            self._visit_and_search(in_var, pruned_axis, pruned_idx)
+            self._visit_and_search(in_var, pruned_axis, transforms)
         if var in self.op.inputs(self.input_name):
             out_var = self.op.outputs(self.output_name)[0]
-            self._visit_and_search(out_var, pruned_axis, pruned_idx)
+            self._visit_and_search(out_var, pruned_axis, transforms)
 @PRUNE_WORKER.register
@@ -389,14 +518,14 @@ class default_worker(PruneWorker):
         super(default_worker, self).__init__(op, pruned_params, visited,
                                              skip_stranger)

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         if var in self.op.all_outputs():
             for in_var in self.op.all_inputs():
                 if len(in_var.shape()) == len(var.shape()):
-                    self._visit_and_search(in_var, pruned_axis, pruned_idx)
+                    self._visit_and_search(in_var, pruned_axis, transforms)
         for out_var in self.op.all_outputs():
             if len(out_var.shape()) == len(var.shape()):
-                self._visit_and_search(out_var, pruned_axis, pruned_idx)
+                self._visit_and_search(out_var, pruned_axis, transforms)
 @PRUNE_WORKER.register
@@ -428,6 +557,12 @@ class relu(activation):
         super(relu, self).__init__(op, pruned_params, visited, skip_stranger)

+@PRUNE_WORKER.register
+class gelu(activation):
+    def __init__(self, op, pruned_params, visited, skip_stranger):
+        super(gelu, self).__init__(op, pruned_params, visited, skip_stranger)

 @PRUNE_WORKER.register
 class leaky_relu(activation):
     def __init__(self, op, pruned_params, visited, skip_stranger):
@@ -458,16 +593,16 @@ class sum(PruneWorker):
     def __init__(self, op, pruned_params, visited, skip_stranger):
         super(sum, self).__init__(op, pruned_params, visited, skip_stranger)

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         if var in self.op.outputs("Out"):
             for in_var in self.op.inputs("X"):
-                self._visit_and_search(in_var, pruned_axis, pruned_idx)
+                self._visit_and_search(in_var, pruned_axis, transforms)
         elif var in self.op.inputs("X"):
             for in_var in self.op.inputs("X"):
                 if in_var != var:
-                    self._visit_and_search(in_var, pruned_axis, pruned_idx)
+                    self._visit_and_search(in_var, pruned_axis, transforms)
             out_var = self.op.outputs("Out")[0]
-            self._visit_and_search(out_var, pruned_axis, pruned_idx)
+            self._visit_and_search(out_var, pruned_axis, transforms)
 @PRUNE_WORKER.register
@@ -482,7 +617,7 @@ class split(PruneWorker):
     def _prune(self, var, pruned_axis, transforms):
         if var == self.in_var:
             if pruned_axis != self.axis:
-                for out_var in self.out_vars:
+                for i, out_var in enumerate(self.out_vars):
                     self._visit_and_search(out_var, pruned_axis, transforms)
             else:
                 raise UnsupportOpError(
@@ -607,13 +742,18 @@ class mul(PruneWorker):
                 }])
         elif var == y:
             if (pruned_axis < y_num_col_dims) and (
-                    1 < len(x_shape) - x_num_col_dims):
+                    1 < len(x_shape) - x_num_col_dims) and max(x_shape[
+                        x_num_col_dims:]) != np.prod(y_shape[:y_num_col_dims]):
                 raise UnsupportOpError(
                     "Unsupport pruning y of mul when pruned_axis < y_num_col_dims and 1 < len(x_shape) - x_num_col_dims."
                 )
             tile = 1
             repeat = 1
+            self.append_pruned_vars(var, pruned_axis, trans)
+            self._visit_and_search(var, pruned_axis, trans)
             if pruned_axis >= y_num_col_dims:
                 for i in range(y_num_col_dims, pruned_axis):
                     tile *= y_shape[i]
@@ -632,13 +772,21 @@ class mul(PruneWorker):
                     tile *= y_shape[i]
                 for i in range(pruned_axis + 1, y_num_col_dims):
                     repeat *= y_shape[i]
-                self.append_pruned_vars(x,
-                                        len(x_shape) - 1, trans + [{
+                new_pruned_axis = int(np.argmax(x_shape[
+                    x_num_col_dims:])) + x_num_col_dims
+                self.append_pruned_vars(
+                    x,
+                    new_pruned_axis,
+                    trans + [{
                         "tile": tile,
                         "repeat": repeat
                     }])
-                self._visit_and_search(x,
-                                       len(x_shape) - 1, trans + [{
+                self._visit_and_search(
+                    x,
+                    new_pruned_axis,
+                    trans + [{
                         "tile": tile,
                         "repeat": repeat
                     }])
@@ -663,7 +811,7 @@ class matmul(PruneWorker):
     def __init__(self, op, pruned_params, visited, skip_stranger):
         super(matmul, self).__init__(op, pruned_params, visited, skip_stranger)

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         x = self.op.inputs("X")[0]
         y = self.op.inputs("Y")[0]
         out = self.op.outputs("Out")[0]
@@ -680,40 +828,44 @@ class matmul(PruneWorker):
             mappings = [(1, -1, 1), (2, 0, -1)]
         elif x_shape_len == 2 and y_shape_len == 3:
             mappings = [(0, -1, 1), (1, 1, -1), (-1, 2, 2)]
+        elif x_shape_len == 3 and y_shape_len == 2:
+            mappings = [(2, 0, -1), (-1, 1, 2)]
+        elif x_shape_len == 4 and y_shape_len == 4:
+            mappings = [(1, 1, 1)]
         elif x_shape_len >= 3 and y_shape_len >= 3:
             mappings = [(x_shape_len - 2, -1, x_shape_len - 2),
                         (x_shape_len - 1, x_shape_len - 2, -1),
                         (-1, x_shape_len - 1, x_shape_len - 1)]
         if var == x:
             for x_i, y_i, out_i in mappings:
                 if pruned_axis == x_i:
                     if y_i != -1:
-                        self.append_pruned_vars(y, y_i, pruned_idx)
-                        self._visit_and_search(y, y_i, pruned_idx)
+                        self.append_pruned_vars(y, y_i, transforms)
+                        self._visit_and_search(y, y_i, transforms)
                     if out_i != -1:
-                        #self.append_pruned_vars(out, out_i, pruned_idx)
-                        self._visit_and_search(out, out_i, pruned_idx)
+                        self._visit_and_search(out, out_i, transforms)
                     break
         if var == y:
             for x_i, y_i, out_i in mappings:
                 if pruned_axis == y_i:
                     if x_i != -1:
-                        self.append_pruned_vars(x, x_i, pruned_idx)
-                        self._visit_and_search(x, x_i, pruned_idx)
+                        self.append_pruned_vars(x, x_i, transforms)
+                        self._visit_and_search(x, x_i, transforms)
+                    if 'w_' in var.name():
+                        self.append_pruned_vars(var, y_i, transforms)
+                        self._visit_and_search(var, y_i, transforms)
                     if out_i != -1:
-                        #self.append_pruned_vars(out, out_i, pruned_idx)
-                        self._visit_and_search(out, out_i, pruned_idx)
+                        self._visit_and_search(out, out_i, transforms)
                     break
         if var == out:
             for x_i, y_i, out_i in mappings:
                 if pruned_axis == out_i:
                     if x_i != -1:
-                        self.append_pruned_vars(x, x_i, pruned_idx)
-                        self._visit_and_search(x, x_i, pruned_idx)
+                        self.append_pruned_vars(x, x_i, transforms)
+                        self._visit_and_search(x, x_i, transforms)
                     if y_i != -1:
-                        self.append_pruned_vars(y, y_i, pruned_idx)
-                        self._visit_and_search(y, y_i, pruned_idx)
+                        self.append_pruned_vars(y, y_i, transforms)
+                        self._visit_and_search(y, y_i, transforms)
                     break
@@ -729,13 +881,13 @@ class scale(PruneWorker):
     def __init__(self, op, pruned_params, visited, skip_stranger):
         super(scale, self).__init__(op, pruned_params, visited, skip_stranger)

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         if var in self.op.inputs("X"):
             out_var = self.op.outputs("Out")[0]
-            self._visit_and_search(out_var, pruned_axis, pruned_idx)
+            self._visit_and_search(out_var, pruned_axis, transforms)
         elif var in self.op.outputs("Out"):
             in_var = self.op.inputs("X")[0]
-            self._visit_and_search(in_var, pruned_axis, pruned_idx)
+            self._visit_and_search(in_var, pruned_axis, transforms)
 @PRUNE_WORKER.register
@@ -744,10 +896,10 @@ class momentum(PruneWorker):
         super(momentum, self).__init__(op, pruned_params, visited,
                                        skip_stranger)

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         if var in self.op.inputs("Param"):
             velocity_var = self.op.inputs("Velocity")[0]
-            self.append_pruned_vars(velocity_var, pruned_axis, pruned_idx)
+            self.append_pruned_vars(velocity_var, pruned_axis, transforms)
 @PRUNE_WORKER.register
@@ -755,12 +907,12 @@ class adam(PruneWorker):
     def __init__(self, op, pruned_params, visited, skip_stranger):
         super(adam, self).__init__(op, pruned_params, visited, skip_stranger)

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         if var in self.op.inputs("Param"):
             moment1_var = self.op.inputs("Moment1")[0]
-            self.append_pruned_vars(moment1_var, pruned_axis, pruned_idx)
+            self.append_pruned_vars(moment1_var, pruned_axis, transforms)
             moment2_var = self.op.inputs("Moment2")[0]
-            self.append_pruned_vars(moment2_var, pruned_axis, pruned_idx)
+            self.append_pruned_vars(moment2_var, pruned_axis, transforms)
 @PRUNE_WORKER.register
@@ -769,22 +921,22 @@ class affine_channel(PruneWorker):
         super(affine_channel, self).__init__(op, pruned_params, visited,
                                              skip_stranger)

-    def _prune(self, var, pruned_axis, pruned_idx):
+    def _prune(self, var, pruned_axis, transforms):
         if (var not in self.op.outputs("Out")) and (
                 var not in self.op.inputs("X")):
             return
         if var in self.op.outputs("Out"):
             in_var = self.op.inputs("X")[0]
-            self._visit_and_search(in_var, pruned_axis, pruned_idx)
+            self._visit_and_search(in_var, pruned_axis, transforms)
         for param in ["Scale", "Bias"]:
             param_var = self.op.inputs(param)[0]
-            self._visit_and_search(param_var, 0, pruned_idx)
-            self.append_pruned_vars(param_var, 0, pruned_idx)
+            self._visit_and_search(param_var, 0, transforms)
+            self.append_pruned_vars(param_var, 0, transforms)
         out_var = self.op.outputs("Out")[0]
-        self._visit_and_search(out_var, pruned_axis, pruned_idx)
+        self._visit_and_search(out_var, pruned_axis, transforms)
 @PRUNE_WORKER.register
@@ -794,9 +946,15 @@ class flatten_contiguous_range(PruneWorker):
                                                        visited, skip_stranger)

     def _prune(self, var, pruned_axis, transforms):
         start_axis = self.op.attr("start_axis")
         stop_axis = self.op.attr("stop_axis")
+        if var in self.op.outputs("Out"):
+            out_var = self.op.outputs("Out")[0]
+            in_var = self.op.inputs("X")[0]
+            assert pruned_axis == start_axis and pruned_axis == int(
+                np.argmax(in_var.shape()[1:]) + 1)
+            self._visit_and_search(in_var, pruned_axis, transforms)
         if var in self.op.inputs("X"):
             out_var = self.op.outputs("Out")[0]
             in_var = self.op.inputs("X")[0]
@@ -188,6 +188,13 @@ class Pruner():
                     for idx in src:
                         idx = idx * repeat
                         target.extend(range(idx, idx + repeat))
+                elif "squeeze" in trans:
+                    squeeze = trans['squeeze']
+                    targets_set = set()
+                    for idx in src:
+                        # map each flat index back to its group index
+                        targets_set.add(idx // squeeze)
+                    target = list(targets_set)
                 src = target
         ret.append((name, axis, src))
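
The static-graph counterpart works on index lists rather than masks. An illustration of both index transforms (values are made up):

    # "repeat": each kept index expands into a run of `repeat` indices.
    src, repeat = [1, 3], 4
    target = []
    for idx in src:
        target.extend(range(idx * repeat, idx * repeat + repeat))
    # target == [4, 5, 6, 7, 12, 13, 14, 15]

    # "squeeze": the inverse, flat indices collapse to their group index.
    squeeze = 4
    groups = sorted({idx // squeeze for idx in target})
    # groups == [1, 3]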