未验证 提交 d5f6d39b 编写于 作者: W whs 提交者: GitHub

Format en api of prune, one-shot and common modules (#101)

上级 6ac0720c
...@@ -4,21 +4,16 @@ ...@@ -4,21 +4,16 @@
contain the root `toctree` directive. contain the root `toctree` directive.
API Documents API Documents
======== ==============
.. toctree:: .. toctree::
:maxdepth: 1 :maxdepth: 1
modules.rst
paddleslim.analysis.rst paddleslim.analysis.rst
paddleslim.common.rst paddleslim.prune.rst
paddleslim.core.rst
paddleslim.dist.rst paddleslim.dist.rst
paddleslim.models.rst paddleslim.quant.rst
paddleslim.nas.one_shot.rst
paddleslim.nas.rst paddleslim.nas.rst
paddleslim.nas.one_shot.rst
paddleslim.nas.search_space.rst paddleslim.nas.search_space.rst
paddleslim.pantheon.rst paddleslim.pantheon.rst
paddleslim.prune.rst
paddleslim.quant.rst
paddleslim.rst
paddleslim.analysis package paddleslim\.analysis package
=========================== ============================
.. automodule:: paddleslim.analysis .. automodule:: paddleslim.analysis
:members: :members:
...@@ -9,24 +9,24 @@ paddleslim.analysis package ...@@ -9,24 +9,24 @@ paddleslim.analysis package
Submodules Submodules
---------- ----------
paddleslim.analysis.flops module paddleslim\.analysis\.flops module
-------------------------------- ----------------------------------
.. automodule:: paddleslim.analysis.flops .. automodule:: paddleslim.analysis.flops
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.analysis.latency module paddleslim\.analysis\.latency module
---------------------------------- ------------------------------------
.. automodule:: paddleslim.analysis.latency .. automodule:: paddleslim.analysis.latency
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.analysis.model\_size module paddleslim\.analysis\.model\_size module
-------------------------------------- ----------------------------------------
.. automodule:: paddleslim.analysis.model_size .. automodule:: paddleslim.analysis.model_size
:members: :members:
......
paddleslim.common package paddleslim\.common package
========================= ==========================
.. automodule:: paddleslim.common .. automodule:: paddleslim.common
:members: :members:
...@@ -9,56 +9,56 @@ paddleslim.common package ...@@ -9,56 +9,56 @@ paddleslim.common package
Submodules Submodules
---------- ----------
paddleslim.common.cached\_reader module paddleslim\.common\.cached\_reader module
--------------------------------------- -----------------------------------------
.. automodule:: paddleslim.common.cached_reader .. automodule:: paddleslim.common.cached_reader
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.common.controller module paddleslim\.common\.controller module
----------------------------------- -------------------------------------
.. automodule:: paddleslim.common.controller .. automodule:: paddleslim.common.controller
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.common.controller\_client module paddleslim\.common\.controller\_client module
------------------------------------------- ---------------------------------------------
.. automodule:: paddleslim.common.controller_client .. automodule:: paddleslim.common.controller_client
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.common.controller\_server module paddleslim\.common\.controller\_server module
------------------------------------------- ---------------------------------------------
.. automodule:: paddleslim.common.controller_server .. automodule:: paddleslim.common.controller_server
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.common.lock module paddleslim\.common\.lock module
----------------------------- -------------------------------
.. automodule:: paddleslim.common.lock .. automodule:: paddleslim.common.lock
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.common.log\_helper module paddleslim\.common\.log\_helper module
------------------------------------ --------------------------------------
.. automodule:: paddleslim.common.log_helper .. automodule:: paddleslim.common.log_helper
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.common.sa\_controller module paddleslim\.common\.sa\_controller module
--------------------------------------- -----------------------------------------
.. automodule:: paddleslim.common.sa_controller .. automodule:: paddleslim.common.sa_controller
:members: :members:
......
paddleslim.core package paddleslim\.core package
======================= ========================
.. automodule:: paddleslim.core .. automodule:: paddleslim.core
:members: :members:
...@@ -9,16 +9,16 @@ paddleslim.core package ...@@ -9,16 +9,16 @@ paddleslim.core package
Submodules Submodules
---------- ----------
paddleslim.core.graph\_wrapper module paddleslim\.core\.graph\_wrapper module
------------------------------------- ---------------------------------------
.. automodule:: paddleslim.core.graph_wrapper .. automodule:: paddleslim.core.graph_wrapper
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.core.registry module paddleslim\.core\.registry module
------------------------------- ---------------------------------
.. automodule:: paddleslim.core.registry .. automodule:: paddleslim.core.registry
:members: :members:
......
paddleslim.dist package paddleslim\.dist package
======================= ========================
.. automodule:: paddleslim.dist .. automodule:: paddleslim.dist
:members: :members:
...@@ -9,8 +9,8 @@ paddleslim.dist package ...@@ -9,8 +9,8 @@ paddleslim.dist package
Submodules Submodules
---------- ----------
paddleslim.dist.single\_distiller module paddleslim\.dist\.single\_distiller module
---------------------------------------- ------------------------------------------
.. automodule:: paddleslim.dist.single_distiller .. automodule:: paddleslim.dist.single_distiller
:members: :members:
......
paddleslim.models package paddleslim\.models package
========================= ==========================
.. automodule:: paddleslim.models .. automodule:: paddleslim.models
:members: :members:
...@@ -9,40 +9,40 @@ paddleslim.models package ...@@ -9,40 +9,40 @@ paddleslim.models package
Submodules Submodules
---------- ----------
paddleslim.models.classification\_models module paddleslim\.models\.classification\_models module
----------------------------------------------- -------------------------------------------------
.. automodule:: paddleslim.models.classification_models .. automodule:: paddleslim.models.classification_models
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.models.mobilenet module paddleslim\.models\.mobilenet module
---------------------------------- ------------------------------------
.. automodule:: paddleslim.models.mobilenet .. automodule:: paddleslim.models.mobilenet
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.models.mobilenet\_v2 module paddleslim\.models\.mobilenet\_v2 module
-------------------------------------- ----------------------------------------
.. automodule:: paddleslim.models.mobilenet_v2 .. automodule:: paddleslim.models.mobilenet_v2
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.models.resnet module paddleslim\.models\.resnet module
------------------------------- ---------------------------------
.. automodule:: paddleslim.models.resnet .. automodule:: paddleslim.models.resnet
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.models.util module paddleslim\.models\.util module
----------------------------- -------------------------------
.. automodule:: paddleslim.models.util .. automodule:: paddleslim.models.util
:members: :members:
......
paddleslim.nas.one\_shot package paddleslim\.nas\.one\_shot package
================================ ==================================
.. automodule:: paddleslim.nas.one_shot .. automodule:: paddleslim.nas.one_shot
:members: :members:
...@@ -9,16 +9,16 @@ paddleslim.nas.one\_shot package ...@@ -9,16 +9,16 @@ paddleslim.nas.one\_shot package
Submodules Submodules
---------- ----------
paddleslim.nas.one\_shot.one\_shot\_nas module paddleslim\.nas\.one\_shot\.one\_shot\_nas module
---------------------------------------------- -------------------------------------------------
.. automodule:: paddleslim.nas.one_shot.one_shot_nas .. automodule:: paddleslim.nas.one_shot.one_shot_nas
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.nas.one\_shot.super\_mnasnet module paddleslim\.nas\.one\_shot\.super\_mnasnet module
---------------------------------------------- -------------------------------------------------
.. automodule:: paddleslim.nas.one_shot.super_mnasnet .. automodule:: paddleslim.nas.one_shot.super_mnasnet
:members: :members:
......
paddleslim.nas package paddleslim\.nas package
====================== =======================
.. automodule:: paddleslim.nas .. automodule:: paddleslim.nas
:members: :members:
...@@ -17,8 +17,8 @@ Subpackages ...@@ -17,8 +17,8 @@ Subpackages
Submodules Submodules
---------- ----------
paddleslim.nas.sa\_nas module paddleslim\.nas\.sa\_nas module
----------------------------- -------------------------------
.. automodule:: paddleslim.nas.sa_nas .. automodule:: paddleslim.nas.sa_nas
:members: :members:
......
paddleslim.nas.search\_space package paddleslim\.nas\.search\_space package
==================================== ======================================
.. automodule:: paddleslim.nas.search_space .. automodule:: paddleslim.nas.search_space
:members: :members:
...@@ -9,96 +9,96 @@ paddleslim.nas.search\_space package ...@@ -9,96 +9,96 @@ paddleslim.nas.search\_space package
Submodules Submodules
---------- ----------
paddleslim.nas.search\_space.base\_layer module paddleslim\.nas\.search\_space\.base\_layer module
----------------------------------------------- --------------------------------------------------
.. automodule:: paddleslim.nas.search_space.base_layer .. automodule:: paddleslim.nas.search_space.base_layer
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.nas.search\_space.combine\_search\_space module paddleslim\.nas\.search\_space\.combine\_search\_space module
---------------------------------------------------------- -------------------------------------------------------------
.. automodule:: paddleslim.nas.search_space.combine_search_space .. automodule:: paddleslim.nas.search_space.combine_search_space
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.nas.search\_space.inception\_block module paddleslim\.nas\.search\_space\.inception\_block module
---------------------------------------------------- -------------------------------------------------------
.. automodule:: paddleslim.nas.search_space.inception_block .. automodule:: paddleslim.nas.search_space.inception_block
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.nas.search\_space.mobilenet\_block module paddleslim\.nas\.search\_space\.mobilenet\_block module
---------------------------------------------------- -------------------------------------------------------
.. automodule:: paddleslim.nas.search_space.mobilenet_block .. automodule:: paddleslim.nas.search_space.mobilenet_block
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.nas.search\_space.mobilenetv1 module paddleslim\.nas\.search\_space\.mobilenetv1 module
----------------------------------------------- --------------------------------------------------
.. automodule:: paddleslim.nas.search_space.mobilenetv1 .. automodule:: paddleslim.nas.search_space.mobilenetv1
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.nas.search\_space.mobilenetv2 module paddleslim\.nas\.search\_space\.mobilenetv2 module
----------------------------------------------- --------------------------------------------------
.. automodule:: paddleslim.nas.search_space.mobilenetv2 .. automodule:: paddleslim.nas.search_space.mobilenetv2
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.nas.search\_space.resnet module paddleslim\.nas\.search\_space\.resnet module
------------------------------------------ ---------------------------------------------
.. automodule:: paddleslim.nas.search_space.resnet .. automodule:: paddleslim.nas.search_space.resnet
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.nas.search\_space.resnet\_block module paddleslim\.nas\.search\_space\.resnet\_block module
------------------------------------------------- ----------------------------------------------------
.. automodule:: paddleslim.nas.search_space.resnet_block .. automodule:: paddleslim.nas.search_space.resnet_block
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.nas.search\_space.search\_space\_base module paddleslim\.nas\.search\_space\.search\_space\_base module
------------------------------------------------------- ----------------------------------------------------------
.. automodule:: paddleslim.nas.search_space.search_space_base .. automodule:: paddleslim.nas.search_space.search_space_base
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.nas.search\_space.search\_space\_factory module paddleslim\.nas\.search\_space\.search\_space\_factory module
---------------------------------------------------------- -------------------------------------------------------------
.. automodule:: paddleslim.nas.search_space.search_space_factory .. automodule:: paddleslim.nas.search_space.search_space_factory
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.nas.search\_space.search\_space\_registry module paddleslim\.nas\.search\_space\.search\_space\_registry module
----------------------------------------------------------- --------------------------------------------------------------
.. automodule:: paddleslim.nas.search_space.search_space_registry .. automodule:: paddleslim.nas.search_space.search_space_registry
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.nas.search\_space.utils module paddleslim\.nas\.search\_space\.utils module
----------------------------------------- --------------------------------------------
.. automodule:: paddleslim.nas.search_space.utils .. automodule:: paddleslim.nas.search_space.utils
:members: :members:
......
paddleslim.pantheon package paddleslim\.pantheon package
=========================== ============================
.. automodule:: paddleslim.pantheon .. automodule:: paddleslim.pantheon
:members: :members:
...@@ -9,24 +9,24 @@ paddleslim.pantheon package ...@@ -9,24 +9,24 @@ paddleslim.pantheon package
Submodules Submodules
---------- ----------
paddleslim.pantheon.student module paddleslim\.pantheon\.student module
---------------------------------- ------------------------------------
.. automodule:: paddleslim.pantheon.student .. automodule:: paddleslim.pantheon.student
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.pantheon.teacher module paddleslim\.pantheon\.teacher module
---------------------------------- ------------------------------------
.. automodule:: paddleslim.pantheon.teacher .. automodule:: paddleslim.pantheon.teacher
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.pantheon.utils module paddleslim\.pantheon\.utils module
-------------------------------- ----------------------------------
.. automodule:: paddleslim.pantheon.utils .. automodule:: paddleslim.pantheon.utils
:members: :members:
......
paddleslim.prune package paddleslim\.prune package
======================== =========================
.. automodule:: paddleslim.prune .. automodule:: paddleslim.prune
:members: :members:
...@@ -9,48 +9,48 @@ paddleslim.prune package ...@@ -9,48 +9,48 @@ paddleslim.prune package
Submodules Submodules
---------- ----------
paddleslim.prune.auto\_pruner module paddleslim\.prune\.auto\_pruner module
------------------------------------ --------------------------------------
.. automodule:: paddleslim.prune.auto_pruner .. automodule:: paddleslim.prune.auto_pruner
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.prune.prune\_io module paddleslim\.prune\.prune\_io module
--------------------------------- -----------------------------------
.. automodule:: paddleslim.prune.prune_io .. automodule:: paddleslim.prune.prune_io
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.prune.prune\_walker module paddleslim\.prune\.prune\_walker module
------------------------------------- ---------------------------------------
.. automodule:: paddleslim.prune.prune_walker .. automodule:: paddleslim.prune.prune_walker
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.prune.pruner module paddleslim\.prune\.pruner module
------------------------------ --------------------------------
.. automodule:: paddleslim.prune.pruner .. automodule:: paddleslim.prune.pruner
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.prune.sensitive module paddleslim\.prune\.sensitive module
--------------------------------- -----------------------------------
.. automodule:: paddleslim.prune.sensitive .. automodule:: paddleslim.prune.sensitive
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.prune.sensitive\_pruner module paddleslim\.prune\.sensitive\_pruner module
----------------------------------------- -------------------------------------------
.. automodule:: paddleslim.prune.sensitive_pruner .. automodule:: paddleslim.prune.sensitive_pruner
:members: :members:
......
paddleslim.quant package paddleslim\.quant package
======================== =========================
.. automodule:: paddleslim.quant .. automodule:: paddleslim.quant
:members: :members:
...@@ -9,16 +9,16 @@ paddleslim.quant package ...@@ -9,16 +9,16 @@ paddleslim.quant package
Submodules Submodules
---------- ----------
paddleslim.quant.quant\_embedding module paddleslim\.quant\.quant\_embedding module
---------------------------------------- ------------------------------------------
.. automodule:: paddleslim.quant.quant_embedding .. automodule:: paddleslim.quant.quant_embedding
:members: :members:
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
paddleslim.quant.quanter module paddleslim\.quant\.quanter module
------------------------------- ---------------------------------
.. automodule:: paddleslim.quant.quanter .. automodule:: paddleslim.quant.quanter
:members: :members:
......
...@@ -24,8 +24,8 @@ Subpackages ...@@ -24,8 +24,8 @@ Subpackages
Submodules Submodules
---------- ----------
paddleslim.version module paddleslim\.version module
------------------------- --------------------------
.. automodule:: paddleslim.version .. automodule:: paddleslim.version
:members: :members:
......
...@@ -25,6 +25,7 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -25,6 +25,7 @@ _logger = get_logger(__name__, level=logging.INFO)
def cached_reader(reader, sampled_rate, cache_path, cached_id): def cached_reader(reader, sampled_rate, cache_path, cached_id):
""" """
Sample partial data from reader and cache them into local file system. Sample partial data from reader and cache them into local file system.
Args: Args:
reader: Iterative data source. reader: Iterative data source.
sampled_rate(float): The sampled rate used to sample partial data for evaluation. None means using all data in eval_reader. default: None. sampled_rate(float): The sampled rate used to sample partial data for evaluation. None means using all data in eval_reader. default: None.
......
...@@ -24,11 +24,9 @@ class EvolutionaryController(object): ...@@ -24,11 +24,9 @@ class EvolutionaryController(object):
"""Abstract controller for all evolutionary searching method. """Abstract controller for all evolutionary searching method.
""" """
def __init__(self, *args, **kwargs):
pass
def update(self, tokens, reward): def update(self, tokens, reward):
"""Update the status of controller according current tokens and reward. """Update the status of controller according current tokens and reward.
Args: Args:
tokens(list<int>): A solution of searching task. tokens(list<int>): A solution of searching task.
reward(list<int>): The reward of tokens. reward(list<int>): The reward of tokens.
...@@ -37,6 +35,7 @@ class EvolutionaryController(object): ...@@ -37,6 +35,7 @@ class EvolutionaryController(object):
def reset(self, range_table, constrain_func=None): def reset(self, range_table, constrain_func=None):
"""Reset the controller. """Reset the controller.
Args: Args:
range_table(list<int>): It is used to define the searching space of controller. range_table(list<int>): It is used to define the searching space of controller.
The tokens[i] generated by controller should be in [0, range_table[i]). The tokens[i] generated by controller should be in [0, range_table[i]).
...@@ -47,5 +46,8 @@ class EvolutionaryController(object): ...@@ -47,5 +46,8 @@ class EvolutionaryController(object):
def next_tokens(self): def next_tokens(self):
"""Generate new tokens. """Generate new tokens.
Returns:
list<list>: The next searched tokens.
""" """
raise NotImplementedError('Abstract method.') raise NotImplementedError('Abstract method.')
...@@ -24,6 +24,11 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -24,6 +24,11 @@ _logger = get_logger(__name__, level=logging.INFO)
class ControllerClient(object): class ControllerClient(object):
""" """
Controller client. Controller client.
Args:
server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
key(str): The key used to identify legal agent for controller server. Default: "light-nas"
client_name(str): Current client name, random generate for counting client number. Default: None.
""" """
def __init__(self, def __init__(self,
...@@ -32,11 +37,6 @@ class ControllerClient(object): ...@@ -32,11 +37,6 @@ class ControllerClient(object):
key=None, key=None,
client_name=None): client_name=None):
""" """
Args:
server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
key(str): The key used to identify legal agent for controller server. Default: "light-nas"
client_name(str): Current client name, random generate for counting client number. Default: None.
""" """
self.server_ip = server_ip self.server_ip = server_ip
self.server_port = server_port self.server_port = server_port
...@@ -46,9 +46,11 @@ class ControllerClient(object): ...@@ -46,9 +46,11 @@ class ControllerClient(object):
def update(self, tokens, reward, iter): def update(self, tokens, reward, iter):
""" """
Update the controller according to latest tokens and reward. Update the controller according to latest tokens and reward.
Args: Args:
tokens(list<int>): The tokens generated in last step. tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens. reward(float): The reward of tokens.
iter(int): The iteration number of current client.
""" """
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port)) socket_client.connect((self.server_ip, self.server_port))
......
...@@ -26,8 +26,14 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -26,8 +26,14 @@ _logger = get_logger(__name__, level=logging.INFO)
class ControllerServer(object): class ControllerServer(object):
""" """The controller wrapper with a socket server to handle the request of search agent.
The controller wrapper with a socket server to handle the request of search agent. Args:
controller(slim.searcher.Controller): The controller used to generate tokens.
address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
which means setting ip automatically
max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
search_steps(int|None): The total steps of searching. None means never stopping. Default: None
key(str|None): Config information. Default: None.
""" """
def __init__(self, def __init__(self,
...@@ -37,13 +43,6 @@ class ControllerServer(object): ...@@ -37,13 +43,6 @@ class ControllerServer(object):
search_steps=None, search_steps=None,
key=None): key=None):
""" """
Args:
controller(slim.searcher.Controller): The controller used to generate tokens.
address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
which means setting ip automatically
max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
search_steps(int|None): The total steps of searching. None means never stopping. Default: None
key(str|None): Config information. Default: None.
""" """
self._controller = controller self._controller = controller
self._address = address self._address = address
...@@ -84,6 +83,8 @@ class ControllerServer(object): ...@@ -84,6 +83,8 @@ class ControllerServer(object):
return self._ip return self._ip
def run(self): def run(self):
"""Start the server.
"""
_logger.info("Controller Server run...") _logger.info("Controller Server run...")
try: try:
while ((self._search_steps is None) or while ((self._search_steps is None) or
......
...@@ -24,14 +24,19 @@ def get_logger(name, level, fmt='%(asctime)s-%(levelname)s: %(message)s'): ...@@ -24,14 +24,19 @@ def get_logger(name, level, fmt='%(asctime)s-%(levelname)s: %(message)s'):
Get logger from logging with given name, level and format without Get logger from logging with given name, level and format without
setting logging basicConfig. For setting basicConfig in paddle setting logging basicConfig. For setting basicConfig in paddle
will disable basicConfig setting after import paddle. will disable basicConfig setting after import paddle.
Args: Args:
name (str): The logger name. name (str): The logger name.
level (logging.LEVEL): The base level of the logger level (logging.LEVEL): The base level of the logger
fmt (str): Format of logger output fmt (str): Format of logger output
Returns: Returns:
logging.Logger: logging logger with given settings logging.Logger: logging logger with given settings
Examples: Examples:
.. code-block:: python .. code-block:: python
logger = log_helper.get_logger(__name__, logging.INFO, logger = log_helper.get_logger(__name__, logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s') fmt='%(asctime)s-%(levelname)s: %(message)s')
""" """
......
...@@ -29,22 +29,8 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -29,22 +29,8 @@ _logger = get_logger(__name__, level=logging.INFO)
class SAController(EvolutionaryController): class SAController(EvolutionaryController):
"""Simulated annealing controller.""" """Simulated annealing controller.
def __init__(self,
range_table=None,
reduce_rate=0.85,
init_temperature=None,
max_try_times=300,
init_tokens=None,
reward=-1,
max_reward=-1,
iters=0,
best_tokens=None,
constrain_func=None,
checkpoints=None,
searched=None):
"""Initialize.
Args: Args:
range_table(list<int>): Range table. range_table(list<int>): Range table.
reduce_rate(float): The decay rate of temperature. reduce_rate(float): The decay rate of temperature.
...@@ -59,6 +45,20 @@ class SAController(EvolutionaryController): ...@@ -59,6 +45,20 @@ class SAController(EvolutionaryController):
checkpoints(str): if checkpoint is None, do not save checkpoints, else save scene to checkpoints file. checkpoints(str): if checkpoint is None, do not save checkpoints, else save scene to checkpoints file.
searched(dict<list, float>): remember tokens which are searched. searched(dict<list, float>): remember tokens which are searched.
""" """
def __init__(self,
range_table=None,
reduce_rate=0.85,
init_temperature=None,
max_try_times=300,
init_tokens=None,
reward=-1,
max_reward=-1,
iters=0,
best_tokens=None,
constrain_func=None,
checkpoints=None,
searched=None):
super(SAController, self).__init__() super(SAController, self).__init__()
self._range_table = range_table self._range_table = range_table
assert isinstance(self._range_table, tuple) and ( assert isinstance(self._range_table, tuple) and (
...@@ -92,6 +92,11 @@ class SAController(EvolutionaryController): ...@@ -92,6 +92,11 @@ class SAController(EvolutionaryController):
@property @property
def best_tokens(self): def best_tokens(self):
"""Get current best tokens.
Returns:
list<int>: The best tokens.
"""
return self._best_tokens return self._best_tokens
@property @property
...@@ -100,14 +105,23 @@ class SAController(EvolutionaryController): ...@@ -100,14 +105,23 @@ class SAController(EvolutionaryController):
@property @property
def current_tokens(self): def current_tokens(self):
"""Get tokens generated in current searching step.
Returns:
list<int>: The current tokens.
"""
return self._current_tokens return self._current_tokens
def update(self, tokens, reward, iter, client_num): def update(self, tokens, reward, iter, client_num):
""" """
Update the controller according to latest tokens and reward. Update the controller according to latest tokens and reward.
Args: Args:
tokens(list<int>): The tokens generated in last step. tokens(list<int>): The tokens generated in current step.
reward(float): The reward of tokens. reward(float): The reward of tokens.
iter(int): The current step of searching client.
client_num(int): The total number of searching client.
""" """
iter = int(iter) iter = int(iter)
if iter > self._iter: if iter > self._iter:
...@@ -136,6 +150,12 @@ class SAController(EvolutionaryController): ...@@ -136,6 +150,12 @@ class SAController(EvolutionaryController):
def next_tokens(self, control_token=None): def next_tokens(self, control_token=None):
""" """
Get next tokens. Get next tokens.
Args:
control_token: The tokens used to generate next tokens.
Returns:
list<int>: The next tokens.
""" """
if control_token: if control_token:
tokens = control_token[:] tokens = control_token[:]
......
...@@ -72,6 +72,7 @@ class VarWrapper(object): ...@@ -72,6 +72,7 @@ class VarWrapper(object):
def inputs(self): def inputs(self):
""" """
Get all the operators that use this variable as output. Get all the operators that use this variable as output.
Returns: Returns:
list<OpWrapper>: A list of operators. list<OpWrapper>: A list of operators.
""" """
...@@ -84,6 +85,7 @@ class VarWrapper(object): ...@@ -84,6 +85,7 @@ class VarWrapper(object):
def outputs(self): def outputs(self):
""" """
Get all the operators that use this variable as input. Get all the operators that use this variable as input.
Returns: Returns:
list<OpWrapper>: A list of operators. list<OpWrapper>: A list of operators.
""" """
...@@ -196,10 +198,7 @@ class GraphWrapper(object): ...@@ -196,10 +198,7 @@ class GraphWrapper(object):
""" """
It is a wrapper of paddle.fluid.framework.IrGraph with some special functions It is a wrapper of paddle.fluid.framework.IrGraph with some special functions
for paddle slim framework. for paddle slim framework.
"""
def __init__(self, program=None, in_nodes=[], out_nodes=[]):
"""
Args: Args:
program(framework.Program): A program with program(framework.Program): A program with
in_nodes(dict): A dict to indicate the input nodes of the graph. in_nodes(dict): A dict to indicate the input nodes of the graph.
...@@ -209,6 +208,10 @@ class GraphWrapper(object): ...@@ -209,6 +208,10 @@ class GraphWrapper(object):
The key is user-defined and human-readable name. The key is user-defined and human-readable name.
The value is the name of Variable. The value is the name of Variable.
""" """
def __init__(self, program=None, in_nodes=[], out_nodes=[]):
"""
"""
super(GraphWrapper, self).__init__() super(GraphWrapper, self).__init__()
self.program = Program() if program is None else program self.program = Program() if program is None else program
self.persistables = {} self.persistables = {}
...@@ -226,6 +229,7 @@ class GraphWrapper(object): ...@@ -226,6 +229,7 @@ class GraphWrapper(object):
def all_parameters(self): def all_parameters(self):
""" """
Get all the parameters in this graph. Get all the parameters in this graph.
Returns: Returns:
list<VarWrapper>: A list of VarWrapper instances. list<VarWrapper>: A list of VarWrapper instances.
""" """
...@@ -238,6 +242,7 @@ class GraphWrapper(object): ...@@ -238,6 +242,7 @@ class GraphWrapper(object):
def is_parameter(self, var): def is_parameter(self, var):
""" """
Whether the given variable is parameter. Whether the given variable is parameter.
Args: Args:
var(VarWrapper): The given varibale. var(VarWrapper): The given varibale.
""" """
...@@ -246,6 +251,7 @@ class GraphWrapper(object): ...@@ -246,6 +251,7 @@ class GraphWrapper(object):
def is_persistable(self, var): def is_persistable(self, var):
""" """
Whether the given variable is persistable. Whether the given variable is persistable.
Args: Args:
var(VarWrapper): The given varibale. var(VarWrapper): The given varibale.
""" """
...@@ -279,6 +285,7 @@ class GraphWrapper(object): ...@@ -279,6 +285,7 @@ class GraphWrapper(object):
def clone(self, for_test=False): def clone(self, for_test=False):
""" """
Clone a new graph from current graph. Clone a new graph from current graph.
Returns: Returns:
(GraphWrapper): The wrapper of a new graph. (GraphWrapper): The wrapper of a new graph.
""" """
...@@ -295,8 +302,10 @@ class GraphWrapper(object): ...@@ -295,8 +302,10 @@ class GraphWrapper(object):
def pre_ops(self, op): def pre_ops(self, op):
""" """
Get all the previous operators of target operator. Get all the previous operators of target operator.
Args: Args:
op(OpWrapper): Target operator.. op(OpWrapper): Target operator.
Returns: Returns:
list<OpWrapper>: A list of operators. list<OpWrapper>: A list of operators.
""" """
...@@ -310,8 +319,10 @@ class GraphWrapper(object): ...@@ -310,8 +319,10 @@ class GraphWrapper(object):
def next_ops(self, op): def next_ops(self, op):
""" """
Get all the next operators of target operator. Get all the next operators of target operator.
Args: Args:
op(OpWrapper): Target operator.. op(OpWrapper): Target operator.
Returns: Returns:
list<OpWrapper>: A list of operators. list<OpWrapper>: A list of operators.
""" """
......
...@@ -22,14 +22,16 @@ __all__ = ['OneShotSuperNet', 'OneShotSearch'] ...@@ -22,14 +22,16 @@ __all__ = ['OneShotSuperNet', 'OneShotSearch']
def OneShotSearch(model, eval_func, strategy='sa', search_steps=100): def OneShotSearch(model, eval_func, strategy='sa', search_steps=100):
""" """
Search a best tokens which represents a sub-network. Search a best tokens which represents a sub-network.
Archs:
Args:
model(fluid.dygraph.Layer): A dynamic graph module whose sub-modules should contain model(fluid.dygraph.Layer): A dynamic graph module whose sub-modules should contain
one instance of `OneShotSuperNet` at least. one instance of `OneShotSuperNet` at least.
eval_func(function): A callback function which accept model and tokens as arguments. eval_func(function): A callback function which accept model and tokens as arguments.
strategy(str): The name of strategy used to search. Default: 'sa'. strategy(str): The name of strategy used to search. Default: 'sa'.
search_steps(int): The total steps for searching. search_steps(int): The total steps for searching.
Returns: Returns:
tokens(list): The best tokens searched. list<int>: The best tokens searched.
""" """
super_net = None super_net = None
for layer in model.sublayers(include_sublayers=False): for layer in model.sublayers(include_sublayers=False):
...@@ -52,8 +54,7 @@ def OneShotSearch(model, eval_func, strategy='sa', search_steps=100): ...@@ -52,8 +54,7 @@ def OneShotSearch(model, eval_func, strategy='sa', search_steps=100):
class OneShotSuperNet(fluid.dygraph.Layer): class OneShotSuperNet(fluid.dygraph.Layer):
""" """The base class of super net used in one-shot searching strategy.
The base class of super net used in one-shot searching strategy.
A super net is a dygraph layer. A super net is a dygraph layer.
Args: Args:
...@@ -65,14 +66,16 @@ class OneShotSuperNet(fluid.dygraph.Layer): ...@@ -65,14 +66,16 @@ class OneShotSuperNet(fluid.dygraph.Layer):
def init_tokens(self): def init_tokens(self):
"""Get init tokens in search space. """Get init tokens in search space.
Return:
tokens(list): The init tokens which is a list of integer. Returns:
lis<int>t: The init tokens which is a list of integer.
""" """
raise NotImplementedError('Abstract method.') raise NotImplementedError('Abstract method.')
def range_table(self): def range_table(self):
"""Get range table of current search space. """Get range table of current search space.
Return:
Returns:
range_table(tuple): The maximum value and minimum value in each position of tokens range_table(tuple): The maximum value and minimum value in each position of tokens
with format `(min_values, max_values)`. The `min_values` is with format `(min_values, max_values)`. The `min_values` is
a list of integers indicating the minimum values while `max_values` a list of integers indicating the minimum values while `max_values`
...@@ -81,9 +84,9 @@ class OneShotSuperNet(fluid.dygraph.Layer): ...@@ -81,9 +84,9 @@ class OneShotSuperNet(fluid.dygraph.Layer):
raise NotImplementedError('Abstract method.') raise NotImplementedError('Abstract method.')
def _forward_impl(self, *inputs, **kwargs): def _forward_impl(self, *inputs, **kwargs):
""" """Defines the computation performed at every call.
Defines the computation performed at every call.
Should be overridden by all subclasses. Should be overridden by all subclasses.
Args: Args:
inputs(tuple): unpacked tuple arguments inputs(tuple): unpacked tuple arguments
kwargs(dict): unpacked dict arguments kwargs(dict): unpacked dict arguments
...@@ -93,6 +96,7 @@ class OneShotSuperNet(fluid.dygraph.Layer): ...@@ -93,6 +96,7 @@ class OneShotSuperNet(fluid.dygraph.Layer):
def forward(self, input, tokens=None): def forward(self, input, tokens=None):
""" """
Defines the computation performed at every call. Defines the computation performed at every call.
Args: Args:
input(variable): The input of super net. input(variable): The input of super net.
tokens(list): The tokens used to generate a sub-network. tokens(list): The tokens used to generate a sub-network.
...@@ -100,8 +104,9 @@ class OneShotSuperNet(fluid.dygraph.Layer): ...@@ -100,8 +104,9 @@ class OneShotSuperNet(fluid.dygraph.Layer):
Otherwise, it will execute the sub-network generated by tokens. Otherwise, it will execute the sub-network generated by tokens.
The `tokens` should be set in searching stage and final training stage. The `tokens` should be set in searching stage and final training stage.
Default: None. Default: None.
Returns: Returns:
output(varaible): The output of super net. Varaible: The output of super net.
""" """
if tokens == None: if tokens == None:
tokens = self._random_tokens() tokens = self._random_tokens()
......
...@@ -31,26 +31,9 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -31,26 +31,9 @@ _logger = get_logger(__name__, level=logging.INFO)
class AutoPruner(object): class AutoPruner(object):
def __init__(self,
program,
scope,
place,
params=[],
init_ratios=None,
pruned_flops=0.5,
pruned_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_times=300,
max_client_num=10,
search_steps=300,
max_ratios=[0.9],
min_ratios=[0],
key="auto_pruner",
is_server=True):
""" """
Search a group of ratios used to prune program. Search a group of ratios used to prune program.
Args: Args:
program(Program): The program to be pruned. program(Program): The program to be pruned.
scope(Scope): The scope to be pruned. scope(Scope): The scope to be pruned.
...@@ -81,6 +64,25 @@ class AutoPruner(object): ...@@ -81,6 +64,25 @@ class AutoPruner(object):
is_server(bool): Whether current host is controller server. Default: True. is_server(bool): Whether current host is controller server. Default: True.
""" """
def __init__(self,
program,
scope,
place,
params=[],
init_ratios=None,
pruned_flops=0.5,
pruned_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_times=300,
max_client_num=10,
search_steps=300,
max_ratios=[0.9],
min_ratios=[0],
key="auto_pruner",
is_server=True):
self._program = program self._program = program
self._scope = scope self._scope = scope
self._place = place self._place = place
...@@ -177,10 +179,12 @@ class AutoPruner(object): ...@@ -177,10 +179,12 @@ class AutoPruner(object):
def prune(self, program, eval_program=None): def prune(self, program, eval_program=None):
""" """
Prune program with latest tokens generated by controller. Prune program with latest tokens generated by controller.
Args: Args:
program(fluid.Program): The program to be pruned. program(fluid.Program): The program to be pruned.
Returns: Returns:
Program: The pruned program. paddle.fluid.Program: The pruned program.
""" """
self._current_ratios = self._next_ratios() self._current_ratios = self._next_ratios()
pruned_program, _, _ = self._pruner.prune( pruned_program, _, _ = self._pruner.prune(
...@@ -208,8 +212,9 @@ class AutoPruner(object): ...@@ -208,8 +212,9 @@ class AutoPruner(object):
def reward(self, score): def reward(self, score):
""" """
Return reward of current pruned program. Return reward of current pruned program.
Args: Args:
score(float): The score of pruned program. float: The score of pruned program.
""" """
self._restore(self._scope) self._restore(self._scope)
self._param_backup = {} self._param_backup = {}
......
...@@ -19,8 +19,9 @@ def save_model(exe, graph, dirname): ...@@ -19,8 +19,9 @@ def save_model(exe, graph, dirname):
Save weights of model and information of shapes into filesystem. Save weights of model and information of shapes into filesystem.
Args: Args:
- graph(Program|Graph): The graph to be saved. exe(paddle.fluid.Executor): The executor used to save model.
- dirname(str): The directory that the model saved into. graph(Program|Graph): The graph to be saved.
dirname(str): The directory that the model saved into.
""" """
assert graph is not None and dirname is not None assert graph is not None and dirname is not None
graph = GraphWrapper(graph) if isinstance(graph, Program) else graph graph = GraphWrapper(graph) if isinstance(graph, Program) else graph
...@@ -46,8 +47,8 @@ def load_model(exe, graph, dirname): ...@@ -46,8 +47,8 @@ def load_model(exe, graph, dirname):
Load weights of model and information of shapes from filesystem. Load weights of model and information of shapes from filesystem.
Args: Args:
- graph(Program|Graph): The graph to be saved. graph(Program|Graph): The graph to be updated by loaded information..
- dirname(str): The directory that the model saved into. dirname(str): The directory that the model will be loaded.
""" """
assert graph is not None and dirname is not None assert graph is not None and dirname is not None
graph = GraphWrapper(graph) if isinstance(graph, Program) else graph graph = GraphWrapper(graph) if isinstance(graph, Program) else graph
......
...@@ -26,12 +26,14 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -26,12 +26,14 @@ _logger = get_logger(__name__, level=logging.INFO)
class Pruner(): class Pruner():
def __init__(self, criterion="l1_norm"): """The pruner used to prune channels of convolution.
"""
Args: Args:
criterion(str): the criterion used to sort channels for pruning. criterion(str): the criterion used to sort channels for pruning. It only supports 'l1_norm' currently.
It only supports 'l1_norm' currently.
""" """
def __init__(self, criterion="l1_norm"):
self.criterion = criterion self.criterion = criterion
def prune(self, def prune(self,
...@@ -44,9 +46,10 @@ class Pruner(): ...@@ -44,9 +46,10 @@ class Pruner():
only_graph=False, only_graph=False,
param_backup=False, param_backup=False,
param_shape_backup=False): param_shape_backup=False):
""" """Pruning the given parameters.
Pruning the given parameters.
Args: Args:
program(fluid.Program): The program to be pruned. program(fluid.Program): The program to be pruned.
scope(fluid.Scope): The scope storing paramaters to be pruned. scope(fluid.Scope): The scope storing paramaters to be pruned.
params(list<str>): A list of parameter names to be pruned. params(list<str>): A list of parameter names to be pruned.
...@@ -58,10 +61,9 @@ class Pruner(): ...@@ -58,10 +61,9 @@ class Pruner():
False means modifying graph and variables in scope. Default: False. False means modifying graph and variables in scope. Default: False.
param_backup(bool): Whether to return a dict to backup the values of parameters. Default: False. param_backup(bool): Whether to return a dict to backup the values of parameters. Default: False.
param_shape_backup(bool): Whether to return a dict to backup the shapes of parameters. Default: False. param_shape_backup(bool): Whether to return a dict to backup the shapes of parameters. Default: False.
Returns: Returns:
Program: The pruned program. tuple: ``(pruned_program, param_backup, param_shape_backup)``. ``pruned_program`` is the pruned program. ``param_backup`` is a dict to backup the values of parameters. ``param_shape_backup`` is a dict to backup the shapes of parameters.
param_backup: A dict to backup the values of parameters.
param_shape_backup: A dict to backup the shapes of parameters.
""" """
self.pruned_list = [] self.pruned_list = []
...@@ -131,6 +133,7 @@ class Pruner(): ...@@ -131,6 +133,7 @@ class Pruner():
def _cal_pruned_idx(self, param, ratio, axis): def _cal_pruned_idx(self, param, ratio, axis):
""" """
Calculate the index to be pruned on axis by given pruning ratio. Calculate the index to be pruned on axis by given pruning ratio.
Args: Args:
name(str): The name of parameter to be pruned. name(str): The name of parameter to be pruned.
param(np.array): The data of parameter to be pruned. param(np.array): The data of parameter to be pruned.
...@@ -138,6 +141,7 @@ class Pruner(): ...@@ -138,6 +141,7 @@ class Pruner():
axis(int): The axis to be used for pruning given parameter. axis(int): The axis to be used for pruning given parameter.
If it is None, the value in self.pruning_axis will be used. If it is None, the value in self.pruning_axis will be used.
default: None. default: None.
Returns: Returns:
list<int>: The indexes to be pruned on axis. list<int>: The indexes to be pruned on axis.
""" """
...@@ -151,6 +155,7 @@ class Pruner(): ...@@ -151,6 +155,7 @@ class Pruner():
def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False): def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
""" """
Pruning a array by indexes on given axis. Pruning a array by indexes on given axis.
Args: Args:
tensor(numpy.array): The target array to be pruned. tensor(numpy.array): The target array to be pruned.
pruned_idx(list<int>): The indexes to be pruned. pruned_idx(list<int>): The indexes to be pruned.
...@@ -158,6 +163,7 @@ class Pruner(): ...@@ -158,6 +163,7 @@ class Pruner():
lazy(bool): True means setting the pruned elements to zero. lazy(bool): True means setting the pruned elements to zero.
False means remove the pruned elements from memory. False means remove the pruned elements from memory.
default: False. default: False.
Returns: Returns:
numpy.array: The pruned array. numpy.array: The pruned array.
""" """
......
...@@ -37,6 +37,35 @@ def sensitivity(program, ...@@ -37,6 +37,35 @@ def sensitivity(program,
eval_func, eval_func,
sensitivities_file=None, sensitivities_file=None,
pruned_ratios=None): pruned_ratios=None):
"""Compute the sensitivities of convolutions in a model. The sensitivity of a convolution is the losses of accuracy on test dataset in differenct pruned ratios. The sensitivities can be used to get a group of best ratios with some condition.
This function return a dict storing sensitivities as below:
.. code-block:: python
{"weight_0":
{0.1: 0.22,
0.2: 0.33
},
"weight_1":
{0.1: 0.21,
0.2: 0.4
}
}
``weight_0`` is parameter name of convolution. ``sensitivities['weight_0']`` is a dict in which key is pruned ratio and value is the percent of losses.
Args:
program(paddle.fluid.Program): The program to be analysised.
place(fluid.CPUPlace | fluid.CUDAPlace): The device place of filter parameters.
param_names(list): The parameter names of convolutions to be analysised.
eval_func(function): The callback function used to evaluate the model. It should accept a instance of `paddle.fluid.Program` as argument and return a score on test dataset.
sensitivities_file(str): The file to save the sensitivities. It will append the latest computed sensitivities into the file. And the sensitivities in the file would not be computed again. This file can be loaded by `pickle` library.
pruned_ratios(list): The ratios to be pruned. default: ``[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]``.
Returns:
dict: A dict storing sensitivities.
"""
scope = fluid.global_scope() scope = fluid.global_scope()
graph = GraphWrapper(program) graph = GraphWrapper(program)
sensitivities = load_sensitivities(sensitivities_file) sensitivities = load_sensitivities(sensitivities_file)
...@@ -159,13 +188,13 @@ def flops_sensitivity(program, ...@@ -159,13 +188,13 @@ def flops_sensitivity(program,
def merge_sensitive(sensitivities): def merge_sensitive(sensitivities):
""" """Merge sensitivities.
Merge sensitivities.
Args: Args:
sensitivities(list<dict> | list<str>): The sensitivities to be merged. It cann be a list of sensitivities files or dict. sensitivities(list<dict> | list<str>): The sensitivities to be merged. It cann be a list of sensitivities files or dict.
Returns: Returns:
sensitivities(dict): A dict with sensitivities. dict: A dict stroring sensitivities.
""" """
assert len(sensitivities) > 0 assert len(sensitivities) > 0
if not isinstance(sensitivities[0], dict): if not isinstance(sensitivities[0], dict):
...@@ -182,8 +211,13 @@ def merge_sensitive(sensitivities): ...@@ -182,8 +211,13 @@ def merge_sensitive(sensitivities):
def load_sensitivities(sensitivities_file): def load_sensitivities(sensitivities_file):
""" """Load sensitivities from file.
Load sensitivities from file.
Args:
sensitivities_file(str): The file storing sensitivities.
Returns:
dict: A dict stroring sensitivities.
""" """
sensitivities = {} sensitivities = {}
if sensitivities_file and os.path.exists(sensitivities_file): if sensitivities_file and os.path.exists(sensitivities_file):
...@@ -196,8 +230,11 @@ def load_sensitivities(sensitivities_file): ...@@ -196,8 +230,11 @@ def load_sensitivities(sensitivities_file):
def _save_sensitivities(sensitivities, sensitivities_file): def _save_sensitivities(sensitivities, sensitivities_file):
""" """Save sensitivities into file.
Save sensitivities into file.
Args:
sensitivities(dict): The sensitivities to be saved.
sensitivities_file(str): The file to saved sensitivities.
""" """
with open(sensitivities_file, 'wb') as f: with open(sensitivities_file, 'wb') as f:
pickle.dump(sensitivities, f) pickle.dump(sensitivities, f)
...@@ -217,7 +254,7 @@ def get_ratios_by_loss(sensitivities, loss): ...@@ -217,7 +254,7 @@ def get_ratios_by_loss(sensitivities, loss):
Returns: Returns:
ratios(dict): A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned. dict: A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned.
""" """
ratios = {} ratios = {}
for param, losses in sensitivities.items(): for param, losses in sensitivities.items():
......
...@@ -30,10 +30,10 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -30,10 +30,10 @@ _logger = get_logger(__name__, level=logging.INFO)
class SensitivePruner(object): class SensitivePruner(object):
def __init__(self, place, eval_func, scope=None, checkpoints=None):
""" """
Pruner used to prune parameters iteratively according to sensitivities Pruner used to prune parameters iteratively according to sensitivities
of parameters in each step. of parameters in each step.
Args: Args:
place(fluid.CUDAPlace | fluid.CPUPlace): The device place where place(fluid.CUDAPlace | fluid.CPUPlace): The device place where
program execute. program execute.
...@@ -42,6 +42,8 @@ class SensitivePruner(object): ...@@ -42,6 +42,8 @@ class SensitivePruner(object):
And it return a score of given program. And it return a score of given program.
scope(fluid.scope): The scope used to execute program. scope(fluid.scope): The scope used to execute program.
""" """
def __init__(self, place, eval_func, scope=None, checkpoints=None):
self._eval_func = eval_func self._eval_func = eval_func
self._iter = 0 self._iter = 0
self._place = place self._place = place
...@@ -135,12 +137,14 @@ class SensitivePruner(object): ...@@ -135,12 +137,14 @@ class SensitivePruner(object):
def prune(self, train_program, eval_program, params, pruned_flops): def prune(self, train_program, eval_program, params, pruned_flops):
""" """
Pruning parameters of training and evaluation network by sensitivities in current step. Pruning parameters of training and evaluation network by sensitivities in current step.
Args: Args:
train_program(fluid.Program): The training program to be pruned. train_program(fluid.Program): The training program to be pruned.
eval_program(fluid.Program): The evaluation program to be pruned. And it is also used to calculate sensitivities of parameters. eval_program(fluid.Program): The evaluation program to be pruned. And it is also used to calculate sensitivities of parameters.
params(list<str>): The parameters to be pruned. params(list<str>): The parameters to be pruned.
pruned_flops(float): The ratio of FLOPS to be pruned in current step. pruned_flops(float): The ratio of FLOPS to be pruned in current step.
Return:
Returns:
tuple: A tuple of pruned training program and pruned evaluation program. tuple: A tuple of pruned training program and pruned evaluation program.
""" """
_logger.info("Pruning: {}".format(params)) _logger.info("Pruning: {}".format(params))
...@@ -199,9 +203,9 @@ class SensitivePruner(object): ...@@ -199,9 +203,9 @@ class SensitivePruner(object):
pruned_flops(float): The percent of FLOPS to be pruned. pruned_flops(float): The percent of FLOPS to be pruned.
eval_program(Program): The program whose FLOPS is considered. eval_program(Program): The program whose FLOPS is considered.
Return: Returns:
ratios(dict): A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned. dict: A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned.
""" """
min_loss = 0. min_loss = 0.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册