未验证 提交 e6ae7640 编写于 作者: Z zhouzj 提交者: GitHub

Automatic set gmp_config (#1103)

* Automatic set gmp_config

* Automatic set gmp_config

* add a note.

* add a note.
上级 b05085ed
......@@ -248,6 +248,21 @@ class AutoCompression:
feed_target_names, fetch_targets)
config_dict = dict(config._asdict())
if config_dict["prune_strategy"] == "gmp" and config_dict[
'gmp_config'] is None:
_logger.info(
"Calculating the iterations per epoch……(It will take some time)")
# NOTE:XXX: This way of calculating the iters needs to be improved.
iters_per_epoch = len(list(self.train_dataloader()))
total_iters = self.train_config.epochs * iters_per_epoch
config_dict['gmp_config'] = {
'stable_iterations': 0,
'pruning_iterations': 0.45 * total_iters,
'tunning_iterations': 0.45 * total_iters,
'resume_iteration': -1,
'pruning_steps': 100,
'initial_ratio': 0.15,
}
### add prune program
self._pruner = None
if 'prune' in strategy:
......@@ -280,13 +295,14 @@ class AutoCompression:
test_program_info)
if self.train_config.sparse_model:
from ..prune.unstructured_pruner import UnstructuredPruner
# NOTE: The initialization parameter of this pruner doesn't work, it is only used to call the 'set_static_masks' function
self._pruner = UnstructuredPruner(
train_program_info.program,
mode='ratio',
ratio=0.75,
prune_params_type='conv1x1_only',
place=self._places)
self._pruner.set_static_masks()
self._pruner.set_static_masks() # Fixed model sparsity
self._exe.run(train_program_info.startup_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册