diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bdeabc208673cf327f414710e88b1f4950a7a52
--- /dev/null
+++ b/paddleslim/common/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import controller
+from controller import *
+import sa_controller
+from sa_controller import *
+import log_helper
+from log_helper import *
+
+__all__ = []
+__all__ += controller.__all__
+__all__ += sa_controller.__all__
diff --git a/paddleslim/common/controller.py b/paddleslim/common/controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..651d37c32d04a0bf7af1c247072da68103588190
--- /dev/null
+++ b/paddleslim/common/controller.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The controller used to search hyperparameters or neural architecture"""
+
+import copy
+import math
+import numpy as np
+
+__all__ = ['EvolutionaryController']
+
+class EvolutionaryController(object):
+    """Abstract controller for all evolutionary searching methods.
+    """
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def update(self, tokens, reward):
+        """Update the status of the controller according to the latest tokens and reward.
+        Args:
+            tokens(list): A solution of the searching task.
+            reward(float): The reward of the tokens.
+        """
+        raise NotImplementedError('Abstract method.')
+
+    def reset(self, range_table, constrain_func=None):
+        """Reset the controller.
+        Args:
+            range_table(list): It is used to define the searching space of the controller.
+                The tokens[i] generated by the controller should be in [0, range_table[i]).
+            constrain_func(function): It is used to check whether the tokens meet the constraint.
+                None means there is no constraint. Default: None.
+        """
+        raise NotImplementedError('Abstract method.')
+
+    def next_tokens(self):
+        """Generate new tokens.
+        """
+        raise NotImplementedError('Abstract method.')
+
+
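Note: a concrete controller only has to implement update, reset and next_tokens. The snippet below is a hypothetical random-mutation controller, not part of this patch, sketched against the interface above (range_table[i] is treated as the exclusive upper bound of tokens[i]):

import numpy as np
from paddleslim.common import EvolutionaryController


class RandomMutationController(EvolutionaryController):
    """Toy controller: mutate one position of the best tokens seen so far."""

    def reset(self, range_table, constrain_func=None):
        self._range_table = range_table
        self._constrain_func = constrain_func
        self._tokens = [0] * len(range_table)
        self._best_reward = -1

    def update(self, tokens, reward):
        # Keep the tokens with the highest reward observed so far.
        if reward > self._best_reward:
            self._best_reward = reward
            self._tokens = list(tokens)

    def next_tokens(self):
        new_tokens = list(self._tokens)
        index = np.random.randint(len(new_tokens))
        new_tokens[index] = np.random.randint(self._range_table[index])
        if self._constrain_func is not None and not self._constrain_func(new_tokens):
            return list(self._tokens)  # fall back to the current best tokens
        return new_tokens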
diff --git a/paddleslim/common/log_helper.py b/paddleslim/common/log_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..1088761e0284181bc485f5ee1824e1cbd9c7eb81
--- /dev/null
+++ b/paddleslim/common/log_helper.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import logging
+
+__all__ = ['get_logger']
+
+
+def get_logger(name, level, fmt=None):
+    """
+    Get a logger from the logging module with the given name, level and
+    format, without setting logging basicConfig; a basicConfig set here
+    would be disabled once paddle is imported, because paddle sets it.
+    Args:
+        name (str): The logger name.
+        level (int): The base level of the logger, e.g. logging.INFO.
+        fmt (str): The format of the logger output. Default: None.
+    Returns:
+        logging.Logger: A logging logger with the given settings.
+    Examples:
+        .. code-block:: python
+           logger = log_helper.get_logger(__name__, logging.INFO,
+                            fmt='%(asctime)s-%(levelname)s: %(message)s')
+    """
+
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    handler = logging.StreamHandler()
+
+    if fmt:
+        formatter = logging.Formatter(fmt=fmt)
+        handler.setFormatter(formatter)
+
+    logger.addHandler(handler)
+    return logger
diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6d6caa1438e133e526a22041170adcd77cf3593
--- /dev/null
+++ b/paddleslim/common/sa_controller.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The controller used to search hyperparameters or neural architecture"""
+
+import copy
+import math
+import logging
+import numpy as np
+from .controller import EvolutionaryController
+from .log_helper import get_logger
+
+__all__ = ["SAController"]
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+class SAController(EvolutionaryController):
+    """Simulated annealing controller."""
+
+    def __init__(self,
+                 range_table=None,
+                 reduce_rate=0.85,
+                 init_temperature=1024,
+                 max_iter_number=300,
+                 init_tokens=None,
+                 constrain_func=None):
+        """Initialize.
+        Args:
+            range_table(tuple): A tuple (min_tokens, max_tokens) defining the inclusive lower and upper bound of each token.
+            reduce_rate(float): The decay rate of the temperature.
+            init_temperature(float): The initial temperature.
+            max_iter_number(int): The max number of retries when the generated tokens fail the constraint check.
+            init_tokens(list): The initial tokens.
+            constrain_func(function): The callback function used to check whether the tokens meet the constraint. None means there is no constraint. Default: None.
+        """
+        super(SAController, self).__init__()
+        self._range_table = range_table
+        assert isinstance(self._range_table, tuple) and (len(self._range_table) == 2)
+        self._reduce_rate = reduce_rate
+        self._init_temperature = init_temperature
+        self._max_iter_number = max_iter_number
+        self._reward = -1
+        self._tokens = init_tokens
+        self._constrain_func = constrain_func
+        self._max_reward = -1
+        self._best_tokens = None
+        self._iter = 0
+
+    def __getstate__(self):
+        d = {}
+        for key in self.__dict__:
+            if key != "_constrain_func":
+                d[key] = self.__dict__[key]
+        return d
+
+    def update(self, tokens, reward):
+        """
+        Update the controller according to the latest tokens and reward.
+        Args:
+            tokens(list): The tokens generated in the last step.
+            reward(float): The reward of the tokens.
+        """
+        self._iter += 1
+        temperature = self._init_temperature * self._reduce_rate**self._iter
+        if (reward > self._reward) or (np.random.random() <= math.exp(
+            (reward - self._reward) / temperature)):
+            self._reward = reward
+            self._tokens = tokens
+        if reward > self._max_reward:
+            self._max_reward = reward
+            self._best_tokens = tokens
+        _logger.info("iter: {}; max_reward: {}; best_tokens: {}".format(
+            self._iter, self._max_reward, self._best_tokens))
+        _logger.info("current_reward: {}; current tokens: {}".format(
+            self._reward, self._tokens))
+
+    def next_tokens(self, control_token=None):
+        """
+        Generate the next tokens by mutating one randomly chosen position.
+        """
+        if control_token:
+            tokens = control_token[:]
+        else:
+            tokens = self._tokens
+        new_tokens = tokens[:]
+        index = int(len(self._range_table[0]) * np.random.random())
+        new_tokens[index] = np.random.randint(self._range_table[0][index], self._range_table[1][index] + 1)
+        _logger.info("change index[{}] from {} to {}".format(index, tokens[
+            index], new_tokens[index]))
+        if self._constrain_func is None:
+            return new_tokens
+        for _ in range(self._max_iter_number):
+            if not self._constrain_func(new_tokens):
+                index = int(len(self._range_table[0]) * np.random.random())
+                new_tokens = tokens[:]
+                new_tokens[index] = np.random.randint(self._range_table[0][index], self._range_table[1][index] + 1)
+            else:
+                break
+        return new_tokens
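Note: a minimal, illustrative driver for the SAController above (the token range and toy reward are invented for the example; range_table is the (min_tokens, max_tokens) pair that the constructor asserts on):

import numpy as np
from paddleslim.common import SAController

# Search three integer tokens, each in [0, 9].
controller = SAController(
    range_table=([0, 0, 0], [9, 9, 9]),
    reduce_rate=0.85,
    init_temperature=1024,
    max_iter_number=300,
    init_tokens=[5, 5, 5])

tokens = controller.next_tokens()
for step in range(100):
    reward = float(sum(tokens))  # toy objective: larger tokens are better
    controller.update(tokens, reward)
    tokens = controller.next_tokens()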
diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py
index 454c23d69d22fe3ef0786dddf9be3dc8ad353c24..bb615b9dfca03ed2b289f902f6d75c73543f6fb2 100644
--- a/paddleslim/prune/__init__.py
+++ b/paddleslim/prune/__init__.py
@@ -15,6 +15,13 @@ import pruner
 from pruner import *
 import auto_pruner
 from auto_pruner import *
+import controller_server
+from controller_server import *
+import controller_client
+from controller_client import *
+
 __all__ = []
 __all__ += pruner.__all__
 __all__ += auto_pruner.__all__
+__all__ += controller_server.__all__
+__all__ += controller_client.__all__
diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py
index a5ce2a4d67375916fc6a12cadcefc71188f8f3dc..178bcacf42476e94f189f1632a68ba1716ec389a 100644
--- a/paddleslim/prune/auto_pruner.py
+++ b/paddleslim/prune/auto_pruner.py
@@ -24,12 +24,21 @@ __all__ = ["AutoPruner"]
 
 class AutoPruner(object):
     def __init__(self,
+                 program,
                  params=[],
                  init_ratios=None,
                  pruned_flops=0.5,
                  pruned_latency=None,
-                 server_addr=("", ""),
-                 search_strategy="sa"):
+                 server_addr=("", 0),
+                 init_temperature=100,
+                 reduce_rate=0.85,
+                 max_iter_number=300,
+                 max_client_num=10,
+                 search_steps=300,
+                 max_ratios=[0.9],
+                 min_ratios=[0],
+                 key="auto_pruner"
+                 ):
         """
         Search a group of ratios used to prune program.
         Args:
@@ -45,71 +54,86 @@ class AutoPruner(object):
             search_strategy(str): The search strategy. Default: 'sa'.
         """
         # step1: Create controller server. And start server if current host match server_ip.
-        self._controller_server = ControllerServer(
-            addr=(server_ip, server_port), search_strategy="sa")
+
+        self._program = program
         self._params = params
         self._init_ratios = init_ratios
         self._pruned_flops = pruned_flops
         self._pruned_latency = pruned_latency
+        self._reduce_rate = reduce_rate
+        self._init_temperature = init_temperature
+        self._max_iter_number = max_iter_number
+
+        assert isinstance(max_ratios, (float, list))
+        self._range_table = self._get_range_table(min_ratios, max_ratios)
+
         self._pruner = Pruner()
-        self._controller_agent = None
-        self._base_flops = None
-        self._base_latency = None
+        if self._pruned_flops:
+            self._base_flops = flops(program)
+        if self._pruned_latency:
+            self._base_latency = latency(program)
+        if self._init_ratios is None:
+            self._init_ratios = self._get_init_ratios(
+                self._program, self._params, self._pruned_flops,
+                self._pruned_latency)
+        init_tokens = self._ratios2tokens(self._init_ratios)
+
-    def prune(self, program, scope, place):
+        controller = SAController(self._range_table,
+                                  self._reduce_rate,
+                                  self._init_temperature,
+                                  self._max_iter_number,
+                                  init_tokens,
+                                  self._constrain_func)
-        if self._controller_agent is None:
-            self._controller_agent = PrunerAgent(
-                addr=self._controller_server.addr, self._range_table)
-            if self._init_ratios is None:
-                self._init_ratios = self._get_init_ratios(
-                    program, self._params, self._pruned_flops,
-                    self._pruned_latency)
-            self._current_ratios = self._init_ratios
-        else:
-            self._current_ratios = self._controller_agent.next_ratios()
+        self._controller_server = ControllerServer(
+            controller=controller,
+            address=server_addr,
+            max_client_num=max_client_num,
+            search_steps=search_steps,
+            key=key)
-        if self._base_flops == None:
-            self._base_flops = flops(program)
-        for i in range(self._max_try_num):
-            pruned_program = self._pruner.prune(
-                program,
-                scope,
-                self._params,
-                self._current_ratios,
-                only_graph=True)
-            if flops(pruned_program) < self._base_flops * (
-                    1 - self._pruned_flops):
-                break
-            self._current_ratios = self._controller_agent.illegal_ratios(
-                self._current_ratios)
+        self._controller_client = ControllerClient(server_addr[0], server_addr[1], key=key)
+
+        self._iter = 0
+
+    def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
+        pass
+
+    def _get_range_table(self, min_ratios, max_ratios):
+        assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
+        assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
+        min_ratios = min_ratios if isinstance(min_ratios, list) else [min_ratios]
+        max_ratios = max_ratios if isinstance(max_ratios, list) else [max_ratios]
+        min_tokens = self._ratios2tokens(min_ratios)
+        max_tokens = self._ratios2tokens(max_ratios)
+        return (min_tokens, max_tokens)
+
+    def _constrain_func(self, tokens):
+        ratios = self._tokens2ratios(tokens)
+        pruned_program = self._pruner.prune(
+            self._program,
+            None,  # scope is not needed when only_graph is True
+            self._params,
+            ratios,
+            only_graph=True)
+        return flops(pruned_program) < self._base_flops * (1 - self._pruned_flops)
+
+    def prune(self, program, scope, place):
+        self._current_ratios = self._next_ratios()
         pruned_program = self._pruner.prune(program, scope, self._params,
                                             self._current_ratios)
         return pruned_program
 
     def reward(self, score):
-        self._controller_agent.reward(self._current_ratios, score)
+        tokens = self._ratios2tokens(self._current_ratios)
+        self._controller_client.update(tokens, score)
+        self._iter += 1
 
-
-class PrunerAgent(object):
-    """
-    The agent used to talk with controller server.
-    """
-
-    def __init__(self, server_attr=("", ""), range_table):
-        self._range_table = range_table
-        self._controller_client = ControllerClient(server_attr)
-        self._controller_client.send_range_table(range_table)
-
-    def next_ratios(self):
+    def _next_ratios(self):
         tokens = self._controller_client.next_tokens()
-        self._tokens2ratios(tokens)
-
-    def illegal_ratios(self, ratios):
-        tokens = self._ratios2tokens(ratios)
-        tokens = self._controller_client.illegal_tokens(tokens)
         return self._tokens2ratios(tokens)
 
     def _ratios2tokens(self, ratios):
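Note: a rough sketch of the intended AutoPruner search loop, assuming train_program is an existing fluid.Program and eval_acc is a user-supplied evaluation function (both hypothetical); since _get_init_ratios is still a stub in this patch, init_ratios has to be passed explicitly:

import paddle.fluid as fluid
from paddleslim.prune import AutoPruner

place = fluid.CPUPlace()
scope = fluid.global_scope()

pruner = AutoPruner(
    train_program,                       # hypothetical fluid.Program
    params=["conv1_weights", "conv2_weights"],  # hypothetical parameter names
    init_ratios=[0.3, 0.3],
    pruned_flops=0.5,
    server_addr=("127.0.0.1", 8889),
    search_steps=100)

for step in range(100):
    pruned_program = pruner.prune(train_program, scope, place)
    score = eval_acc(pruned_program)     # hypothetical evaluation function
    pruner.reward(score)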
""" # step1: Create controller server. And start server if current host match server_ip. - self._controller_server = ControllerServer( - addr=(server_ip, server_port), search_strategy="sa") + + self._program = program self._params = params self._init_ratios = init_ratios self._pruned_flops = pruned_flops self._pruned_latency = pruned_latency + self._reduce_rate = reduce_rate + self._init_temperature = init_temperature + self._max_try_number = max_try_number + + assert isinstance(self._max_ratios, float) or isinstance(self._max_ratios) + self._range_table = self._get_range_table(min_ratios, max_ratios) + self._pruner = Pruner() - self._controller_agent = None - self._base_flops = None - self._base_latency = None + if self._pruned_flops: + self._base_flops = flops(program) + if self._pruned_latency: + self._base_latency = latency(program) + if self._init_ratios is None: + self._init_ratios = self._get_init_ratios( + self,_program, self._params, self._pruned_flops, + self._pruned_latency) + init_tokens = self._ratios2tokens(self._init_ratios) + - def prune(self, program, scope, place): + controller = SAController(self._range_table, + self._reduce_rate, + self._init_temperature, + self._max_try_number, + init_tokens, + self._constrain_func) - if self._controller_agent is None: - self._controller_agent = PrunerAgent( - addr=self._controller_server.addr, self._range_table) - if self._init_ratios is None: - self._init_ratios = self._get_init_ratios( - program, self._params, self._pruned_flops, - self._pruned_latency) - self._current_ratios = self._init_ratios - else: - self._current_ratios = self._controller_agent.next_ratios() + self._controller_server = ControllerServer( + controller=controller, + addr=server_addr, + max_client_num, + search_steps, + key=key) - if self._base_flops == None: - self._base_flops = flops(program) - for i in range(self._max_try_num): - pruned_program = self._pruner.prune( - program, - scope, - self._params, - self._current_ratios, - only_graph=True) - if flops(pruned_program) < self._base_flops * ( - 1 - self._pruned_flops): - break - self._current_ratios = self._controller_agent.illegal_ratios( - self._current_ratios) + self._controller_client = ControllerClient(server_addr, key=key) + + self._iter = 0 + def _get_init_ratios(self, program, params, pruned_flops, pruned_latency): + pass + + def _get_range_table(self, min_ratios, max_ratios): + assert isinstance(min_ratios, list) or isinstance(min_ratios, float) + assert isinstance(max_ratios, list) or isinstance(max_ratios, float) + min_ratios = min_ratios if isinstance(min_ratios, list) else [min_ratios] + max_ratios = max_ratios if isinstance(max_ratios, list) else [max_ratios] + min_tokens = self._ratios2tokens(min_ratios) + max_tokens = self._ratios2tokens(max_ratios) + return (min_tokens, max_tokens) + + def _constrain_func(self, tokens): + ratios = self._tokens2ratios(tokens) + + pruned_program = self._pruner.prune( + program, + scope, + self._params, + self._current_ratios, + only_graph=True) + return flops(pruned_program) < self._base_flops + + def prune(self, program, scope, place): + self._current_ratios = self._next_ratios() pruned_program = self._pruner.prune(program, scope, self._params, self._current_ratios) return pruned_program def reward(self, score): - self._controller_agent.reward(self._current_ratios, score) + tokens = self.ratios2tokens(self._current_ratios) + self._controller_client.reward(tokens, score) + self._iter += 1 - -class PrunerAgent(object): - """ - The agent used to talk with 
diff --git a/paddleslim/prune/controller_server.py b/paddleslim/prune/controller_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..22bbd1e3d35b0df8b3bb834faff6f078a40a859e
--- /dev/null
+++ b/paddleslim/prune/controller_server.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import socket
+from threading import Thread
+
+from ..common import get_logger
+
+__all__ = ['ControllerServer']
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+
+class ControllerServer(object):
+    """
+    The controller wrapper with a socket server to handle the requests of search agents.
+    """
+
+    def __init__(self,
+                 controller=None,
+                 address=('', 0),
+                 max_client_num=100,
+                 search_steps=None,
+                 key=None):
+        """
+        Args:
+            controller(EvolutionaryController): The controller used to generate tokens.
+            address(tuple): The (ip, port) address that the server binds to. Default: ('', 0),
+                which means binding to all interfaces and choosing a free port automatically.
+            max_client_num(int): The maximum number of clients connecting to the server simultaneously. Default: 100.
+            search_steps(int): The total number of searching steps. None means never stopping. Default: None.
+            key(str): The key used to identify legal agents. Default: None.
+        """
+        self._controller = controller
+        self._address = address
+        self._max_client_num = max_client_num
+        self._search_steps = search_steps
+        self._closed = False
+        self._port = address[1]
+        self._ip = address[0]
+        self._key = key
+
+    def start(self):
+        self._socket_server = socket.socket(socket.AF_INET,
+                                            socket.SOCK_STREAM)
+        self._socket_server.bind(self._address)
+        self._socket_server.listen(self._max_client_num)
+        self._port = self._socket_server.getsockname()[1]
+        self._ip = self._socket_server.getsockname()[0]
+        _logger.info("listen on: [{}:{}]".format(self._ip, self._port))
+        thread = Thread(target=self.run)
+        thread.start()
+        return str(thread)
+
+    def close(self):
+        """Close the server."""
+        self._closed = True
+
+    def port(self):
+        """Get the port that the server listens on."""
+        return self._port
+
+    def ip(self):
+        """Get the ip that the server listens on."""
+        return self._ip
+
+    def run(self):
+        _logger.info("Controller Server run...")
+        while ((self._search_steps is None) or
+               (self._controller._iter < self._search_steps)) and not self._closed:
+            conn, addr = self._socket_server.accept()
+            message = conn.recv(1024).decode()
+            if message.strip("\n") == "next_tokens":
+                tokens = self._controller.next_tokens()
+                tokens = ",".join([str(token) for token in tokens])
+                conn.send(tokens.encode())
+            else:
+                _logger.info("recv message from {}: [{}]".format(addr, message))
+                messages = message.strip('\n').split("\t")
+                if (len(messages) < 3) or (messages[0] != self._key):
+                    _logger.info("recv noise from {}: [{}]".format(addr,
+                                                                   message))
+                    continue
+                tokens = messages[1]
+                reward = messages[2]
+                tokens = [int(token) for token in tokens.split(",")]
+                self._controller.update(tokens, float(reward))
+                tokens = self._controller.next_tokens()
+                tokens = ",".join([str(token) for token in tokens])
+                conn.send(tokens.encode())
+                _logger.info("send message to {}: [{}]".format(addr, tokens))
+            conn.close()
+        self._socket_server.close()
+        _logger.info("server closed!")
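Note: putting the pieces together, a minimal single-process sketch of the server/client loop (token range, reward and key are invented for this example; after search_steps updates the server loop exits on its own):

from paddleslim.common import SAController
from paddleslim.prune import ControllerServer, ControllerClient

controller = SAController(range_table=([0, 0], [9, 9]), init_tokens=[3, 3])
server = ControllerServer(controller=controller,
                          address=("", 0),   # port 0: let the OS pick a free port
                          max_client_num=10,
                          search_steps=5,
                          key="auto_pruner")
server.start()

client = ControllerClient(server_ip="127.0.0.1",
                          server_port=server.port(),
                          key="auto_pruner")
tokens = client.next_tokens()
for _ in range(5):
    reward = float(sum(tokens))              # toy reward
    tokens = client.update(tokens, reward)
server.close()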