Commit 74bdbd72 authored by W wanghaoshuang

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSlim into bert

......@@ -25,9 +25,6 @@ class Registry(object):
return self._module_dict.get(key, None)
def _register_module(self, module_class):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, but receive {}.'.format(
type(module_class)))
module_name = module_class.__name__
if module_name in self._module_dict:
raise KeyError('{} is already registered in {}.'.format(
......
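For context, `Registry` here is the register-by-decorator pattern used throughout PaddleSlim's prune module: classes are stored in a dict keyed by class name, lookups return `None` for unknown keys, and duplicate registration raises `KeyError`. A minimal, self-contained sketch of that pattern (anything beyond what the hunk displays is an assumption, not the library's exact code):

```python
import inspect


class Registry(object):
    """Minimal sketch of a name -> class registry (illustrative only)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = dict()

    def get(self, key):
        # Unknown keys return None so callers can fall back (see the walker lookup below).
        return self._module_dict.get(key, None)

    def register(self, module_class):
        # Used as a decorator: @PRUNE_WORKER.register
        self._register_module(module_class)
        return module_class

    def _register_module(self, module_class):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but receive {}.'.format(
                type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}.'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class


# Hypothetical usage mirroring the workers registered in this commit.
PRUNE_WORKER = Registry('prune_worker')
```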
......@@ -52,12 +52,11 @@ def default_idx_selector(group, ratio):
list: pruned indexes
"""
assert (isinstance(graph, GraphWrapper))
name, axis, score = group[
0] # sort channels by the first convolution's score
sorted_idx = score.argsort()
pruned_num = len(sorted_idx) * ratio
pruned_num = int(round(len(sorted_idx) * ratio))
pruned_idx = sorted_idx[:pruned_num]
idxs = []
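The key fix in this hunk is `pruned_num = int(round(len(sorted_idx) * ratio))`: the old expression produced a float, which cannot be used as a slice bound. A small standalone sketch of the selection step with made-up scores (numpy only; the scores and ratio are illustrative):

```python
import numpy as np

# Hypothetical per-channel scores for one group (e.g. L1-norms of conv filters).
score = np.array([0.9, 0.1, 0.5, 0.3, 0.7, 0.2])
ratio = 0.5

sorted_idx = score.argsort()                       # channels ordered from least to most important
pruned_num = int(round(len(sorted_idx) * ratio))   # integer count; len * ratio alone is a float
pruned_idx = sorted_idx[:pruned_num]               # indexes of the channels to prune

print(pruned_idx)  # -> [1 5 3] for the scores above
```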
......@@ -94,7 +93,6 @@ def optimal_threshold(group, ratio):
list: pruned indexes
"""
assert (isinstance(graph, GraphWrapper))
name, axis, score = group[
0] # sort channels by the first convolution's score
......
......@@ -71,8 +71,11 @@ class PruneWorker(object):
if visited is not None:
self.visited = visited
cls = PRUNE_WORKER.get(op.type())
assert cls is not None, "The walker of {} is not registered.".format(
op.type())
if cls is None:
_logger.warn(
"{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
format(op.type()))
cls = PRUNE_WORKER.get("default_walker")
_logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
self.op, op, pruned_axis, var.name()))
walker = cls(op,
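Instead of asserting that a walker exists for every op type, the new code logs a warning and falls back to the registered `default_walker`, which keeps input and output shapes aligned. A hedged sketch of that lookup-with-fallback pattern (the helper name, registry object, and logger here are stand-ins, not PaddleSlim's exact internals):

```python
import logging

_logger = logging.getLogger(__name__)


def get_prune_worker(registry, op_type, default_key="default_walker"):
    """Return the worker class registered for op_type, or the default walker."""
    cls = registry.get(op_type)
    if cls is None:
        _logger.warning(
            "%s op will be pruned by the default walker to keep its input and "
            "output shapes aligned, because no walker is registered for it.",
            op_type)
        cls = registry.get(default_key)
    return cls
```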
......@@ -236,6 +239,7 @@ class elementwise_mul(elementwise_op):
super(elementwise_mul, self).__init__(op, pruned_params, visited)
@PRUNE_WORKER.register
class activation(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(activation, self).__init__(op, pruned_params, visited)
......@@ -256,6 +260,27 @@ class activation(PruneWorker):
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class default_walker(PruneWorker):
def __init__(self, op, pruned_params, visited):
super(default_walker, self).__init__(op, pruned_params, visited)
def _prune(self, var, pruned_axis, pruned_idx):
if var in self.op.all_outputs():
for in_var in self.op.inputs():
if len(in_var.shape()) == len(var.shape()):
pre_ops = in_var.inputs()
for op in pre_ops:
self._prune_op(op, in_var, pruned_axis, pruned_idx)
for out_var in self.op.all_outputs():
if len(out_var.shape()) == len(var.shape()):
self._visit(out_var, pruned_axis)
next_ops = out_var.outputs()
for op in next_ops:
self._prune_op(op, out_var, pruned_axis, pruned_idx)
@PRUNE_WORKER.register
class uniform_random_batch_size_like(activation):
def __init__(self, op, pruned_params, visited):
......
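The `default_walker` added above forwards a pruned axis through any op whose inputs and outputs share the pruned variable's rank, which is the same invariant the `activation` worker (and subclasses such as `uniform_random_batch_size_like`) rely on: channels removed before a pass-through op are the same channels missing after it. A toy numpy illustration of that invariant (shapes and indices are made up):

```python
import numpy as np

# Feature map with 6 channels on axis 0; suppose channels 1 and 3 were pruned upstream.
x = np.random.rand(6, 16, 16)
pruned_idx = [1, 3]
keep = [c for c in range(x.shape[0]) if c not in pruned_idx]

y = np.maximum(x[keep], 0)        # a pass-through op such as ReLU
assert y.shape == x[keep].shape   # output rank and shape track the pruned input,
                                  # so the walker only needs to propagate pruned_idx onward
```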
......@@ -41,8 +41,6 @@ class Pruner():
def __init__(self,
criterion="l1_norm",
idx_selector="default_idx_selector"):
self.criterion = criterion
self.channel_sortor = channel_sortor
if isinstance(criterion, str):
self.criterion = CRITERION.get(criterion)
else:
......@@ -98,7 +96,7 @@ class Pruner():
param_v = graph.var(param)
pruned_num = int(round(param_v.shape()[0] * ratio))
pruned_idx = [0] * pruned_num
for name, aixs in group:
for name, axis in group:
pruned_params.append((name, axis, pruned_idx))
else:
......@@ -109,10 +107,10 @@ class Pruner():
values = np.array(scope.find_var(name).get_tensor())
group_values.append((name, values, axis))
scores = self.criterion(group_with_values,
scores = self.criterion(group_values,
graph) # [(name, axis, score)]
pruned_params = self.idx_selector(scores)
pruned_params.extend(self.idx_selector(scores, ratio))
merge_pruned_params = {}
for param, pruned_axis, pruned_idx in pruned_params:
......
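`Pruner.__init__` accepts either a registered name or a callable for `criterion` (and, per the surrounding hunks, for `idx_selector`), and the corrected calls pass `group_values` to the criterion and forward `ratio` to the selector. A hedged sketch of the name-or-callable resolution pattern this relies on (the helper and error messages are illustrative, not PaddleSlim's code):

```python
def resolve(option, registry):
    """Map a registered name to its callable, or accept a callable directly."""
    if isinstance(option, str):
        fn = registry.get(option)
        if fn is None:
            raise KeyError('{} is not registered.'.format(option))
        return fn
    if callable(option):
        return option
    raise TypeError(
        'expected a registered name or a callable, but got {}.'.format(type(option)))


# Hypothetical usage inside __init__:
#   self.criterion = resolve(criterion, CRITERION)
#   self.idx_selector = resolve(idx_selector, IDX_SELECTOR)
```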