diff --git a/docs/en/api_en/index_en.rst b/docs/en/api_en/index_en.rst index d252687665fdb5ffc3d8789417838454e33e6600..6e15635a74d55337363743128c2018fd3f8474ad 100644 --- a/docs/en/api_en/index_en.rst +++ b/docs/en/api_en/index_en.rst @@ -4,21 +4,16 @@ contain the root `toctree` directive. API Documents -======== +============== .. toctree:: :maxdepth: 1 - modules.rst paddleslim.analysis.rst - paddleslim.common.rst - paddleslim.core.rst + paddleslim.prune.rst paddleslim.dist.rst - paddleslim.models.rst - paddleslim.nas.one_shot.rst + paddleslim.quant.rst paddleslim.nas.rst + paddleslim.nas.one_shot.rst paddleslim.nas.search_space.rst paddleslim.pantheon.rst - paddleslim.prune.rst - paddleslim.quant.rst - paddleslim.rst diff --git a/docs/en/api_en/paddleslim.analysis.rst b/docs/en/api_en/paddleslim.analysis.rst index 9c6ba5dca27a1574fd9c78ded6789e0f07944f74..e9dc6db11dce2d4df4704fe5d43d3c943f50677f 100644 --- a/docs/en/api_en/paddleslim.analysis.rst +++ b/docs/en/api_en/paddleslim.analysis.rst @@ -1,5 +1,5 @@ -paddleslim.analysis package -=========================== +paddleslim\.analysis package +============================ .. automodule:: paddleslim.analysis :members: @@ -9,24 +9,24 @@ paddleslim.analysis package Submodules ---------- -paddleslim.analysis.flops module --------------------------------- +paddleslim\.analysis\.flops module +---------------------------------- .. automodule:: paddleslim.analysis.flops :members: :undoc-members: :show-inheritance: -paddleslim.analysis.latency module ----------------------------------- +paddleslim\.analysis\.latency module +------------------------------------ .. automodule:: paddleslim.analysis.latency :members: :undoc-members: :show-inheritance: -paddleslim.analysis.model\_size module --------------------------------------- +paddleslim\.analysis\.model\_size module +---------------------------------------- .. 
automodule:: paddleslim.analysis.model_size :members: diff --git a/docs/en/api_en/paddleslim.common.rst b/docs/en/api_en/paddleslim.common.rst index fe498ca83478e2bb9b34760dd04a9d333b5ab22a..a59bd0085952d4b75bb4c336e7e19e21c0353bea 100644 --- a/docs/en/api_en/paddleslim.common.rst +++ b/docs/en/api_en/paddleslim.common.rst @@ -1,5 +1,5 @@ -paddleslim.common package -========================= +paddleslim\.common package +========================== .. automodule:: paddleslim.common :members: @@ -9,56 +9,56 @@ paddleslim.common package Submodules ---------- -paddleslim.common.cached\_reader module ---------------------------------------- +paddleslim\.common\.cached\_reader module +----------------------------------------- .. automodule:: paddleslim.common.cached_reader :members: :undoc-members: :show-inheritance: -paddleslim.common.controller module ------------------------------------ +paddleslim\.common\.controller module +------------------------------------- .. automodule:: paddleslim.common.controller :members: :undoc-members: :show-inheritance: -paddleslim.common.controller\_client module -------------------------------------------- +paddleslim\.common\.controller\_client module +--------------------------------------------- .. automodule:: paddleslim.common.controller_client :members: :undoc-members: :show-inheritance: -paddleslim.common.controller\_server module -------------------------------------------- +paddleslim\.common\.controller\_server module +--------------------------------------------- .. automodule:: paddleslim.common.controller_server :members: :undoc-members: :show-inheritance: -paddleslim.common.lock module ------------------------------ +paddleslim\.common\.lock module +------------------------------- .. automodule:: paddleslim.common.lock :members: :undoc-members: :show-inheritance: -paddleslim.common.log\_helper module ------------------------------------- +paddleslim\.common\.log\_helper module +-------------------------------------- .. 
automodule:: paddleslim.common.log_helper :members: :undoc-members: :show-inheritance: -paddleslim.common.sa\_controller module ---------------------------------------- +paddleslim\.common\.sa\_controller module +----------------------------------------- .. automodule:: paddleslim.common.sa_controller :members: diff --git a/docs/en/api_en/paddleslim.core.rst b/docs/en/api_en/paddleslim.core.rst index e3cc6cdb0b17e49dc6482316dd92ceac37b60762..38ed2f4d153a4d84833d1dca458292acef350072 100644 --- a/docs/en/api_en/paddleslim.core.rst +++ b/docs/en/api_en/paddleslim.core.rst @@ -1,5 +1,5 @@ -paddleslim.core package -======================= +paddleslim\.core package +======================== .. automodule:: paddleslim.core :members: @@ -9,16 +9,16 @@ paddleslim.core package Submodules ---------- -paddleslim.core.graph\_wrapper module -------------------------------------- +paddleslim\.core\.graph\_wrapper module +--------------------------------------- .. automodule:: paddleslim.core.graph_wrapper :members: :undoc-members: :show-inheritance: -paddleslim.core.registry module -------------------------------- +paddleslim\.core\.registry module +--------------------------------- .. automodule:: paddleslim.core.registry :members: diff --git a/docs/en/api_en/paddleslim.dist.rst b/docs/en/api_en/paddleslim.dist.rst index e7b776bd00f6804de16a7ccc45c2e253b3d2e20f..a886778da14600e33ee0f02d98dce1be4b9a1e89 100644 --- a/docs/en/api_en/paddleslim.dist.rst +++ b/docs/en/api_en/paddleslim.dist.rst @@ -1,5 +1,5 @@ -paddleslim.dist package -======================= +paddleslim\.dist package +======================== .. automodule:: paddleslim.dist :members: @@ -9,8 +9,8 @@ paddleslim.dist package Submodules ---------- -paddleslim.dist.single\_distiller module ----------------------------------------- +paddleslim\.dist\.single\_distiller module +------------------------------------------ .. 
automodule:: paddleslim.dist.single_distiller :members: diff --git a/docs/en/api_en/paddleslim.models.rst b/docs/en/api_en/paddleslim.models.rst index fb233e48ae168efe45fd455d6b37627f11608f75..958a9682e08ff89be5ccbb1d61fcc36ceab00e55 100644 --- a/docs/en/api_en/paddleslim.models.rst +++ b/docs/en/api_en/paddleslim.models.rst @@ -1,5 +1,5 @@ -paddleslim.models package -========================= +paddleslim\.models package +========================== .. automodule:: paddleslim.models :members: @@ -9,40 +9,40 @@ paddleslim.models package Submodules ---------- -paddleslim.models.classification\_models module ------------------------------------------------ +paddleslim\.models\.classification\_models module +------------------------------------------------- .. automodule:: paddleslim.models.classification_models :members: :undoc-members: :show-inheritance: -paddleslim.models.mobilenet module ----------------------------------- +paddleslim\.models\.mobilenet module +------------------------------------ .. automodule:: paddleslim.models.mobilenet :members: :undoc-members: :show-inheritance: -paddleslim.models.mobilenet\_v2 module --------------------------------------- +paddleslim\.models\.mobilenet\_v2 module +---------------------------------------- .. automodule:: paddleslim.models.mobilenet_v2 :members: :undoc-members: :show-inheritance: -paddleslim.models.resnet module -------------------------------- +paddleslim\.models\.resnet module +--------------------------------- .. automodule:: paddleslim.models.resnet :members: :undoc-members: :show-inheritance: -paddleslim.models.util module ------------------------------ +paddleslim\.models\.util module +------------------------------- .. 
automodule:: paddleslim.models.util :members: diff --git a/docs/en/api_en/paddleslim.nas.one_shot.rst b/docs/en/api_en/paddleslim.nas.one_shot.rst index a0dcf7af840cf8a8401e8ee10ce24ed677018914..f9ebde94783a33ede7463148b909d3147e9bbb1f 100644 --- a/docs/en/api_en/paddleslim.nas.one_shot.rst +++ b/docs/en/api_en/paddleslim.nas.one_shot.rst @@ -1,5 +1,5 @@ -paddleslim.nas.one\_shot package -================================ +paddleslim\.nas\.one\_shot package +================================== .. automodule:: paddleslim.nas.one_shot :members: @@ -9,16 +9,16 @@ paddleslim.nas.one\_shot package Submodules ---------- -paddleslim.nas.one\_shot.one\_shot\_nas module ----------------------------------------------- +paddleslim\.nas\.one\_shot\.one\_shot\_nas module +------------------------------------------------- .. automodule:: paddleslim.nas.one_shot.one_shot_nas :members: :undoc-members: :show-inheritance: -paddleslim.nas.one\_shot.super\_mnasnet module ----------------------------------------------- +paddleslim\.nas\.one\_shot\.super\_mnasnet module +------------------------------------------------- .. automodule:: paddleslim.nas.one_shot.super_mnasnet :members: diff --git a/docs/en/api_en/paddleslim.nas.rst b/docs/en/api_en/paddleslim.nas.rst index 1fac5029f4e077352eeff465d181afa486a29f47..30eddca0baf1355dbd5872f7fcb97e71698a98c2 100644 --- a/docs/en/api_en/paddleslim.nas.rst +++ b/docs/en/api_en/paddleslim.nas.rst @@ -1,5 +1,5 @@ -paddleslim.nas package -====================== +paddleslim\.nas package +======================= .. automodule:: paddleslim.nas :members: @@ -17,8 +17,8 @@ Subpackages Submodules ---------- -paddleslim.nas.sa\_nas module ------------------------------ +paddleslim\.nas\.sa\_nas module +------------------------------- .. 
automodule:: paddleslim.nas.sa_nas :members: diff --git a/docs/en/api_en/paddleslim.nas.search_space.rst b/docs/en/api_en/paddleslim.nas.search_space.rst index f07a71cec6d2a4eb3e6d0b4020c359e6422f83dd..f078d4af7cc0d9d77a7ea05e3cff43ce3c112788 100644 --- a/docs/en/api_en/paddleslim.nas.search_space.rst +++ b/docs/en/api_en/paddleslim.nas.search_space.rst @@ -1,5 +1,5 @@ -paddleslim.nas.search\_space package -==================================== +paddleslim\.nas\.search\_space package +====================================== .. automodule:: paddleslim.nas.search_space :members: @@ -9,96 +9,96 @@ paddleslim.nas.search\_space package Submodules ---------- -paddleslim.nas.search\_space.base\_layer module ------------------------------------------------ +paddleslim\.nas\.search\_space\.base\_layer module +-------------------------------------------------- .. automodule:: paddleslim.nas.search_space.base_layer :members: :undoc-members: :show-inheritance: -paddleslim.nas.search\_space.combine\_search\_space module ----------------------------------------------------------- +paddleslim\.nas\.search\_space\.combine\_search\_space module +------------------------------------------------------------- .. automodule:: paddleslim.nas.search_space.combine_search_space :members: :undoc-members: :show-inheritance: -paddleslim.nas.search\_space.inception\_block module ----------------------------------------------------- +paddleslim\.nas\.search\_space\.inception\_block module +------------------------------------------------------- .. automodule:: paddleslim.nas.search_space.inception_block :members: :undoc-members: :show-inheritance: -paddleslim.nas.search\_space.mobilenet\_block module ----------------------------------------------------- +paddleslim\.nas\.search\_space\.mobilenet\_block module +------------------------------------------------------- .. 
automodule:: paddleslim.nas.search_space.mobilenet_block :members: :undoc-members: :show-inheritance: -paddleslim.nas.search\_space.mobilenetv1 module ------------------------------------------------ +paddleslim\.nas\.search\_space\.mobilenetv1 module +-------------------------------------------------- .. automodule:: paddleslim.nas.search_space.mobilenetv1 :members: :undoc-members: :show-inheritance: -paddleslim.nas.search\_space.mobilenetv2 module ------------------------------------------------ +paddleslim\.nas\.search\_space\.mobilenetv2 module +-------------------------------------------------- .. automodule:: paddleslim.nas.search_space.mobilenetv2 :members: :undoc-members: :show-inheritance: -paddleslim.nas.search\_space.resnet module ------------------------------------------- +paddleslim\.nas\.search\_space\.resnet module +--------------------------------------------- .. automodule:: paddleslim.nas.search_space.resnet :members: :undoc-members: :show-inheritance: -paddleslim.nas.search\_space.resnet\_block module -------------------------------------------------- +paddleslim\.nas\.search\_space\.resnet\_block module +---------------------------------------------------- .. automodule:: paddleslim.nas.search_space.resnet_block :members: :undoc-members: :show-inheritance: -paddleslim.nas.search\_space.search\_space\_base module -------------------------------------------------------- +paddleslim\.nas\.search\_space\.search\_space\_base module +---------------------------------------------------------- .. automodule:: paddleslim.nas.search_space.search_space_base :members: :undoc-members: :show-inheritance: -paddleslim.nas.search\_space.search\_space\_factory module ----------------------------------------------------------- +paddleslim\.nas\.search\_space\.search\_space\_factory module +------------------------------------------------------------- .. 
automodule:: paddleslim.nas.search_space.search_space_factory :members: :undoc-members: :show-inheritance: -paddleslim.nas.search\_space.search\_space\_registry module ------------------------------------------------------------ +paddleslim\.nas\.search\_space\.search\_space\_registry module +-------------------------------------------------------------- .. automodule:: paddleslim.nas.search_space.search_space_registry :members: :undoc-members: :show-inheritance: -paddleslim.nas.search\_space.utils module ------------------------------------------ +paddleslim\.nas\.search\_space\.utils module +-------------------------------------------- .. automodule:: paddleslim.nas.search_space.utils :members: diff --git a/docs/en/api_en/paddleslim.pantheon.rst b/docs/en/api_en/paddleslim.pantheon.rst index bfaf4948769260fcdb4453efa812ba10c1a9cd21..59f48ce9dc4c653b9724eda46050da768623a0b3 100644 --- a/docs/en/api_en/paddleslim.pantheon.rst +++ b/docs/en/api_en/paddleslim.pantheon.rst @@ -1,5 +1,5 @@ -paddleslim.pantheon package -=========================== +paddleslim\.pantheon package +============================ .. automodule:: paddleslim.pantheon :members: @@ -9,24 +9,24 @@ paddleslim.pantheon package Submodules ---------- -paddleslim.pantheon.student module ----------------------------------- +paddleslim\.pantheon\.student module +------------------------------------ .. automodule:: paddleslim.pantheon.student :members: :undoc-members: :show-inheritance: -paddleslim.pantheon.teacher module ----------------------------------- +paddleslim\.pantheon\.teacher module +------------------------------------ .. automodule:: paddleslim.pantheon.teacher :members: :undoc-members: :show-inheritance: -paddleslim.pantheon.utils module --------------------------------- +paddleslim\.pantheon\.utils module +---------------------------------- .. 
automodule:: paddleslim.pantheon.utils :members: diff --git a/docs/en/api_en/paddleslim.prune.rst b/docs/en/api_en/paddleslim.prune.rst index af7c896beda266966c8ed8debf0bdf7d61eb1353..ec663d35ecd0f240fd47328654619feff16e805b 100644 --- a/docs/en/api_en/paddleslim.prune.rst +++ b/docs/en/api_en/paddleslim.prune.rst @@ -1,5 +1,5 @@ -paddleslim.prune package -======================== +paddleslim\.prune package +========================= .. automodule:: paddleslim.prune :members: @@ -9,48 +9,48 @@ paddleslim.prune package Submodules ---------- -paddleslim.prune.auto\_pruner module ------------------------------------- +paddleslim\.prune\.auto\_pruner module +-------------------------------------- .. automodule:: paddleslim.prune.auto_pruner :members: :undoc-members: :show-inheritance: -paddleslim.prune.prune\_io module ---------------------------------- +paddleslim\.prune\.prune\_io module +----------------------------------- .. automodule:: paddleslim.prune.prune_io :members: :undoc-members: :show-inheritance: -paddleslim.prune.prune\_walker module -------------------------------------- +paddleslim\.prune\.prune\_walker module +--------------------------------------- .. automodule:: paddleslim.prune.prune_walker :members: :undoc-members: :show-inheritance: -paddleslim.prune.pruner module ------------------------------- +paddleslim\.prune\.pruner module +-------------------------------- .. automodule:: paddleslim.prune.pruner :members: :undoc-members: :show-inheritance: -paddleslim.prune.sensitive module ---------------------------------- +paddleslim\.prune\.sensitive module +----------------------------------- .. automodule:: paddleslim.prune.sensitive :members: :undoc-members: :show-inheritance: -paddleslim.prune.sensitive\_pruner module ------------------------------------------ +paddleslim\.prune\.sensitive\_pruner module +------------------------------------------- .. 
automodule:: paddleslim.prune.sensitive_pruner :members: diff --git a/docs/en/api_en/paddleslim.quant.rst b/docs/en/api_en/paddleslim.quant.rst index fd2006c1584f4be2417031b36218ea214415d5e9..6051606a65214c4d3e9027472bd26701d99d5f54 100644 --- a/docs/en/api_en/paddleslim.quant.rst +++ b/docs/en/api_en/paddleslim.quant.rst @@ -1,5 +1,5 @@ -paddleslim.quant package -======================== +paddleslim\.quant package +========================= .. automodule:: paddleslim.quant :members: @@ -9,16 +9,16 @@ paddleslim.quant package Submodules ---------- -paddleslim.quant.quant\_embedding module ----------------------------------------- +paddleslim\.quant\.quant\_embedding module +------------------------------------------ .. automodule:: paddleslim.quant.quant_embedding :members: :undoc-members: :show-inheritance: -paddleslim.quant.quanter module -------------------------------- +paddleslim\.quant\.quanter module +--------------------------------- .. automodule:: paddleslim.quant.quanter :members: diff --git a/docs/en/api_en/paddleslim.rst b/docs/en/api_en/paddleslim.rst index 958e7b40b5baa1ae8c838ba971657db9cb4e5dbe..85bf130ce662f5fa165879dfa766b4672e09b9ab 100644 --- a/docs/en/api_en/paddleslim.rst +++ b/docs/en/api_en/paddleslim.rst @@ -24,8 +24,8 @@ Subpackages Submodules ---------- -paddleslim.version module -------------------------- +paddleslim\.version module +-------------------------- .. automodule:: paddleslim.version :members: diff --git a/paddleslim/common/cached_reader.py b/paddleslim/common/cached_reader.py index 55f27054efe55d9df90352b3e707fe51c8996023..c297fe9dfdd56b92b4e46f79ad014f0f4261b022 100644 --- a/paddleslim/common/cached_reader.py +++ b/paddleslim/common/cached_reader.py @@ -25,6 +25,7 @@ _logger = get_logger(__name__, level=logging.INFO) def cached_reader(reader, sampled_rate, cache_path, cached_id): """ Sample partial data from reader and cache them into local file system. + Args: reader: Iterative data source. 
sampled_rate(float): The sampled rate used to sample partial data for evaluation. None means using all data in eval_reader. default: None. diff --git a/paddleslim/common/controller.py b/paddleslim/common/controller.py index 8c30f49c3aec27a326417554bac3163789342ff6..d06b6b88dbd32d5efaf763cd4da36d3709e54470 100644 --- a/paddleslim/common/controller.py +++ b/paddleslim/common/controller.py @@ -24,11 +24,9 @@ class EvolutionaryController(object): """Abstract controller for all evolutionary searching method. """ - def __init__(self, *args, **kwargs): - pass - def update(self, tokens, reward): """Update the status of controller according current tokens and reward. + Args: tokens(list): A solution of searching task. reward(list): The reward of tokens. @@ -37,6 +35,7 @@ class EvolutionaryController(object): def reset(self, range_table, constrain_func=None): """Reset the controller. + Args: range_table(list): It is used to define the searching space of controller. The tokens[i] generated by controller should be in [0, range_table[i]). @@ -47,5 +46,8 @@ class EvolutionaryController(object): def next_tokens(self): """Generate new tokens. + + Returns: + list: The next searched tokens. """ raise NotImplementedError('Abstract method.') diff --git a/paddleslim/common/controller_client.py b/paddleslim/common/controller_client.py index c60bcfaffdb03771588011f11dd8f875d491dcdf..2c192c1d457b4c640de1590d037719f6101fd479 100644 --- a/paddleslim/common/controller_client.py +++ b/paddleslim/common/controller_client.py @@ -24,6 +24,11 @@ _logger = get_logger(__name__, level=logging.INFO) class ControllerClient(object): """ Controller client. + Args: + server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None. + server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0. + key(str): The key used to identify legal agent for controller server. 
Default: "light-nas" + client_name(str): Current client name, random generate for counting client number. Default: None. """ def __init__(self, @@ -32,11 +37,6 @@ class ControllerClient(object): key=None, client_name=None): """ - Args: - server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None. - server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0. - key(str): The key used to identify legal agent for controller server. Default: "light-nas" - client_name(str): Current client name, random generate for counting client number. Default: None. """ self.server_ip = server_ip self.server_port = server_port @@ -46,9 +46,11 @@ class ControllerClient(object): def update(self, tokens, reward, iter): """ Update the controller according to latest tokens and reward. + Args: tokens(list): The tokens generated in last step. reward(float): The reward of tokens. + iter(int): The iteration number of current client. """ socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) socket_client.connect((self.server_ip, self.server_port)) diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py index 5e1ef737a690788d562df7978128a06812bed1e4..008639c0a6a372f4f5d13dffc7f433b919486547 100644 --- a/paddleslim/common/controller_server.py +++ b/paddleslim/common/controller_server.py @@ -26,8 +26,14 @@ _logger = get_logger(__name__, level=logging.INFO) class ControllerServer(object): - """ - The controller wrapper with a socket server to handle the request of search agent. + """The controller wrapper with a socket server to handle the request of search agent. + Args: + controller(slim.searcher.Controller): The controller used to generate tokens. + address(tuple): The address of current server binding with format (ip, port). Default: ('', 0). 
+ which means setting ip automatically + max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100. + search_steps(int|None): The total steps of searching. None means never stopping. Default: None + key(str|None): Config information. Default: None. """ def __init__(self, @@ -37,13 +43,6 @@ class ControllerServer(object): search_steps=None, key=None): """ - Args: - controller(slim.searcher.Controller): The controller used to generate tokens. - address(tuple): The address of current server binding with format (ip, port). Default: ('', 0). - which means setting ip automatically - max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100. - search_steps(int|None): The total steps of searching. None means never stopping. Default: None - key(str|None): Config information. Default: None. """ self._controller = controller self._address = address @@ -84,6 +83,8 @@ class ControllerServer(object): return self._ip def run(self): + """Start the server. + """ _logger.info("Controller Server run...") try: while ((self._search_steps is None) or diff --git a/paddleslim/common/log_helper.py b/paddleslim/common/log_helper.py index 18000ce4ec6c472914de49a053e960c02cfd8e32..e0b38e893d74fee1f5ac2b5410498295994253f8 100644 --- a/paddleslim/common/log_helper.py +++ b/paddleslim/common/log_helper.py @@ -24,15 +24,20 @@ def get_logger(name, level, fmt='%(asctime)s-%(levelname)s: %(message)s'): Get logger from logging with given name, level and format without setting logging basicConfig. For setting basicConfig in paddle will disable basicConfig setting after import paddle. + Args: name (str): The logger name. level (logging.LEVEL): The base level of the logger fmt (str): Format of logger output + Returns: logging.Logger: logging logger with given setttings + Examples: - .. code-block:: python - logger = log_helper.get_logger(__name__, logging.INFO, + + .. 
code-block:: python + + logger = log_helper.get_logger(__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') """ diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py index b1034762a5bd8999acd0c8e4a40249ba68590415..aa150b1bdd979a415ffb1a2a74310e1bd4fa854a 100644 --- a/paddleslim/common/sa_controller.py +++ b/paddleslim/common/sa_controller.py @@ -29,7 +29,22 @@ _logger = get_logger(__name__, level=logging.INFO) class SAController(EvolutionaryController): - """Simulated annealing controller.""" + """Simulated annealing controller. + + Args: + range_table(list): Range table. + reduce_rate(float): The decay rate of temperature. + init_temperature(float): Init temperature. + max_try_times(int): max try times before get legal tokens. Default: 300. + init_tokens(list): The initial tokens. Default: None. + reward(float): The reward of current tokens. Default: -1. + max_reward(float): The max reward in the search of sanas, in general, best tokens get max reward. Default: -1. + iters(int): The iteration of sa controller. Default: 0. + best_tokens(list): The best tokens in the search of sanas, in general, best tokens get max reward. Default: None. + constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None. + checkpoints(str): if checkpoint is None, donnot save checkpoints, else save scene to checkpoints file. + searched(dict): remember tokens which are searched. + """ def __init__(self, range_table=None, @@ -44,21 +59,6 @@ class SAController(EvolutionaryController): constrain_func=None, checkpoints=None, searched=None): - """Initialize. - Args: - range_table(list): Range table. - reduce_rate(float): The decay rate of temperature. - init_temperature(float): Init temperature. - max_try_times(int): max try times before get legal tokens. Default: 300. - init_tokens(list): The initial tokens. Default: None. 
- reward(float): The reward of current tokens. Default: -1. - max_reward(float): The max reward in the search of sanas, in general, best tokens get max reward. Default: -1. - iters(int): The iteration of sa controller. Default: 0. - best_tokens(list): The best tokens in the search of sanas, in general, best tokens get max reward. Default: None. - constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None. - checkpoints(str): if checkpoint is None, donnot save checkpoints, else save scene to checkpoints file. - searched(dict): remember tokens which are searched. - """ super(SAController, self).__init__() self._range_table = range_table assert isinstance(self._range_table, tuple) and ( @@ -92,6 +92,11 @@ class SAController(EvolutionaryController): @property def best_tokens(self): + """Get current best tokens. + + Returns: + list: The best tokens. + """ return self._best_tokens @property @@ -100,14 +105,23 @@ class SAController(EvolutionaryController): @property def current_tokens(self): + """Get tokens generated in current searching step. + + Returns: + list: The best tokens. + """ + return self._current_tokens def update(self, tokens, reward, iter, client_num): """ Update the controller according to latest tokens and reward. + Args: - tokens(list): The tokens generated in last step. + tokens(list): The tokens generated in current step. reward(float): The reward of tokens. + iter(int): The current step of searching client. + client_num(int): The total number of searching client. """ iter = int(iter) if iter > self._iter: @@ -136,6 +150,12 @@ class SAController(EvolutionaryController): def next_tokens(self, control_token=None): """ Get next tokens. + + Args: + control_token: The tokens used to generate next tokens. + + Returns: + list: The next tokens. 
""" if control_token: tokens = control_token[:] diff --git a/paddleslim/core/graph_wrapper.py b/paddleslim/core/graph_wrapper.py index 1feedd121ccf8c08c7962c66acc022423fd69cc9..d865865284cc5598a8940558bd28a8a95e1539fc 100644 --- a/paddleslim/core/graph_wrapper.py +++ b/paddleslim/core/graph_wrapper.py @@ -72,6 +72,7 @@ class VarWrapper(object): def inputs(self): """ Get all the operators that use this variable as output. + Returns: list: A list of operators. """ @@ -84,6 +85,7 @@ class VarWrapper(object): def outputs(self): """ Get all the operators that use this variable as input. + Returns: list: A list of operators. """ @@ -196,18 +198,19 @@ class GraphWrapper(object): """ It is a wrapper of paddle.fluid.framework.IrGraph with some special functions for paddle slim framework. + + Args: + program(framework.Program): A program with + in_nodes(dict): A dict to indicate the input nodes of the graph. + The key is user-defined and human-readable name. + The value is the name of Variable. + out_nodes(dict): A dict to indicate the input nodes of the graph. + The key is user-defined and human-readable name. + The value is the name of Variable. """ def __init__(self, program=None, in_nodes=[], out_nodes=[]): """ - Args: - program(framework.Program): A program with - in_nodes(dict): A dict to indicate the input nodes of the graph. - The key is user-defined and human-readable name. - The value is the name of Variable. - out_nodes(dict): A dict to indicate the input nodes of the graph. - The key is user-defined and human-readable name. - The value is the name of Variable. """ super(GraphWrapper, self).__init__() self.program = Program() if program is None else program @@ -226,6 +229,7 @@ class GraphWrapper(object): def all_parameters(self): """ Get all the parameters in this graph. + Returns: list: A list of VarWrapper instances. """ @@ -238,6 +242,7 @@ class GraphWrapper(object): def is_parameter(self, var): """ Whether the given variable is parameter. 
+ Args: var(VarWrapper): The given varibale. """ @@ -246,6 +251,7 @@ class GraphWrapper(object): def is_persistable(self, var): """ Whether the given variable is persistable. + Args: var(VarWrapper): The given varibale. """ @@ -279,6 +285,7 @@ class GraphWrapper(object): def clone(self, for_test=False): """ Clone a new graph from current graph. + Returns: (GraphWrapper): The wrapper of a new graph. """ @@ -295,8 +302,10 @@ class GraphWrapper(object): def pre_ops(self, op): """ Get all the previous operators of target operator. + Args: - op(OpWrapper): Target operator.. + op(OpWrapper): Target operator. + Returns: list: A list of operators. """ @@ -310,8 +319,10 @@ class GraphWrapper(object): def next_ops(self, op): """ Get all the next operators of target operator. + Args: - op(OpWrapper): Target operator.. + op(OpWrapper): Target operator. + Returns: list: A list of operators. """ diff --git a/paddleslim/nas/one_shot/one_shot_nas.py b/paddleslim/nas/one_shot/one_shot_nas.py index 20e0d64046a074077e677ea259f206eff26fd07a..444f512ac8daee43e949a3ad3f917c8cf7344724 100644 --- a/paddleslim/nas/one_shot/one_shot_nas.py +++ b/paddleslim/nas/one_shot/one_shot_nas.py @@ -22,14 +22,16 @@ __all__ = ['OneShotSuperNet', 'OneShotSearch'] def OneShotSearch(model, eval_func, strategy='sa', search_steps=100): """ Search a best tokens which represents a sub-network. - Archs: + + Args: model(fluid.dygraph.Layer): A dynamic graph module whose sub-modules should contain one instance of `OneShotSuperNet` at least. eval_func(function): A callback function which accept model and tokens as arguments. strategy(str): The name of strategy used to search. Default: 'sa'. search_steps(int): The total steps for searching. + Returns: - tokens(list): The best tokens searched. + list: The best tokens searched. 
""" super_net = None for layer in model.sublayers(include_sublayers=False): @@ -52,8 +54,7 @@ def OneShotSearch(model, eval_func, strategy='sa', search_steps=100): class OneShotSuperNet(fluid.dygraph.Layer): - """ - The base class of super net used in one-shot searching strategy. + """The base class of super net used in one-shot searching strategy. A super net is a dygraph layer. Args: @@ -65,14 +66,16 @@ class OneShotSuperNet(fluid.dygraph.Layer): def init_tokens(self): """Get init tokens in search space. - Return: - tokens(list): The init tokens which is a list of integer. + + Returns: + list: The init tokens which is a list of integer. """ raise NotImplementedError('Abstract method.') def range_table(self): """Get range table of current search space. - Return: + + Returns: range_table(tuple): The maximum value and minimum value in each position of tokens with format `(min_values, max_values)`. The `min_values` is a list of integers indicating the minimum values while `max_values` @@ -81,9 +84,9 @@ class OneShotSuperNet(fluid.dygraph.Layer): raise NotImplementedError('Abstract method.') def _forward_impl(self, *inputs, **kwargs): - """ - Defines the computation performed at every call. + """Defines the computation performed at every call. Should be overridden by all subclasses. + Args: inputs(tuple): unpacked tuple arguments kwargs(dict): unpacked dict arguments @@ -93,6 +96,7 @@ class OneShotSuperNet(fluid.dygraph.Layer): def forward(self, input, tokens=None): """ Defines the computation performed at every call. + Args: input(variable): The input of super net. tokens(list): The tokens used to generate a sub-network. @@ -100,8 +104,9 @@ class OneShotSuperNet(fluid.dygraph.Layer): Otherwise, it will execute the sub-network generated by tokens. The `tokens` should be set in searching stage and final training stage. Default: None. + Returns: - output(varaible): The output of super net. + Varaible: The output of super net. 
""" if tokens == None: tokens = self._random_tokens() diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py index 672ce78d12a3c69a9648543fa8b48dddc87c11f1..55ec0f55e387c5d8729ea85c233e8f2d54f82352 100644 --- a/paddleslim/prune/auto_pruner.py +++ b/paddleslim/prune/auto_pruner.py @@ -31,6 +31,39 @@ _logger = get_logger(__name__, level=logging.INFO) class AutoPruner(object): + """ + Search a group of ratios used to prune program. + + Args: + program(Program): The program to be pruned. + scope(Scope): The scope to be pruned. + place(fluid.Place): The device place of parameters. + params(list): The names of parameters to be pruned. + init_ratios(list|float): Init ratios used to pruned parameters in `params`. + List means ratios used for pruning each parameter in `params`. + The length of `init_ratios` should be equal to length of params when `init_ratios` is a list. + If it is a scalar, all the parameters in `params` will be pruned by uniform ratio. + None means get a group of init ratios by `pruned_flops` of `pruned_latency`. Default: None. + pruned_flops(float): The percent of FLOPS to be pruned. Default: None. + pruned_latency(float): The percent of latency to be pruned. Default: None. + server_addr(tuple): A tuple of server ip and server port for controller server. + init_temperature(float): The init temperature used in simulated annealing search strategy. + reduce_rate(float): The decay rate used in simulated annealing search strategy. + max_try_times(int): The max number of trying to generate legal tokens. + max_client_num(int): The max number of connections of controller server. + search_steps(int): The steps of searching. + max_ratios(float|list): Max ratios used to pruned parameters in `params`. + List means max ratios for each parameter in `params`. + The length of `max_ratios` should be equal to length of params when `max_ratios` is a list. + If it is a scalar, it will used for all the parameters in `params`. 
+ min_ratios(float|list): Min ratios used to pruned parameters in `params`. + List means min ratios for each parameter in `params`. + The length of `min_ratios` should be equal to length of params when `min_ratios` is a list. + If it is a scalar, it will used for all the parameters in `params`. + key(str): Identity used in communication between controller server and clients. + is_server(bool): Whether current host is controller server. Default: True. + """ + def __init__(self, program, scope, @@ -49,37 +82,6 @@ class AutoPruner(object): min_ratios=[0], key="auto_pruner", is_server=True): - """ - Search a group of ratios used to prune program. - Args: - program(Program): The program to be pruned. - scope(Scope): The scope to be pruned. - place(fluid.Place): The device place of parameters. - params(list): The names of parameters to be pruned. - init_ratios(list|float): Init ratios used to pruned parameters in `params`. - List means ratios used for pruning each parameter in `params`. - The length of `init_ratios` should be equal to length of params when `init_ratios` is a list. - If it is a scalar, all the parameters in `params` will be pruned by uniform ratio. - None means get a group of init ratios by `pruned_flops` of `pruned_latency`. Default: None. - pruned_flops(float): The percent of FLOPS to be pruned. Default: None. - pruned_latency(float): The percent of latency to be pruned. Default: None. - server_addr(tuple): A tuple of server ip and server port for controller server. - init_temperature(float): The init temperature used in simulated annealing search strategy. - reduce_rate(float): The decay rate used in simulated annealing search strategy. - max_try_times(int): The max number of trying to generate legal tokens. - max_client_num(int): The max number of connections of controller server. - search_steps(int): The steps of searching. - max_ratios(float|list): Max ratios used to pruned parameters in `params`. 
- List means max ratios for each parameter in `params`. - The length of `max_ratios` should be equal to length of params when `max_ratios` is a list. - If it is a scalar, it will used for all the parameters in `params`. - min_ratios(float|list): Min ratios used to pruned parameters in `params`. - List means min ratios for each parameter in `params`. - The length of `min_ratios` should be equal to length of params when `min_ratios` is a list. - If it is a scalar, it will used for all the parameters in `params`. - key(str): Identity used in communication between controller server and clients. - is_server(bool): Whether current host is controller server. Default: True. - """ self._program = program self._scope = scope @@ -177,10 +179,12 @@ class AutoPruner(object): def prune(self, program, eval_program=None): """ Prune program with latest tokens generated by controller. + Args: program(fluid.Program): The program to be pruned. + Returns: - Program: The pruned program. + paddle.fluid.Program: The pruned program. """ self._current_ratios = self._next_ratios() pruned_program, _, _ = self._pruner.prune( @@ -208,8 +212,9 @@ class AutoPruner(object): def reward(self, score): """ Return reward of current pruned program. + Args: - score(float): The score of pruned program. + score(float): The score of pruned program. """ self._restore(self._scope) self._param_backup = {} diff --git a/paddleslim/prune/prune_io.py b/paddleslim/prune/prune_io.py index 5dcd781c1a658ee757441207a1801174784fdfcc..a37ddd33a04eb23ff2a797b765e2d8ff4143560f 100644 --- a/paddleslim/prune/prune_io.py +++ b/paddleslim/prune/prune_io.py @@ -19,8 +19,9 @@ def save_model(exe, graph, dirname): Save weights of model and information of shapes into filesystem. Args: - - graph(Program|Graph): The graph to be saved. - - dirname(str): The directory that the model saved into. + exe(paddle.fluid.Executor): The executor used to save model. + graph(Program|Graph): The graph to be saved. 
+ dirname(str): The directory that the model saved into. """ assert graph is not None and dirname is not None graph = GraphWrapper(graph) if isinstance(graph, Program) else graph @@ -46,8 +47,8 @@ def load_model(exe, graph, dirname): Load weights of model and information of shapes from filesystem. Args: - - graph(Program|Graph): The graph to be saved. - - dirname(str): The directory that the model saved into. + graph(Program|Graph): The graph to be updated by loaded information. + dirname(str): The directory from which the model will be loaded. """ assert graph is not None and dirname is not None graph = GraphWrapper(graph) if isinstance(graph, Program) else graph diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py index 98abe950bf505e89ffeeac1bf286d4ca71da36b4..e5038aad6221ddd735b3646ae878dcd3e35cd435 100644 --- a/paddleslim/prune/pruner.py +++ b/paddleslim/prune/pruner.py @@ -26,12 +26,14 @@ _logger = get_logger(__name__, level=logging.INFO) class Pruner(): + """The pruner used to prune channels of convolution. + + Args: + criterion(str): the criterion used to sort channels for pruning. It only supports 'l1_norm' currently. + + """ + def __init__(self, criterion="l1_norm"): - """ - Args: - criterion(str): the criterion used to sort channels for pruning. - It only supports 'l1_norm' currently. - """ self.criterion = criterion def prune(self, @@ -44,9 +46,10 @@ class Pruner(): only_graph=False, param_backup=False, param_shape_backup=False): - """ - Pruning the given parameters. + """Pruning the given parameters. + Args: + program(fluid.Program): The program to be pruned. scope(fluid.Scope): The scope storing paramaters to be pruned. params(list): A list of parameter names to be pruned. @@ -58,10 +61,9 @@ class Pruner(): False means modifying graph and variables in scope. Default: False. param_backup(bool): Whether to return a dict to backup the values of parameters. Default: False. 
param_shape_backup(bool): Whether to return a dict to backup the shapes of parameters. Default: False. + Returns: - Program: The pruned program. - param_backup: A dict to backup the values of parameters. - param_shape_backup: A dict to backup the shapes of parameters. + tuple: ``(pruned_program, param_backup, param_shape_backup)``. ``pruned_program`` is the pruned program. ``param_backup`` is a dict to backup the values of parameters. ``param_shape_backup`` is a dict to backup the shapes of parameters. """ self.pruned_list = [] @@ -131,6 +133,7 @@ class Pruner(): def _cal_pruned_idx(self, param, ratio, axis): """ Calculate the index to be pruned on axis by given pruning ratio. + Args: name(str): The name of parameter to be pruned. param(np.array): The data of parameter to be pruned. @@ -138,6 +141,7 @@ class Pruner(): axis(int): The axis to be used for pruning given parameter. If it is None, the value in self.pruning_axis will be used. default: None. + Returns: list: The indexes to be pruned on axis. """ @@ -151,6 +155,7 @@ class Pruner(): def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False): """ Pruning a array by indexes on given axis. + Args: tensor(numpy.array): The target array to be pruned. pruned_idx(list): The indexes to be pruned. @@ -158,6 +163,7 @@ class Pruner(): lazy(bool): True means setting the pruned elements to zero. False means remove the pruned elements from memory. default: False. + Returns: numpy.array: The pruned array. """ diff --git a/paddleslim/prune/sensitive.py b/paddleslim/prune/sensitive.py index 981fd46b6c1f90c417a3d90d6fb91a9a2ba33006..608bc83fa42574d060d07d1d758d535c3027850e 100644 --- a/paddleslim/prune/sensitive.py +++ b/paddleslim/prune/sensitive.py @@ -37,6 +37,35 @@ def sensitivity(program, eval_func, sensitivities_file=None, pruned_ratios=None): + """Compute the sensitivities of convolutions in a model. The sensitivity of a convolution is the losses of accuracy on test dataset in differenct pruned ratios. 
The sensitivities can be used to get a group of best ratios with some condition. + This function returns a dict storing sensitivities as below: + + .. code-block:: python + + {"weight_0": + {0.1: 0.22, + 0.2: 0.33 + }, + "weight_1": + {0.1: 0.21, + 0.2: 0.4 + } + } + + ``weight_0`` is parameter name of convolution. ``sensitivities['weight_0']`` is a dict in which key is pruned ratio and value is the percent of losses. + + + Args: + program(paddle.fluid.Program): The program to be analyzed. + place(fluid.CPUPlace | fluid.CUDAPlace): The device place of filter parameters. + param_names(list): The parameter names of convolutions to be analyzed. + eval_func(function): The callback function used to evaluate the model. It should accept an instance of `paddle.fluid.Program` as argument and return a score on test dataset. + sensitivities_file(str): The file to save the sensitivities. It will append the latest computed sensitivities into the file. And the sensitivities in the file would not be computed again. This file can be loaded by `pickle` library. + pruned_ratios(list): The ratios to be pruned. default: ``[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]``. + + Returns: + dict: A dict storing sensitivities. + """ scope = fluid.global_scope() graph = GraphWrapper(program) sensitivities = load_sensitivities(sensitivities_file) @@ -159,13 +188,13 @@ def flops_sensitivity(program, def merge_sensitive(sensitivities): - """ - Merge sensitivities. + """Merge sensitivities. + Args: sensitivities(list | list): The sensitivities to be merged. It cann be a list of sensitivities files or dict. Returns: - sensitivities(dict): A dict with sensitivities. + dict: A dict storing sensitivities. """ assert len(sensitivities) > 0 if not isinstance(sensitivities[0], dict): @@ -182,8 +211,13 @@ def merge_sensitive(sensitivities): def load_sensitivities(sensitivities_file): - """ - Load sensitivities from file. + """Load sensitivities from file. 
+ + Args: + sensitivities_file(str): The file storing sensitivities. + + Returns: + dict: A dict storing sensitivities. """ sensitivities = {} if sensitivities_file and os.path.exists(sensitivities_file): @@ -196,8 +230,11 @@ def load_sensitivities(sensitivities_file): def _save_sensitivities(sensitivities, sensitivities_file): - """ - Save sensitivities into file. + """Save sensitivities into file. + + Args: + sensitivities(dict): The sensitivities to be saved. + sensitivities_file(str): The file to save sensitivities. """ with open(sensitivities_file, 'wb') as f: pickle.dump(sensitivities, f) @@ -217,7 +254,7 @@ def get_ratios_by_loss(sensitivities, loss): Returns: - ratios(dict): A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned. + dict: A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned. """ ratios = {} for param, losses in sensitivities.items(): diff --git a/paddleslim/prune/sensitive_pruner.py b/paddleslim/prune/sensitive_pruner.py index afbf9ff47eb05a7a6488b9ccbf3571fe2e553d1c..c216482271129a222ad128c14052855a34bafe80 100644 --- a/paddleslim/prune/sensitive_pruner.py +++ b/paddleslim/prune/sensitive_pruner.py @@ -30,18 +30,20 @@ _logger = get_logger(__name__, level=logging.INFO) class SensitivePruner(object): + """ + Pruner used to prune parameters iteratively according to sensitivities + of parameters in each step. + + Args: + place(fluid.CUDAPlace | fluid.CPUPlace): The device place where + program executes. + eval_func(function): A callback function used to evaluate pruned + program. The argument of this function is pruned program. + And it returns a score of given program. + scope(fluid.scope): The scope used to execute program. + """ + def __init__(self, place, eval_func, scope=None, checkpoints=None): - """ - Pruner used to prune parameters iteratively according to sensitivities - of parameters in each step. 
- Args: - place(fluid.CUDAPlace | fluid.CPUPlace): The device place where - program execute. - eval_func(function): A callback function used to evaluate pruned - program. The argument of this function is pruned program. - And it return a score of given program. - scope(fluid.scope): The scope used to execute program. - """ self._eval_func = eval_func self._iter = 0 self._place = place @@ -135,12 +137,14 @@ class SensitivePruner(object): def prune(self, train_program, eval_program, params, pruned_flops): """ Pruning parameters of training and evaluation network by sensitivities in current step. + Args: train_program(fluid.Program): The training program to be pruned. eval_program(fluid.Program): The evaluation program to be pruned. And it is also used to calculate sensitivities of parameters. params(list): The parameters to be pruned. pruned_flops(float): The ratio of FLOPS to be pruned in current step. - Return: + + Returns: tuple: A tuple of pruned training program and pruned evaluation program. """ _logger.info("Pruning: {}".format(params)) @@ -199,9 +203,9 @@ class SensitivePruner(object): pruned_flops(float): The percent of FLOPS to be pruned. eval_program(Program): The program whose FLOPS is considered. - Return: + Returns: - ratios(dict): A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned. + dict: A group of ratios. The key of dict is name of parameters while the value is the ratio to be pruned. """ min_loss = 0.