Commit 74bdbd72 authored by W wanghaoshuang

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSlim into bert

@@ -25,9 +25,6 @@ class Registry(object):
         return self._module_dict.get(key, None)

     def _register_module(self, module_class):
-        if not inspect.isclass(module_class):
-            raise TypeError('module must be a class, but receive {}.'.format(
-                type(module_class)))
         module_name = module_class.__name__
         if module_name in self._module_dict:
             raise KeyError('{} is already registered in {}.'.format(
......
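For context, PRUNE_WORKER and CRITERION in the hunks below are instances of this Registry class: register is used as a class decorator and get looks classes up by name, returning None for unknown keys. A minimal runnable sketch of that pattern (names and exact signatures here are illustrative, not the verbatim PaddleSlim implementation):

class Registry(object):
    def __init__(self, name):
        self.name = name
        self._module_dict = dict()

    def get(self, key):
        # None signals "nothing registered under this key".
        return self._module_dict.get(key, None)

    def register(self, module_class):
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}.'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class
        return module_class


PRUNE_WORKER = Registry('prune_worker')


@PRUNE_WORKER.register
class default_walker(object):
    pass


assert PRUNE_WORKER.get('default_walker') is default_walker
assert PRUNE_WORKER.get('unknown_op') is None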
@@ -52,12 +52,11 @@ def default_idx_selector(group, ratio):
         list: pruned indexes
     """
-    assert (isinstance(graph, GraphWrapper))
     name, axis, score = group[
         0]  # sort channels by the first convolution's score
     sorted_idx = score.argsort()
-    pruned_num = len(sorted_idx) * ratio
+    pruned_num = int(round(len(sorted_idx) * ratio))
     pruned_idx = sorted_idx[:pruned_num]
     idxs = []
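The int(round(...)) change above is worth spelling out: the pre-patch expression left pruned_num as a float, which NumPy will reject (or silently truncate, in old versions) as a slice bound, and truncation under-prunes whenever the product is non-integral. A quick illustration with hypothetical numbers (32 channels, ratio 0.3):

import numpy as np

sorted_idx = np.arange(32)                        # stand-in for score.argsort()
ratio = 0.3
pruned_num = int(round(len(sorted_idx) * ratio))  # round(9.6) -> 10, not 9
pruned_idx = sorted_idx[:pruned_num]              # the 10 lowest-scoring channels
print(pruned_num, len(pruned_idx))                # 10 10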
@@ -94,7 +93,6 @@ def optimal_threshold(group, ratio):
         list: pruned indexes
     """
-    assert (isinstance(graph, GraphWrapper))
     name, axis, score = group[
         0]  # sort channels by the first convolution's score
......
@@ -71,8 +71,11 @@ class PruneWorker(object):
         if visited is not None:
             self.visited = visited
         cls = PRUNE_WORKER.get(op.type())
-        assert cls is not None, "The walker of {} is not registered.".format(
-            op.type())
+        if cls is None:
+            _logger.warn(
+                "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
+                format(op.type()))
+            cls = PRUNE_WORKER.get("default_walker")
         _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
             self.op, op, pruned_axis, var.name()))
         walker = cls(op,
@@ -236,6 +239,7 @@ class elementwise_mul(elementwise_op):
         super(elementwise_mul, self).__init__(op, pruned_params, visited)


+@PRUNE_WORKER.register
 class activation(PruneWorker):
     def __init__(self, op, pruned_params, visited):
         super(activation, self).__init__(op, pruned_params, visited)
@@ -256,6 +260,27 @@ class activation(PruneWorker):
             self._prune_op(op, out_var, pruned_axis, pruned_idx)


+@PRUNE_WORKER.register
+class default_walker(PruneWorker):
+    def __init__(self, op, pruned_params, visited):
+        super(default_walker, self).__init__(op, pruned_params, visited)
+
+    def _prune(self, var, pruned_axis, pruned_idx):
+        if var in self.op.all_outputs():
+            for in_var in self.op.inputs():
+                if len(in_var.shape()) == len(var.shape()):
+                    pre_ops = in_var.inputs()
+                    for op in pre_ops:
+                        self._prune_op(op, in_var, pruned_axis, pruned_idx)
+
+        for out_var in self.op.all_outputs():
+            if len(out_var.shape()) == len(var.shape()):
+                self._visit(out_var, pruned_axis)
+                next_ops = out_var.outputs()
+                for op in next_ops:
+                    self._prune_op(op, out_var, pruned_axis, pruned_idx)
+
+
 @PRUNE_WORKER.register
 class uniform_random_batch_size_like(activation):
     def __init__(self, op, pruned_params, visited):
......
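The behavioural change in this file is the lookup fallback: an op type without a registered walker no longer aborts pruning via an assertion; instead default_walker handles it by forwarding the same pruned_axis and pruned_idx to every input and output tensor whose rank matches the pruned variable, so the op's input and output shapes stay consistent. A condensed sketch of that dispatch, reusing the Registry sketch above (create_walker is a hypothetical helper, not the actual PruneWorker method):

def create_walker(op, pruned_params, visited):
    # Prefer an op-specific walker registered under the op's type name.
    cls = PRUNE_WORKER.get(op.type())
    if cls is None:
        # Unregistered op: fall back to the generic walker, which simply
        # propagates the pruned channels through same-rank inputs/outputs.
        cls = PRUNE_WORKER.get("default_walker")
    return cls(op, pruned_params, visited)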
@@ -41,8 +41,6 @@ class Pruner():
     def __init__(self,
                  criterion="l1_norm",
                  idx_selector="default_idx_selector"):
-        self.criterion = criterion
-        self.channel_sortor = channel_sortor
         if isinstance(criterion, str):
             self.criterion = CRITERION.get(criterion)
         else:
@@ -98,7 +96,7 @@ class Pruner():
                 param_v = graph.var(param)
                 pruned_num = int(round(param_v.shape()[0] * ratio))
                 pruned_idx = [0] * pruned_num
-                for name, aixs in group:
+                for name, axis in group:
                     pruned_params.append((name, axis, pruned_idx))
             else:
@@ -109,10 +107,10 @@ class Pruner():
                     values = np.array(scope.find_var(name).get_tensor())
                     group_values.append((name, values, axis))

-                scores = self.criterion(group_with_values,
+                scores = self.criterion(group_values,
                                         graph)  # [(name, axis, score)]
-                pruned_params = self.idx_selector(scores)
+                pruned_params.extend(self.idx_selector(scores, ratio))

         merge_pruned_params = {}
         for param, pruned_axis, pruned_idx in pruned_params:
......
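Taken together, the Pruner changes clarify the contract between the two pluggable pieces: the criterion maps (name, values, axis) tuples to per-channel scores, and the idx_selector maps those scores plus a pruning ratio to concrete channel indexes, which are now accumulated with extend instead of overwriting earlier groups. A minimal sketch of that pipeline with hypothetical standalone functions (the registered implementations also receive the graph and handle more cases):

import numpy as np

def l1_norm(group_values):
    # [(name, values, axis)] -> [(name, axis, per-channel L1 score)]
    return [(name, axis,
             np.abs(values).sum(axis=tuple(
                 i for i in range(values.ndim) if i != axis)))
            for name, values, axis in group_values]

def default_idx_selector(scores, ratio):
    # Sort by the first tensor's score and keep the lowest-scoring channels.
    _, _, score = scores[0]
    pruned_num = int(round(len(score) * ratio))
    pruned_idx = score.argsort()[:pruned_num]
    return [(name, axis, pruned_idx) for name, axis, _ in scores]

group_values = [("conv1_weights", np.random.rand(8, 3, 3, 3), 0)]
pruned_params = []
pruned_params.extend(default_idx_selector(l1_norm(group_values), 0.25))
print(pruned_params)  # [("conv1_weights", 0, <2 channel indexes>)]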