diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..98b314ab6d144924bff6b68e3fb176ce73583f5c --- /dev/null +++ b/paddleslim/common/__init__.py @@ -0,0 +1,32 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import controller +from controller import * +import sa_controller +from sa_controller import * +import log_helper +from log_helper import * +import controller_server +from controller_server import * +import controller_client +from controller_client import * +import lock_utils +from lock_utils import * + +__all__ = [] +__all__ += controller.__all__ +__all__ += sa_controller.__all__ +__all__ += controller_server.__all__ +__all__ += controller_client.__all__ +__all__ += lock_utils.__all__ diff --git a/paddleslim/common/controller.py b/paddleslim/common/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..8c30f49c3aec27a326417554bac3163789342ff6 --- /dev/null +++ b/paddleslim/common/controller.py @@ -0,0 +1,51 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The controller used to search hyperparameters or neural architecture""" + +import copy +import math +import numpy as np + +__all__ = ['EvolutionaryController'] + + +class EvolutionaryController(object): + """Abstract controller for all evolutionary searching method. + """ + + def __init__(self, *args, **kwargs): + pass + + def update(self, tokens, reward): + """Update the status of controller according current tokens and reward. + Args: + tokens(list): A solution of searching task. + reward(list): The reward of tokens. + """ + raise NotImplementedError('Abstract method.') + + def reset(self, range_table, constrain_func=None): + """Reset the controller. + Args: + range_table(list): It is used to define the searching space of controller. + The tokens[i] generated by controller should be in [0, range_table[i]). + constrain_func(function): It is used to check whether tokens meet the constraint. + None means there is no constraint. Default: None. + """ + raise NotImplementedError('Abstract method.') + + def next_tokens(self): + """Generate new tokens. + """ + raise NotImplementedError('Abstract method.') diff --git a/paddleslim/common/controller_client.py b/paddleslim/common/controller_client.py new file mode 100644 index 0000000000000000000000000000000000000000..5dcbd7bb64bf4460371d523a0f745e2490a7b3a0 --- /dev/null +++ b/paddleslim/common/controller_client.py @@ -0,0 +1,66 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import socket +from log_helper import get_logger + +__all__ = ['ControllerClient'] + +_logger = get_logger(__name__, level=logging.INFO) + + +class ControllerClient(object): + """ + Controller client. + """ + + def __init__(self, server_ip=None, server_port=None, key=None): + """ + Args: + server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None. + server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0. + key(str): The key used to identify legal agent for controller server. Default: "light-nas" + """ + self.server_ip = server_ip + self.server_port = server_port + self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._key = key + + def update(self, tokens, reward): + """ + Update the controller according to latest tokens and reward. + Args: + tokens(list): The tokens generated in last step. + reward(float): The reward of tokens. + """ + socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + socket_client.connect((self.server_ip, self.server_port)) + tokens = ",".join([str(token) for token in tokens]) + socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward) + .encode()) + tokens = socket_client.recv(1024).decode() + tokens = [int(token) for token in tokens.strip("\n").split(",")] + return tokens + + def next_tokens(self): + """ + Get next tokens. + """ + socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + socket_client.connect((self.server_ip, self.server_port)) + socket_client.send("next_tokens".encode()) + tokens = socket_client.recv(1024).decode() + tokens = [int(token) for token in tokens.strip("\n").split(",")] + return tokens diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py new file mode 100644 index 0000000000000000000000000000000000000000..10883d5988652a8c3f738416eae3dd768ba74e67 --- /dev/null +++ b/paddleslim/common/controller_server.py @@ -0,0 +1,128 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import socket +from .log_helper import get_logger +from threading import Thread +from .lock import lock, unlock + +__all__ = ['ControllerServer'] + +_logger = get_logger(__name__, level=logging.INFO) + + +class ControllerServer(object): + """ + The controller wrapper with a socket server to handle the request of search agent. + """ + + def __init__(self, + controller=None, + address=('', 0), + max_client_num=100, + search_steps=None, + key=None): + """ + Args: + controller(slim.searcher.Controller): The controller used to generate tokens. + address(tuple): The address of current server binding with format (ip, port). Default: ('', 0). + which means setting ip automatically + max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100. + search_steps(int): The total steps of searching. None means never stopping. Default: None + """ + self._controller = controller + self._address = address + self._max_client_num = max_client_num + self._search_steps = search_steps + self._closed = False + self._port = address[1] + self._ip = address[0] + self._key = key + self._socket_file = "./controller_server.socket" + + def start(self): + open(self._socket_file, 'a').close() + socket_file = open(self._socket_file, 'r+') + lock(socket_file) + tid = socket_file.readline() + if tid == '': + _logger.info("start controller server...") + tid = self._start() + socket_file.write("tid: {}\nip: {}\nport: {}\n".format( + tid, self._ip, self._port)) + _logger.info("started controller server...") + unlock(socket_file) + socket_file.close() + + def _start(self): + self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket_server.bind(self._address) + self._socket_server.listen(self._max_client_num) + self._port = self._socket_server.getsockname()[1] + self._ip = self._socket_server.getsockname()[0] + _logger.info("ControllerServer - listen on: [{}:{}]".format( + self._ip, self._port)) + thread = Thread(target=self.run) + thread.start() + return str(thread) + + def close(self): + """Close the server.""" + self._closed = True + os.remove(self._socket_file) + _logger.info("server closed!") + + def port(self): + """Get the port.""" + return self._port + + def ip(self): + """Get the ip.""" + return self._ip + + def run(self): + _logger.info("Controller Server run...") + try: + while ((self._search_steps is None) or + (self._controller._iter < + (self._search_steps))) and not self._closed: + conn, addr = self._socket_server.accept() + message = conn.recv(1024).decode() + if message.strip("\n") == "next_tokens": + tokens = self._controller.next_tokens() + tokens = ",".join([str(token) for token in tokens]) + conn.send(tokens.encode()) + else: + _logger.debug("recv message from {}: [{}]".format(addr, + message)) + messages = message.strip('\n').split("\t") + if (len(messages) < 3) or (messages[0] != self._key): + _logger.debug("recv noise from {}: [{}]".format( + addr, message)) + continue + tokens = messages[1] + reward = messages[2] + tokens = [int(token) for token in tokens.split(",")] + self._controller.update(tokens, float(reward)) + tokens = self._controller.next_tokens() + tokens = ",".join([str(token) for token in tokens]) + conn.send(tokens.encode()) + _logger.debug("send message to {}: [{}]".format(addr, + tokens)) + conn.close() + finally: + self._socket_server.close() + self.close() diff --git a/paddleslim/common/lock_utils.py b/paddleslim/common/lock_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9daf4f3f6e842609a39fd286dfa49eb705c631a7 --- /dev/null +++ b/paddleslim/common/lock_utils.py @@ -0,0 +1,38 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +__all__ = ['lock', 'unlock'] + +if os.name == 'nt': + + def lock(file): + raise NotImplementedError('Windows is not supported.') + + def unlock(file): + raise NotImplementedError('Windows is not supported.') + +elif os.name == 'posix': + from fcntl import flock, LOCK_EX, LOCK_UN + + def lock(file): + """Lock the file in local file system.""" + flock(file.fileno(), LOCK_EX) + + def unlock(file): + """Unlock the file in local file system.""" + flock(file.fileno(), LOCK_UN) +else: + raise RuntimeError("File Locker only support NT and Posix platforms!") diff --git a/paddleslim/common/log_helper.py b/paddleslim/common/log_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..1088761e0284181bc485f5ee1824e1cbd9c7eb81 --- /dev/null +++ b/paddleslim/common/log_helper.py @@ -0,0 +1,48 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import logging + +__all__ = ['get_logger'] + + +def get_logger(name, level, fmt=None): + """ + Get logger from logging with given name, level and format without + setting logging basicConfig. For setting basicConfig in paddle + will disable basicConfig setting after import paddle. + Args: + name (str): The logger name. + level (logging.LEVEL): The base level of the logger + fmt (str): Format of logger output + Returns: + logging.Logger: logging logger with given setttings + Examples: + .. code-block:: python + logger = log_helper.get_logger(__name__, logging.INFO, + fmt='%(asctime)s-%(levelname)s: %(message)s') + """ + + logger = logging.getLogger(name) + logger.setLevel(level) + handler = logging.StreamHandler() + + if fmt: + formatter = logging.Formatter(fmt=fmt) + handler.setFormatter(formatter) + + logger.addHandler(handler) + return logger diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..b619b818a3208d740c1ddb6753cf5931f3d058f5 --- /dev/null +++ b/paddleslim/common/sa_controller.py @@ -0,0 +1,113 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The controller used to search hyperparameters or neural architecture""" + +import copy +import math +import logging +import numpy as np +from .controller import EvolutionaryController +from log_helper import get_logger + +__all__ = ["SAController"] + +_logger = get_logger(__name__, level=logging.INFO) + + +class SAController(EvolutionaryController): + """Simulated annealing controller.""" + + def __init__(self, + range_table=None, + reduce_rate=0.85, + init_temperature=1024, + max_iter_number=300, + init_tokens=None, + constrain_func=None): + """Initialize. + Args: + range_table(list): Range table. + reduce_rate(float): The decay rate of temperature. + init_temperature(float): Init temperature. + max_iter_number(int): max iteration number. + init_tokens(list): The initial tokens. + constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None. + """ + super(SAController, self).__init__() + self._range_table = range_table + assert isinstance(self._range_table, tuple) and ( + len(self._range_table) == 2) + self._reduce_rate = reduce_rate + self._init_temperature = init_temperature + self._max_iter_number = max_iter_number + self._reward = -1 + self._tokens = init_tokens + self._constrain_func = constrain_func + self._max_reward = -1 + self._best_tokens = None + self._iter = 0 + + def __getstate__(self): + d = {} + for key in self.__dict__: + if key != "_constrain_func": + d[key] = self.__dict__[key] + return d + + def update(self, tokens, reward): + """ + Update the controller according to latest tokens and reward. + Args: + tokens(list): The tokens generated in last step. + reward(float): The reward of tokens. + """ + self._iter += 1 + temperature = self._init_temperature * self._reduce_rate**self._iter + if (reward > self._reward) or (np.random.random() <= math.exp( + (reward - self._reward) / temperature)): + self._reward = reward + self._tokens = tokens + if reward > self._max_reward: + self._max_reward = reward + self._best_tokens = tokens + _logger.info( + "Controller - iter: {}; current_reward: {}; current tokens: {}". + format(self._iter, self._reward, self._tokens)) + + def next_tokens(self, control_token=None): + """ + Get next tokens. + """ + if control_token: + tokens = control_token[:] + else: + tokens = self._tokens + new_tokens = tokens[:] + index = int(len(self._range_table[0]) * np.random.random()) + new_tokens[index] = np.random.randint(self._range_table[0][index], + self._range_table[1][index] + 1) + _logger.debug("change index[{}] from {} to {}".format(index, tokens[ + index], new_tokens[index])) + if self._constrain_func is None: + return new_tokens + for _ in range(self._max_iter_number): + if not self._constrain_func(new_tokens): + index = int(len(self._range_table[0]) * np.random.random()) + new_tokens = tokens[:] + new_tokens[index] = np.random.randint( + self._range_table[0][index], + self._range_table[1][index] + 1) + else: + break + return new_tokens diff --git a/paddleslim/nas/__init__.py b/paddleslim/nas/__init__.py index 2f5509144f53529ae717b72bbb7252b4b06a0048..f11948f6bcbdd3d52e334bed3b06510e226825bc 100644 --- a/paddleslim/nas/__init__.py +++ b/paddleslim/nas/__init__.py @@ -14,6 +14,9 @@ import search_space from search_space import * +import sa_nas +from sa_nas import * __all__ = [] __all__ += search_space.__all__ +__all__ += sa_nas.__all__ diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc747b0ec9f977dc4e41d2fb128b29823cfd3a3 --- /dev/null +++ b/paddleslim/nas/sa_nas.py @@ -0,0 +1,145 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import logging +import numpy as np +import paddle.fluid as fluid +from ..core import VarWrapper, OpWrapper, GraphWrapper +from ..common import SAController +from ..common import get_logger +from ..analysis import flops + +from ..common import ControllerServer +from ..common import ControllerClient +from .search_space import SearchSpaceFactory + +__all__ = ["SANAS"] + +_logger = get_logger(__name__, level=logging.INFO) + + +class SANAS(object): + def __init__(self, + configs, + max_flops=None, + max_latency=None, + server_addr=("", 0), + init_temperature=100, + reduce_rate=0.85, + max_try_number=300, + max_client_num=10, + search_steps=300, + key="sa_nas", + is_server=True): + """ + Search a group of ratios used to prune program. + Args: + configs(list): A list of search space configuration with format (key, input_size, output_size, block_num). + `key` is the name of search space with data type str. `input_size` and `output_size` are + input size and output size of searched sub-network. `block_num` is the number of blocks in searched network. + max_flops(int): The max flops of searched network. None means no constrains. Default: None. + max_latency(float): The max latency of searched network. None means no constrains. Default: None. + server_addr(tuple): A tuple of server ip and server port for controller server. + init_temperature(float): The init temperature used in simulated annealing search strategy. + reduce_rate(float): The decay rate used in simulated annealing search strategy. + max_try_number(int): The max number of trying to generate legal tokens. + max_client_num(int): The max number of connections of controller server. + search_steps(int): The steps of searching. + key(str): Identity used in communication between controller server and clients. + is_server(bool): Whether current host is controller server. Default: True. + """ + + self._reduce_rate = reduce_rate + self._init_temperature = init_temperature + self._max_try_number = max_try_number + self._is_server = is_server + self._max_flops = max_flops + self._max_latency = max_latency + + self._configs = configs + + factory = SearchSpaceFactory() + self._search_space = factory.get_search_space(configs) + init_tokens = self._search_space.init_tokens() + range_table = self._search_space.range_table() + range_table = (len(range_table) * [0], range_table) + + print range_table + + controller = SAController(range_table, self._reduce_rate, + self._init_temperature, self._max_try_number, + init_tokens, self._constrain_func) + + server_ip, server_port = server_addr + if server_ip == None or server_ip == "": + server_ip = self._get_host_ip() + + self._controller_server = ControllerServer( + controller=controller, + address=(server_ip, server_port), + max_client_num=max_client_num, + search_steps=search_steps, + key=key) + + # create controller server + if self._is_server: + self._controller_server.start() + + self._controller_client = ControllerClient( + self._controller_server.ip(), + self._controller_server.port(), + key=key) + + self._iter = 0 + + def _get_host_ip(self): + return socket.gethostbyname(socket.gethostname()) + + def _constrain_func(self, tokens): + if (self._max_flops is None) and (self._max_latency is None): + return True + archs = self._search_space.token2arch(tokens) + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + i = 0 + for config, arch in zip(self._configs, archs): + input_size = config[1]["input_size"] + input = fluid.data( + name="data_{}".format(i), + shape=[None, 3, input_size, input_size], + dtype="float32") + output = arch(input) + i += 1 + return flops(main_program) < self._max_flops + + def next_archs(self): + """ + Get next network architectures. + Returns: + list: A list of functions that define networks. + """ + self._current_tokens = self._controller_client.next_tokens() + archs = self._search_space.token2arch(self._current_tokens) + return archs + + def reward(self, score): + """ + Return reward of current searched network. + Args: + score(float): The score of current searched network. + """ + self._controller_client.update(self._current_tokens, score) + self._iter += 1 diff --git a/paddleslim/nas/search_space/mobilenetv2.py b/paddleslim/nas/search_space/mobilenetv2.py index 90e0b2a0e704a2f44a8031b05d475419bc534677..28d8a7ea03bc94618b9b5575f837f09879d309c8 100644 --- a/paddleslim/nas/search_space/mobilenetv2.py +++ b/paddleslim/nas/search_space/mobilenetv2.py @@ -52,7 +52,6 @@ class MobileNetV2Space(SearchSpaceBase): self.scale = scale self.class_dim = class_dim - def init_tokens(self): """ The initial token send to controller. @@ -71,10 +70,11 @@ class MobileNetV2Space(SearchSpaceBase): 4, 9, 0, 0] # 6, 320, 1 # yapf: enable - if self.block_num < 5: + if self.block_num < 5: self.token_len = 1 + (self.block_num - 1) * 4 else: - self.token_len = 1 + (self.block_num + 2 * (self.block_num - 5)) * 4 + self.token_len = 1 + (self.block_num + 2 * + (self.block_num - 5)) * 4 return init_token_base[:self.token_len] @@ -92,6 +92,7 @@ class MobileNetV2Space(SearchSpaceBase): 5, 10, 6, 2, 5, 10, 6, 2, 5, 12, 6, 2] + range_table_base = list(np.array(range_table_base) - 1) # yapf: enable return range_table_base[:self.token_len] @@ -107,24 +108,36 @@ class MobileNetV2Space(SearchSpaceBase): tokens = self.init_tokens() bottleneck_params_list = [] - if self.block_num >= 1: bottleneck_params_list.append((1, self.head_num[tokens[0]], 1, 1, 3)) - if self.block_num >= 2: bottleneck_params_list.append((self.multiply[tokens[1]], self.filter_num1[tokens[2]], - self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) - if self.block_num >= 3: bottleneck_params_list.append((self.multiply[tokens[5]], self.filter_num1[tokens[6]], - self.repeat[tokens[7]], 2, self.k_size[tokens[8]])) - if self.block_num >= 4: bottleneck_params_list.append((self.multiply[tokens[9]], self.filter_num2[tokens[10]], - self.repeat[tokens[11]], 2, self.k_size[tokens[12]])) - if self.block_num >= 5: - bottleneck_params_list.append((self.multiply[tokens[13]], self.filter_num3[tokens[14]], - self.repeat[tokens[15]], 2, self.k_size[tokens[16]])) - bottleneck_params_list.append((self.multiply[tokens[17]], self.filter_num3[tokens[18]], - self.repeat[tokens[19]], 1, self.k_size[tokens[20]])) - if self.block_num >= 6: - bottleneck_params_list.append((self.multiply[tokens[21]], self.filter_num5[tokens[22]], - self.repeat[tokens[23]], 2, self.k_size[tokens[24]])) - bottleneck_params_list.append((self.multiply[tokens[25]], self.filter_num6[tokens[26]], - self.repeat[tokens[27]], 1, self.k_size[tokens[28]])) - + if self.block_num >= 1: + bottleneck_params_list.append( + (1, self.head_num[tokens[0]], 1, 1, 3)) + if self.block_num >= 2: + bottleneck_params_list.append( + (self.multiply[tokens[1]], self.filter_num1[tokens[2]], + self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) + if self.block_num >= 3: + bottleneck_params_list.append( + (self.multiply[tokens[5]], self.filter_num1[tokens[6]], + self.repeat[tokens[7]], 2, self.k_size[tokens[8]])) + if self.block_num >= 4: + bottleneck_params_list.append( + (self.multiply[tokens[9]], self.filter_num2[tokens[10]], + self.repeat[tokens[11]], 2, self.k_size[tokens[12]])) + if self.block_num >= 5: + bottleneck_params_list.append( + (self.multiply[tokens[13]], self.filter_num3[tokens[14]], + self.repeat[tokens[15]], 2, self.k_size[tokens[16]])) + bottleneck_params_list.append( + (self.multiply[tokens[17]], self.filter_num3[tokens[18]], + self.repeat[tokens[19]], 1, self.k_size[tokens[20]])) + if self.block_num >= 6: + bottleneck_params_list.append( + (self.multiply[tokens[21]], self.filter_num5[tokens[22]], + self.repeat[tokens[23]], 2, self.k_size[tokens[24]])) + bottleneck_params_list.append( + (self.multiply[tokens[25]], self.filter_num6[tokens[26]], + self.repeat[tokens[27]], 1, self.k_size[tokens[28]])) + def net_arch(input): #conv1 # all padding is 'SAME' in the conv2d, can compute the actual padding automatic. @@ -182,15 +195,15 @@ class MobileNetV2Space(SearchSpaceBase): return fluid.layers.elementwise_add(input, data_residual) def _inverted_residual_unit(self, - input, - num_in_filter, - num_filters, - ifshortcut, - stride, - filter_size, - expansion_factor, - reduction_ratio=4, - name=None): + input, + num_in_filter, + num_filters, + ifshortcut, + stride, + filter_size, + expansion_factor, + reduction_ratio=4, + name=None): """Build inverted residual unit. Args: input(Variable), input. diff --git a/paddleslim/nas/search_space/resnet.py b/paddleslim/nas/search_space/resnet.py index a6ac5817ce89190987d67f1eda644fa2aef79037..7ed404e5e145c9f173aee95823c8d6ac6a47dfdb 100644 --- a/paddleslim/nas/search_space/resnet.py +++ b/paddleslim/nas/search_space/resnet.py @@ -25,19 +25,25 @@ from .search_space_registry import SEARCHSPACE __all__ = ["ResNetSpace"] + @SEARCHSPACE.register class ResNetSpace(SearchSpaceBase): - def __init__(self, input_size, output_size, block_num, scale=1.0, class_dim=1000): + def __init__(self, + input_size, + output_size, + block_num, + scale=1.0, + class_dim=1000): super(ResNetSpace, self).__init__(input_size, output_size, block_num) pass def init_tokens(self): - return [0,0,0,0,0,0] + return [0, 0, 0, 0, 0, 0] def range_table(self): - return [3,3,3,3,3,3] + return [2, 2, 2, 2, 2, 2] - def token2arch(self,tokens=None): + def token2arch(self, tokens=None): if tokens is None: self.init_tokens() @@ -54,5 +60,3 @@ class ResNetSpace(SearchSpaceBase): return input return net_arch - - diff --git a/tests/test_sa_nas.py b/tests/test_sa_nas.py new file mode 100644 index 0000000000000000000000000000000000000000..c1bcd08dadf87e24f31af1a525f67aa9a92bd26e --- /dev/null +++ b/tests/test_sa_nas.py @@ -0,0 +1,59 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +sys.path.append("../") +import unittest +import paddle.fluid as fluid +from paddleslim.nas import SANAS +from paddleslim.nas import SearchSpaceFactory +from paddleslim.analysis import flops + + +class TestSANAS(unittest.TestCase): + def test_nas(self): + + factory = SearchSpaceFactory() + config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5} + config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2} + configs = [('MobileNetV2Space', config0), ('ResNetSpace', config1)] + + space = factory.get_search_space([('MobileNetV2Space', config0)]) + origin_arch = space.token2arch()[0] + + main_program = fluid.Program() + s_program = fluid.Program() + with fluid.program_guard(main_program, s_program): + input = fluid.data( + name="input", shape=[None, 3, 224, 224], dtype="float32") + origin_arch(input) + base_flops = flops(main_program) + + search_steps = 3 + sa_nas = SANAS( + configs, max_flops=base_flops, search_steps=search_steps) + + for i in range(search_steps): + archs = sa_nas.next_archs() + main_program = fluid.Program() + s_program = fluid.Program() + with fluid.program_guard(main_program, s_program): + input = fluid.data( + name="input", shape=[None, 3, 224, 224], dtype="float32") + archs[0](input) + sa_nas.reward(1) + self.assertTrue(flops(main_program) < base_flops) + + +if __name__ == '__main__': + unittest.main()