diff --git a/paddleslim/dygraph/prune/unstructured_pruner.py b/paddleslim/dygraph/prune/unstructured_pruner.py index 4e7ef182a7073f62f763338e3dd2bc3656ef32eb..deaa7e3ae969e23a945b8bf17aafed54a5bd9f1e 100644 --- a/paddleslim/dygraph/prune/unstructured_pruner.py +++ b/paddleslim/dygraph/prune/unstructured_pruner.py @@ -304,6 +304,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): self.cur_iteration = configs.get('resume_iteration') assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin." + self._need_prune_once = False self._prepare_training_hyper_parameters() def _prepare_training_hyper_parameters(self): @@ -330,6 +331,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): # pop out used ratios to resume training for i in range(self.cur_iteration): + self._need_prune_once = True if len(self. ratios_stack) > 0 and i % self.ratio_increment_period == 0: self.ratio = self.ratios_stack.pop() @@ -344,7 +346,8 @@ class GMPUnstructuredPruner(UnstructuredPruner): # Update the threshold and masks only when a new ratio has been set. # This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period. - if ori_ratio != self.ratio: + if ori_ratio != self.ratio or self._need_prune_once: self.update_threshold() self._update_masks() + self._need_prune_once = False self.cur_iteration += 1 diff --git a/paddleslim/prune/unstructured_pruner.py b/paddleslim/prune/unstructured_pruner.py index cf9064aee5d490a2ebf2a8fe287f0cd9cd38a546..bf19fcfc535233fd7d93c0e346415efd0554ab52 100644 --- a/paddleslim/prune/unstructured_pruner.py +++ b/paddleslim/prune/unstructured_pruner.py @@ -350,6 +350,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): self.cur_iteration = configs.get('resume_iteration') assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin." + self._need_prune_once = False self._prepare_training_hyper_parameters() def _prepare_training_hyper_parameters(self): @@ -376,6 +377,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): # pop out used ratios to resume training for i in range(self.cur_iteration): + self._need_prune_once = True if len(self. ratios_stack) > 0 and i % self.ratio_increment_period == 0: self.ratio = self.ratios_stack.pop() @@ -393,7 +395,8 @@ class GMPUnstructuredPruner(UnstructuredPruner): # Update the threshold and masks only when a new ratio has been set. # This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period. - if ori_ratio != self.ratio: + if ori_ratio != self.ratio or self._need_prune_once: self.update_threshold() self._update_masks() + self._need_prune_once = False self.cur_iteration += 1