diff --git a/paddleslim/analysis/sensitive.py b/paddleslim/analysis/sensitive.py
index 498de1d6022eb628d99ecccf847758c3b5aea3ab..7fbfde15647df5f5b4e7dd00d922003edafaaf6f 100644
--- a/paddleslim/analysis/sensitive.py
+++ b/paddleslim/analysis/sensitive.py
@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import numpy as np
 from ..core import GraphWrapper
+from ..common import get_logger
+
+_logger = get_logger(__name__, level=logging.INFO)
 
 __all__ = ["sensitivity"]
diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..98b314ab6d144924bff6b68e3fb176ce73583f5c
--- /dev/null
+++ b/paddleslim/common/__init__.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from . import controller
+from .controller import *
+from . import sa_controller
+from .sa_controller import *
+from . import log_helper
+from .log_helper import *
+from . import controller_server
+from .controller_server import *
+from . import controller_client
+from .controller_client import *
+from . import lock_utils
+from .lock_utils import *
+
+__all__ = []
+__all__ += controller.__all__
+__all__ += sa_controller.__all__
+__all__ += log_helper.__all__
+__all__ += controller_server.__all__
+__all__ += controller_client.__all__
+__all__ += lock_utils.__all__
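Note: the new paddleslim.common package re-exports get_logger from log_helper, which is how sensitive.py above picks it up. A minimal usage sketch, assuming paddleslim is importable; the log message is illustrative only:

    import logging
    from paddleslim.common import get_logger

    # Module-level logger, mirroring the pattern added to sensitive.py above.
    _logger = get_logger(__name__, level=logging.INFO)
    _logger.info("sensitivity analysis started")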
diff --git a/paddleslim/common/controller.py b/paddleslim/common/controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c30f49c3aec27a326417554bac3163789342ff6
--- /dev/null
+++ b/paddleslim/common/controller.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The controller used to search hyperparameters or neural architecture"""
+
+import copy
+import math
+import numpy as np
+
+__all__ = ['EvolutionaryController']
+
+
+class EvolutionaryController(object):
+    """Abstract controller for all evolutionary searching methods.
+    """
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def update(self, tokens, reward):
+        """Update the status of the controller according to the current tokens and reward.
+        Args:
+            tokens(list): A solution of the searching task.
+            reward(list): The reward of tokens.
+        """
+        raise NotImplementedError('Abstract method.')
+
+    def reset(self, range_table, constrain_func=None):
+        """Reset the controller.
+        Args:
+            range_table(list): It is used to define the searching space of the controller.
+                The tokens[i] generated by the controller should be in [0, range_table[i]).
+            constrain_func(function): It is used to check whether tokens meet the constraint.
+                None means there is no constraint. Default: None.
+        """
+        raise NotImplementedError('Abstract method.')
+
+    def next_tokens(self):
+        """Generate new tokens.
+        """
+        raise NotImplementedError('Abstract method.')
diff --git a/paddleslim/common/controller_client.py b/paddleslim/common/controller_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dcbd7bb64bf4460371d523a0f745e2490a7b3a0
--- /dev/null
+++ b/paddleslim/common/controller_client.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import socket
+from .log_helper import get_logger
+
+__all__ = ['ControllerClient']
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+
+class ControllerClient(object):
+    """
+    Controller client.
+    """
+
+    def __init__(self, server_ip=None, server_port=None, key=None):
+        """
+        Args:
+            server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
+            server_port(int): The port that controller server listens on. Default: None.
+            key(str): The key used to identify a legal agent for the controller server. Default: None.
+        """
+        self.server_ip = server_ip
+        self.server_port = server_port
+        self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self._key = key
+
+    def update(self, tokens, reward):
+        """
+        Update the controller according to the latest tokens and reward; returns the next tokens.
+        Args:
+            tokens(list): The tokens generated in the last step.
+            reward(float): The reward of tokens.
+        """
+        socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        socket_client.connect((self.server_ip, self.server_port))
+        tokens = ",".join([str(token) for token in tokens])
+        socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
+                           .encode())
+        tokens = socket_client.recv(1024).decode()
+        tokens = [int(token) for token in tokens.strip("\n").split(",")]
+        return tokens
+
+    def next_tokens(self):
+        """
+        Get next tokens from the controller server.
+        """
+        socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        socket_client.connect((self.server_ip, self.server_port))
+        socket_client.send("next_tokens".encode())
+        tokens = socket_client.recv(1024).decode()
+        tokens = [int(token) for token in tokens.strip("\n").split(",")]
+        return tokens
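The client above speaks a plain text protocol: sending "next_tokens" requests a candidate, and a tab-separated "key\ttokens\treward" message reports a result; both replies are comma-separated token lists. A minimal sketch of the client side of a search loop, assuming a ControllerServer is already listening at the given (illustrative) address and that evaluate is a user-supplied function that trains and scores a candidate:

    from paddleslim.common import ControllerClient

    client = ControllerClient(server_ip="127.0.0.1", server_port=8989, key="sa_nas")

    tokens = client.next_tokens()               # ask the server for a candidate
    for _ in range(10):
        reward = evaluate(tokens)               # user-defined training/evaluation
        tokens = client.update(tokens, reward)  # report reward, get next candidate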
diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..10883d5988652a8c3f738416eae3dd768ba74e67
--- /dev/null
+++ b/paddleslim/common/controller_server.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import logging
+import socket
+from .log_helper import get_logger
+from threading import Thread
+from .lock_utils import lock, unlock
+
+__all__ = ['ControllerServer']
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+
+class ControllerServer(object):
+    """
+    The controller wrapper with a socket server to handle requests from search agents.
+    """
+
+    def __init__(self,
+                 controller=None,
+                 address=('', 0),
+                 max_client_num=100,
+                 search_steps=None,
+                 key=None):
+        """
+        Args:
+            controller(slim.searcher.Controller): The controller used to generate tokens.
+            address(tuple): The address that current server binds to, with format (ip, port).
+                            An empty ip means choosing the ip automatically. Default: ('', 0).
+            max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
+            search_steps(int): The total steps of searching. None means never stopping. Default: None.
+            key(str): The key used to identify legal requests from search agents. Default: None.
+        """
+        self._controller = controller
+        self._address = address
+        self._max_client_num = max_client_num
+        self._search_steps = search_steps
+        self._closed = False
+        self._port = address[1]
+        self._ip = address[0]
+        self._key = key
+        self._socket_file = "./controller_server.socket"
+
+    def start(self):
+        open(self._socket_file, 'a').close()
+        socket_file = open(self._socket_file, 'r+')
+        lock(socket_file)
+        tid = socket_file.readline()
+        if tid == '':
+            _logger.info("start controller server...")
+            tid = self._start()
+            socket_file.write("tid: {}\nip: {}\nport: {}\n".format(
+                tid, self._ip, self._port))
+            _logger.info("started controller server...")
+        unlock(socket_file)
+        socket_file.close()
+
+    def _start(self):
+        self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self._socket_server.bind(self._address)
+        self._socket_server.listen(self._max_client_num)
+        self._port = self._socket_server.getsockname()[1]
+        self._ip = self._socket_server.getsockname()[0]
+        _logger.info("ControllerServer - listen on: [{}:{}]".format(
+            self._ip, self._port))
+        thread = Thread(target=self.run)
+        thread.start()
+        return str(thread)
+
+    def close(self):
+        """Close the server."""
+        self._closed = True
+        os.remove(self._socket_file)
+        _logger.info("server closed!")
+
+    def port(self):
+        """Get the port."""
+        return self._port
+
+    def ip(self):
+        """Get the ip."""
+        return self._ip
+
+    def run(self):
+        _logger.info("Controller Server run...")
+        try:
+            while ((self._search_steps is None) or
+                   (self._controller._iter <
+                    (self._search_steps))) and not self._closed:
+                conn, addr = self._socket_server.accept()
+                message = conn.recv(1024).decode()
+                if message.strip("\n") == "next_tokens":
+                    tokens = self._controller.next_tokens()
+                    tokens = ",".join([str(token) for token in tokens])
+                    conn.send(tokens.encode())
+                else:
+                    _logger.debug("recv message from {}: [{}]".format(addr,
+                                                                      message))
+                    messages = message.strip('\n').split("\t")
+                    if (len(messages) < 3) or (messages[0] != self._key):
+                        _logger.debug("recv noise from {}: [{}]".format(
+                            addr, message))
+                        conn.close()
+                        continue
+                    tokens = messages[1]
+                    reward = messages[2]
+                    tokens = [int(token) for token in tokens.split(",")]
+                    self._controller.update(tokens, float(reward))
+                    tokens = self._controller.next_tokens()
+                    tokens = ",".join([str(token) for token in tokens])
+                    conn.send(tokens.encode())
+                    _logger.debug("send message to {}: [{}]".format(addr,
+                                                                    tokens))
+                conn.close()
+        finally:
+            self._socket_server.close()
+            self.close()
diff --git a/paddleslim/common/lock_utils.py b/paddleslim/common/lock_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9daf4f3f6e842609a39fd286dfa49eb705c631a7
--- /dev/null
+++ b/paddleslim/common/lock_utils.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+__all__ = ['lock', 'unlock']
+
+if os.name == 'nt':
+
+    def lock(file):
+        raise NotImplementedError('Windows is not supported.')
+
+    def unlock(file):
+        raise NotImplementedError('Windows is not supported.')
+
+elif os.name == 'posix':
+    from fcntl import flock, LOCK_EX, LOCK_UN
+
+    def lock(file):
+        """Lock the file in local file system."""
+        flock(file.fileno(), LOCK_EX)
+
+    def unlock(file):
+        """Unlock the file in local file system."""
+        flock(file.fileno(), LOCK_UN)
+else:
+    raise RuntimeError("File Locker only supports NT and Posix platforms!")
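A short sketch of the advisory-lock helpers on a POSIX platform; this mirrors the pattern ControllerServer.start() uses above to guard its socket file, with placeholder values for the record it writes:

    from paddleslim.common import lock, unlock

    open("./controller_server.socket", 'a').close()   # make sure the file exists
    socket_file = open("./controller_server.socket", 'r+')
    lock(socket_file)                  # blocks until an exclusive flock is acquired
    if socket_file.readline() == '':   # the first process to get the lock wins
        socket_file.write("tid: 1\nip: 127.0.0.1\nport: 8989\n")
    unlock(socket_file)
    socket_file.close()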
diff --git a/paddleslim/common/log_helper.py b/paddleslim/common/log_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..1088761e0284181bc485f5ee1824e1cbd9c7eb81
--- /dev/null
+++ b/paddleslim/common/log_helper.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import logging
+
+__all__ = ['get_logger']
+
+
+def get_logger(name, level, fmt=None):
+    """
+    Get a logger from the logging module with the given name, level and
+    format, without setting logging basicConfig: paddle sets basicConfig
+    on import, so a later basicConfig call would take no effect.
+    Args:
+        name (str): The logger name.
+        level (logging.LEVEL): The base level of the logger.
+        fmt (str): Format of logger output.
+    Returns:
+        logging.Logger: logging logger with the given settings.
+    Examples:
+        .. code-block:: python
+            logger = log_helper.get_logger(__name__, logging.INFO,
+                fmt='%(asctime)s-%(levelname)s: %(message)s')
+    """
+
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    handler = logging.StreamHandler()
+
+    if fmt:
+        formatter = logging.Formatter(fmt=fmt)
+        handler.setFormatter(formatter)
+
+    logger.addHandler(handler)
+    return logger
+ """ + if control_token: + tokens = control_token[:] + else: + tokens = self._tokens + new_tokens = tokens[:] + index = int(len(self._range_table[0]) * np.random.random()) + new_tokens[index] = np.random.randint(self._range_table[0][index], + self._range_table[1][index] + 1) + _logger.debug("change index[{}] from {} to {}".format(index, tokens[ + index], new_tokens[index])) + if self._constrain_func is None: + return new_tokens + for _ in range(self._max_iter_number): + if not self._constrain_func(new_tokens): + index = int(len(self._range_table[0]) * np.random.random()) + new_tokens = tokens[:] + new_tokens[index] = np.random.randint( + self._range_table[0][index], + self._range_table[1][index] + 1) + else: + break + return new_tokens diff --git a/paddleslim/nas/__init__.py b/paddleslim/nas/__init__.py index 2f5509144f53529ae717b72bbb7252b4b06a0048..f11948f6bcbdd3d52e334bed3b06510e226825bc 100644 --- a/paddleslim/nas/__init__.py +++ b/paddleslim/nas/__init__.py @@ -14,6 +14,9 @@ import search_space from search_space import * +import sa_nas +from sa_nas import * __all__ = [] __all__ += search_space.__all__ +__all__ += sa_nas.__all__ diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc747b0ec9f977dc4e41d2fb128b29823cfd3a3 --- /dev/null +++ b/paddleslim/nas/sa_nas.py @@ -0,0 +1,145 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import logging +import numpy as np +import paddle.fluid as fluid +from ..core import VarWrapper, OpWrapper, GraphWrapper +from ..common import SAController +from ..common import get_logger +from ..analysis import flops + +from ..common import ControllerServer +from ..common import ControllerClient +from .search_space import SearchSpaceFactory + +__all__ = ["SANAS"] + +_logger = get_logger(__name__, level=logging.INFO) + + +class SANAS(object): + def __init__(self, + configs, + max_flops=None, + max_latency=None, + server_addr=("", 0), + init_temperature=100, + reduce_rate=0.85, + max_try_number=300, + max_client_num=10, + search_steps=300, + key="sa_nas", + is_server=True): + """ + Search a group of ratios used to prune program. + Args: + configs(list): A list of search space configuration with format (key, input_size, output_size, block_num). + `key` is the name of search space with data type str. `input_size` and `output_size` are + input size and output size of searched sub-network. `block_num` is the number of blocks in searched network. + max_flops(int): The max flops of searched network. None means no constrains. Default: None. + max_latency(float): The max latency of searched network. None means no constrains. Default: None. + server_addr(tuple): A tuple of server ip and server port for controller server. + init_temperature(float): The init temperature used in simulated annealing search strategy. + reduce_rate(float): The decay rate used in simulated annealing search strategy. 
diff --git a/paddleslim/nas/__init__.py b/paddleslim/nas/__init__.py
index 2f5509144f53529ae717b72bbb7252b4b06a0048..f11948f6bcbdd3d52e334bed3b06510e226825bc 100644
--- a/paddleslim/nas/__init__.py
+++ b/paddleslim/nas/__init__.py
@@ -14,6 +14,9 @@
 
 import search_space
 from search_space import *
+import sa_nas
+from sa_nas import *
 
 __all__ = []
 __all__ += search_space.__all__
+__all__ += sa_nas.__all__
diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfc747b0ec9f977dc4e41d2fb128b29823cfd3a3
--- /dev/null
+++ b/paddleslim/nas/sa_nas.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import socket
+import logging
+import numpy as np
+import paddle.fluid as fluid
+from ..core import VarWrapper, OpWrapper, GraphWrapper
+from ..common import SAController
+from ..common import get_logger
+from ..analysis import flops
+
+from ..common import ControllerServer
+from ..common import ControllerClient
+from .search_space import SearchSpaceFactory
+
+__all__ = ["SANAS"]
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+
+class SANAS(object):
+    def __init__(self,
+                 configs,
+                 max_flops=None,
+                 max_latency=None,
+                 server_addr=("", 0),
+                 init_temperature=100,
+                 reduce_rate=0.85,
+                 max_try_number=300,
+                 max_client_num=10,
+                 search_steps=300,
+                 key="sa_nas",
+                 is_server=True):
+        """
+        Search a group of network architectures with simulated annealing.
+        Args:
+            configs(list): A list of search space configurations with format (key, config).
+                           `key` is the name of the search space with data type str; `config` is a dict
+                           containing `input_size`, `output_size` and `block_num` of the searched sub-network.
+            max_flops(int): The max flops of the searched network. None means no constraint. Default: None.
+            max_latency(float): The max latency of the searched network. None means no constraint. Default: None.
+            server_addr(tuple): A tuple of server ip and server port for controller server.
+            init_temperature(float): The init temperature used in simulated annealing search strategy.
+            reduce_rate(float): The decay rate used in simulated annealing search strategy.
+            max_try_number(int): The max number of attempts to generate legal tokens.
+            max_client_num(int): The max number of connections of controller server.
+            search_steps(int): The steps of searching.
+            key(str): Identity used in communication between controller server and clients.
+            is_server(bool): Whether current host is controller server. Default: True.
+        """
+
+        self._reduce_rate = reduce_rate
+        self._init_temperature = init_temperature
+        self._max_try_number = max_try_number
+        self._is_server = is_server
+        self._max_flops = max_flops
+        self._max_latency = max_latency
+
+        self._configs = configs
+
+        factory = SearchSpaceFactory()
+        self._search_space = factory.get_search_space(configs)
+        init_tokens = self._search_space.init_tokens()
+        range_table = self._search_space.range_table()
+        range_table = (len(range_table) * [0], range_table)
+
+        _logger.debug("range table: {}".format(range_table))
+
+        controller = SAController(range_table, self._reduce_rate,
+                                  self._init_temperature, self._max_try_number,
+                                  init_tokens, self._constrain_func)
+
+        server_ip, server_port = server_addr
+        if server_ip is None or server_ip == "":
+            server_ip = self._get_host_ip()
+
+        self._controller_server = ControllerServer(
+            controller=controller,
+            address=(server_ip, server_port),
+            max_client_num=max_client_num,
+            search_steps=search_steps,
+            key=key)
+
+        # create controller server
+        if self._is_server:
+            self._controller_server.start()
+
+        self._controller_client = ControllerClient(
+            self._controller_server.ip(),
+            self._controller_server.port(),
+            key=key)
+
+        self._iter = 0
+
+    def _get_host_ip(self):
+        return socket.gethostbyname(socket.gethostname())
+
+    def _constrain_func(self, tokens):
+        if (self._max_flops is None) and (self._max_latency is None):
+            return True
+        archs = self._search_space.token2arch(tokens)
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            i = 0
+            for config, arch in zip(self._configs, archs):
+                input_size = config[1]["input_size"]
+                input = fluid.data(
+                    name="data_{}".format(i),
+                    shape=[None, 3, input_size, input_size],
+                    dtype="float32")
+                output = arch(input)
+                i += 1
+        return flops(main_program) < self._max_flops
+
+    def next_archs(self):
+        """
+        Get next network architectures.
+        Returns:
+            list: A list of functions that define networks.
+        """
+        self._current_tokens = self._controller_client.next_tokens()
+        archs = self._search_space.token2arch(self._current_tokens)
+        return archs
+
+    def reward(self, score):
+        """
+        Return the reward of the current searched network.
+        Args:
+            score(float): The score of the current searched network.
+        """
+        self._controller_client.update(self._current_tokens, score)
+        self._iter += 1
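SANAS wires the pieces together: it builds an SAController, wraps it in a ControllerServer when is_server is True, and always talks to that server through a ControllerClient, so single-machine and distributed search share one code path. A hedged sketch of the two-process setup implied by the is_server flag; the port, peer address and score are placeholders:

    from paddleslim.nas import SANAS

    configs = [('MobileNetV2Space',
                {'input_size': 224, 'output_size': 7, 'block_num': 5})]

    # Process A: hosts the controller server and also searches.
    sa_nas = SANAS(configs, server_addr=("", 8989), search_steps=300, is_server=True)

    # Process B, possibly on another machine: connects instead of serving.
    # sa_nas = SANAS(configs, server_addr=("<ip of process A>", 8989), is_server=False)

    archs = sa_nas.next_archs()   # functions that append candidate sub-networks
    # ... build a fluid.Program with archs, train and evaluate it, then:
    sa_nas.reward(0.9)            # report the score for the current tokens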
diff --git a/paddleslim/nas/search_space/mobilenetv2.py b/paddleslim/nas/search_space/mobilenetv2.py
index 90e0b2a0e704a2f44a8031b05d475419bc534677..28d8a7ea03bc94618b9b5575f837f09879d309c8 100644
--- a/paddleslim/nas/search_space/mobilenetv2.py
+++ b/paddleslim/nas/search_space/mobilenetv2.py
@@ -52,7 +52,6 @@ class MobileNetV2Space(SearchSpaceBase):
         self.scale = scale
         self.class_dim = class_dim
 
-
     def init_tokens(self):
         """
         The initial token send to controller.
@@ -71,10 +70,11 @@ class MobileNetV2Space(SearchSpaceBase):
                              4, 9, 0, 0] # 6, 320, 1
         # yapf: enable
 
-        if self.block_num < 5:
+        if self.block_num < 5:
             self.token_len = 1 + (self.block_num - 1) * 4
         else:
-            self.token_len = 1 + (self.block_num + 2 * (self.block_num - 5)) * 4
+            self.token_len = 1 + (self.block_num + 2 *
+                                  (self.block_num - 5)) * 4
 
         return init_token_base[:self.token_len]
@@ -92,6 +92,7 @@ class MobileNetV2Space(SearchSpaceBase):
                              5, 10, 6, 2,
                              5, 10, 6, 2,
                              5, 12, 6, 2]
+        range_table_base = list(np.array(range_table_base) - 1)
         # yapf: enable
         return range_table_base[:self.token_len]
@@ -107,24 +108,36 @@ class MobileNetV2Space(SearchSpaceBase):
             tokens = self.init_tokens()
 
         bottleneck_params_list = []
-        if self.block_num >= 1: bottleneck_params_list.append((1, self.head_num[tokens[0]], 1, 1, 3))
-        if self.block_num >= 2: bottleneck_params_list.append((self.multiply[tokens[1]], self.filter_num1[tokens[2]],
-                                    self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
-        if self.block_num >= 3: bottleneck_params_list.append((self.multiply[tokens[5]], self.filter_num1[tokens[6]],
-                                    self.repeat[tokens[7]], 2, self.k_size[tokens[8]]))
-        if self.block_num >= 4: bottleneck_params_list.append((self.multiply[tokens[9]], self.filter_num2[tokens[10]],
-                                    self.repeat[tokens[11]], 2, self.k_size[tokens[12]]))
-        if self.block_num >= 5:
-            bottleneck_params_list.append((self.multiply[tokens[13]], self.filter_num3[tokens[14]],
-                                    self.repeat[tokens[15]], 2, self.k_size[tokens[16]]))
-            bottleneck_params_list.append((self.multiply[tokens[17]], self.filter_num3[tokens[18]],
-                                    self.repeat[tokens[19]], 1, self.k_size[tokens[20]]))
-        if self.block_num >= 6:
-            bottleneck_params_list.append((self.multiply[tokens[21]], self.filter_num5[tokens[22]],
-                                    self.repeat[tokens[23]], 2, self.k_size[tokens[24]]))
-            bottleneck_params_list.append((self.multiply[tokens[25]], self.filter_num6[tokens[26]],
-                                    self.repeat[tokens[27]], 1, self.k_size[tokens[28]]))
-
+        if self.block_num >= 1:
+            bottleneck_params_list.append(
+                (1, self.head_num[tokens[0]], 1, 1, 3))
+        if self.block_num >= 2:
+            bottleneck_params_list.append(
+                (self.multiply[tokens[1]], self.filter_num1[tokens[2]],
+                 self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
+        if self.block_num >= 3:
+            bottleneck_params_list.append(
+                (self.multiply[tokens[5]], self.filter_num1[tokens[6]],
+                 self.repeat[tokens[7]], 2, self.k_size[tokens[8]]))
+        if self.block_num >= 4:
+            bottleneck_params_list.append(
+                (self.multiply[tokens[9]], self.filter_num2[tokens[10]],
+                 self.repeat[tokens[11]], 2, self.k_size[tokens[12]]))
+        if self.block_num >= 5:
+            bottleneck_params_list.append(
+                (self.multiply[tokens[13]], self.filter_num3[tokens[14]],
+                 self.repeat[tokens[15]], 2, self.k_size[tokens[16]]))
+            bottleneck_params_list.append(
+                (self.multiply[tokens[17]], self.filter_num3[tokens[18]],
+                 self.repeat[tokens[19]], 1, self.k_size[tokens[20]]))
+        if self.block_num >= 6:
+            bottleneck_params_list.append(
+                (self.multiply[tokens[21]], self.filter_num5[tokens[22]],
+                 self.repeat[tokens[23]], 2, self.k_size[tokens[24]]))
+            bottleneck_params_list.append(
+                (self.multiply[tokens[25]], self.filter_num6[tokens[26]],
+                 self.repeat[tokens[27]], 1, self.k_size[tokens[28]]))
+
         def net_arch(input):
             #conv1
             # all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
@@ -182,15 +195,15 @@ class MobileNetV2Space(SearchSpaceBase):
         return fluid.layers.elementwise_add(input, data_residual)
 
     def _inverted_residual_unit(self,
-                               input,
-                               num_in_filter,
-                               num_filters,
-                               ifshortcut,
-                               stride,
-                               filter_size,
-                               expansion_factor,
-                               reduction_ratio=4,
-                               name=None):
+                                input,
+                                num_in_filter,
+                                num_filters,
+                                ifshortcut,
+                                stride,
+                                filter_size,
+                                expansion_factor,
+                                reduction_ratio=4,
+                                name=None):
         """Build inverted residual unit.
         Args:
             input(Variable), input.
diff --git a/paddleslim/nas/search_space/resnet.py b/paddleslim/nas/search_space/resnet.py
index a6ac5817ce89190987d67f1eda644fa2aef79037..7ed404e5e145c9f173aee95823c8d6ac6a47dfdb 100644
--- a/paddleslim/nas/search_space/resnet.py
+++ b/paddleslim/nas/search_space/resnet.py
@@ -25,19 +25,25 @@ from .search_space_registry import SEARCHSPACE
 
 __all__ = ["ResNetSpace"]
 
+
 @SEARCHSPACE.register
 class ResNetSpace(SearchSpaceBase):
-    def __init__(self, input_size, output_size, block_num, scale=1.0, class_dim=1000):
+    def __init__(self,
+                 input_size,
+                 output_size,
+                 block_num,
+                 scale=1.0,
+                 class_dim=1000):
         super(ResNetSpace, self).__init__(input_size, output_size, block_num)
         pass
 
     def init_tokens(self):
-        return [0,0,0,0,0,0]
+        return [0, 0, 0, 0, 0, 0]
 
     def range_table(self):
-        return [3,3,3,3,3,3]
+        return [2, 2, 2, 2, 2, 2]
 
-    def token2arch(self,tokens=None):
+    def token2arch(self, tokens=None):
         if tokens is None:
             self.init_tokens()
 
@@ -54,5 +60,3 @@ class ResNetSpace(SearchSpaceBase):
             return input
 
         return net_arch
-
-
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 30341f63407aa1b0cc52ec5b43eadead27aec2ab..b653c2b5a0d46b3de58cc5a609dd20dba7363c6a 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -14,9 +14,9 @@
 
 import numpy as np
 import paddle.fluid as fluid
-from core import VarWrapper, OpWrapper, GraphWrapper
+from ..core import VarWrapper, OpWrapper, GraphWrapper
 
-__all__ = ["prune"]
+__all__ = ["Pruner"]
 
 
 class Pruner():
diff --git a/paddleslim/quant/__init__.py b/paddleslim/quant/__init__.py
index f97b9f1f28fdf4d812dd37f90b01750d1d475e6a..5f5f9a300630abac32a9c0301328e344da082c55 100644
--- a/paddleslim/quant/__init__.py
+++ b/paddleslim/quant/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .quanter import quant_aware, quant_post, convert
 from .quant_embedding import quant_embedding
diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py
new file mode 100644
index 0000000000000000000000000000000000000000..0db22772d712951ed895f2d2e897142d6ce3c377
--- /dev/null
+++ b/paddleslim/quant/quanter.py
@@ -0,0 +1,203 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.framework import IrGraph
+from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
+from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
+from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
+from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
+from paddle.fluid import core
+
+WEIGHT_QUANTIZATION_TYPES = ['abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max']
+ACTIVATION_QUANTIZATION_TYPES = ['abs_max', 'range_abs_max', 'moving_average_abs_max']
+VALID_DTYPES = ['int8']
+
+_quant_config_default = {
+    # weight quantize type, default is 'abs_max'
+    'weight_quantize_type': 'abs_max',
+    # activation quantize type, default is 'abs_max'
+    'activation_quantize_type': 'abs_max',
+    # weight quantize bit num, default is 8
+    'weight_bits': 8,
+    # activation quantize bit num, default is 8
+    'activation_bits': 8,
+    # ops in the name_scopes listed in not_quant_pattern will not be quantized
+    'not_quant_pattern': ['skip_quant'],
+    # ops of the types listed in quantize_op_types will be quantized
+    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
+    # data type after quantization, default is 'int8'
+    'dtype': 'int8',
+    # window size for 'range_abs_max' quantization, default is 10000
+    'window_size': 10000,
+    # the decay coefficient of moving average, default is 0.9
+    'moving_rate': 0.9,
+    # if quant_weight_only is True, only the parameters of layers that need to be
+    # quantized are quantized, and activations are not quantized
+    'quant_weight_only': False
+}
+
+
+def _parse_configs(user_config):
+    """
+    Check whether user configs are valid, and set default values for options the user did not configure.
+    Args:
+        user_config(dict): the config of user.
+    Return:
+        configs(dict): the final configs that will be used.
+    """
+
+    configs = copy.deepcopy(_quant_config_default)
+    configs.update(user_config)
+
+    # check whether configs are valid
+    assert configs['weight_quantize_type'] in WEIGHT_QUANTIZATION_TYPES, \
+        "Unknown weight_quantize_type: '%s'. It can only be %s." % (configs['weight_quantize_type'], " ".join(WEIGHT_QUANTIZATION_TYPES))
+
+    assert configs['activation_quantize_type'] in ACTIVATION_QUANTIZATION_TYPES, \
+        "Unknown activation_quantize_type: '%s'. It can only be %s." % (configs['activation_quantize_type'], " ".join(ACTIVATION_QUANTIZATION_TYPES))
+
+    assert isinstance(configs['weight_bits'], int), \
+        "weight_bits must be int value."
+
+    assert (configs['weight_bits'] >= 1 and configs['weight_bits'] <= 16), \
+        "weight_bits should be between 1 and 16."
+
+    assert isinstance(configs['activation_bits'], int), \
+        "activation_bits must be int value."
+
+    assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
+        "activation_bits should be between 1 and 16."
+
+    assert isinstance(configs['not_quant_pattern'], list), \
+        "not_quant_pattern must be a list"
+
+    assert isinstance(configs['quantize_op_types'], list), \
+        "quantize_op_types must be a list"
+
+    assert isinstance(configs['dtype'], str), \
+        "dtype must be a str."
+
+    assert (configs['dtype'] in VALID_DTYPES), \
+        "dtype can only be " + " ".join(VALID_DTYPES)
+
+    assert isinstance(configs['window_size'], int), \
+        "window_size must be int value, window size for 'range_abs_max' quantization, default is 10000."
+
+    assert isinstance(configs['moving_rate'], float), \
+        "moving_rate must be float value, the decay coefficient of moving average, default is 0.9."
+
+    assert isinstance(configs['quant_weight_only'], bool), \
+        "quant_weight_only must be bool value; if quant_weight_only is True, " \
+        "only the parameters of layers that need to be quantized are quantized " \
+        "and activations are not quantized."
+
+    return configs
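Only the keys a user wants to override need to appear in the config: _parse_configs merges the user dict over _quant_config_default. A small sketch of that merge; _parse_configs is module-private and is normally invoked through quant_aware below, and the 'last_fc' scope name is a made-up example:

    user_config = {
        'weight_quantize_type': 'channel_wise_abs_max',
        'activation_quantize_type': 'moving_average_abs_max',
        'not_quant_pattern': ['skip_quant', 'last_fc'],
    }
    config = _parse_configs(user_config)
    assert config['weight_bits'] == 8   # unset keys keep their default values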
+
+
+def quant_aware(program, place, config, scope=None, for_test=False):
+    """
+    Add trainable quantization ops to program.
+    Args:
+        program(fluid.Program): the program to be quantized
+        place(fluid.CPUPlace or fluid.CUDAPlace): place
+        config(dict): configs for quantization; default values are in _quant_config_default dict.
+        scope(fluid.Scope): the scope to store variables; it should be the scope of the program, usually fluid.global_scope(). Default: None, which means fluid.global_scope() will be used.
+        for_test(bool): it should be set True if program is a test program, otherwise False.
+    Return:
+        fluid.Program: the user can finetune this quantization program to enhance the accuracy.
+    """
+
+    scope = fluid.global_scope() if not scope else scope
+    assert isinstance(config, dict), "config must be dict"
+
+    assert 'weight_quantize_type' in config.keys(
+    ), 'weight_quantize_type must be configured'
+    assert 'activation_quantize_type' in config.keys(
+    ), 'activation_quantize_type must be configured'
+
+    config = _parse_configs(config)
+    main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
+
+    transform_pass = QuantizationTransformPass(
+        scope=scope,
+        place=place,
+        weight_bits=config['weight_bits'],
+        activation_bits=config['activation_bits'],
+        activation_quantize_type=config['activation_quantize_type'],
+        weight_quantize_type=config['weight_quantize_type'],
+        window_size=config['window_size'],
+        moving_rate=config['moving_rate'],
+        quantizable_op_type=config['quantize_op_types'],
+        skip_pattern=config['not_quant_pattern'])
+
+    transform_pass.apply(main_graph)
+
+    if for_test:
+        quant_program = main_graph.to_program()
+    else:
+        quant_program = fluid.CompiledProgram(main_graph.graph)
+    return quant_program
+
+
+def quant_post(program, place, config, scope=None):
+    """
+    Add quantization ops to program; the returned program is not trainable.
+    Args:
+        program(fluid.Program): the program to be quantized
+        place(fluid.CPUPlace or fluid.CUDAPlace): place
+        config(dict): configs for quantization; default values are in _quant_config_default dict.
+        scope(fluid.Scope): the scope to store variables; it should be the scope of the program, usually fluid.global_scope(). Default: None.
+    Return:
+        fluid.Program: the quantized program, which is not trainable.
+    """
+    pass
+
+
+def convert(program, scope, place, config, save_int8=False):
+    """
+    Convert a quant-aware-trained program into a freezed program for inference.
+    Args:
+        program(fluid.Program): the program to be converted
+        scope(fluid.Scope): the scope to store variables; it should be the scope of the program, usually fluid.global_scope()
+        place(fluid.CPUPlace or fluid.CUDAPlace): place
+        config(dict): configs for quantization; default values are in _quant_config_default dict.
+        save_int8(bool): whether to additionally export a freezed program with int8 weights.
+    Return:
+        fluid.Program: freezed float program which can be used for inference;
+                       parameters are of float32 type, but their values are in the int8 range.
+        fluid.Program: freezed int8 program which can be used for inference;
+                       if save_int8 is False, this value is None.
+    """
+
+    test_graph = IrGraph(core.Graph(program.desc), for_test=True)
+
+    # Freeze the graph after training by adjusting the quantize
+    # operators' order for the inference.
+    freeze_pass = QuantizationFreezePass(
+        scope=scope,
+        place=place,
+        weight_quantize_type=config['weight_quantize_type'])
+    freeze_pass.apply(test_graph)
+    freezed_program = test_graph.to_program()
+
+    if save_int8:
+        convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
+        convert_int8_pass.apply(test_graph)
+        freezed_program_int8 = test_graph.to_program()
+        return freezed_program, freezed_program_int8
+    else:
+        return freezed_program
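A hedged end-to-end sketch of the quant-aware-training flow these functions imply; train_program, test_program and the fine-tuning loop are assumed to already exist:

    import paddle.fluid as fluid
    from paddleslim.quant import quant_aware, convert

    place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
    config = {'weight_quantize_type': 'abs_max',
              'activation_quantize_type': 'moving_average_abs_max'}

    # Insert fake-quant ops: a compiled program for training, a plain one for eval.
    quant_train = quant_aware(train_program, place, config, for_test=False)
    quant_test = quant_aware(test_program, place, config, for_test=True)

    # ... fine-tune quant_train with the usual executor loop ...

    # Freeze the test graph for inference, optionally exporting int8 weights too.
    float_infer, int8_infer = convert(
        quant_test, fluid.global_scope(), place, config, save_int8=True)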
diff --git a/tests/test_sa_nas.py b/tests/test_sa_nas.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1bcd08dadf87e24f31af1a525f67aa9a92bd26e
--- /dev/null
+++ b/tests/test_sa_nas.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+sys.path.append("../")
+import unittest
+import paddle.fluid as fluid
+from paddleslim.nas import SANAS
+from paddleslim.nas import SearchSpaceFactory
+from paddleslim.analysis import flops
+
+
+class TestSANAS(unittest.TestCase):
+    def test_nas(self):
+
+        factory = SearchSpaceFactory()
+        config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5}
+        config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2}
+        configs = [('MobileNetV2Space', config0), ('ResNetSpace', config1)]
+
+        space = factory.get_search_space([('MobileNetV2Space', config0)])
+        origin_arch = space.token2arch()[0]
+
+        main_program = fluid.Program()
+        s_program = fluid.Program()
+        with fluid.program_guard(main_program, s_program):
+            input = fluid.data(
+                name="input", shape=[None, 3, 224, 224], dtype="float32")
+            origin_arch(input)
+        base_flops = flops(main_program)
+
+        search_steps = 3
+        sa_nas = SANAS(
+            configs, max_flops=base_flops, search_steps=search_steps)
+
+        for i in range(search_steps):
+            archs = sa_nas.next_archs()
+            main_program = fluid.Program()
+            s_program = fluid.Program()
+            with fluid.program_guard(main_program, s_program):
+                input = fluid.data(
+                    name="input", shape=[None, 3, 224, 224], dtype="float32")
+                archs[0](input)
+            sa_nas.reward(1)
+            self.assertTrue(flops(main_program) < base_flops)
+
+
+if __name__ == '__main__':
+    unittest.main()