Commit 189ddbab authored by W wanghaoshuang

Merge branch 'develop' of http://gitlab.baidu.com/PaddlePaddle/PaddleSlim into fix_prune

@@ -117,8 +117,6 @@ class ControllerServer(object):
             reward = messages[2]
             tokens = [int(token) for token in tokens.split(",")]
             self._controller.update(tokens, float(reward))
-            #tokens = self._controller.next_tokens()
-            #tokens = ",".join([str(token) for token in tokens])
             response = "ok"
             conn.send(response.encode())
             _logger.debug("send message to {}: [{}]".format(addr,
......
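The hunk above removes a leftover reply path in ControllerServer: after applying the client's reported reward with self._controller.update(tokens, float(reward)), the server now only sends back the literal acknowledgement "ok" instead of also returning a fresh token string. A minimal sketch of that update-and-acknowledge exchange, assuming a plain TCP socket and a tab-separated "<id>\t<tokens>\t<reward>" message; the wire format, port, and the serve_once helper are illustrative assumptions, not code from the repository.

import socket

def serve_once(controller, host="127.0.0.1", port=8989):
    # Sketch only: the real ControllerServer runs this in a loop with
    # logging and search-step bookkeeping; here we keep just the
    # exchange that survives this commit.
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind((host, port))
    server.listen(1)
    conn, addr = server.accept()
    try:
        messages = conn.recv(1024).decode().strip().split("\t")
        tokens = [int(token) for token in messages[1].split(",")]
        reward = float(messages[2])
        controller.update(tokens, reward)  # apply the reported reward
        conn.send("ok".encode())           # acknowledge only; no new tokens returned
    finally:
        conn.close()
        server.close()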
@@ -96,8 +96,10 @@ class AutoPruner(object):
         self._pruner = Pruner()
         if self._pruned_flops:
             self._base_flops = flops(program)
-            _logger.info("AutoPruner - base flops: {};".format(
-                self._base_flops))
+            self._max_flops = self._base_flops * (1 - self._pruned_flops)
+            _logger.info(
+                "AutoPruner - base flops: {}; pruned_flops: {}; max_flops: {}".
+                format(self._base_flops, self._pruned_flops, self._max_flops))
         if self._pruned_latency:
             self._base_latency = latency(program)
@@ -160,8 +162,15 @@ class AutoPruner(object):
             ratios,
             place=self._place,
             only_graph=True)
-        return flops(pruned_program) < self._base_flops * (
-            1 - self._pruned_flops)
+        current_flops = flops(pruned_program)
+        result = current_flops < self._max_flops
+        if not result:
+            _logger.info("Failed try ratios: {}; flops: {}; max_flops: {}".
+                         format(ratios, current_flops, self._max_flops))
+        else:
+            _logger.info("Success try ratios: {}; flops: {}; max_flops: {}".
+                         format(ratios, current_flops, self._max_flops))
+        return result

     def prune(self, program, eval_program=None):
         """
......
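The two AutoPruner hunks implement a single constraint: the FLOPs budget max_flops = base_flops * (1 - pruned_flops) is computed once, and each candidate set of pruning ratios is accepted only while the pruned program stays under that budget, with the attempt logged either way. A self-contained sketch of that check follows; the standalone is_legal_by_flops function and the example numbers are illustrative assumptions (in AutoPruner this logic is an instance method and current_flops comes from flops() on the pruned program).

import logging

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

def is_legal_by_flops(base_flops, pruned_flops, current_flops, ratios):
    # A candidate is legal only if it stays below the precomputed budget.
    max_flops = base_flops * (1 - pruned_flops)
    result = current_flops < max_flops
    status = "Success" if result else "Failed"
    _logger.info("{} try ratios: {}; flops: {}; max_flops: {}".format(
        status, ratios, current_flops, max_flops))
    return result

# Pruning 30% of a 1.2 GFLOPs model caps candidates at roughly 0.84 GFLOPs.
print(is_legal_by_flops(1.2e9, 0.3, 0.9e9, [0.5, 0.5]))  # False: 0.9e9 is over budget
print(is_legal_by_flops(1.2e9, 0.3, 0.8e9, [0.5, 0.5]))  # True:  0.8e9 is under budget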