Commit b352e44f authored by: W wanghaoshuang

Fix unit test of auto pruner.

Parent d80ed89f
@@ -41,7 +41,6 @@ class SAController(EvolutionaryController):
                  max_reward=-1,
                  iters=0,
                  best_tokens=None,
-                 constrain_func=None,
                  checkpoints=None,
                  searched=None):
         """Initialize.
@@ -55,7 +54,6 @@ class SAController(EvolutionaryController):
             max_reward(float): The max reward in the search of sanas, in general, best tokens get max reward. Default: -1.
             iters(int): The iteration of sa controller. Default: 0.
             best_tokens(list<int>): The best tokens in the search of sanas, in general, best tokens get max reward. Default: None.
-            constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
             checkpoints(str): if checkpoint is None, donnot save checkpoints, else save scene to checkpoints file.
             searched(dict<list, float>): remember tokens which are searched.
         """
@@ -75,7 +73,6 @@ class SAController(EvolutionaryController):
         else:
             self._init_temperature = 1.0
-        self._constrain_func = constrain_func
         self._max_reward = max_reward
         self._best_tokens = best_tokens
         self._iter = iters
@@ -86,8 +83,7 @@ class SAController(EvolutionaryController):
     def __getstate__(self):
         d = {}
         for key in self.__dict__:
-            if key != "_constrain_func":
-                d[key] = self.__dict__[key]
+            d[key] = self.__dict__[key]
         return d

     @property
......
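Taken together, the SAController changes drop the constrain_func hook entirely: it disappears from the constructor signature, the docstring, the attribute assignment, and the __getstate__ filter. For readers updating their own code, here is a minimal sketch of constructing the controller after this commit. The positional argument order follows the SAController(...) call visible in the AutoPruner hunk below; the import path and the concrete values are assumptions for illustration only.

from paddleslim.common import SAController  # import path assumed

# Hypothetical search space: one token per prunable layer, each in [0, 9].
range_table = ([0, 0, 0], [9, 9, 9])
init_tokens = [5, 5, 5]

# Positional order as used by AutoPruner in this commit:
# (range_table, reduce_rate, init_temperature, max_try_times, init_tokens).
# constrain_func is no longer accepted; any feasibility check (e.g. a FLOPs
# budget) now has to happen outside the controller.
controller = SAController(range_table, 0.85, 100, 300, init_tokens)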
@@ -32,13 +32,10 @@ _logger = get_logger(__name__, level=logging.INFO)
 class AutoPruner(object):
     def __init__(self,
-                 program,
                  scope,
                  place,
-                 params=[],
+                 params=None,
                  init_ratios=None,
-                 pruned_flops=0.5,
-                 pruned_latency=None,
                  server_addr=("", 0),
                  init_temperature=100,
                  reduce_rate=0.85,
@@ -52,7 +49,6 @@ class AutoPruner(object):
         """
         Search a group of ratios used to prune program.
         Args:
-            program(Program): The program to be pruned.
             scope(Scope): The scope to be pruned.
             place(fluid.Place): The device place of parameters.
             params(list<str>): The names of parameters to be pruned.
@@ -61,8 +57,6 @@ class AutoPruner(object):
                               The length of `init_ratios` should be equal to length of params when `init_ratios` is a list.
                               If it is a scalar, all the parameters in `params` will be pruned by uniform ratio.
                               None means get a group of init ratios by `pruned_flops` of `pruned_latency`. Default: None.
-            pruned_flops(float): The percent of FLOPS to be pruned. Default: None.
-            pruned_latency(float): The percent of latency to be pruned. Default: None.
             server_addr(tuple): A tuple of server ip and server port for controller server.
             init_temperature(float): The init temperature used in simulated annealing search strategy.
             reduce_rate(float): The decay rate used in simulated annealing search strategy.
@@ -81,13 +75,11 @@ class AutoPruner(object):
             is_server(bool): Whether current host is controller server. Default: True.
         """
-        self._program = program
         self._scope = scope
         self._place = place
         self._params = params
         self._init_ratios = init_ratios
-        self._pruned_flops = pruned_flops
-        self._pruned_latency = pruned_latency
+        assert (params is not None and init_ratios is not None)
         self._reduce_rate = reduce_rate
         self._init_temperature = init_temperature
         self._max_try_times = max_try_times
@@ -96,24 +88,11 @@ class AutoPruner(object):
         self._range_table = self._get_range_table(min_ratios, max_ratios)
         self._pruner = Pruner()
-        if self._pruned_flops:
-            self._base_flops = flops(program)
-            self._max_flops = self._base_flops * (1 - self._pruned_flops)
-            _logger.info(
-                "AutoPruner - base flops: {}; pruned_flops: {}; max_flops: {}".
-                format(self._base_flops, self._pruned_flops, self._max_flops))
-        if self._pruned_latency:
-            self._base_latency = latency(program)
-        if self._init_ratios is None:
-            self._init_ratios = self._get_init_ratios(
-                self, _program, self._params, self._pruned_flops,
-                self._pruned_latency)
         init_tokens = self._ratios2tokens(self._init_ratios)
         _logger.info("range table: {}".format(self._range_table))
         controller = SAController(self._range_table, self._reduce_rate,
                                   self._init_temperature, self._max_try_times,
-                                  init_tokens, self._constrain_func)
+                                  init_tokens)
         server_ip, server_port = server_addr
         if server_ip == None or server_ip == "":
@@ -141,9 +120,6 @@ class AutoPruner(object):
     def _get_host_ip(self):
         return socket.gethostbyname(socket.gethostname())

-    def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
-        pass

     def _get_range_table(self, min_ratios, max_ratios):
         assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
         assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
@@ -155,25 +131,6 @@ class AutoPruner(object):
         max_tokens = self._ratios2tokens(max_ratios)
         return (min_tokens, max_tokens)

-    def _constrain_func(self, tokens):
-        ratios = self._tokens2ratios(tokens)
-        pruned_program, _, _ = self._pruner.prune(
-            self._program,
-            self._scope,
-            self._params,
-            ratios,
-            place=self._place,
-            only_graph=True)
-        current_flops = flops(pruned_program)
-        result = current_flops < self._max_flops
-        if not result:
-            _logger.info("Failed try ratios: {}; flops: {}; max_flops: {}".
-                         format(ratios, current_flops, self._max_flops))
-        else:
-            _logger.info("Success try ratios: {}; flops: {}; max_flops: {}".
-                         format(ratios, current_flops, self._max_flops))
-        return result

     def prune(self, program, eval_program=None):
         """
         Prune program with latest tokens generated by controller.
......
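The deleted _constrain_func above was the only place AutoPruner enforced a FLOPs budget, so after this commit the pruner no longer rejects over-budget ratio candidates on its own. A caller who still wants a budget can check it against the returned program and fold the result into the reward. The sketch below is one way to do that under stated assumptions: the import path for flops, the evaluate() helper, and the reward scale are illustrative, and the second return value of prune() is inferred from the updated test below.

from paddleslim.analysis import flops  # import path assumed


def search(pruner, main_program, search_steps, evaluate):
    """Drive an AutoPruner search while enforcing a FLOPs budget externally.

    `pruner` is an AutoPruner built as in the test below; `evaluate` is a
    caller-supplied function scoring a pruned program (hypothetical helper).
    """
    max_flops = flops(main_program) * 0.5  # hypothetical 50% FLOPs budget
    for _ in range(search_steps):
        pruned_program, _ = pruner.prune(main_program)
        if flops(pruned_program) > max_flops:
            # Over budget: report a poor reward so the controller is steered
            # away from these ratios. (The reward scale is illustrative.)
            pruner.reward(0)
            continue
        pruner.reward(evaluate(pruned_program))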
@@ -51,19 +51,15 @@ class TestPrune(unittest.TestCase):
         scope = fluid.Scope()
         exe.run(startup_program, scope=scope)
-        pruned_flops = 0.5
         pruner = AutoPruner(
-            main_program,
             scope,
             place,
             params=["conv4_weights"],
             init_ratios=[0.5],
-            pruned_flops=0.5,
-            pruned_latency=None,
             server_addr=("", 0),
             init_temperature=100,
             reduce_rate=0.85,
-            max_try_number=300,
+            max_try_times=300,
             max_client_num=10,
             search_steps=2,
             max_ratios=[0.9],
@@ -71,12 +67,12 @@ class TestPrune(unittest.TestCase):
             key="auto_pruner")
         base_flops = flops(main_program)
-        program = pruner.prune(main_program)
-        self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
+        program, _ = pruner.prune(main_program)
+        self.assertTrue(flops(program) <= base_flops)
         pruner.reward(1)
-        program = pruner.prune(main_program)
-        self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
+        program, _ = pruner.prune(main_program)
+        self.assertTrue(flops(program) <= base_flops)
         pruner.reward(1)
......
@@ -15,7 +15,7 @@ import sys
 sys.path.append("../")
 import unittest
 import paddle.fluid as fluid
-from paddleslim.prune.walk_pruner import Pruner
+from paddleslim.prune import Pruner
 from layers import conv_bn_layer
@@ -72,7 +72,8 @@ class TestPrune(unittest.TestCase):
         for param in main_program.global_block().all_parameters():
             if "weights" in param.name:
-                print("param: {}; param shape: {}".format(param.name, param.shape))
+                print("param: {}; param shape: {}".format(param.name,
+                                                          param.shape))
                 self.assertTrue(param.shape == shapes[param.name])
......
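As a closing note, the Pruner imported above (now from paddleslim.prune rather than paddleslim.prune.walk_pruner) is the same low-level pruner the removed _constrain_func called. A hedged sketch of that call pattern, reusing the signature from the deleted code, shows how the FLOPs of a candidate ratio set can still be estimated without touching the trained parameters; the program, scope, place, parameter names, and the flops import path are placeholders or assumptions.

from paddleslim.prune import Pruner
from paddleslim.analysis import flops  # import path assumed

pruner = Pruner()
# only_graph=True updates the graph/shape information without modifying the
# parameters held in scope, which is enough to estimate FLOPs of a candidate.
pruned_program, _, _ = pruner.prune(
    main_program,          # program built elsewhere (placeholder)
    scope,                 # fluid.Scope holding the parameters
    ["conv4_weights"],     # parameter names to prune (illustrative)
    [0.5],                 # pruning ratio per parameter
    place=place,
    only_graph=True)
print("candidate FLOPs: {}".format(flops(pruned_program)))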