From 06175aeaf62b460d46aa619bde2c12e8a40d468e Mon Sep 17 00:00:00 2001
From: minghaoBD <79566150+minghaoBD@users.noreply.github.com>
Date: Wed, 22 Dec 2021 21:31:50 +0800
Subject: [PATCH] Resume training (#958)

---
 paddleslim/dygraph/prune/unstructured_pruner.py | 5 ++++-
 paddleslim/prune/unstructured_pruner.py         | 5 ++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/paddleslim/dygraph/prune/unstructured_pruner.py b/paddleslim/dygraph/prune/unstructured_pruner.py
index 4e7ef182..deaa7e3a 100644
--- a/paddleslim/dygraph/prune/unstructured_pruner.py
+++ b/paddleslim/dygraph/prune/unstructured_pruner.py
@@ -304,6 +304,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
         self.cur_iteration = configs.get('resume_iteration')
 
         assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
+        self._need_prune_once = False
         self._prepare_training_hyper_parameters()
 
     def _prepare_training_hyper_parameters(self):
@@ -330,6 +331,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
 
         # pop out used ratios to resume training
         for i in range(self.cur_iteration):
+            self._need_prune_once = True
             if len(self.
                    ratios_stack) > 0 and i % self.ratio_increment_period == 0:
                 self.ratio = self.ratios_stack.pop()
@@ -344,7 +346,8 @@ class GMPUnstructuredPruner(UnstructuredPruner):
 
         # Update the threshold and masks only when a new ratio has been set.
         # This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period.
-        if ori_ratio != self.ratio:
+        if ori_ratio != self.ratio or self._need_prune_once:
             self.update_threshold()
             self._update_masks()
+            self._need_prune_once = False
         self.cur_iteration += 1
diff --git a/paddleslim/prune/unstructured_pruner.py b/paddleslim/prune/unstructured_pruner.py
index 2f3c840f..ca2fb2fd 100644
--- a/paddleslim/prune/unstructured_pruner.py
+++ b/paddleslim/prune/unstructured_pruner.py
@@ -349,6 +349,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
         self.cur_iteration = configs.get('resume_iteration')
 
         assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
+        self._need_prune_once = False
         self._prepare_training_hyper_parameters()
 
     def _prepare_training_hyper_parameters(self):
@@ -375,6 +376,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
 
         # pop out used ratios to resume training
         for i in range(self.cur_iteration):
+            self._need_prune_once = True
             if len(self.
                    ratios_stack) > 0 and i % self.ratio_increment_period == 0:
                 self.ratio = self.ratios_stack.pop()
@@ -392,7 +394,8 @@ class GMPUnstructuredPruner(UnstructuredPruner):
 
         # Update the threshold and masks only when a new ratio has been set.
         # This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period.
-        if ori_ratio != self.ratio:
+        if ori_ratio != self.ratio or self._need_prune_once:
             self.update_threshold()
             self._update_masks()
+            self._need_prune_once = False
         self.cur_iteration += 1
--
GitLab
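
The sketch below distills what this patch changes, for readers who want to see the resume bug in isolation. It is a self-contained toy, not the PaddleSlim implementation: the GMPScheduleSketch class, the mask_updates counter, and the simplified modulo pop condition in step() are illustrative assumptions; only the attribute names (ratios_stack, ratio_increment_period, cur_iteration) and the _need_prune_once control flow are taken from the diff above. Before the fix, a run restored via resume_iteration fast-forwarded self.ratio by popping the already-consumed ratios, but the freshly loaded weights carried no masks, so training continued dense until the next ratio increment; the new flag forces exactly one threshold/mask update on the first post-resume step.

# Minimal sketch of the scheduling fix in #958, under the assumptions above.
class GMPScheduleSketch:
    def __init__(self, ratios_stack, ratio_increment_period,
                 resume_iteration=0):
        self.ratios_stack = list(ratios_stack)  # popped from the end
        self.ratio_increment_period = ratio_increment_period
        self.ratio = 0.0
        self.cur_iteration = resume_iteration
        self._need_prune_once = False
        self.mask_updates = 0  # stands in for update_threshold()/_update_masks()

        # Fast-forward the ratio schedule when resuming from a checkpoint,
        # mirroring the "pop out used ratios" loop in the patch. Restoring
        # self.ratio is not enough: the reloaded weights have no masks yet,
        # hence the _need_prune_once flag.
        for i in range(resume_iteration):
            self._need_prune_once = True
            if self.ratios_stack and i % self.ratio_increment_period == 0:
                self.ratio = self.ratios_stack.pop()

    def step(self):
        ori_ratio = self.ratio
        # Simplified stand-in for the real pop condition in step().
        if (self.ratios_stack and
                self.cur_iteration % self.ratio_increment_period == 0):
            self.ratio = self.ratios_stack.pop()
        # Before the fix, masks were rebuilt only when the ratio changed, so
        # a resumed run trained fully dense until the next ratio increment.
        if ori_ratio != self.ratio or self._need_prune_once:
            self.mask_updates += 1
            self._need_prune_once = False
        self.cur_iteration += 1


# Resuming at iteration 150: the first step() rebuilds the masks immediately,
# even though no new ratio is popped at that iteration (150 % 100 != 0).
pruner = GMPScheduleSketch([0.9, 0.7, 0.5], ratio_increment_period=100,
                           resume_iteration=150)
pruner.step()
assert pruner.mask_updates == 1 and pruner.ratio == 0.7

Clearing the flag inside the conditional keeps the patched behavior a one-off: after the first post-resume update, the cheap ori_ratio != self.ratio check takes over again, preserving the training-time saving the original comment describes.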