diff --git a/paddleslim/dygraph/prune/unstructured_pruner.py b/paddleslim/dygraph/prune/unstructured_pruner.py index 4e7ef182a7073f62f763338e3dd2bc3656ef32eb..deaa7e3ae969e23a945b8bf17aafed54a5bd9f1e 100644 --- a/paddleslim/dygraph/prune/unstructured_pruner.py +++ b/paddleslim/dygraph/prune/unstructured_pruner.py @@ -304,6 +304,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): self.cur_iteration = configs.get('resume_iteration') assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin." + self._need_prune_once = False self._prepare_training_hyper_parameters() def _prepare_training_hyper_parameters(self): @@ -330,6 +331,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): # pop out used ratios to resume training for i in range(self.cur_iteration): + self._need_prune_once = True if len(self. ratios_stack) > 0 and i % self.ratio_increment_period == 0: self.ratio = self.ratios_stack.pop() @@ -344,7 +346,8 @@ class GMPUnstructuredPruner(UnstructuredPruner): # Update the threshold and masks only when a new ratio has been set. # This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period. - if ori_ratio != self.ratio: + if ori_ratio != self.ratio or self._need_prune_once: self.update_threshold() self._update_masks() + self._need_prune_once = False self.cur_iteration += 1 diff --git a/paddleslim/prune/unstructured_pruner.py b/paddleslim/prune/unstructured_pruner.py index 2f3c840f6fa34cfddb9422b78ec4cfb4d029ef09..ca2fb2fd87e6be44373bcfd5dd93c19377303d61 100644 --- a/paddleslim/prune/unstructured_pruner.py +++ b/paddleslim/prune/unstructured_pruner.py @@ -349,6 +349,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): self.cur_iteration = configs.get('resume_iteration') assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin." + self._need_prune_once = False self._prepare_training_hyper_parameters() def _prepare_training_hyper_parameters(self): @@ -375,6 +376,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): # pop out used ratios to resume training for i in range(self.cur_iteration): + self._need_prune_once = True if len(self. ratios_stack) > 0 and i % self.ratio_increment_period == 0: self.ratio = self.ratios_stack.pop() @@ -392,7 +394,8 @@ class GMPUnstructuredPruner(UnstructuredPruner): # Update the threshold and masks only when a new ratio has been set. # This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period. - if ori_ratio != self.ratio: + if ori_ratio != self.ratio or self._need_prune_once: self.update_threshold() self._update_masks() + self._need_prune_once = False self.cur_iteration += 1