未验证 提交 06175aea 编写于 作者: M minghaoBD 提交者: GitHub

Resume training (#958)

上级 1c6c326f
...@@ -304,6 +304,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -304,6 +304,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
self.cur_iteration = configs.get('resume_iteration') self.cur_iteration = configs.get('resume_iteration')
assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin." assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
self._need_prune_once = False
self._prepare_training_hyper_parameters() self._prepare_training_hyper_parameters()
def _prepare_training_hyper_parameters(self): def _prepare_training_hyper_parameters(self):
...@@ -330,6 +331,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -330,6 +331,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# pop out used ratios to resume training # pop out used ratios to resume training
for i in range(self.cur_iteration): for i in range(self.cur_iteration):
self._need_prune_once = True
if len(self. if len(self.
ratios_stack) > 0 and i % self.ratio_increment_period == 0: ratios_stack) > 0 and i % self.ratio_increment_period == 0:
self.ratio = self.ratios_stack.pop() self.ratio = self.ratios_stack.pop()
...@@ -344,7 +346,8 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -344,7 +346,8 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# Update the threshold and masks only when a new ratio has been set. # Update the threshold and masks only when a new ratio has been set.
# This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period. # This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period.
if ori_ratio != self.ratio: if ori_ratio != self.ratio or self._need_prune_once:
self.update_threshold() self.update_threshold()
self._update_masks() self._update_masks()
self._need_prune_once = False
self.cur_iteration += 1 self.cur_iteration += 1
...@@ -349,6 +349,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -349,6 +349,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
self.cur_iteration = configs.get('resume_iteration') self.cur_iteration = configs.get('resume_iteration')
assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin." assert self.pruning_iterations / self.pruning_steps > 10, "To guarantee the performance of GMP pruner, pruning iterations must be larger than pruning steps by a margin."
self._need_prune_once = False
self._prepare_training_hyper_parameters() self._prepare_training_hyper_parameters()
def _prepare_training_hyper_parameters(self): def _prepare_training_hyper_parameters(self):
...@@ -375,6 +376,7 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -375,6 +376,7 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# pop out used ratios to resume training # pop out used ratios to resume training
for i in range(self.cur_iteration): for i in range(self.cur_iteration):
self._need_prune_once = True
if len(self. if len(self.
ratios_stack) > 0 and i % self.ratio_increment_period == 0: ratios_stack) > 0 and i % self.ratio_increment_period == 0:
self.ratio = self.ratios_stack.pop() self.ratio = self.ratios_stack.pop()
...@@ -392,7 +394,8 @@ class GMPUnstructuredPruner(UnstructuredPruner): ...@@ -392,7 +394,8 @@ class GMPUnstructuredPruner(UnstructuredPruner):
# Update the threshold and masks only when a new ratio has been set. # Update the threshold and masks only when a new ratio has been set.
# This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period. # This condition check would save training time dramatically since we only update the threshold by the triger of self.ratio_increment_period.
if ori_ratio != self.ratio: if ori_ratio != self.ratio or self._need_prune_once:
self.update_threshold() self.update_threshold()
self._update_masks() self._update_masks()
self._need_prune_once = False
self.cur_iteration += 1 self.cur_iteration += 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册