From 82bfa0a5ba5a4e54c8200c12f9f59716d6044e72 Mon Sep 17 00:00:00 2001 From: whs Date: Mon, 10 Jun 2019 17:54:37 +0800 Subject: [PATCH] Enhence auto pruning and fix save/load graph when using py_reader 1. Enhence auto pruning. 2. Fix save/load graph when using py_reader. test=develop --- .../fluid/contrib/slim/graph/graph_wrapper.py | 13 ++++-- .../contrib/slim/prune/auto_prune_strategy.py | 41 +++++++++++++++---- 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py index 689f6441170..b01c98aab9d 100644 --- a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py +++ b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py @@ -477,8 +477,12 @@ class GraphWrapper(object): for var in self.program.list_vars(): if var.persistable and var.name not in self.persistables: self.persistables[var.name] = var + persistables = [] + for var in self.persistables: + if 'reader' not in var and 'double_buffer' not in var: + persistables.append(self.persistables[var]) - io.save_vars(exe.exe, path, vars=self.persistables.values()) + io.save_vars(exe.exe, path, vars=persistables) def load_persistables(self, path, exe): """ @@ -491,8 +495,11 @@ class GraphWrapper(object): def if_exist(var): return os.path.exists(os.path.join(path, var.name)) - io.load_vars( - exe.exe, path, vars=self.persistables.values(), predicate=if_exist) + persistables = [] + for var in self.persistables: + if 'reader' not in var and 'double_buffer' not in var: + persistables.append(self.persistables[var]) + io.load_vars(exe.exe, path, vars=persistables, predicate=if_exist) def update_param_shape(self, scope): """ diff --git a/python/paddle/fluid/contrib/slim/prune/auto_prune_strategy.py b/python/paddle/fluid/contrib/slim/prune/auto_prune_strategy.py index 680b644fdd7..2f29385870d 100644 --- a/python/paddle/fluid/contrib/slim/prune/auto_prune_strategy.py +++ b/python/paddle/fluid/contrib/slim/prune/auto_prune_strategy.py @@ -39,7 +39,9 @@ class AutoPruneStrategy(PruneStrategy): max_ratio=0.7, metric_name='top1_acc', pruned_params='conv.*_weights', - retrain_epoch=0): + retrain_epoch=0, + uniform_range=None, + init_tokens=None): """ Args: pruner(slim.Pruner): The pruner used to prune the parameters. Default: None. @@ -52,6 +54,8 @@ class AutoPruneStrategy(PruneStrategy): It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc' pruned_params(str): The pattern str to match the parameter names to be pruned. Default: 'conv.*_weights' retrain_epoch(int): The training epochs in each seaching step. Default: 0 + uniform_range(int): The token range in each position of tokens generated by controller. None means getting the range automatically. Default: None. + init_tokens(list): The initial tokens. None means getting the initial tokens automatically. Default: None. """ super(AutoPruneStrategy, self).__init__(pruner, start_epoch, end_epoch, 0.0, metric_name, pruned_params) @@ -60,8 +64,9 @@ class AutoPruneStrategy(PruneStrategy): self._controller = controller self._metric_name = metric_name self._pruned_param_names = [] - self._retrain_epoch = 0 - + self._retrain_epoch = retrain_epoch + self._uniform_range = uniform_range + self._init_tokens = init_tokens self._current_tokens = None def on_compression_begin(self, context): @@ -75,9 +80,18 @@ class AutoPruneStrategy(PruneStrategy): if re.match(self.pruned_params, param.name()): self._pruned_param_names.append(param.name()) - self._current_tokens = self._get_init_tokens(context) - self._range_table = copy.deepcopy(self._current_tokens) + if self._init_tokens is not None: + self._current_tokens = self._init_tokens + else: + self._current_tokens = self._get_init_tokens(context) + if self._uniform_range is not None: + self._range_table = [round(self._uniform_range, 2) / 0.01] * len( + self._pruned_param_names) + else: + self._range_table = copy.deepcopy(self._current_tokens) + _logger.info('init tokens: {}'.format(self._current_tokens)) + _logger.info("range_table: {}".format(self._range_table)) constrain_func = functools.partial( self._constrain_func, context=context) @@ -104,14 +118,20 @@ class AutoPruneStrategy(PruneStrategy): context.eval_graph.var(param).set_shape(param_shape_backup[param]) flops_ratio = (1 - float(flops) / ori_flops) if flops_ratio >= self._min_ratio and flops_ratio <= self._max_ratio: + _logger.info("Success try [{}]; flops: -{}".format(tokens, + flops_ratio)) return True else: + _logger.info("Failed try [{}]; flops: -{}".format(tokens, + flops_ratio)) return False def _get_init_tokens(self, context): """Get initial tokens. """ ratios = self._get_uniform_ratios(context) + _logger.info('Get init ratios: {}'.format( + [round(r, 2) for r in ratios])) return self._ratios_to_tokens(ratios) def _ratios_to_tokens(self, ratios): @@ -171,7 +191,7 @@ class AutoPruneStrategy(PruneStrategy): if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and ( self._retrain_epoch == 0 or (context.epoch_id - self.start_epoch) % self._retrain_epoch == 0): - self._current_tokens = self._controller.next_tokens() + _logger.info("on_epoch_begin") params = self._pruned_param_names ratios = self._tokens_to_ratios(self._current_tokens) @@ -189,7 +209,7 @@ class AutoPruneStrategy(PruneStrategy): context.optimize_graph.update_groups_of_conv() context.eval_graph.update_groups_of_conv() context.optimize_graph.compile( - mem_opt=True) # to update the compiled program + mem_opt=False) # to update the compiled program context.skip_training = (self._retrain_epoch == 0) def on_epoch_end(self, context): @@ -199,10 +219,13 @@ class AutoPruneStrategy(PruneStrategy): """ if context.epoch_id >= self.start_epoch and context.epoch_id < self.end_epoch and ( self._retrain_epoch == 0 or - (context.epoch_id - self.start_epoch) % self._retrain_epoch == 0): + (context.epoch_id - self.start_epoch + 1 + ) % self._retrain_epoch == 0): + _logger.info("on_epoch_end") reward = context.eval_results[self._metric_name][-1] self._controller.update(self._current_tokens, reward) + self._current_tokens = self._controller.next_tokens() # restore pruned parameters for param_name in self._param_backup.keys(): param_t = context.scope.find_var(param_name).get_tensor() @@ -218,7 +241,7 @@ class AutoPruneStrategy(PruneStrategy): context.optimize_graph.update_groups_of_conv() context.eval_graph.update_groups_of_conv() context.optimize_graph.compile( - mem_opt=True) # to update the compiled program + mem_opt=False) # to update the compiled program elif context.epoch_id == self.end_epoch: # restore graph for final training # restore pruned parameters -- GitLab