From b352e44fbeee4624ab7e7aff75aefc8392bf4811 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Fri, 7 Feb 2020 08:46:38 +0800
Subject: [PATCH] Fix unittest of auto pruner.

---
 paddleslim/common/sa_controller.py |  6 +---
 paddleslim/prune/auto_pruner.py    | 49 ++----------------------------
 tests/test_auto_prune.py           | 14 +++------
 tests/test_prune.py                |  5 +--
 4 files changed, 12 insertions(+), 62 deletions(-)

diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py
index b1034762..733ce321 100644
--- a/paddleslim/common/sa_controller.py
+++ b/paddleslim/common/sa_controller.py
@@ -41,7 +41,6 @@ class SAController(EvolutionaryController):
                  max_reward=-1,
                  iters=0,
                  best_tokens=None,
-                 constrain_func=None,
                  checkpoints=None,
                  searched=None):
         """Initialize.
@@ -55,7 +54,6 @@ class SAController(EvolutionaryController):
             max_reward(float): The max reward in the search of sanas; in general, the best tokens get the max reward. Default: -1.
             iters(int): The iteration of sa controller. Default: 0.
             best_tokens(list): The best tokens in the search of sanas; in general, the best tokens get the max reward. Default: None.
-            constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
             checkpoints(str): If checkpoints is None, do not save checkpoints; else save the scene to the checkpoints file.
             searched(dict): Remember tokens which have been searched.
         """
@@ -75,7 +73,6 @@ class SAController(EvolutionaryController):
         else:
             self._init_temperature = 1.0
 
-        self._constrain_func = constrain_func
         self._max_reward = max_reward
         self._best_tokens = best_tokens
         self._iter = iters
@@ -86,8 +83,7 @@ class SAController(EvolutionaryController):
     def __getstate__(self):
         d = {}
         for key in self.__dict__:
-            if key != "_constrain_func":
-                d[key] = self.__dict__[key]
+            d[key] = self.__dict__[key]
         return d
 
     @property
diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py
index 672ce78d..178d131a 100644
--- a/paddleslim/prune/auto_pruner.py
+++ b/paddleslim/prune/auto_pruner.py
@@ -32,13 +32,10 @@ _logger = get_logger(__name__, level=logging.INFO)
 
 class AutoPruner(object):
     def __init__(self,
-                 program,
                  scope,
                  place,
-                 params=[],
+                 params=None,
                  init_ratios=None,
-                 pruned_flops=0.5,
-                 pruned_latency=None,
                  server_addr=("", 0),
                  init_temperature=100,
                  reduce_rate=0.85,
@@ -52,7 +49,6 @@ class AutoPruner(object):
         """
         Search a group of ratios used to prune program.
         Args:
-            program(Program): The program to be pruned.
             scope(Scope): The scope to be pruned.
             place(fluid.Place): The device place of parameters.
             params(list): The names of parameters to be pruned.
@@ -61,8 +57,6 @@ class AutoPruner(object):
                 The length of `init_ratios` should be equal to the length of `params` when `init_ratios` is a list.
                 If it is a scalar, all the parameters in `params` will be pruned by a uniform ratio.
                 None means getting a group of init ratios by `pruned_flops` or `pruned_latency`. Default: None.
-            pruned_flops(float): The percent of FLOPS to be pruned. Default: None.
-            pruned_latency(float): The percent of latency to be pruned. Default: None.
             server_addr(tuple): A tuple of server ip and server port for controller server.
             init_temperature(float): The init temperature used in simulated annealing search strategy.
             reduce_rate(float): The decay rate used in simulated annealing search strategy.
@@ -81,13 +75,11 @@ class AutoPruner(object):
             is_server(bool): Whether current host is controller server. Default: True.
""" - self._program = program self._scope = scope self._place = place self._params = params self._init_ratios = init_ratios - self._pruned_flops = pruned_flops - self._pruned_latency = pruned_latency + assert (params is not None and init_ratios is not None) self._reduce_rate = reduce_rate self._init_temperature = init_temperature self._max_try_times = max_try_times @@ -96,24 +88,11 @@ class AutoPruner(object): self._range_table = self._get_range_table(min_ratios, max_ratios) self._pruner = Pruner() - if self._pruned_flops: - self._base_flops = flops(program) - self._max_flops = self._base_flops * (1 - self._pruned_flops) - _logger.info( - "AutoPruner - base flops: {}; pruned_flops: {}; max_flops: {}". - format(self._base_flops, self._pruned_flops, self._max_flops)) - if self._pruned_latency: - self._base_latency = latency(program) - - if self._init_ratios is None: - self._init_ratios = self._get_init_ratios( - self, _program, self._params, self._pruned_flops, - self._pruned_latency) init_tokens = self._ratios2tokens(self._init_ratios) _logger.info("range table: {}".format(self._range_table)) controller = SAController(self._range_table, self._reduce_rate, self._init_temperature, self._max_try_times, - init_tokens, self._constrain_func) + init_tokens) server_ip, server_port = server_addr if server_ip == None or server_ip == "": @@ -141,9 +120,6 @@ class AutoPruner(object): def _get_host_ip(self): return socket.gethostbyname(socket.gethostname()) - def _get_init_ratios(self, program, params, pruned_flops, pruned_latency): - pass - def _get_range_table(self, min_ratios, max_ratios): assert isinstance(min_ratios, list) or isinstance(min_ratios, float) assert isinstance(max_ratios, list) or isinstance(max_ratios, float) @@ -155,25 +131,6 @@ class AutoPruner(object): max_tokens = self._ratios2tokens(max_ratios) return (min_tokens, max_tokens) - def _constrain_func(self, tokens): - ratios = self._tokens2ratios(tokens) - pruned_program, _, _ = self._pruner.prune( - self._program, - self._scope, - self._params, - ratios, - place=self._place, - only_graph=True) - current_flops = flops(pruned_program) - result = current_flops < self._max_flops - if not result: - _logger.info("Failed try ratios: {}; flops: {}; max_flops: {}". - format(ratios, current_flops, self._max_flops)) - else: - _logger.info("Success try ratios: {}; flops: {}; max_flops: {}". - format(ratios, current_flops, self._max_flops)) - return result - def prune(self, program, eval_program=None): """ Prune program with latest tokens generated by controller. 
diff --git a/tests/test_auto_prune.py b/tests/test_auto_prune.py
index c9cdc72c..471552a3 100644
--- a/tests/test_auto_prune.py
+++ b/tests/test_auto_prune.py
@@ -51,19 +51,15 @@ class TestPrune(unittest.TestCase):
         scope = fluid.Scope()
         exe.run(startup_program, scope=scope)
 
-        pruned_flops = 0.5
         pruner = AutoPruner(
-            main_program,
             scope,
             place,
            params=["conv4_weights"],
            init_ratios=[0.5],
-            pruned_flops=0.5,
-            pruned_latency=None,
            server_addr=("", 0),
            init_temperature=100,
            reduce_rate=0.85,
-            max_try_number=300,
+            max_try_times=300,
            max_client_num=10,
            search_steps=2,
            max_ratios=[0.9],
@@ -71,12 +67,12 @@ class TestPrune(unittest.TestCase):
             key="auto_pruner")
 
         base_flops = flops(main_program)
-        program = pruner.prune(main_program)
-        self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
+        program, _ = pruner.prune(main_program)
+        self.assertTrue(flops(program) <= base_flops)
         pruner.reward(1)
-        program = pruner.prune(main_program)
-        self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
+        program, _ = pruner.prune(main_program)
+        self.assertTrue(flops(program) <= base_flops)
         pruner.reward(1)
 
diff --git a/tests/test_prune.py b/tests/test_prune.py
index 60fe603c..3d8bc50c 100644
--- a/tests/test_prune.py
+++ b/tests/test_prune.py
@@ -15,7 +15,7 @@ import sys
 sys.path.append("../")
 import unittest
 import paddle.fluid as fluid
-from paddleslim.prune.walk_pruner import Pruner
+from paddleslim.prune import Pruner
 from layers import conv_bn_layer
 
 
@@ -72,7 +72,8 @@ class TestPrune(unittest.TestCase):
 
         for param in main_program.global_block().all_parameters():
             if "weights" in param.name:
-                print("param: {}; param shape: {}".format(param.name, param.shape))
+                print("param: {}; param shape: {}".format(param.name,
+                                                          param.shape))
                 self.assertTrue(param.shape == shapes[param.name])
 
--
GitLab
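For reference, the calling convention exercised by the updated
`tests/test_auto_prune.py`: the program to prune is now passed to `prune()`
rather than to the constructor, `params` and `init_ratios` are required, and
`max_try_number` is renamed to `max_try_times`. A condensed, hedged sketch; the
import paths, the toy network, and the `min_ratios` value are assumptions, not
part of this patch:

    # Assumed import paths; the one-conv network and min_ratios are placeholders.
    import paddle.fluid as fluid
    from paddleslim.prune import AutoPruner

    # Build a trivial program containing a prunable "conv4_weights" parameter.
    main_program = fluid.Program()
    startup_program = fluid.Program()
    with fluid.program_guard(main_program, startup_program):
        image = fluid.data(name="image", shape=[None, 3, 16, 16], dtype="float32")
        conv = fluid.layers.conv2d(
            image, num_filters=8, filter_size=3,
            param_attr=fluid.ParamAttr(name="conv4_weights"))

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    scope = fluid.Scope()
    exe.run(startup_program, scope=scope)  # as the test does

    pruner = AutoPruner(
        scope,
        place,
        params=["conv4_weights"],  # required now: no default [] anymore
        init_ratios=[0.5],         # required alongside params
        server_addr=("", 0),       # empty ip: run the controller server locally
        init_temperature=100,
        reduce_rate=0.85,
        max_try_times=300,         # renamed from max_try_number
        max_client_num=10,
        search_steps=2,
        max_ratios=[0.9],
        min_ratios=[0.1],          # placeholder value
        key="auto_pruner")

    # prune() returns a pair; the test unpacks and discards the second element.
    pruned_program, _ = pruner.prune(main_program)
    pruner.reward(1)  # report the evaluation score for the tried ratios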