未验证 提交 06175aea 编写于 作者: M minghaoBD 提交者: GitHub

Resume training (#958)

上级 1c6c326f
......@@ -304,6 +304,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
self.cur_iteration = configs.get('resume_iteration')
assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
self._need_prune_once = False
self._prepare_training_hyper_parameters()
def _prepare_training_hyper_parameters(self):
......@@ -330,6 +331,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# pop out used ratios to resume training
for i in range(self.cur_iteration):
self._need_prune_once = True
if len(self.
ratios_stack) > 0 and i % self.ratio_increment_period == 0:
self.ratio = self.ratios_stack.pop()
......@@ -344,7 +346,8 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# Update the threshold and masks only when a new ratio has been set.
# This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period.
if ori_ratio != self.ratio:
if ori_ratio != self.ratio or self._need_prune_once:
self.update_threshold()
self._update_masks()
self._need_prune_once = False
self.cur_iteration += 1
......@@ -349,6 +349,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
self.cur_iteration = configs.get('resume_iteration')
assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
self._need_prune_once = False
self._prepare_training_hyper_parameters()
def _prepare_training_hyper_parameters(self):
......@@ -375,6 +376,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# pop out used ratios to resume training
for i in range(self.cur_iteration):
self._need_prune_once = True
if len(self.
ratios_stack) > 0 and i % self.ratio_increment_period == 0:
self.ratio = self.ratios_stack.pop()
......@@ -392,7 +394,8 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# Update the threshold and masks only when a new ratio has been set.
# This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period.
if ori_ratio != self.ratio:
if ori_ratio != self.ratio or self._need_prune_once:
self.update_threshold()
self._update_masks()
self._need_prune_once = False
self.cur_iteration += 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册