未验证 提交 82bfa0a5 编写于 作者: W whs 提交者: GitHub

Enhence auto pruning and fix save/load graph when using py_reader

1. Enhence auto pruning. 
2. Fix save/load graph when using py_reader.
test=develop
上级 aab4d12c
...@@ -477,8 +477,12 @@ class GraphWrapper(object): ...@@ -477,8 +477,12 @@ class GraphWrapper(object):
for var in self.program.list_vars(): for var in self.program.list_vars():
if var.persistable and var.name not in self.persistables: if var.persistable and var.name not in self.persistables:
self.persistables[var.name] = var self.persistables[var.name] = var
persistables = []
for var in self.persistables:
if 'reader' not in var and 'double_buffer' not in var:
persistables.append(self.persistables[var])
io.save_vars(exe.exe, path, vars=self.persistables.values()) io.save_vars(exe.exe, path, vars=persistables)
def load_persistables(self, path, exe): def load_persistables(self, path, exe):
""" """
...@@ -491,8 +495,11 @@ class GraphWrapper(object): ...@@ -491,8 +495,11 @@ class GraphWrapper(object):
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(path, var.name)) return os.path.exists(os.path.join(path, var.name))
io.load_vars( persistables = []
exe.exe, path, vars=self.persistables.values(), predicate=if_exist) for var in self.persistables:
if 'reader' not in var and 'double_buffer' not in var:
persistables.append(self.persistables[var])
io.load_vars(exe.exe, path, vars=persistables, predicate=if_exist)
def update_param_shape(self, scope): def update_param_shape(self, scope):
""" """
......
...@@ -39,7 +39,9 @@ class AutoPruneStrategy(PruneStrategy): ...@@ -39,7 +39,9 @@ class AutoPruneStrategy(PruneStrategy):
max_ratio=0.7, max_ratio=0.7,
metric_name='top1_acc', metric_name='top1_acc',
pruned_params='conv.*_weights', pruned_params='conv.*_weights',
retrain_epoch=0): retrain_epoch=0,
uniform_range=None,
init_tokens=None):
""" """
Args: Args:
pruner(slim.Pruner): The pruner used to prune the parameters. Default: None. pruner(slim.Pruner): The pruner used to prune the parameters. Default: None.
...@@ -52,6 +54,8 @@ class AutoPruneStrategy(PruneStrategy): ...@@ -52,6 +54,8 @@ class AutoPruneStrategy(PruneStrategy):
It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc' It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc'
pruned_params(str): The pattern str to match the parameter names to be pruned. Default: 'conv.*_weights' pruned_params(str): The pattern str to match the parameter names to be pruned. Default: 'conv.*_weights'
retrain_epoch(int): The training epochs in each seaching step. Default: 0 retrain_epoch(int): The training epochs in each seaching step. Default: 0
uniform_range(int): The token range in each position of tokens generated by controller. None means getting the range automatically. Default: None.
init_tokens(list<int>): The initial tokens. None means getting the initial tokens automatically. Default: None.
""" """
super(AutoPruneStrategy, self).__init__(pruner, start_epoch, end_epoch, super(AutoPruneStrategy, self).__init__(pruner, start_epoch, end_epoch,
0.0, metric_name, pruned_params) 0.0, metric_name, pruned_params)
...@@ -60,8 +64,9 @@ class AutoPruneStrategy(PruneStrategy): ...@@ -60,8 +64,9 @@ class AutoPruneStrategy(PruneStrategy):
self._controller = controller self._controller = controller
self._metric_name = metric_name self._metric_name = metric_name
self._pruned_param_names = [] self._pruned_param_names = []
self._retrain_epoch = 0 self._retrain_epoch = retrain_epoch
self._uniform_range = uniform_range
self._init_tokens = init_tokens
self._current_tokens = None self._current_tokens = None
def on_compression_begin(self, context): def on_compression_begin(self, context):
...@@ -75,9 +80,18 @@ class AutoPruneStrategy(PruneStrategy): ...@@ -75,9 +80,18 @@ class AutoPruneStrategy(PruneStrategy):
if re.match(self.pruned_params, param.name()): if re.match(self.pruned_params, param.name()):
self._pruned_param_names.append(param.name()) self._pruned_param_names.append(param.name())
self._current_tokens = self._get_init_tokens(context) if self._init_tokens is not None:
self._range_table = copy.deepcopy(self._current_tokens) self._current_tokens = self._init_tokens
else:
self._current_tokens = self._get_init_tokens(context)
if self._uniform_range is not None:
self._range_table = [round(self._uniform_range, 2) / 0.01] * len(
self._pruned_param_names)
else:
self._range_table = copy.deepcopy(self._current_tokens)
_logger.info('init tokens: {}'.format(self._current_tokens))
_logger.info("range_table: {}".format(self._range_table))
constrain_func = functools.partial( constrain_func = functools.partial(
self._constrain_func, context=context) self._constrain_func, context=context)
...@@ -104,14 +118,20 @@ class AutoPruneStrategy(PruneStrategy): ...@@ -104,14 +118,20 @@ class AutoPruneStrategy(PruneStrategy):
context.eval_graph.var(param).set_shape(param_shape_backup[param]) context.eval_graph.var(param).set_shape(param_shape_backup[param])
flops_ratio = (1 - float(flops) / ori_flops) flops_ratio = (1 - float(flops) / ori_flops)
if flops_ratio >= self._min_ratio and flops_ratio <= self._max_ratio: if flops_ratio >= self._min_ratio and flops_ratio <= self._max_ratio:
_logger.info("Success try [{}]; flops: -{}".format(tokens,
flops_ratio))
return True return True
else: else:
_logger.info("Failed try [{}]; flops: -{}".format(tokens,
flops_ratio))
return False return False
def _get_init_tokens(self, context): def _get_init_tokens(self, context):
"""Get initial tokens. """Get initial tokens.
""" """
ratios = self._get_uniform_ratios(context) ratios = self._get_uniform_ratios(context)
_logger.info('Get init ratios: {}'.format(
[round(r, 2) for r in ratios]))
return self._ratios_to_tokens(ratios) return self._ratios_to_tokens(ratios)
def _ratios_to_tokens(self, ratios): def _ratios_to_tokens(self, ratios):
...@@ -171,7 +191,7 @@ class AutoPruneStrategy(PruneStrategy): ...@@ -171,7 +191,7 @@ class AutoPruneStrategy(PruneStrategy):
if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and ( if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and (
self._retrain_epoch == 0 or self._retrain_epoch == 0 or
(context.epoch_id - self.start_epoch) % self._retrain_epoch == 0): (context.epoch_id - self.start_epoch) % self._retrain_epoch == 0):
self._current_tokens = self._controller.next_tokens() _logger.info("on_epoch_begin")
params = self._pruned_param_names params = self._pruned_param_names
ratios = self._tokens_to_ratios(self._current_tokens) ratios = self._tokens_to_ratios(self._current_tokens)
...@@ -189,7 +209,7 @@ class AutoPruneStrategy(PruneStrategy): ...@@ -189,7 +209,7 @@ class AutoPruneStrategy(PruneStrategy):
context.optimize_graph.update_groups_of_conv() context.optimize_graph.update_groups_of_conv()
context.eval_graph.update_groups_of_conv() context.eval_graph.update_groups_of_conv()
context.optimize_graph.compile( context.optimize_graph.compile(
mem_opt=True) # to update the compiled program mem_opt=False) # to update the compiled program
context.skip_training = (self._retrain_epoch == 0) context.skip_training = (self._retrain_epoch == 0)
def on_epoch_end(self, context): def on_epoch_end(self, context):
...@@ -199,10 +219,13 @@ class AutoPruneStrategy(PruneStrategy): ...@@ -199,10 +219,13 @@ class AutoPruneStrategy(PruneStrategy):
""" """
if context.epoch_id >= self.start_epoch and context.epoch_id < self.end_epoch and ( if context.epoch_id >= self.start_epoch and context.epoch_id < self.end_epoch and (
self._retrain_epoch == 0 or self._retrain_epoch == 0 or
(context.epoch_id - self.start_epoch) % self._retrain_epoch == 0): (context.epoch_id - self.start_epoch + 1
) % self._retrain_epoch == 0):
_logger.info("on_epoch_end")
reward = context.eval_results[self._metric_name][-1] reward = context.eval_results[self._metric_name][-1]
self._controller.update(self._current_tokens, reward) self._controller.update(self._current_tokens, reward)
self._current_tokens = self._controller.next_tokens()
# restore pruned parameters # restore pruned parameters
for param_name in self._param_backup.keys(): for param_name in self._param_backup.keys():
param_t = context.scope.find_var(param_name).get_tensor() param_t = context.scope.find_var(param_name).get_tensor()
...@@ -218,7 +241,7 @@ class AutoPruneStrategy(PruneStrategy): ...@@ -218,7 +241,7 @@ class AutoPruneStrategy(PruneStrategy):
context.optimize_graph.update_groups_of_conv() context.optimize_graph.update_groups_of_conv()
context.eval_graph.update_groups_of_conv() context.eval_graph.update_groups_of_conv()
context.optimize_graph.compile( context.optimize_graph.compile(
mem_opt=True) # to update the compiled program mem_opt=False) # to update the compiled program
elif context.epoch_id == self.end_epoch: # restore graph for final training elif context.epoch_id == self.end_epoch: # restore graph for final training
# restore pruned parameters # restore pruned parameters
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册