未验证 提交 f4c1f48c 编写于 作者: M minghaoBD 提交者: GitHub

[cherry-pick][unstructured_prune]Resume training (#958) (#960)

上级 853e1d03
......@@ -304,6 +304,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
self.cur_iteration = configs.get('resume_iteration')
assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
self._need_prune_once = False
self._prepare_training_hyper_parameters()
def _prepare_training_hyper_parameters(self):
......@@ -330,6 +331,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# pop out used ratios to resume training
for i in range(self.cur_iteration):
self._need_prune_once = True
if len(self.
ratios_stack) > 0 and i % self.ratio_increment_period == 0:
self.ratio = self.ratios_stack.pop()
......@@ -344,7 +346,8 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# Update the threshold and masks only when a new ratio has been set.
# This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period.
if ori_ratio != self.ratio:
if ori_ratio != self.ratio or self._need_prune_once:
self.update_threshold()
self._update_masks()
self._need_prune_once = False
self.cur_iteration += 1
......@@ -350,6 +350,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
self.cur_iteration = configs.get('resume_iteration')
assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
self._need_prune_once = False
self._prepare_training_hyper_parameters()
def _prepare_training_hyper_parameters(self):
......@@ -376,6 +377,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# pop out used ratios to resume training
for i in range(self.cur_iteration):
self._need_prune_once = True
if len(self.
ratios_stack) > 0 and i % self.ratio_increment_period == 0:
self.ratio = self.ratios_stack.pop()
......@@ -393,7 +395,8 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# Update the threshold and masks only when a new ratio has been set.
# This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period.
if ori_ratio != self.ratio:
if ori_ratio != self.ratio or self._need_prune_once:
self.update_threshold()
self._update_masks()
self._need_prune_once = False
self.cur_iteration += 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册