提交 b352e44f 编写于 作者: W wanghaoshuang

Fix unitest of auto pruner.

上级 d80ed89f
......@@ -41,7 +41,6 @@ class SAController(EvolutionaryController):
max_reward=-1,
iters=0,
best_tokens=None,
constrain_func=None,
checkpoints=None,
searched=None):
"""Initialize.
......@@ -55,7 +54,6 @@ class SAController(EvolutionaryController):
max_reward(float): The max reward in the search of sanas, in general, best tokens get max reward. Default: -1.
iters(int): The iteration of sa controller. Default: 0.
best_tokens(list<int>): The best tokens in the search of sanas, in general, best tokens get max reward. Default: None.
constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
checkpoints(str): if checkpoint is None, donnot save checkpoints, else save scene to checkpoints file.
searched(dict<list, float>): remember tokens which are searched.
"""
......@@ -75,7 +73,6 @@ class SAController(EvolutionaryController):
else:
self._init_temperature = 1.0
self._constrain_func = constrain_func
self._max_reward = max_reward
self._best_tokens = best_tokens
self._iter = iters
......@@ -86,8 +83,7 @@ class SAController(EvolutionaryController):
def __getstate__(self):
d = {}
for key in self.__dict__:
if key != "_constrain_func":
d[key] = self.__dict__[key]
d[key] = self.__dict__[key]
return d
@property
......
......@@ -32,13 +32,10 @@ _logger = get_logger(__name__, level=logging.INFO)
class AutoPruner(object):
def __init__(self,
program,
scope,
place,
params=[],
params=None,
init_ratios=None,
pruned_flops=0.5,
pruned_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
......@@ -52,7 +49,6 @@ class AutoPruner(object):
"""
Search a group of ratios used to prune program.
Args:
program(Program): The program to be pruned.
scope(Scope): The scope to be pruned.
place(fluid.Place): The device place of parameters.
params(list<str>): The names of parameters to be pruned.
......@@ -61,8 +57,6 @@ class AutoPruner(object):
The length of `init_ratios` should be equal to length of params when `init_ratios` is a list.
If it is a scalar, all the parameters in `params` will be pruned by uniform ratio.
None means get a group of init ratios by `pruned_flops` of `pruned_latency`. Default: None.
pruned_flops(float): The percent of FLOPS to be pruned. Default: None.
pruned_latency(float): The percent of latency to be pruned. Default: None.
server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy.
......@@ -81,13 +75,11 @@ class AutoPruner(object):
is_server(bool): Whether current host is controller server. Default: True.
"""
self._program = program
self._scope = scope
self._place = place
self._params = params
self._init_ratios = init_ratios
self._pruned_flops = pruned_flops
self._pruned_latency = pruned_latency
assert (params is not None and init_ratios is not None)
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_try_times = max_try_times
......@@ -96,24 +88,11 @@ class AutoPruner(object):
self._range_table = self._get_range_table(min_ratios, max_ratios)
self._pruner = Pruner()
if self._pruned_flops:
self._base_flops = flops(program)
self._max_flops = self._base_flops * (1 - self._pruned_flops)
_logger.info(
"AutoPruner - base flops: {}; pruned_flops: {}; max_flops: {}".
format(self._base_flops, self._pruned_flops, self._max_flops))
if self._pruned_latency:
self._base_latency = latency(program)
if self._init_ratios is None:
self._init_ratios = self._get_init_ratios(
self, _program, self._params, self._pruned_flops,
self._pruned_latency)
init_tokens = self._ratios2tokens(self._init_ratios)
_logger.info("range table: {}".format(self._range_table))
controller = SAController(self._range_table, self._reduce_rate,
self._init_temperature, self._max_try_times,
init_tokens, self._constrain_func)
init_tokens)
server_ip, server_port = server_addr
if server_ip == None or server_ip == "":
......@@ -141,9 +120,6 @@ class AutoPruner(object):
def _get_host_ip(self):
return socket.gethostbyname(socket.gethostname())
def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
pass
def _get_range_table(self, min_ratios, max_ratios):
assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
......@@ -155,25 +131,6 @@ class AutoPruner(object):
max_tokens = self._ratios2tokens(max_ratios)
return (min_tokens, max_tokens)
def _constrain_func(self, tokens):
ratios = self._tokens2ratios(tokens)
pruned_program, _, _ = self._pruner.prune(
self._program,
self._scope,
self._params,
ratios,
place=self._place,
only_graph=True)
current_flops = flops(pruned_program)
result = current_flops < self._max_flops
if not result:
_logger.info("Failed try ratios: {}; flops: {}; max_flops: {}".
format(ratios, current_flops, self._max_flops))
else:
_logger.info("Success try ratios: {}; flops: {}; max_flops: {}".
format(ratios, current_flops, self._max_flops))
return result
def prune(self, program, eval_program=None):
"""
Prune program with latest tokens generated by controller.
......
......@@ -51,19 +51,15 @@ class TestPrune(unittest.TestCase):
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
pruned_flops = 0.5
pruner = AutoPruner(
main_program,
scope,
place,
params=["conv4_weights"],
init_ratios=[0.5],
pruned_flops=0.5,
pruned_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_try_times=300,
max_client_num=10,
search_steps=2,
max_ratios=[0.9],
......@@ -71,12 +67,12 @@ class TestPrune(unittest.TestCase):
key="auto_pruner")
base_flops = flops(main_program)
program = pruner.prune(main_program)
self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
program, _ = pruner.prune(main_program)
self.assertTrue(flops(program) <= base_flops)
pruner.reward(1)
program = pruner.prune(main_program)
self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
program, _ = pruner.prune(main_program)
self.assertTrue(flops(program) <= base_flops)
pruner.reward(1)
......
......@@ -15,7 +15,7 @@ import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune.walk_pruner import Pruner
from paddleslim.prune import Pruner
from layers import conv_bn_layer
......@@ -72,7 +72,8 @@ class TestPrune(unittest.TestCase):
for param in main_program.global_block().all_parameters():
if "weights" in param.name:
print("param: {}; param shape: {}".format(param.name, param.shape))
print("param: {}; param shape: {}".format(param.name,
param.shape))
self.assertTrue(param.shape == shapes[param.name])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册