diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py index 5dbdb6d4aa064fc6d5534f0ea02fefe19e580899..b144251a0a9a294094f7101f30958486abcf0543 100644 --- a/paddleslim/prune/auto_pruner.py +++ b/paddleslim/prune/auto_pruner.py @@ -106,7 +106,7 @@ class AutoPruner(object): self, _program, self._params, self._pruned_flops, self._pruned_latency) init_tokens = self._ratios2tokens(self._init_ratios) - + _logger.info("range table: {}".format(self._range_table)) controller = SAController(self._range_table, self._reduce_rate, self._init_temperature, self._max_try_number, init_tokens, self._constrain_func) @@ -143,10 +143,10 @@ class AutoPruner(object): def _get_range_table(self, min_ratios, max_ratios): assert isinstance(min_ratios, list) or isinstance(min_ratios, float) assert isinstance(max_ratios, list) or isinstance(max_ratios, float) - min_ratios = min_ratios if isinstance(min_ratios, - list) else [min_ratios] - max_ratios = max_ratios if isinstance(max_ratios, - list) else [max_ratios] + min_ratios = min_ratios if isinstance( + min_ratios, list) else [min_ratios] * len(self._params) + max_ratios = max_ratios if isinstance( + max_ratios, list) else [max_ratios] * len(self._params) min_tokens = self._ratios2tokens(min_ratios) max_tokens = self._ratios2tokens(max_ratios) return (min_tokens, max_tokens) @@ -163,7 +163,7 @@ class AutoPruner(object): return flops(pruned_program) < self._base_flops * ( 1 - self._pruned_flops) - def prune(self, program): + def prune(self, program, eval_program=None): """ Prune program with latest tokens generated by controller. Args: @@ -178,10 +178,21 @@ class AutoPruner(object): self._params, self._current_ratios, place=self._place, + only_graph=False, param_backup=self._param_backup) + pruned_val_program = None + if eval_program is not None: + pruned_val_program = self._pruner.prune( + program, + self._scope, + self._params, + self._current_ratios, + place=self._place, + only_graph=True) + _logger.info("AutoPruner - pruned ratios: {}".format( self._current_ratios)) - return pruned_program + return pruned_program, pruned_val_program def reward(self, score): """