提交 67a45929 编写于 作者: W wanghaoshuang

Add some comments.

上级 9ff3159e
......@@ -52,6 +52,9 @@ class AutoPruner(object):
"""
Search a group of ratios used to prune program.
Args:
program(Program): The program to be pruned.
scope(Scope): The scope to be pruned.
place(fluid.Place): The device place of parameters.
params(list<str>): The names of parameters to be pruned.
init_ratios(list<float>|float): Init ratios used to pruned parameters in `params`.
List means ratios used for pruning each parameter in `params`.
......@@ -61,9 +64,20 @@ class AutoPruner(object):
pruned_flops(float): The percent of FLOPS to be pruned. Default: None.
pruned_latency(float): The percent of latency to be pruned. Default: None.
server_addr(tuple): A tuple of server ip and server port for controller server.
search_strategy(str): The search strategy. Default: 'sa'.
init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy.
max_try_number(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching.
max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`.
The length of `max_ratios` should be equal to length of params when `max_ratios` is a list.
If it is a scalar, it will used for all the parameters in `params`.
min_ratios(float|list<float>): Min ratios used to pruned parameters in `params`. List means min ratios for each parameter in `params`.
The length of `min_ratios` should be equal to length of params when `min_ratios` is a list.
If it is a scalar, it will used for all the parameters in `params`.
key(str): Identity used in communication between controller server and clients.
is_server(bool): Whether current host is controller server. Default: True.
"""
# step1: Create controller server. And start server if current host match server_ip.
self._program = program
self._scope = scope
......@@ -150,6 +164,13 @@ class AutoPruner(object):
1 - self._pruned_flops)
def prune(self, program):
"""
Prune program with latest tokens generated by controller.
Args:
program(fluid.Program): The program to be pruned.
Returns:
Program: The pruned program.
"""
self._current_ratios = self._next_ratios()
pruned_program = self._pruner.prune(
program,
......@@ -163,6 +184,11 @@ class AutoPruner(object):
return pruned_program
def reward(self, score):
"""
Return reward of current pruned program.
Args:
score(float): The score of pruned program.
"""
self._restore(self._scope)
self._param_backup = {}
tokens = self._ratios2tokens(self._current_ratios)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册