未验证 提交 d5f6d39b 编写于 作者: W whs 提交者: GitHub

Format en api of prune, one-shot and common modules (#101)

上级 6ac0720c
......@@ -4,21 +4,16 @@
contain the root `toctree` directive.
API Documents
========
==============
.. toctree::
:maxdepth: 1
modules.rst
paddleslim.analysis.rst
paddleslim.common.rst
paddleslim.core.rst
paddleslim.prune.rst
paddleslim.dist.rst
paddleslim.models.rst
paddleslim.nas.one_shot.rst
paddleslim.quant.rst
paddleslim.nas.rst
paddleslim.nas.one_shot.rst
paddleslim.nas.search_space.rst
paddleslim.pantheon.rst
paddleslim.prune.rst
paddleslim.quant.rst
paddleslim.rst
paddleslim.analysis package
===========================
paddleslim\.analysis package
============================
.. automodule:: paddleslim.analysis
:members:
......@@ -9,24 +9,24 @@ paddleslim.analysis package
Submodules
----------
paddleslim.analysis.flops module
--------------------------------
paddleslim\.analysis\.flops module
----------------------------------
.. automodule:: paddleslim.analysis.flops
:members:
:undoc-members:
:show-inheritance:
paddleslim.analysis.latency module
----------------------------------
paddleslim\.analysis\.latency module
------------------------------------
.. automodule:: paddleslim.analysis.latency
:members:
:undoc-members:
:show-inheritance:
paddleslim.analysis.model\_size module
--------------------------------------
paddleslim\.analysis\.model\_size module
----------------------------------------
.. automodule:: paddleslim.analysis.model_size
:members:
......
paddleslim.common package
=========================
paddleslim\.common package
==========================
.. automodule:: paddleslim.common
:members:
......@@ -9,56 +9,56 @@ paddleslim.common package
Submodules
----------
paddleslim.common.cached\_reader module
---------------------------------------
paddleslim\.common\.cached\_reader module
-----------------------------------------
.. automodule:: paddleslim.common.cached_reader
:members:
:undoc-members:
:show-inheritance:
paddleslim.common.controller module
-----------------------------------
paddleslim\.common\.controller module
-------------------------------------
.. automodule:: paddleslim.common.controller
:members:
:undoc-members:
:show-inheritance:
paddleslim.common.controller\_client module
-------------------------------------------
paddleslim\.common\.controller\_client module
---------------------------------------------
.. automodule:: paddleslim.common.controller_client
:members:
:undoc-members:
:show-inheritance:
paddleslim.common.controller\_server module
-------------------------------------------
paddleslim\.common\.controller\_server module
---------------------------------------------
.. automodule:: paddleslim.common.controller_server
:members:
:undoc-members:
:show-inheritance:
paddleslim.common.lock module
-----------------------------
paddleslim\.common\.lock module
-------------------------------
.. automodule:: paddleslim.common.lock
:members:
:undoc-members:
:show-inheritance:
paddleslim.common.log\_helper module
------------------------------------
paddleslim\.common\.log\_helper module
--------------------------------------
.. automodule:: paddleslim.common.log_helper
:members:
:undoc-members:
:show-inheritance:
paddleslim.common.sa\_controller module
---------------------------------------
paddleslim\.common\.sa\_controller module
-----------------------------------------
.. automodule:: paddleslim.common.sa_controller
:members:
......
paddleslim.core package
=======================
paddleslim\.core package
========================
.. automodule:: paddleslim.core
:members:
......@@ -9,16 +9,16 @@ paddleslim.core package
Submodules
----------
paddleslim.core.graph\_wrapper module
-------------------------------------
paddleslim\.core\.graph\_wrapper module
---------------------------------------
.. automodule:: paddleslim.core.graph_wrapper
:members:
:undoc-members:
:show-inheritance:
paddleslim.core.registry module
-------------------------------
paddleslim\.core\.registry module
---------------------------------
.. automodule:: paddleslim.core.registry
:members:
......
paddleslim.dist package
=======================
paddleslim\.dist package
========================
.. automodule:: paddleslim.dist
:members:
......@@ -9,8 +9,8 @@ paddleslim.dist package
Submodules
----------
paddleslim.dist.single\_distiller module
----------------------------------------
paddleslim\.dist\.single\_distiller module
------------------------------------------
.. automodule:: paddleslim.dist.single_distiller
:members:
......
paddleslim.models package
=========================
paddleslim\.models package
==========================
.. automodule:: paddleslim.models
:members:
......@@ -9,40 +9,40 @@ paddleslim.models package
Submodules
----------
paddleslim.models.classification\_models module
-----------------------------------------------
paddleslim\.models\.classification\_models module
-------------------------------------------------
.. automodule:: paddleslim.models.classification_models
:members:
:undoc-members:
:show-inheritance:
paddleslim.models.mobilenet module
----------------------------------
paddleslim\.models\.mobilenet module
------------------------------------
.. automodule:: paddleslim.models.mobilenet
:members:
:undoc-members:
:show-inheritance:
paddleslim.models.mobilenet\_v2 module
--------------------------------------
paddleslim\.models\.mobilenet\_v2 module
----------------------------------------
.. automodule:: paddleslim.models.mobilenet_v2
:members:
:undoc-members:
:show-inheritance:
paddleslim.models.resnet module
-------------------------------
paddleslim\.models\.resnet module
---------------------------------
.. automodule:: paddleslim.models.resnet
:members:
:undoc-members:
:show-inheritance:
paddleslim.models.util module
-----------------------------
paddleslim\.models\.util module
-------------------------------
.. automodule:: paddleslim.models.util
:members:
......
paddleslim.nas.one\_shot package
================================
paddleslim\.nas\.one\_shot package
==================================
.. automodule:: paddleslim.nas.one_shot
:members:
......@@ -9,16 +9,16 @@ paddleslim.nas.one\_shot package
Submodules
----------
paddleslim.nas.one\_shot.one\_shot\_nas module
----------------------------------------------
paddleslim\.nas\.one\_shot\.one\_shot\_nas module
-------------------------------------------------
.. automodule:: paddleslim.nas.one_shot.one_shot_nas
:members:
:undoc-members:
:show-inheritance:
paddleslim.nas.one\_shot.super\_mnasnet module
----------------------------------------------
paddleslim\.nas\.one\_shot\.super\_mnasnet module
-------------------------------------------------
.. automodule:: paddleslim.nas.one_shot.super_mnasnet
:members:
......
paddleslim.nas package
======================
paddleslim\.nas package
=======================
.. automodule:: paddleslim.nas
:members:
......@@ -17,8 +17,8 @@ Subpackages
Submodules
----------
paddleslim.nas.sa\_nas module
-----------------------------
paddleslim\.nas\.sa\_nas module
-------------------------------
.. automodule:: paddleslim.nas.sa_nas
:members:
......
paddleslim.nas.search\_space package
====================================
paddleslim\.nas\.search\_space package
======================================
.. automodule:: paddleslim.nas.search_space
:members:
......@@ -9,96 +9,96 @@ paddleslim.nas.search\_space package
Submodules
----------
paddleslim.nas.search\_space.base\_layer module
-----------------------------------------------
paddleslim\.nas\.search\_space\.base\_layer module
--------------------------------------------------
.. automodule:: paddleslim.nas.search_space.base_layer
:members:
:undoc-members:
:show-inheritance:
paddleslim.nas.search\_space.combine\_search\_space module
----------------------------------------------------------
paddleslim\.nas\.search\_space\.combine\_search\_space module
-------------------------------------------------------------
.. automodule:: paddleslim.nas.search_space.combine_search_space
:members:
:undoc-members:
:show-inheritance:
paddleslim.nas.search\_space.inception\_block module
----------------------------------------------------
paddleslim\.nas\.search\_space\.inception\_block module
-------------------------------------------------------
.. automodule:: paddleslim.nas.search_space.inception_block
:members:
:undoc-members:
:show-inheritance:
paddleslim.nas.search\_space.mobilenet\_block module
----------------------------------------------------
paddleslim\.nas\.search\_space\.mobilenet\_block module
-------------------------------------------------------
.. automodule:: paddleslim.nas.search_space.mobilenet_block
:members:
:undoc-members:
:show-inheritance:
paddleslim.nas.search\_space.mobilenetv1 module
-----------------------------------------------
paddleslim\.nas\.search\_space\.mobilenetv1 module
--------------------------------------------------
.. automodule:: paddleslim.nas.search_space.mobilenetv1
:members:
:undoc-members:
:show-inheritance:
paddleslim.nas.search\_space.mobilenetv2 module
-----------------------------------------------
paddleslim\.nas\.search\_space\.mobilenetv2 module
--------------------------------------------------
.. automodule:: paddleslim.nas.search_space.mobilenetv2
:members:
:undoc-members:
:show-inheritance:
paddleslim.nas.search\_space.resnet module
------------------------------------------
paddleslim\.nas\.search\_space\.resnet module
---------------------------------------------
.. automodule:: paddleslim.nas.search_space.resnet
:members:
:undoc-members:
:show-inheritance:
paddleslim.nas.search\_space.resnet\_block module
-------------------------------------------------
paddleslim\.nas\.search\_space\.resnet\_block module
----------------------------------------------------
.. automodule:: paddleslim.nas.search_space.resnet_block
:members:
:undoc-members:
:show-inheritance:
paddleslim.nas.search\_space.search\_space\_base module
-------------------------------------------------------
paddleslim\.nas\.search\_space\.search\_space\_base module
----------------------------------------------------------
.. automodule:: paddleslim.nas.search_space.search_space_base
:members:
:undoc-members:
:show-inheritance:
paddleslim.nas.search\_space.search\_space\_factory module
----------------------------------------------------------
paddleslim\.nas\.search\_space\.search\_space\_factory module
-------------------------------------------------------------
.. automodule:: paddleslim.nas.search_space.search_space_factory
:members:
:undoc-members:
:show-inheritance:
paddleslim.nas.search\_space.search\_space\_registry module
-----------------------------------------------------------
paddleslim\.nas\.search\_space\.search\_space\_registry module
--------------------------------------------------------------
.. automodule:: paddleslim.nas.search_space.search_space_registry
:members:
:undoc-members:
:show-inheritance:
paddleslim.nas.search\_space.utils module
-----------------------------------------
paddleslim\.nas\.search\_space\.utils module
--------------------------------------------
.. automodule:: paddleslim.nas.search_space.utils
:members:
......
paddleslim.pantheon package
===========================
paddleslim\.pantheon package
============================
.. automodule:: paddleslim.pantheon
:members:
......@@ -9,24 +9,24 @@ paddleslim.pantheon package
Submodules
----------
paddleslim.pantheon.student module
----------------------------------
paddleslim\.pantheon\.student module
------------------------------------
.. automodule:: paddleslim.pantheon.student
:members:
:undoc-members:
:show-inheritance:
paddleslim.pantheon.teacher module
----------------------------------
paddleslim\.pantheon\.teacher module
------------------------------------
.. automodule:: paddleslim.pantheon.teacher
:members:
:undoc-members:
:show-inheritance:
paddleslim.pantheon.utils module
--------------------------------
paddleslim\.pantheon\.utils module
----------------------------------
.. automodule:: paddleslim.pantheon.utils
:members:
......
paddleslim.prune package
========================
paddleslim\.prune package
=========================
.. automodule:: paddleslim.prune
:members:
......@@ -9,48 +9,48 @@ paddleslim.prune package
Submodules
----------
paddleslim.prune.auto\_pruner module
------------------------------------
paddleslim\.prune\.auto\_pruner module
--------------------------------------
.. automodule:: paddleslim.prune.auto_pruner
:members:
:undoc-members:
:show-inheritance:
paddleslim.prune.prune\_io module
---------------------------------
paddleslim\.prune\.prune\_io module
-----------------------------------
.. automodule:: paddleslim.prune.prune_io
:members:
:undoc-members:
:show-inheritance:
paddleslim.prune.prune\_walker module
-------------------------------------
paddleslim\.prune\.prune\_walker module
---------------------------------------
.. automodule:: paddleslim.prune.prune_walker
:members:
:undoc-members:
:show-inheritance:
paddleslim.prune.pruner module
------------------------------
paddleslim\.prune\.pruner module
--------------------------------
.. automodule:: paddleslim.prune.pruner
:members:
:undoc-members:
:show-inheritance:
paddleslim.prune.sensitive module
---------------------------------
paddleslim\.prune\.sensitive module
-----------------------------------
.. automodule:: paddleslim.prune.sensitive
:members:
:undoc-members:
:show-inheritance:
paddleslim.prune.sensitive\_pruner module
-----------------------------------------
paddleslim\.prune\.sensitive\_pruner module
-------------------------------------------
.. automodule:: paddleslim.prune.sensitive_pruner
:members:
......
paddleslim.quant package
========================
paddleslim\.quant package
=========================
.. automodule:: paddleslim.quant
:members:
......@@ -9,16 +9,16 @@ paddleslim.quant package
Submodules
----------
paddleslim.quant.quant\_embedding module
----------------------------------------
paddleslim\.quant\.quant\_embedding module
------------------------------------------
.. automodule:: paddleslim.quant.quant_embedding
:members:
:undoc-members:
:show-inheritance:
paddleslim.quant.quanter module
-------------------------------
paddleslim\.quant\.quanter module
---------------------------------
.. automodule:: paddleslim.quant.quanter
:members:
......
......@@ -24,8 +24,8 @@ Subpackages
Submodules
----------
paddleslim.version module
-------------------------
paddleslim\.version module
--------------------------
.. automodule:: paddleslim.version
:members:
......
......@@ -25,6 +25,7 @@ _logger = get_logger(__name__, level=logging.INFO)
def cached_reader(reader, sampled_rate, cache_path, cached_id):
"""
Sample partial data from reader and cache them into local file system.
Args:
reader: Iterative data source.
sampled_rate(float): The sampled rate used to sample partial data for evaluation. None means using all data in eval_reader. default: None.
......
......@@ -24,11 +24,9 @@ class EvolutionaryController(object):
"""Abstract controller for all evolutionary searching method.
"""
def __init__(self, *args, **kwargs):
pass
def update(self, tokens, reward):
"""Update the status of controller according current tokens and reward.
Args:
tokens(list<int>): A solution of searching task.
reward(list<int>): The reward of tokens.
......@@ -37,6 +35,7 @@ class EvolutionaryController(object):
def reset(self, range_table, constrain_func=None):
"""Reset the controller.
Args:
range_table(list<int>): It is used to define the searching space of controller.
The tokens[i] generated by controller should be in [0, range_table[i]).
......@@ -47,5 +46,8 @@ class EvolutionaryController(object):
def next_tokens(self):
"""Generate new tokens.
Returns:
list<list>: The next searched tokens.
"""
raise NotImplementedError('Abstract method.')
......@@ -24,6 +24,11 @@ _logger = get_logger(__name__, level=logging.INFO)
class ControllerClient(object):
"""
Controller client.
Args:
server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
key(str): The key used to identify legal agent for controller server. Default: "light-nas"
client_name(str): Current client name, random generate for counting client number. Default: None.
"""
def __init__(self,
......@@ -32,11 +37,6 @@ class ControllerClient(object):
key=None,
client_name=None):
"""
Args:
server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
key(str): The key used to identify legal agent for controller server. Default: "light-nas"
client_name(str): Current client name, random generate for counting client number. Default: None.
"""
self.server_ip = server_ip
self.server_port = server_port
......@@ -46,9 +46,11 @@ class ControllerClient(object):
def update(self, tokens, reward, iter):
"""
Update the controller according to latest tokens and reward.
Args:
tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens.
iter(int): The iteration number of current client.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
......
......@@ -26,8 +26,14 @@ _logger = get_logger(__name__, level=logging.INFO)
class ControllerServer(object):
"""
The controller wrapper with a socket server to handle the request of search agent.
"""The controller wrapper with a socket server to handle the request of search agent.
Args:
controller(slim.searcher.Controller): The controller used to generate tokens.
address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
which means setting ip automatically
max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
search_steps(int|None): The total steps of searching. None means never stopping. Default: None
key(str|None): Config information. Default: None.
"""
def __init__(self,
......@@ -37,13 +43,6 @@ class ControllerServer(object):
search_steps=None,
key=None):
"""
Args:
controller(slim.searcher.Controller): The controller used to generate tokens.
address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
which means setting ip automatically
max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
search_steps(int|None): The total steps of searching. None means never stopping. Default: None
key(str|None): Config information. Default: None.
"""
self._controller = controller
self._address = address
......@@ -84,6 +83,8 @@ class ControllerServer(object):
return self._ip
def run(self):
"""Start the server.
"""
_logger.info("Controller Server run...")
try:
while ((self._search_steps is None) or
......
......@@ -24,14 +24,19 @@ def get_logger(name, level, fmt='%(asctime)s-%(levelname)s: %(message)s'):
Get logger from logging with given name, level and format without
setting logging basicConfig. For setting basicConfig in paddle
will disable basicConfig setting after import paddle.
Args:
name (str): The logger name.
level (logging.LEVEL): The base level of the logger
fmt (str): Format of logger output
Returns:
logging.Logger: logging logger with given setttings
Examples:
.. code-block:: python
logger = log_helper.get_logger(__name__, logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
"""
......
......@@ -29,22 +29,8 @@ _logger = get_logger(__name__, level=logging.INFO)
class SAController(EvolutionaryController):
"""Simulated annealing controller."""
"""Simulated annealing controller.
def __init__(self,
range_table=None,
reduce_rate=0.85,
init_temperature=None,
max_try_times=300,
init_tokens=None,
reward=-1,
max_reward=-1,
iters=0,
best_tokens=None,
constrain_func=None,
checkpoints=None,
searched=None):
"""Initialize.
Args:
range_table(list<int>): Range table.
reduce_rate(float): The decay rate of temperature.
......@@ -59,6 +45,20 @@ class SAController(EvolutionaryController):
checkpoints(str): if checkpoint is None, donnot save checkpoints, else save scene to checkpoints file.
searched(dict<list, float>): remember tokens which are searched.
"""
def __init__(self,
range_table=None,
reduce_rate=0.85,
init_temperature=None,
max_try_times=300,
init_tokens=None,
reward=-1,
max_reward=-1,
iters=0,
best_tokens=None,
constrain_func=None,
checkpoints=None,
searched=None):
super(SAController, self).__init__()
self._range_table = range_table
assert isinstance(self._range_table, tuple) and (
......@@ -92,6 +92,11 @@ class SAController(EvolutionaryController):
@property
def best_tokens(self):
"""Get current best tokens.
Returns:
list<int>: The best tokens.
"""
return self._best_tokens
@property
......@@ -100,14 +105,23 @@ class SAController(EvolutionaryController):
@property
def current_tokens(self):
"""Get tokens generated in current searching step.
Returns:
list<int>: The best tokens.
"""
return self._current_tokens
def update(self, tokens, reward, iter, client_num):
"""
Update the controller according to latest tokens and reward.
Args:
tokens(list<int>): The tokens generated in last step.
tokens(list<int>): The tokens generated in current step.
reward(float): The reward of tokens.
iter(int): The current step of searching client.
client_num(int): The total number of searching client.
"""
iter = int(iter)
if iter > self._iter:
......@@ -136,6 +150,12 @@ class SAController(EvolutionaryController):
def next_tokens(self, control_token=None):
"""
Get next tokens.
Args:
control_token: The tokens used to generate next tokens.
Returns:
list<int>: The next tokens.
"""
if control_token:
tokens = control_token[:]
......
......@@ -72,6 +72,7 @@ class VarWrapper(object):
def inputs(self):
"""
Get all the operators that use this variable as output.
Returns:
list<OpWrapper>: A list of operators.
"""
......@@ -84,6 +85,7 @@ class VarWrapper(object):
def outputs(self):
"""
Get all the operators that use this variable as input.
Returns:
list<OpWrapper>: A list of operators.
"""
......@@ -196,10 +198,7 @@ class GraphWrapper(object):
"""
It is a wrapper of paddle.fluid.framework.IrGraph with some special functions
for paddle slim framework.
"""
def __init__(self, program=None, in_nodes=[], out_nodes=[]):
"""
Args:
program(framework.Program): A program with
in_nodes(dict): A dict to indicate the input nodes of the graph.
......@@ -209,6 +208,10 @@ class GraphWrapper(object):
The key is user-defined and human-readable name.
The value is the name of Variable.
"""
def __init__(self, program=None, in_nodes=[], out_nodes=[]):
"""
"""
super(GraphWrapper, self).__init__()
self.program = Program() if program is None else program
self.persistables = {}
......@@ -226,6 +229,7 @@ class GraphWrapper(object):
def all_parameters(self):
"""
Get all the parameters in this graph.
Returns:
list<VarWrapper>: A list of VarWrapper instances.
"""
......@@ -238,6 +242,7 @@ class GraphWrapper(object):
def is_parameter(self, var):
"""
Whether the given variable is parameter.
Args:
var(VarWrapper): The given varibale.
"""
......@@ -246,6 +251,7 @@ class GraphWrapper(object):
def is_persistable(self, var):
"""
Whether the given variable is persistable.
Args:
var(VarWrapper): The given varibale.
"""
......@@ -279,6 +285,7 @@ class GraphWrapper(object):
def clone(self, for_test=False):
"""
Clone a new graph from current graph.
Returns:
(GraphWrapper): The wrapper of a new graph.
"""
......@@ -295,8 +302,10 @@ class GraphWrapper(object):
def pre_ops(self, op):
"""
Get all the previous operators of target operator.
Args:
op(OpWrapper): Target operator..
op(OpWrapper): Target operator.
Returns:
list<OpWrapper>: A list of operators.
"""
......@@ -310,8 +319,10 @@ class GraphWrapper(object):
def next_ops(self, op):
"""
Get all the next operators of target operator.
Args:
op(OpWrapper): Target operator..
op(OpWrapper): Target operator.
Returns:
list<OpWrapper>: A list of operators.
"""
......
......@@ -22,14 +22,16 @@ __all__ = ['OneShotSuperNet', 'OneShotSearch']
def OneShotSearch(model, eval_func, strategy='sa', search_steps=100):
"""
Search a best tokens which represents a sub-network.
Archs:
Args:
model(fluid.dygraph.Layer): A dynamic graph module whose sub-modules should contain
one instance of `OneShotSuperNet` at least.
eval_func(function): A callback function which accept model and tokens as arguments.
strategy(str): The name of strategy used to search. Default: 'sa'.
search_steps(int): The total steps for searching.
Returns:
tokens(list): The best tokens searched.
list<int>: The best tokens searched.
"""
super_net = None
for layer in model.sublayers(include_sublayers=False):
......@@ -52,8 +54,7 @@ def OneShotSearch(model, eval_func, strategy='sa', search_steps=100):
class OneShotSuperNet(fluid.dygraph.Layer):
"""
The base class of super net used in one-shot searching strategy.
"""The base class of super net used in one-shot searching strategy.
A super net is a dygraph layer.
Args:
......@@ -65,14 +66,16 @@ class OneShotSuperNet(fluid.dygraph.Layer):
def init_tokens(self):
"""Get init tokens in search space.
Return:
tokens(list): The init tokens which is a list of integer.
Returns:
lis<int>t: The init tokens which is a list of integer.
"""
raise NotImplementedError('Abstract method.')
def range_table(self):
"""Get range table of current search space.
Return:
Returns:
range_table(tuple): The maximum value and minimum value in each position of tokens
with format `(min_values, max_values)`. The `min_values` is
a list of integers indicating the minimum values while `max_values`
......@@ -81,9 +84,9 @@ class OneShotSuperNet(fluid.dygraph.Layer):
raise NotImplementedError('Abstract method.')
def _forward_impl(self, *inputs, **kwargs):
"""
Defines the computation performed at every call.
"""Defines the computation performed at every call.
Should be overridden by all subclasses.
Args:
inputs(tuple): unpacked tuple arguments
kwargs(dict): unpacked dict arguments
......@@ -93,6 +96,7 @@ class OneShotSuperNet(fluid.dygraph.Layer):
def forward(self, input, tokens=None):
"""
Defines the computation performed at every call.
Args:
input(variable): The input of super net.
tokens(list): The tokens used to generate a sub-network.
......@@ -100,8 +104,9 @@ class OneShotSuperNet(fluid.dygraph.Layer):
Otherwise, it will execute the sub-network generated by tokens.
The `tokens` should be set in searching stage and final training stage.
Default: None.
Returns:
output(varaible): The output of super net.
Varaible: The output of super net.
"""
if tokens == None:
tokens = self._random_tokens()
......
......@@ -31,26 +31,9 @@ _logger = get_logger(__name__, level=logging.INFO)
class AutoPruner(object):
def __init__(self,
program,
scope,
place,
params=[],
init_ratios=None,
pruned_flops=0.5,
pruned_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_times=300,
max_client_num=10,
search_steps=300,
max_ratios=[0.9],
min_ratios=[0],
key="auto_pruner",
is_server=True):
"""
Search a group of ratios used to prune program.
Args:
program(Program): The program to be pruned.
scope(Scope): The scope to be pruned.
......@@ -81,6 +64,25 @@ class AutoPruner(object):
is_server(bool): Whether current host is controller server. Default: True.
"""
def __init__(self,
program,
scope,
place,
params=[],
init_ratios=None,
pruned_flops=0.5,
pruned_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_times=300,
max_client_num=10,
search_steps=300,
max_ratios=[0.9],
min_ratios=[0],
key="auto_pruner",
is_server=True):
self._program = program
self._scope = scope
self._place = place
......@@ -177,10 +179,12 @@ class AutoPruner(object):
def prune(self, program, eval_program=None):
"""
Prune program with latest tokens generated by controller.
Args:
program(fluid.Program): The program to be pruned.
Returns:
Program: The pruned program.
paddle.fluid.Program: The pruned program.
"""
self._current_ratios = self._next_ratios()
pruned_program, _, _ = self._pruner.prune(
......@@ -208,8 +212,9 @@ class AutoPruner(object):
def reward(self, score):
"""
Return reward of current pruned program.
Args:
score(float): The score of pruned program.
float: The score of pruned program.
"""
self._restore(self._scope)
self._param_backup = {}
......
......@@ -19,8 +19,9 @@ def save_model(exe, graph, dirname):
Save weights of model and information of shapes into filesystem.
Args:
- graph(Program|Graph): The graph to be saved.
- dirname(str): The directory that the model saved into.
exe(paddle.fluid.Executor): The executor used to save model.
graph(Program|Graph): The graph to be saved.
dirname(str): The directory that the model saved into.
"""
assert graph is not None and dirname is not None
graph = GraphWrapper(graph) if isinstance(graph, Program) else graph
......@@ -46,8 +47,8 @@ def load_model(exe, graph, dirname):
Load weights of model and information of shapes from filesystem.
Args:
- graph(Program|Graph): The graph to be saved.
- dirname(str): The directory that the model saved into.
graph(Program|Graph): The graph to be updated by loaded information..
dirname(str): The directory that the model will be loaded.
"""
assert graph is not None and dirname is not None
graph = GraphWrapper(graph) if isinstance(graph, Program) else graph
......
......@@ -26,12 +26,14 @@ _logger = get_logger(__name__, level=logging.INFO)
class Pruner():
def __init__(self, criterion="l1_norm"):
"""
"""The pruner used to prune channels of convolution.
Args:
criterion(str): the criterion used to sort channels for pruning.
It only supports 'l1_norm' currently.
criterion(str): the criterion used to sort channels for pruning. It only supports 'l1_norm' currently.
"""
def __init__(self, criterion="l1_norm"):
self.criterion = criterion
def prune(self,
......@@ -44,9 +46,10 @@ class Pruner():
only_graph=False,
param_backup=False,
param_shape_backup=False):
"""
Pruning the given parameters.
"""Pruning the given parameters.
Args:
program(fluid.Program): The program to be pruned.
scope(fluid.Scope): The scope storing paramaters to be pruned.
params(list<str>): A list of parameter names to be pruned.
......@@ -58,10 +61,9 @@ class Pruner():
False means modifying graph and variables in scope. Default: False.
param_backup(bool): Whether to return a dict to backup the values of parameters. Default: False.
param_shape_backup(bool): Whether to return a dict to backup the shapes of parameters. Default: False.
Returns:
Program: The pruned program.
param_backup: A dict to backup the values of parameters.
param_shape_backup: A dict to backup the shapes of parameters.
tuple: ``(pruned_program, param_backup, param_shape_backup)``. ``pruned_program`` is the pruned program. ``param_backup`` is a dict to backup the values of parameters. ``param_shape_backup`` is a dict to backup the shapes of parameters.
"""
self.pruned_list = []
......@@ -131,6 +133,7 @@ class Pruner():
def _cal_pruned_idx(self, param, ratio, axis):
"""
Calculate the index to be pruned on axis by given pruning ratio.
Args:
name(str): The name of parameter to be pruned.
param(np.array): The data of parameter to be pruned.
......@@ -138,6 +141,7 @@ class Pruner():
axis(int): The axis to be used for pruning given parameter.
If it is None, the value in self.pruning_axis will be used.
default: None.
Returns:
list<int>: The indexes to be pruned on axis.
"""
......@@ -151,6 +155,7 @@ class Pruner():
def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
"""
Pruning a array by indexes on given axis.
Args:
tensor(numpy.array): The target array to be pruned.
pruned_idx(list<int>): The indexes to be pruned.
......@@ -158,6 +163,7 @@ class Pruner():
lazy(bool): True means setting the pruned elements to zero.
False means remove the pruned elements from memory.
default: False.
Returns:
numpy.array: The pruned array.
"""
......
......@@ -37,6 +37,35 @@ def sensitivity(program,
eval_func,
sensitivities_file=None,
pruned_ratios=None):
"""Compute the sensitivities of convolutions in a model. The sensitivity of a convolution is the losses of accuracy on test dataset in differenct pruned ratios. The sensitivities can be used to get a group of best ratios with some condition.
This function return a dict storing sensitivities as below:
.. code-block:: python
{"weight_0":
{0.1: 0.22,
0.2: 0.33
},
"weight_1":
{0.1: 0.21,
0.2: 0.4
}
}
``weight_0`` is parameter name of convolution. ``sensitivities['weight_0']`` is a dict in which key is pruned ratio and value is the percent of losses.
Args:
program(paddle.fluid.Program): The program to be analysised.
place(fluid.CPUPlace | fluid.CUDAPlace): The device place of filter parameters.
param_names(list): The parameter names of convolutions to be analysised.
eval_func(function): The callback function used to evaluate the model. It should accept a instance of `paddle.fluid.Program` as argument and return a score on test dataset.
sensitivities_file(str): The file to save the sensitivities. It will append the latest computed sensitivities into the file. And the sensitivities in the file would not be computed again. This file can be loaded by `pickle` library.
pruned_ratios(list): The ratios to be pruned. default: ``[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]``.
Returns:
dict: A dict storing sensitivities.
"""
scope = fluid.global_scope()
graph = GraphWrapper(program)
sensitivities = load_sensitivities(sensitivities_file)
......@@ -159,13 +188,13 @@ def flops_sensitivity(program,
def merge_sensitive(sensitivities):
"""
Merge sensitivities.
"""Merge sensitivities.
Args:
sensitivities(list<dict> | list<str>): The sensitivities to be merged. It cann be a list of sensitivities files or dict.
Returns:
sensitivities(dict): A dict with sensitivities.
dict: A dict stroring sensitivities.
"""
assert len(sensitivities) > 0
if not isinstance(sensitivities[0], dict):
......@@ -182,8 +211,13 @@ def merge_sensitive(sensitivities):
def load_sensitivities(sensitivities_file):
"""
Load sensitivities from file.
"""Load sensitivities from file.
Args:
sensitivities_file(str): The file storing sensitivities.
Returns:
dict: A dict stroring sensitivities.
"""
sensitivities = {}
if sensitivities_file and os.path.exists(sensitivities_file):
......@@ -196,8 +230,11 @@ def load_sensitivities(sensitivities_file):
def _save_sensitivities(sensitivities, sensitivities_file):
"""
Save sensitivities into file.
"""Save sensitivities into file.
Args:
sensitivities(dict): The sensitivities to be saved.
sensitivities_file(str): The file to saved sensitivities.
"""
with open(sensitivities_file, 'wb') as f:
pickle.dump(sensitivities, f)
......@@ -217,7 +254,7 @@ def get_ratios_by_loss(sensitivities, loss):
Returns:
ratios(dict): A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned.
dict: A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned.
"""
ratios = {}
for param, losses in sensitivities.items():
......
......@@ -30,10 +30,10 @@ _logger = get_logger(__name__, level=logging.INFO)
class SensitivePruner(object):
def __init__(self, place, eval_func, scope=None, checkpoints=None):
"""
Pruner used to prune parameters iteratively according to sensitivities
of parameters in each step.
Args:
place(fluid.CUDAPlace | fluid.CPUPlace): The device place where
program execute.
......@@ -42,6 +42,8 @@ class SensitivePruner(object):
And it return a score of given program.
scope(fluid.scope): The scope used to execute program.
"""
def __init__(self, place, eval_func, scope=None, checkpoints=None):
self._eval_func = eval_func
self._iter = 0
self._place = place
......@@ -135,12 +137,14 @@ class SensitivePruner(object):
def prune(self, train_program, eval_program, params, pruned_flops):
"""
Pruning parameters of training and evaluation network by sensitivities in current step.
Args:
train_program(fluid.Program): The training program to be pruned.
eval_program(fluid.Program): The evaluation program to be pruned. And it is also used to calculate sensitivities of parameters.
params(list<str>): The parameters to be pruned.
pruned_flops(float): The ratio of FLOPS to be pruned in current step.
Return:
Returns:
tuple: A tuple of pruned training program and pruned evaluation program.
"""
_logger.info("Pruning: {}".format(params))
......@@ -199,9 +203,9 @@ class SensitivePruner(object):
pruned_flops(float): The percent of FLOPS to be pruned.
eval_program(Program): The program whose FLOPS is considered.
Return:
Returns:
ratios(dict): A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned.
dict: A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned.
"""
min_loss = 0.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册