提交 c906852d 编写于 作者: W wanghaoshuang

Fix auto pruner.

上级 ab25d262
......@@ -106,7 +106,7 @@ class AutoPruner(object):
self, _program, self._params, self._pruned_flops,
self._pruned_latency)
init_tokens = self._ratios2tokens(self._init_ratios)
_logger.info("range table: {}".format(self._range_table))
controller = SAController(self._range_table, self._reduce_rate,
self._init_temperature, self._max_try_number,
init_tokens, self._constrain_func)
......@@ -143,10 +143,10 @@ class AutoPruner(object):
def _get_range_table(self, min_ratios, max_ratios):
assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
min_ratios = min_ratios if isinstance(min_ratios,
list) else [min_ratios]
max_ratios = max_ratios if isinstance(max_ratios,
list) else [max_ratios]
min_ratios = min_ratios if isinstance(
min_ratios, list) else [min_ratios] * len(self._params)
max_ratios = max_ratios if isinstance(
max_ratios, list) else [max_ratios] * len(self._params)
min_tokens = self._ratios2tokens(min_ratios)
max_tokens = self._ratios2tokens(max_ratios)
return (min_tokens, max_tokens)
......@@ -163,7 +163,7 @@ class AutoPruner(object):
return flops(pruned_program) < self._base_flops * (
1 - self._pruned_flops)
def prune(self, program):
def prune(self, program, eval_program=None):
"""
Prune program with latest tokens generated by controller.
Args:
......@@ -178,10 +178,21 @@ class AutoPruner(object):
self._params,
self._current_ratios,
place=self._place,
only_graph=False,
param_backup=self._param_backup)
pruned_val_program = None
if eval_program is not None:
pruned_val_program = self._pruner.prune(
program,
self._scope,
self._params,
self._current_ratios,
place=self._place,
only_graph=True)
_logger.info("AutoPruner - pruned ratios: {}".format(
self._current_ratios))
return pruned_program
return pruned_program, pruned_val_program
def reward(self, score):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册