From f4c1f48c28ca9660efe149098e212be165ee7adb Mon Sep 17 00:00:00 2001 From: minghaoBD <79566150+minghaoBD@users.noreply.github.com> Date: Fri, 31 Dec 2021 10:42:10 +0800 Subject: [PATCH] [cherry-pick][unstructured_prune]Resume training (#958) (#960) --- paddleslim/dygraph/prune/unstructured_pruner.py | 5 ++++- paddleslim/prune/unstructured_pruner.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/paddleslim/dygraph/prune/unstructured_pruner.py b/paddleslim/dygraph/prune/unstructured_pruner.py index 4e7ef182..deaa7e3a 100644 --- a/paddleslim/dygraph/prune/unstructured_pruner.py +++ b/paddleslim/dygraph/prune/unstructured_pruner.py @@ -304,6 +304,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): self.cur_iteration = configs.get('resume_iteration') assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin." + self._need_prune_once = False self._prepare_training_hyper_parameters() def _prepare_training_hyper_parameters(self): @@ -330,6 +331,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): # pop out used ratios to resume training for i in range(self.cur_iteration): + self._need_prune_once = True if len(self. ratios_stack) > 0 and i % self.ratio_increment_period == 0: self.ratio = self.ratios_stack.pop() @@ -344,7 +346,8 @@ class GMPUnstructuredPruner(UnstructuredPruner): # Update the threshold and masks only when a new ratio has been set. # This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period. - if ori_ratio != self.ratio: + if ori_ratio != self.ratio or self._need_prune_once: self.update_threshold() self._update_masks() + self._need_prune_once = False self.cur_iteration += 1 diff --git a/paddleslim/prune/unstructured_pruner.py b/paddleslim/prune/unstructured_pruner.py index cf9064ae..bf19fcfc 100644 --- a/paddleslim/prune/unstructured_pruner.py +++ b/paddleslim/prune/unstructured_pruner.py @@ -350,6 +350,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): self.cur_iteration = configs.get('resume_iteration') assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin." + self._need_prune_once = False self._prepare_training_hyper_parameters() def _prepare_training_hyper_parameters(self): @@ -376,6 +377,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): # pop out used ratios to resume training for i in range(self.cur_iteration): + self._need_prune_once = True if len(self. ratios_stack) > 0 and i % self.ratio_increment_period == 0: self.ratio = self.ratios_stack.pop() @@ -393,7 +395,8 @@ class GMPUnstructuredPruner(UnstructuredPruner): # Update the threshold and masks only when a new ratio has been set. # This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period. - if ori_ratio != self.ratio: + if ori_ratio != self.ratio or self._need_prune_once: self.update_threshold() self._update_masks() + self._need_prune_once = False self.cur_iteration += 1 -- GitLab