diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py
index ac24df86030aae8cb286452b6bd6eeb7b5c80741..e4705a887727bf444b3ba285165d27df59a1ed57 100644
--- a/paddleslim/common/controller_server.py
+++ b/paddleslim/common/controller_server.py
@@ -117,8 +117,6 @@ class ControllerServer(object):
                     reward = messages[2]
                     tokens = [int(token) for token in tokens.split(",")]
                     self._controller.update(tokens, float(reward))
-                    #tokens = self._controller.next_tokens()
-                    #tokens = ",".join([str(token) for token in tokens])
                     response = "ok"
                     conn.send(response.encode())
                     _logger.debug("send message to {}: [{}]".format(addr,
diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py
index b144251a0a9a294094f7101f30958486abcf0543..fba8c11170f3fbf2eddbe15942dc642ad448658b 100644
--- a/paddleslim/prune/auto_pruner.py
+++ b/paddleslim/prune/auto_pruner.py
@@ -96,8 +96,10 @@ class AutoPruner(object):
         self._pruner = Pruner()
         if self._pruned_flops:
             self._base_flops = flops(program)
-            _logger.info("AutoPruner - base flops: {};".format(
-                self._base_flops))
+            self._max_flops = self._base_flops * (1 - self._pruned_flops)
+            _logger.info(
+                "AutoPruner - base flops: {}; pruned_flops: {}; max_flops: {}".
+                format(self._base_flops, self._pruned_flops, self._max_flops))
         if self._pruned_latency:
             self._base_latency = latency(program)
 
@@ -160,8 +162,15 @@ class AutoPruner(object):
             ratios,
             place=self._place,
             only_graph=True)
-        return flops(pruned_program) < self._base_flops * (
-            1 - self._pruned_flops)
+        current_flops = flops(pruned_program)
+        result = current_flops < self._max_flops
+        if not result:
+            _logger.info("Failed try ratios: {}; flops: {}; max_flops: {}".
+                         format(ratios, current_flops, self._max_flops))
+        else:
+            _logger.info("Success try ratios: {}; flops: {}; max_flops: {}".
+                         format(ratios, current_flops, self._max_flops))
+        return result
 
     def prune(self, program, eval_program=None):
         """