未验证 提交 f4c1f48c 编写于 作者: M minghaoBD 提交者: GitHub

[cherry-pick][unstructured_prune]Resume training (#958) (#960)

上级 853e1d03
...@@ -304,6 +304,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -304,6 +304,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
self.cur_iteration = configs.get('resume_iteration') self.cur_iteration = configs.get('resume_iteration')
assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin." assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
self._need_prune_once = False
self._prepare_training_hyper_parameters() self._prepare_training_hyper_parameters()
def _prepare_training_hyper_parameters(self): def _prepare_training_hyper_parameters(self):
...@@ -330,6 +331,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -330,6 +331,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# pop out used ratios to resume training # pop out used ratios to resume training
for i in range(self.cur_iteration): for i in range(self.cur_iteration):
self._need_prune_once = True
if len(self. if len(self.
ratios_stack) > 0 and i % self.ratio_increment_period == 0: ratios_stack) > 0 and i % self.ratio_increment_period == 0:
self.ratio = self.ratios_stack.pop() self.ratio = self.ratios_stack.pop()
...@@ -344,7 +346,8 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -344,7 +346,8 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# Update the threshold and masks only when a new ratio has been set. # Update the threshold and masks only when a new ratio has been set.
# This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period. # This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period.
if ori_ratio != self.ratio: if ori_ratio != self.ratio or self._need_prune_once:
self.update_threshold() self.update_threshold()
self._update_masks() self._update_masks()
self._need_prune_once = False
self.cur_iteration += 1 self.cur_iteration += 1
...@@ -350,6 +350,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -350,6 +350,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
self.cur_iteration = configs.get('resume_iteration') self.cur_iteration = configs.get('resume_iteration')
assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin." assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
self._need_prune_once = False
self._prepare_training_hyper_parameters() self._prepare_training_hyper_parameters()
def _prepare_training_hyper_parameters(self): def _prepare_training_hyper_parameters(self):
...@@ -376,6 +377,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -376,6 +377,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# pop out used ratios to resume training # pop out used ratios to resume training
for i in range(self.cur_iteration): for i in range(self.cur_iteration):
self._need_prune_once = True
if len(self. if len(self.
ratios_stack) > 0 and i % self.ratio_increment_period == 0: ratios_stack) > 0 and i % self.ratio_increment_period == 0:
self.ratio = self.ratios_stack.pop() self.ratio = self.ratios_stack.pop()
...@@ -393,7 +395,8 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -393,7 +395,8 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# Update the threshold and masks only when a new ratio has been set. # Update the threshold and masks only when a new ratio has been set.
# This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period. # This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period.
if ori_ratio != self.ratio: if ori_ratio != self.ratio or self._need_prune_once:
self.update_threshold() self.update_threshold()
self._update_masks() self._update_masks()
self._need_prune_once = False
self.cur_iteration += 1 self.cur_iteration += 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册