diff --git a/demo/sa_nas_mobilenetv2_cifar10.py b/demo/sa_nas_mobilenetv2_cifar10.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e903960b1c783c38d672238d5a2b3a0c1581c4d
--- /dev/null
+++ b/demo/sa_nas_mobilenetv2_cifar10.py
@@ -0,0 +1,122 @@
+import sys
+sys.path.append('..')
+import numpy as np
+import argparse
+import ast
+import paddle
+import paddle.fluid as fluid
+from paddleslim.nas.search_space.search_space_factory import SearchSpaceFactory
+from paddleslim.analysis import flops
+from paddleslim.nas import SANAS
+
+
+def create_data_loader():
+    data = fluid.data(name='data', shape=[-1, 3, 32, 32], dtype='float32')
+    label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
+    data_loader = fluid.io.DataLoader.from_generator(
+        feed_list=[data, label],
+        capacity=1024,
+        use_double_buffer=True,
+        iterable=True)
+    return data_loader, data, label
+
+
+def init_sa_nas(config):
+    factory = SearchSpaceFactory()
+    space = factory.get_search_space(config)
+    model_arch = space.token2arch()[0]
+    main_program = fluid.Program()
+    startup_program = fluid.Program()
+
+    with fluid.program_guard(main_program, startup_program):
+        data_loader, data, label = create_data_loader()
+        output = model_arch(data)
+        cost = fluid.layers.mean(
+            fluid.layers.softmax_with_cross_entropy(
+                logits=output, label=label))
+
+    base_flops = flops(main_program)
+    search_steps = 10000000
+
+    ### start a server and a client
+    sa_nas = SANAS(config, max_flops=base_flops, search_steps=search_steps)
+
+    ### start a client only; server_addr is the server address
+    #sa_nas = SANAS(config, max_flops=base_flops, server_addr=("10.255.125.38", 18607), search_steps=search_steps, is_server=False)
+
+    return sa_nas, search_steps
+
+
+def search_mobilenetv2_cifar10(config, args):
+    sa_nas, search_steps = init_sa_nas(config)
+    for i in range(search_steps):
+        print('search step: ', i)
+        archs = sa_nas.next_archs()[0]
+        train_program = fluid.Program()
+        test_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(train_program, startup_program):
+            train_loader, data, label = create_data_loader()
+            output = archs(data)
+            cost = fluid.layers.mean(
+                fluid.layers.softmax_with_cross_entropy(
+                    logits=output, label=label))
+            acc = fluid.layers.accuracy(input=output, label=label)
+            test_program = train_program.clone(for_test=True)
+
+            optimizer = fluid.optimizer.Momentum(
+                learning_rate=0.1,
+                momentum=0.9,
+                regularization=fluid.regularizer.L2Decay(1e-4))
+            optimizer.minimize(cost)
+
+        place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(startup_program)
+        train_reader = paddle.reader.shuffle(
+            paddle.dataset.cifar.train10(cycle=False), buf_size=1024)
+        train_loader.set_sample_generator(
+            train_reader,
+            batch_size=512,
+            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
+
+        test_loader, _, _ = create_data_loader()
+        test_reader = paddle.dataset.cifar.test10(cycle=False)
+        test_loader.set_sample_generator(
+            test_reader,
+            batch_size=256,
+            drop_last=False,
+            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
+
+        for epoch_id in range(10):
+            for batch_id, data in enumerate(train_loader()):
+                loss = exe.run(train_program,
+                               feed=data,
+                               fetch_list=[cost.name])[0]
+                if batch_id % 5 == 0:
+                    print('epoch: {}, batch: {}, loss: {}'.format(
+                        epoch_id, batch_id, loss[0]))
+
+        # use the average top-1 accuracy on the test set as the reward,
+        # so that a larger reward always means a better architecture
+        accuracies = []
+        for data in test_loader():
+            acc_np = exe.run(test_program, feed=data,
+                             fetch_list=[acc.name])[0]
+            accuracies.append(float(np.mean(acc_np)))
+        reward = float(np.mean(accuracies))
+
+        print('reward:', reward)
+        sa_nas.reward(reward)
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser(
+        description='SA NAS MobileNetV2 cifar10 argparser')
+    parser.add_argument(
+        '--use_gpu',
+        type=ast.literal_eval,
+        default=True,
+        help='Whether to use GPU in train/test model.')
+    args = parser.parse_args()
+    print(args)
+
+    config_info = {'input_size': 32, 'output_size': 1, 'block_num': 5}
+    config = [('MobileNetV2Space', config_info)]
+
+    search_mobilenetv2_cifar10(config, args)
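+
+# Illustrative only: to run the search across machines, keep is_server=True on
+# one host and start the other hosts as clients, as in the commented-out line
+# in init_sa_nas(). The address below is a placeholder, not a real deployment:
+#
+#   sa_nas = SANAS(config, max_flops=base_flops,
+#                  server_addr=("192.168.1.2", 18607),
+#                  search_steps=search_steps, is_server=False)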
diff --git a/paddleslim/analysis/__init__.py b/paddleslim/analysis/__init__.py
index 9caa0d24006a3e59f2d39c646d247b7e68480f96..76904c8d548208adb29188f28e9e0c6a0f11f30d 100644
--- a/paddleslim/analysis/__init__.py
+++ b/paddleslim/analysis/__init__.py
@@ -15,6 +15,9 @@ import flops as flops_module
 from flops import *
 import model_size as model_size_module
 from model_size import *
+import sensitive
+from sensitive import *
 __all__ = []
 __all__ += flops_module.__all__
 __all__ += model_size_module.__all__
+__all__ += sensitive.__all__
diff --git a/paddleslim/analysis/sensitive.py b/paddleslim/analysis/sensitive.py
new file mode 100644
index 0000000000000000000000000000000000000000..09dd2a875ae21caf64034cf79421d7cc1661b817
--- /dev/null
+++ b/paddleslim/analysis/sensitive.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
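+
+# Illustrative usage sketch (`eval_func` and the parameter names below are
+# placeholders supplied by the caller, not part of this module):
+#
+#   def eval_func(program, scope):
+#       return test(program, scope)  # any scalar metric, e.g. top-1 accuracy
+#
+#   sens = sensitivity(eval_program, scope, place,
+#                      ["conv1_weights", "conv2_weights"], eval_func,
+#                      sensitivities_file="sensitivities.data", step_size=0.2)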
+
+import sys
+import os
+import logging
+import pickle
+import numpy as np
+from ..core import GraphWrapper
+from ..common import get_logger
+from ..prune import Pruner
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+__all__ = ["sensitivity"]
+
+
+def sensitivity(program,
+                scope,
+                place,
+                param_names,
+                eval_func,
+                sensitivities_file=None,
+                step_size=0.2):
+    """Compute the sensitivity of each parameter in `param_names` by pruning
+    it with increasing ratios and measuring the relative loss of the metric
+    returned by `eval_func`. Results are cached in `sensitivities_file`.
+    """
+    graph = GraphWrapper(program)
+    sensitivities = _load_sensitivities(sensitivities_file)
+
+    for name in param_names:
+        if name not in sensitivities:
+            size = graph.var(name).shape()[0]
+            sensitivities[name] = {
+                'pruned_percent': [],
+                'loss': [],
+                'size': size
+            }
+    baseline = None
+    for name in sensitivities:
+        ratio = step_size
+        while ratio < 1:
+            ratio = round(ratio, 2)
+            if ratio in sensitivities[name]['pruned_percent']:
+                _logger.debug('{}, {} has been computed.'.format(name, ratio))
+                ratio += step_size
+                continue
+            if baseline is None:
+                baseline = eval_func(graph.program, scope)
+
+            param_backup = {}
+            pruner = Pruner()
+            pruned_program = pruner.prune(
+                program=graph.program,
+                scope=scope,
+                params=[name],
+                ratios=[ratio],
+                place=place,
+                lazy=True,
+                only_graph=False,
+                param_backup=param_backup)
+            pruned_metric = eval_func(pruned_program, scope)
+            loss = (baseline - pruned_metric) / baseline
+            _logger.info("pruned param: {}; {}; loss={}".format(name, ratio,
+                                                                loss))
+            sensitivities[name]['pruned_percent'].append(ratio)
+            sensitivities[name]['loss'].append(loss)
+            _save_sensitivities(sensitivities, sensitivities_file)
+
+            # restore pruned parameters
+            for param_name in param_backup.keys():
+                param_t = scope.find_var(param_name).get_tensor()
+                param_t.set(param_backup[param_name], place)
+            ratio += step_size
+    return sensitivities
+
+
+def _load_sensitivities(sensitivities_file):
+    """
+    Load sensitivities from file.
+    """
+    sensitivities = {}
+    if sensitivities_file and os.path.exists(sensitivities_file):
+        with open(sensitivities_file, 'rb') as f:
+            if sys.version_info < (3, 0):
+                sensitivities = pickle.load(f)
+            else:
+                sensitivities = pickle.load(f, encoding='bytes')
+
+    for param in sensitivities:
+        sensitivities[param]['pruned_percent'] = [
+            round(p, 2) for p in sensitivities[param]['pruned_percent']
+        ]
+    return sensitivities
+
+
+def _save_sensitivities(sensitivities, sensitivities_file):
+    """
+    Save sensitivities into file.
+    """
+    if sensitivities_file is None:
+        # nothing to persist when no cache file is given
+        return
+    with open(sensitivities_file, 'wb') as f:
+        pickle.dump(sensitivities, f)
diff --git a/paddleslim/search/__init__.py b/paddleslim/common/__init__.py
similarity index 56%
rename from paddleslim/search/__init__.py
rename to paddleslim/common/__init__.py
index 4f3182c3058cb33e46777ab1424242b42406a603..98b314ab6d144924bff6b68e3fb176ce73583f5c 100644
--- a/paddleslim/search/__init__.py
+++ b/paddleslim/common/__init__.py
@@ -11,4 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Controllers and controller server"""
+import controller
+from controller import *
+import sa_controller
+from sa_controller import *
+import log_helper
+from log_helper import *
+import controller_server
+from controller_server import *
+import controller_client
+from controller_client import *
+import lock_utils
+from lock_utils import *
+
+__all__ = []
+__all__ += controller.__all__
+__all__ += sa_controller.__all__
+__all__ += controller_server.__all__
+__all__ += controller_client.__all__
+__all__ += lock_utils.__all__
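+
+# After this change the utilities are importable from the package directly,
+# e.g. (illustrative): from paddleslim.common import SAController, get_logger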
-"""Controllers and controller server""" +import controller +from controller import * +import sa_controller +from sa_controller import * +import log_helper +from log_helper import * +import controller_server +from controller_server import * +import controller_client +from controller_client import * +import lock_utils +from lock_utils import * + +__all__ = [] +__all__ += controller.__all__ +__all__ += sa_controller.__all__ +__all__ += controller_server.__all__ +__all__ += controller_client.__all__ +__all__ += lock_utils.__all__ diff --git a/paddleslim/common/controller.py b/paddleslim/common/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..8c30f49c3aec27a326417554bac3163789342ff6 --- /dev/null +++ b/paddleslim/common/controller.py @@ -0,0 +1,51 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The controller used to search hyperparameters or neural architecture""" + +import copy +import math +import numpy as np + +__all__ = ['EvolutionaryController'] + + +class EvolutionaryController(object): + """Abstract controller for all evolutionary searching method. + """ + + def __init__(self, *args, **kwargs): + pass + + def update(self, tokens, reward): + """Update the status of controller according current tokens and reward. + Args: + tokens(list): A solution of searching task. + reward(list): The reward of tokens. + """ + raise NotImplementedError('Abstract method.') + + def reset(self, range_table, constrain_func=None): + """Reset the controller. + Args: + range_table(list): It is used to define the searching space of controller. + The tokens[i] generated by controller should be in [0, range_table[i]). + constrain_func(function): It is used to check whether tokens meet the constraint. + None means there is no constraint. Default: None. + """ + raise NotImplementedError('Abstract method.') + + def next_tokens(self): + """Generate new tokens. + """ + raise NotImplementedError('Abstract method.') diff --git a/paddleslim/common/controller_client.py b/paddleslim/common/controller_client.py new file mode 100644 index 0000000000000000000000000000000000000000..ad989dd16014fa8e6fa1495516e81048324fb826 --- /dev/null +++ b/paddleslim/common/controller_client.py @@ -0,0 +1,68 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/paddleslim/common/controller_client.py b/paddleslim/common/controller_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad989dd16014fa8e6fa1495516e81048324fb826
--- /dev/null
+++ b/paddleslim/common/controller_client.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import socket
+from .log_helper import get_logger
+
+__all__ = ['ControllerClient']
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+
+class ControllerClient(object):
+    """
+    Controller client.
+    """
+
+    def __init__(self, server_ip=None, server_port=None, key=None):
+        """
+        Args:
+            server_ip(str): The ip that the controller server listens on. None means getting the ip automatically. Default: None.
+            server_port(int): The port that the controller server listens on. None means getting a usable port automatically. Default: None.
+            key(str): The key used to identify legal agents for the controller server. Default: None.
+        """
+        self.server_ip = server_ip
+        self.server_port = server_port
+        self._key = key
+
+    def update(self, tokens, reward):
+        """
+        Update the controller according to the latest tokens and reward.
+        Args:
+            tokens(list): The tokens generated in the last step.
+            reward(float): The reward of tokens.
+        """
+        socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        socket_client.connect((self.server_ip, self.server_port))
+        tokens = ",".join([str(token) for token in tokens])
+        socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
+                           .encode())
+        response = socket_client.recv(1024).decode()
+        # the server replies with a bare "ok" on a successful update
+        return response.strip('\n') == "ok"
+
+    def next_tokens(self):
+        """
+        Get next tokens.
+        """
+        socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        socket_client.connect((self.server_ip, self.server_port))
+        socket_client.send("next_tokens".encode())
+        tokens = socket_client.recv(1024).decode()
+        tokens = [int(token) for token in tokens.strip("\n").split(",")]
+        return tokens
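+
+# Illustrative client-side loop (the address, key and evaluate() are
+# placeholders):
+#
+#   client = ControllerClient(server_ip="127.0.0.1", server_port=8989,
+#                             key="light-nas")
+#   tokens = client.next_tokens()
+#   reward = evaluate(tokens)
+#   client.update(tokens, reward)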
diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac24df86030aae8cb286452b6bd6eeb7b5c80741
--- /dev/null
+++ b/paddleslim/common/controller_server.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import logging
+import socket
+from .log_helper import get_logger
+from threading import Thread
+from .lock_utils import lock, unlock
+
+__all__ = ['ControllerServer']
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+
+class ControllerServer(object):
+    """
+    A controller wrapper with a socket server to handle the requests of search agents.
+    """
+
+    def __init__(self,
+                 controller=None,
+                 address=('', 0),
+                 max_client_num=100,
+                 search_steps=None,
+                 key=None):
+        """
+        Args:
+            controller(paddleslim.common.EvolutionaryController): The controller used to generate tokens.
+            address(tuple): The address that the current server binds to, in the format (ip, port).
+                            Default: ('', 0), which means the ip and a usable port are picked automatically.
+            max_client_num(int): The maximum number of clients connecting to the current server simultaneously. Default: 100.
+            search_steps(int): The total number of searching steps. None means never stopping. Default: None.
+            key(str): The key used to identify legal agents for the controller server. Default: None.
+        """
+        self._controller = controller
+        self._address = address
+        self._max_client_num = max_client_num
+        self._search_steps = search_steps
+        self._closed = False
+        self._port = address[1]
+        self._ip = address[0]
+        self._key = key
+        self._socket_file = "./controller_server.socket"
+
+    def start(self):
+        # use a lock file so that only one process on this host starts the server
+        open(self._socket_file, 'a').close()
+        socket_file = open(self._socket_file, 'r+')
+        lock(socket_file)
+        tid = socket_file.readline()
+        if tid == '':
+            _logger.info("start controller server...")
+            tid = self._start()
+            socket_file.write("tid: {}\nip: {}\nport: {}\n".format(
+                tid, self._ip, self._port))
+            _logger.info("started controller server...")
+        unlock(socket_file)
+        socket_file.close()
+
+    def _start(self):
+        self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self._socket_server.bind(self._address)
+        self._socket_server.listen(self._max_client_num)
+        self._port = self._socket_server.getsockname()[1]
+        self._ip = self._socket_server.getsockname()[0]
+        _logger.info("ControllerServer - listen on: [{}:{}]".format(
+            self._ip, self._port))
+        thread = Thread(target=self.run)
+        thread.start()
+        return str(thread)
+
+    def close(self):
+        """Close the server."""
+        self._closed = True
+        os.remove(self._socket_file)
+        _logger.info("server closed!")
+
+    def port(self):
+        """Get the port."""
+        return self._port
+
+    def ip(self):
+        """Get the ip."""
+        return self._ip
+
+    def run(self):
+        _logger.info("Controller Server run...")
+        try:
+            while ((self._search_steps is None) or
+                   (self._controller._iter <
+                    (self._search_steps))) and not self._closed:
+                conn, addr = self._socket_server.accept()
+                message = conn.recv(1024).decode()
+                if message.strip("\n") == "next_tokens":
+                    tokens = self._controller.next_tokens()
+                    tokens = ",".join([str(token) for token in tokens])
+                    conn.send(tokens.encode())
+                else:
+                    _logger.debug("recv message from {}: [{}]".format(addr,
+                                                                      message))
+                    messages = message.strip('\n').split("\t")
+                    if (len(messages) < 3) or (messages[0] != self._key):
+                        _logger.debug("recv noise from {}: [{}]".format(
+                            addr, message))
+                        continue
+                    tokens = messages[1]
+                    reward = messages[2]
+                    tokens = [int(token) for token in tokens.split(",")]
+                    self._controller.update(tokens, float(reward))
+                    #tokens = self._controller.next_tokens()
+                    #tokens = ",".join([str(token) for token in tokens])
+                    response = "ok"
+                    conn.send(response.encode())
+                    _logger.debug("send message to {}: [{}]".format(addr,
+                                                                    response))
+                conn.close()
+        finally:
+            self._socket_server.close()
+            self.close()
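+
+# Illustrative server-side setup (the controller construction is elided and
+# the key is a placeholder):
+#
+#   server = ControllerServer(controller=sa_controller, address=("", 0),
+#                             search_steps=300, key="light-nas")
+#   server.start()
+#   print(server.ip(), server.port())  # the address clients should connect to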
diff --git a/paddleslim/common/lock_utils.py b/paddleslim/common/lock_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9daf4f3f6e842609a39fd286dfa49eb705c631a7
--- /dev/null
+++ b/paddleslim/common/lock_utils.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+__all__ = ['lock', 'unlock']
+
+if os.name == 'nt':
+
+    def lock(file):
+        raise NotImplementedError('Windows is not supported.')
+
+    def unlock(file):
+        raise NotImplementedError('Windows is not supported.')
+
+elif os.name == 'posix':
+    from fcntl import flock, LOCK_EX, LOCK_UN
+
+    def lock(file):
+        """Lock the file in the local file system."""
+        flock(file.fileno(), LOCK_EX)
+
+    def unlock(file):
+        """Unlock the file in the local file system."""
+        flock(file.fileno(), LOCK_UN)
+else:
+    raise RuntimeError("File locking only supports NT and POSIX platforms!")
diff --git a/paddleslim/common/log_helper.py b/paddleslim/common/log_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..1088761e0284181bc485f5ee1824e1cbd9c7eb81
--- /dev/null
+++ b/paddleslim/common/log_helper.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import logging
+
+__all__ = ['get_logger']
+
+
+def get_logger(name, level, fmt=None):
+    """
+    Get a logger from the logging module with the given name, level and format,
+    without calling logging.basicConfig. basicConfig is avoided here because
+    paddle configures logging on import, which would override such a setting.
+    Args:
+        name(str): The logger name.
+        level(logging.LEVEL): The base level of the logger.
+        fmt(str): The format of the logger output.
+    Returns:
+        logging.Logger: A logging logger with the given settings.
+    Examples:
+        .. code-block:: python
+           logger = log_helper.get_logger(__name__, logging.INFO,
+                                          fmt='%(asctime)s-%(levelname)s: %(message)s')
+    """
+
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    handler = logging.StreamHandler()
+
+    if fmt:
+        formatter = logging.Formatter(fmt=fmt)
+        handler.setFormatter(formatter)
+
+    logger.addHandler(handler)
+    return logger
+"""The controller used to search hyperparameters or neural architecture""" + +import copy +import math +import logging +import numpy as np +from .controller import EvolutionaryController +from log_helper import get_logger + +__all__ = ["SAController"] + +_logger = get_logger(__name__, level=logging.INFO) + + +class SAController(EvolutionaryController): + """Simulated annealing controller.""" + + def __init__(self, + range_table=None, + reduce_rate=0.85, + init_temperature=1024, + max_iter_number=300, + init_tokens=None, + constrain_func=None): + """Initialize. + Args: + range_table(list): Range table. + reduce_rate(float): The decay rate of temperature. + init_temperature(float): Init temperature. + max_iter_number(int): max iteration number. + init_tokens(list): The initial tokens. + constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None. + """ + super(SAController, self).__init__() + self._range_table = range_table + assert isinstance(self._range_table, tuple) and ( + len(self._range_table) == 2) + self._reduce_rate = reduce_rate + self._init_temperature = init_temperature + self._max_iter_number = max_iter_number + self._reward = -1 + self._tokens = init_tokens + self._constrain_func = constrain_func + self._max_reward = -1 + self._best_tokens = None + self._iter = 0 + + def __getstate__(self): + d = {} + for key in self.__dict__: + if key != "_constrain_func": + d[key] = self.__dict__[key] + return d + + def update(self, tokens, reward): + """ + Update the controller according to latest tokens and reward. + Args: + tokens(list): The tokens generated in last step. + reward(float): The reward of tokens. + """ + self._iter += 1 + temperature = self._init_temperature * self._reduce_rate**self._iter + if (reward > self._reward) or (np.random.random() <= math.exp( + (reward - self._reward) / temperature)): + self._reward = reward + self._tokens = tokens + if reward > self._max_reward: + self._max_reward = reward + self._best_tokens = tokens + _logger.info( + "Controller - iter: {}; current_reward: {}; current tokens: {}". + format(self._iter, self._reward, self._tokens)) + + def next_tokens(self, control_token=None): + """ + Get next tokens. 
+ """ + if control_token: + tokens = control_token[:] + else: + tokens = self._tokens + new_tokens = tokens[:] + index = int(len(self._range_table[0]) * np.random.random()) + new_tokens[index] = np.random.randint(self._range_table[0][index], + self._range_table[1][index] + 1) + _logger.debug("change index[{}] from {} to {}".format(index, tokens[ + index], new_tokens[index])) + if self._constrain_func is None: + return new_tokens + for _ in range(self._max_iter_number): + if not self._constrain_func(new_tokens): + index = int(len(self._range_table[0]) * np.random.random()) + new_tokens = tokens[:] + new_tokens[index] = np.random.randint( + self._range_table[0][index], + self._range_table[1][index] + 1) + else: + break + return new_tokens diff --git a/paddleslim/nas/__init__.py b/paddleslim/nas/__init__.py index 2f5509144f53529ae717b72bbb7252b4b06a0048..f11948f6bcbdd3d52e334bed3b06510e226825bc 100644 --- a/paddleslim/nas/__init__.py +++ b/paddleslim/nas/__init__.py @@ -14,6 +14,9 @@ import search_space from search_space import * +import sa_nas +from sa_nas import * __all__ = [] __all__ += search_space.__all__ +__all__ += sa_nas.__all__ diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py new file mode 100644 index 0000000000000000000000000000000000000000..bbee0d8db641c5b61d520e5a8043721893e86ef5 --- /dev/null +++ b/paddleslim/nas/sa_nas.py @@ -0,0 +1,147 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import logging +import numpy as np +import paddle.fluid as fluid +from ..core import VarWrapper, OpWrapper, GraphWrapper +from ..common import SAController +from ..common import get_logger +from ..analysis import flops + +from ..common import ControllerServer +from ..common import ControllerClient +from .search_space import SearchSpaceFactory + +__all__ = ["SANAS"] + +_logger = get_logger(__name__, level=logging.INFO) + + +class SANAS(object): + def __init__(self, + configs, + max_flops=None, + max_latency=None, + server_addr=("", 0), + init_temperature=100, + reduce_rate=0.85, + max_try_number=300, + max_client_num=10, + search_steps=300, + key="sa_nas", + is_server=True): + """ + Search a group of ratios used to prune program. + Args: + configs(list): A list of search space configuration with format (key, input_size, output_size, block_num). + `key` is the name of search space with data type str. `input_size` and `output_size` are + input size and output size of searched sub-network. `block_num` is the number of blocks in searched network. + max_flops(int): The max flops of searched network. None means no constrains. Default: None. + max_latency(float): The max latency of searched network. None means no constrains. Default: None. + server_addr(tuple): A tuple of server ip and server port for controller server. + init_temperature(float): The init temperature used in simulated annealing search strategy. + reduce_rate(float): The decay rate used in simulated annealing search strategy. 
diff --git a/paddleslim/nas/__init__.py b/paddleslim/nas/__init__.py
index 2f5509144f53529ae717b72bbb7252b4b06a0048..f11948f6bcbdd3d52e334bed3b06510e226825bc 100644
--- a/paddleslim/nas/__init__.py
+++ b/paddleslim/nas/__init__.py
@@ -14,6 +14,9 @@
 import search_space
 from search_space import *
+import sa_nas
+from sa_nas import *
 
 __all__ = []
 __all__ += search_space.__all__
+__all__ += sa_nas.__all__
diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbee0d8db641c5b61d520e5a8043721893e86ef5
--- /dev/null
+++ b/paddleslim/nas/sa_nas.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import socket
+import logging
+import numpy as np
+import paddle.fluid as fluid
+from ..core import VarWrapper, OpWrapper, GraphWrapper
+from ..common import SAController
+from ..common import get_logger
+from ..analysis import flops
+
+from ..common import ControllerServer
+from ..common import ControllerClient
+from .search_space import SearchSpaceFactory
+
+__all__ = ["SANAS"]
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+
+class SANAS(object):
+    def __init__(self,
+                 configs,
+                 max_flops=None,
+                 max_latency=None,
+                 server_addr=("", 0),
+                 init_temperature=100,
+                 reduce_rate=0.85,
+                 max_try_number=300,
+                 max_client_num=10,
+                 search_steps=300,
+                 key="sa_nas",
+                 is_server=True):
+        """
+        Search a group of network architectures with a simulated annealing controller.
+        Args:
+            configs(list): A list of search space configurations with the format [(key, config)].
+                           `key` is the name of the search space (str) and `config` is a dict that
+                           holds the `input_size`, `output_size` and `block_num` of the searched sub-network.
+            max_flops(int): The max FLOPs of the searched network. None means no constraint. Default: None.
+            max_latency(float): The max latency of the searched network. None means no constraint. Default: None.
+            server_addr(tuple): A tuple of the server ip and server port of the controller server.
+            init_temperature(float): The initial temperature used in the simulated annealing search strategy.
+            reduce_rate(float): The decay rate used in the simulated annealing search strategy.
+            max_try_number(int): The max number of attempts to generate legal tokens.
+            max_client_num(int): The max number of connections of the controller server.
+            search_steps(int): The number of searching steps.
+            key(str): The identity used in the communication between the controller server and clients.
+            is_server(bool): Whether the current host is the controller server. Default: True.
+        """
+
+        self._reduce_rate = reduce_rate
+        self._init_temperature = init_temperature
+        self._max_try_number = max_try_number
+        self._is_server = is_server
+        self._max_flops = max_flops
+        self._max_latency = max_latency
+
+        self._configs = configs
+
+        factory = SearchSpaceFactory()
+        self._search_space = factory.get_search_space(configs)
+        init_tokens = self._search_space.init_tokens()
+        range_table = self._search_space.range_table()
+        range_table = (len(range_table) * [0], range_table)
+        _logger.info("range table: {}".format(range_table))
+
+        controller = SAController(range_table, self._reduce_rate,
+                                  self._init_temperature, self._max_try_number,
+                                  init_tokens, self._constrain_func)
+
+        server_ip, server_port = server_addr
+        if server_ip is None or server_ip == "":
+            server_ip = self._get_host_ip()
+
+        self._controller_server = ControllerServer(
+            controller=controller,
+            address=(server_ip, server_port),
+            max_client_num=max_client_num,
+            search_steps=search_steps,
+            key=key)
+
+        # start the controller server on the server host only
+        if self._is_server:
+            self._controller_server.start()
+
+        self._controller_client = ControllerClient(
+            self._controller_server.ip(),
+            self._controller_server.port(),
+            key=key)
+
+        self._iter = 0
+
+    def _get_host_ip(self):
+        return socket.gethostbyname(socket.gethostname())
+
+    def _constrain_func(self, tokens):
+        if (self._max_flops is None) and (self._max_latency is None):
+            return True
+        archs = self._search_space.token2arch(tokens)
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            i = 0
+            for config, arch in zip(self._configs, archs):
+                input_size = config[1]["input_size"]
+                input = fluid.data(
+                    name="data_{}".format(i),
+                    shape=[None, 3, input_size, input_size],
+                    dtype="float32")
+                output = arch(input)
+                i += 1
+        return flops(main_program) < self._max_flops
+
+    def next_archs(self):
+        """
+        Get the next network architectures.
+        Returns:
+            list: A list of functions that define networks.
+        """
+        self._current_tokens = self._controller_client.next_tokens()
+        archs = self._search_space.token2arch(self._current_tokens)
+        return archs
+
+    def reward(self, score):
+        """
+        Return the reward of the currently searched network.
+        Args:
+            score(float): The score of the currently searched network.
+        Returns:
+            bool: True means the update succeeded, while False means failure.
+        """
+        self._iter += 1
+        return self._controller_client.update(self._current_tokens, score)
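+
+# Illustrative search loop (config values are placeholders; see
+# demo/sa_nas_mobilenetv2_cifar10.py for a complete example):
+#
+#   config = [('MobileNetV2Space',
+#              {'input_size': 32, 'output_size': 1, 'block_num': 5})]
+#   sa_nas = SANAS(config, max_flops=base_flops)
+#   archs = sa_nas.next_archs()
+#   # build and train a program with archs, then report the score:
+#   sa_nas.reward(score)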
+ """ + self._iter += 1 + return self._controller_client.update(self._current_tokens, score) diff --git a/paddleslim/nas/search_space/__init__.py b/paddleslim/nas/search_space/__init__.py index c8bef8db17e4a4cea110a3ef3fd4f3d7edceeedc..51b433d452b8cd8c3eb32582d9caa43634b700d0 100644 --- a/paddleslim/nas/search_space/__init__.py +++ b/paddleslim/nas/search_space/__init__.py @@ -14,6 +14,8 @@ import mobilenetv2 from .mobilenetv2 import * +import mobilenetv1 +from .mobilenetv1 import * import resnet from .resnet import * import search_space_registry @@ -28,4 +30,3 @@ __all__ += mobilenetv2.__all__ __all__ += search_space_registry.__all__ __all__ += search_space_factory.__all__ __all__ += search_space_base.__all__ - diff --git a/paddleslim/nas/search_space/base_layer.py b/paddleslim/nas/search_space/base_layer.py index 2e769ec6339b639732995849e9f819a08b749c92..b497c92a2ca57b4acab0c39c5dbd69d30083e295 100644 --- a/paddleslim/nas/search_space/base_layer.py +++ b/paddleslim/nas/search_space/base_layer.py @@ -20,7 +20,7 @@ def conv_bn_layer(input, filter_size, num_filters, stride, - padding, + padding='SAME', num_groups=1, act=None, name=None, @@ -51,15 +51,10 @@ def conv_bn_layer(input, param_attr=ParamAttr(name=name + '_weights'), bias_attr=False) bn_name = name + '_bn' - bn = fluid.layers.batch_norm( - input=conv, - param_attr=ParamAttr(name=bn_name + '_scale'), - bias_attr=ParamAttr(name=bn_name + '_offset'), - moving_mean_name=bn_name + '_mean', - moving_variance_name=bn_name + '_variance') - if act == 'relu6': - return fluid.layers.relu6(bn) - elif act == 'sigmoid': - return fluid.layers.sigmoid(bn) - else: - return bn + return fluid.layers.batch_norm( + input=conv, + act = act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(name=bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') diff --git a/paddleslim/nas/search_space/combine_search_space.py b/paddleslim/nas/search_space/combine_search_space.py index 371bcf5347ebbc21e0688d1611ed3b298b940eb1..667720a9110aa92e096a4f8fa30bb3e4b3e3cecb 100644 --- a/paddleslim/nas/search_space/combine_search_space.py +++ b/paddleslim/nas/search_space/combine_search_space.py @@ -25,12 +25,14 @@ from .base_layer import conv_bn_layer __all__ = ["CombineSearchSpace"] + class CombineSearchSpace(object): """ Combine Search Space. Args: configs(list): multi config. """ + def __init__(self, config_lists): self.lens = len(config_lists) self.spaces = [] @@ -50,11 +52,10 @@ class CombineSearchSpace(object): """ cls = SEARCHSPACE.get(key) space = cls(config['input_size'], config['output_size'], - config['block_num']) + config['block_num'], config['block_mask']) return space - def init_tokens(self): """ Combine init tokens. @@ -96,4 +97,3 @@ class CombineSearchSpace(object): model_archs.append(space.token2arch(token)) return model_archs - diff --git a/paddleslim/nas/search_space/mobilenetv1.py b/paddleslim/nas/search_space/mobilenetv1.py new file mode 100644 index 0000000000000000000000000000000000000000..8b3277d2cb1b472ccd5e27407e3099b28e64f42b --- /dev/null +++ b/paddleslim/nas/search_space/mobilenetv1.py @@ -0,0 +1,224 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+from .search_space_base import SearchSpaceBase
+from .base_layer import conv_bn_layer
+from .search_space_registry import SEARCHSPACE
+
+__all__ = ["MobileNetV1Space"]
+
+
+@SEARCHSPACE.register
+class MobileNetV1Space(SearchSpaceBase):
+    def __init__(self,
+                 input_size,
+                 output_size,
+                 block_num,
+                 block_mask=None,
+                 scale=1.0,
+                 class_dim=1000):
+        super(MobileNetV1Space, self).__init__(input_size, output_size,
+                                               block_num, block_mask)
+        self.scale = scale
+        self.class_dim = class_dim
+        # self.head_num means the channel of the first convolution
+        self.head_num = np.array([3, 4, 8, 12, 16, 24, 32])  # 7
+        # self.filter_num1 ~ self.filter_num9 mean the channels of the following convolutions
+        self.filter_num1 = np.array([3, 4, 8, 12, 16, 24, 32, 48])  # 8
+        self.filter_num2 = np.array([8, 12, 16, 24, 32, 48, 64, 80])  # 8
+        self.filter_num3 = np.array(
+            [16, 24, 32, 48, 64, 80, 96, 128, 144, 160])  #10
+        self.filter_num4 = np.array(
+            [24, 32, 48, 64, 80, 96, 128, 144, 160, 192])  #10
+        self.filter_num5 = np.array(
+            [32, 48, 64, 80, 96, 128, 144, 160, 192, 224, 256, 320])  #12
+        self.filter_num6 = np.array(
+            [64, 80, 96, 128, 144, 160, 192, 224, 256, 320, 384])  #11
+        self.filter_num7 = np.array([
+            64, 80, 96, 128, 144, 160, 192, 224, 256, 320, 384, 512, 1024, 1048
+        ])  #14
+        self.filter_num8 = np.array(
+            [128, 144, 160, 192, 224, 256, 320, 384, 512, 576, 640, 704,
+             768])  #13
+        self.filter_num9 = np.array(
+            [160, 192, 224, 256, 320, 384, 512, 640, 768, 832, 1024,
+             1048])  #12
+        # self.k_size means kernel size
+        self.k_size = np.array([3, 5])  #2
+        # self.repeat means the repeat number of units in the fourth downsample stage
+        self.repeat = np.array([1, 2, 3, 4, 5, 6])  #6
+
+        assert self.block_num < 6, 'MobileNetV1: block number must be less than 6, but the received block number is {}'.format(
+            self.block_num)
+
+    def init_tokens(self):
+        """
+        The initial token.
+        The first one is the index of the first layer's channel in self.head_num;
+        each following line represents the indices of [filter_num1, filter_num2, kernel_size],
+        and depth means the repeat number in the fourth downsample stage.
+        """
+        # yapf: disable
+        base_init_tokens = [6,  # 32
+                            6, 6, 0,  # 32, 64, 3
+                            6, 7, 0,  # 64, 128, 3
+                            7, 6, 0,  # 128, 128, 3
+                            6, 10, 0,  # 128, 256, 3
+                            10, 8, 0,  # 256, 256, 3
+                            8, 11, 0,  # 256, 512, 3
+                            4,  # depth 5
+                            11, 8, 0,  # 512, 512, 3
+                            8, 10, 0,  # 512, 1024, 3
+                            10, 10, 0]  # 1024, 1024, 3
+        # yapf: enable
+        if self.block_num < 5:
+            self.token_len = 1 + (self.block_num * 2 - 1) * 3
+        else:
+            self.token_len = 2 + (self.block_num * 2 - 1) * 3
+        return base_init_tokens[:self.token_len]
+ """ + # yapf: disable + base_range_table = [len(self.head_num), + len(self.filter_num1), len(self.filter_num2), len(self.k_size), + len(self.filter_num2), len(self.filter_num3), len(self.k_size), + len(self.filter_num3), len(self.filter_num4), len(self.k_size), + len(self.filter_num4), len(self.filter_num5), len(self.k_size), + len(self.filter_num5), len(self.filter_num6), len(self.k_size), + len(self.filter_num6), len(self.filter_num7), len(self.k_size), + len(self.repeat), + len(self.filter_num7), len(self.filter_num8), len(self.k_size), + len(self.filter_num8), len(self.filter_num9), len(self.k_size), + len(self.filter_num9), len(self.filter_num9), len(self.k_size)] + # yapf: enable + return base_range_table[:self.token_len] + + def token2arch(self, tokens=None): + + if tokens is None: + tokens = self.tokens() + + bottleneck_param_list = [] + + if self.block_num >= 1: + # tokens[0] = 32 + # 32, 64 + bottleneck_param_list.append( + (self.filter_num1[tokens[1]], self.filter_num2[tokens[2]], 1, + self.k_size[tokens[3]])) + if self.block_num >= 2: + # 64 128 128 128 + bottleneck_param_list.append( + (self.filter_num2[tokens[4]], self.filter_num3[tokens[5]], 2, + self.k_size[tokens[6]])) + bottleneck_param_list.append( + (self.filter_num3[tokens[7]], self.filter_num4[tokens[8]], 1, + self.k_size[tokens[9]])) + if self.block_num >= 3: + # 128 256 256 256 + bottleneck_param_list.append( + (self.filter_num4[tokens[10]], self.filter_num5[tokens[11]], 2, + self.k_size[tokens[12]])) + bottleneck_param_list.append( + (self.filter_num5[tokens[13]], self.filter_num6[tokens[14]], 1, + self.k_size[tokens[15]])) + if self.block_num >= 4: + # 256 512 (512 512) * 5 + bottleneck_param_list.append( + (self.filter_num6[tokens[16]], self.filter_num7[tokens[17]], 2, + self.k_size[tokens[18]])) + for i in range(self.repeat[tokens[19]]): + bottleneck_param_list.append( + (self.filter_num7[tokens[20]], + self.filter_num8[tokens[21]], 1, self.k_size[tokens[22]])) + if self.block_num >= 5: + # 512 1024 1024 1024 + bottleneck_param_list.append( + (self.filter_num8[tokens[23]], self.filter_num9[tokens[24]], 2, + self.k_size[tokens[25]])) + bottleneck_param_list.append( + (self.filter_num9[tokens[26]], self.filter_num9[tokens[27]], 1, + self.k_size[tokens[28]])) + + def net_arch(input): + input = conv_bn_layer( + input=input, + filter_size=3, + num_filters=self.head_num[tokens[0]], + stride=2, + name='mobilenetv1') + + for i, layer_setting in enumerate(bottleneck_param_list): + filter_num1, filter_num2, stride, kernel_size = layer_setting + input = self._depthwise_separable( + input=input, + num_filters1=filter_num1, + num_filters2=filter_num2, + num_groups=filter_num1, + stride=stride, + scale=self.scale, + kernel_size=kernel_size, + name='mobilenetv1_{}'.format(str(i + 1))) + + if self.output_size == 1: + print('NOTE: if output_size is 1, add fc layer in the end!!!') + input = fluid.layers.fc( + input=input, + size=self.class_dim, + param_attr=ParamAttr(name='mobilenetv2_fc_weights'), + bias_attr=ParamAttr(name='mobilenetv2_fc_offset')) + else: + assert self.output_size == input.shape[2], \ + ("output_size must EQUAL to input_size / (2^block_num)." 
+ "But receive input_size={}, output_size={}, block_num={}".format( + self.input_size, self.output_size, self.block_num)) + + return input + + return net_arch + + def _depthwise_separable(self, + input, + num_filters1, + num_filters2, + num_groups, + stride, + scale, + kernel_size, + name=None): + depthwise_conv = conv_bn_layer( + input=input, + filter_size=kernel_size, + num_filters=int(num_filters1 * scale), + stride=stride, + num_groups=int(num_groups * scale), + use_cudnn=False, + name=name + '_dw') + pointwise_conv = conv_bn_layer( + input=depthwise_conv, + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + name=name + '_sep') + + return pointwise_conv diff --git a/paddleslim/nas/search_space/mobilenetv2.py b/paddleslim/nas/search_space/mobilenetv2.py index 90e0b2a0e704a2f44a8031b05d475419bc534677..e974a676a70546e19aa4649679393031634e7822 100644 --- a/paddleslim/nas/search_space/mobilenetv2.py +++ b/paddleslim/nas/search_space/mobilenetv2.py @@ -32,11 +32,15 @@ class MobileNetV2Space(SearchSpaceBase): input_size, output_size, block_num, + block_mask=None, scale=1.0, class_dim=1000): super(MobileNetV2Space, self).__init__(input_size, output_size, - block_num) + block_num, block_mask) + assert self.block_mask == None, 'MobileNetV2Space will use origin MobileNetV2 as seach space, so use input_size, output_size and block_num to search' + # self.head_num means the first convolution channel self.head_num = np.array([3, 4, 8, 12, 16, 24, 32]) #7 + # self.filter_num1 ~ self.filter_num6 means following convlution channel self.filter_num1 = np.array([3, 4, 8, 12, 16, 24, 32, 48]) #8 self.filter_num2 = np.array([8, 12, 16, 24, 32, 48, 64, 80]) #8 self.filter_num3 = np.array([16, 24, 32, 48, 64, 80, 96, 128]) #8 @@ -46,16 +50,21 @@ class MobileNetV2Space(SearchSpaceBase): [32, 48, 64, 80, 96, 128, 144, 160, 192, 224]) #10 self.filter_num6 = np.array( [64, 80, 96, 128, 144, 160, 192, 224, 256, 320, 384, 512]) #12 + # self.k_size means kernel size self.k_size = np.array([3, 5]) #2 + # self.multiply means expansion_factor of each _inverted_residual_unit self.multiply = np.array([1, 2, 3, 4, 6]) #5 + # self.repeat means repeat_num _inverted_residual_unit in each _invresi_blocks self.repeat = np.array([1, 2, 3, 4, 5, 6]) #6 self.scale = scale self.class_dim = class_dim + assert self.block_num < 7, 'MobileNetV2: block number must less than 7, but receive block number is {}'.format( + self.block_num) def init_tokens(self): """ - The initial token send to controller. + The initial token. The first one is the index of the first layers' channel in self.head_num, each line in the following represent the index of the [expansion_factor, filter_num, repeat_num, kernel_size] """ @@ -71,27 +80,29 @@ class MobileNetV2Space(SearchSpaceBase): 4, 9, 0, 0] # 6, 320, 1 # yapf: enable - if self.block_num < 5: + if self.block_num < 5: self.token_len = 1 + (self.block_num - 1) * 4 else: - self.token_len = 1 + (self.block_num + 2 * (self.block_num - 5)) * 4 + self.token_len = 1 + (self.block_num + 2 * + (self.block_num - 5)) * 4 return init_token_base[:self.token_len] def range_table(self): """ - get range table of current search space + Get range table of current search space, constrains the range of tokens. 
""" # head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size] # yapf: disable - range_table_base = [7, - 5, 8, 6, 2, - 5, 8, 6, 2, - 5, 8, 6, 2, - 5, 8, 6, 2, - 5, 10, 6, 2, - 5, 10, 6, 2, - 5, 12, 6, 2] + range_table_base = [len(self.head_num), + len(self.multiply), len(self.filter_num1), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num1), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num2), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num3), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num4), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num5), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num6), len(self.repeat), len(self.k_size)] + range_table_base = list(np.array(range_table_base) - 1) # yapf: enable return range_table_base[:self.token_len] @@ -100,31 +111,41 @@ class MobileNetV2Space(SearchSpaceBase): return net_arch function """ - assert self.block_num < 7, 'block number must less than 7, but receive block number is {}'.format( - self.block_num) - if tokens is None: tokens = self.init_tokens() + print(tokens) bottleneck_params_list = [] - if self.block_num >= 1: bottleneck_params_list.append((1, self.head_num[tokens[0]], 1, 1, 3)) - if self.block_num >= 2: bottleneck_params_list.append((self.multiply[tokens[1]], self.filter_num1[tokens[2]], - self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) - if self.block_num >= 3: bottleneck_params_list.append((self.multiply[tokens[5]], self.filter_num1[tokens[6]], - self.repeat[tokens[7]], 2, self.k_size[tokens[8]])) - if self.block_num >= 4: bottleneck_params_list.append((self.multiply[tokens[9]], self.filter_num2[tokens[10]], - self.repeat[tokens[11]], 2, self.k_size[tokens[12]])) - if self.block_num >= 5: - bottleneck_params_list.append((self.multiply[tokens[13]], self.filter_num3[tokens[14]], - self.repeat[tokens[15]], 2, self.k_size[tokens[16]])) - bottleneck_params_list.append((self.multiply[tokens[17]], self.filter_num3[tokens[18]], - self.repeat[tokens[19]], 1, self.k_size[tokens[20]])) - if self.block_num >= 6: - bottleneck_params_list.append((self.multiply[tokens[21]], self.filter_num5[tokens[22]], - self.repeat[tokens[23]], 2, self.k_size[tokens[24]])) - bottleneck_params_list.append((self.multiply[tokens[25]], self.filter_num6[tokens[26]], - self.repeat[tokens[27]], 1, self.k_size[tokens[28]])) - + if self.block_num >= 1: + bottleneck_params_list.append( + (1, self.head_num[tokens[0]], 1, 1, 3)) + if self.block_num >= 2: + bottleneck_params_list.append( + (self.multiply[tokens[1]], self.filter_num1[tokens[2]], + self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) + if self.block_num >= 3: + bottleneck_params_list.append( + (self.multiply[tokens[5]], self.filter_num1[tokens[6]], + self.repeat[tokens[7]], 2, self.k_size[tokens[8]])) + if self.block_num >= 4: + bottleneck_params_list.append( + (self.multiply[tokens[9]], self.filter_num2[tokens[10]], + self.repeat[tokens[11]], 2, self.k_size[tokens[12]])) + if self.block_num >= 5: + bottleneck_params_list.append( + (self.multiply[tokens[13]], self.filter_num3[tokens[14]], + self.repeat[tokens[15]], 2, self.k_size[tokens[16]])) + bottleneck_params_list.append( + (self.multiply[tokens[17]], self.filter_num4[tokens[18]], + self.repeat[tokens[19]], 1, self.k_size[tokens[20]])) + if self.block_num >= 6: + bottleneck_params_list.append( + (self.multiply[tokens[21]], self.filter_num5[tokens[22]], + 
         def net_arch(input):
             #conv1
             # all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
@@ -182,15 +203,15 @@ class MobileNetV2Space(SearchSpaceBase):
         return fluid.layers.elementwise_add(input, data_residual)
 
     def _inverted_residual_unit(self,
-                            input,
-                            num_in_filter,
-                            num_filters,
-                            ifshortcut,
-                            stride,
-                            filter_size,
-                            expansion_factor,
-                            reduction_ratio=4,
-                            name=None):
+                                input,
+                                num_in_filter,
+                                num_filters,
+                                ifshortcut,
+                                stride,
+                                filter_size,
+                                expansion_factor,
+                                reduction_ratio=4,
+                                name=None):
         """Build inverted residual unit.
         Args:
             input(Variable), input.
diff --git a/paddleslim/nas/search_space/resnet.py b/paddleslim/nas/search_space/resnet.py
index a6ac5817ce89190987d67f1eda644fa2aef79037..fd761d417575988e8ba8bd99da25372613c5912f 100644
--- a/paddleslim/nas/search_space/resnet.py
+++ b/paddleslim/nas/search_space/resnet.py
@@ -25,34 +25,151 @@ from .search_space_registry import SEARCHSPACE
 
 __all__ = ["ResNetSpace"]
 
+
 @SEARCHSPACE.register
 class ResNetSpace(SearchSpaceBase):
-    def __init__(self, input_size, output_size, block_num, scale=1.0, class_dim=1000):
-        super(ResNetSpace, self).__init__(input_size, output_size, block_num)
-        pass
+    def __init__(self,
+                 input_size,
+                 output_size,
+                 block_num,
+                 block_mask=None,
+                 extract_feature=False,
+                 class_dim=1000):
+        super(ResNetSpace, self).__init__(input_size, output_size, block_num,
+                                          block_mask)
+        assert self.block_mask is None, 'ResNetSpace will use the original ResNet as the search space, so use input_size, output_size and block_num to search'
+        # self.filter_num1 ~ self.filter_num4 mean the convolution channels
+        self.filter_num1 = np.array([48, 64, 96, 128, 160, 192, 224])  #7
+        self.filter_num2 = np.array([64, 96, 128, 160, 192, 256, 320])  #7
+        self.filter_num3 = np.array([128, 160, 192, 256, 320, 384])  #6
+        self.filter_num4 = np.array([192, 256, 384, 512, 640])  #5
+        # self.repeat1 ~ self.repeat4 mean the depth of the network
+        self.repeat1 = [2, 3, 4, 5, 6]  #5
+        self.repeat2 = [2, 3, 4, 5, 6, 7]  #6
+        self.repeat3 = [2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24]  #13
+        self.repeat4 = [2, 3, 4, 5, 6, 7]  #6
+        self.class_dim = class_dim
+        self.extract_feature = extract_feature
+        assert self.block_num < 5, 'ResNet: block number must be less than 5, but the received block number is {}'.format(
+            self.block_num)
 
     def init_tokens(self):
-        return [0,0,0,0,0,0]
+        """
+        The initial token.
+        Returns 2 * self.block_num entries: for each block, the indices of num_filter and depth.
+        """
+        init_token_base = [0, 0, 0, 0, 0, 0, 0, 0]
+        self.token_len = self.block_num * 2
+        return init_token_base[:self.token_len]
+ """ + #2 * self.block_num, 2 means depth and num_filter + range_table_base = [ + len(self.filter_num1), len(self.repeat1), len(self.filter_num2), + len(self.repeat2), len(self.filter_num3), len(self.repeat3), + len(self.filter_num4), len(self.repeat4) + ] + return range_table_base[:self.token_len] - def token2arch(self,tokens=None): + def token2arch(self, tokens=None): + """ + return net_arch function + """ if tokens is None: - self.init_tokens() + tokens = self.init_tokens() + + depth = [] + num_filters = [] + if self.block_num >= 1: + filter1 = self.filter_num1[tokens[0]] + repeat1 = self.repeat1[tokens[1]] + num_filters.append(filter1) + depth.append(repeat1) + if self.block_num >= 2: + filter2 = self.filter_num2[tokens[2]] + repeat2 = self.repeat2[tokens[3]] + num_filters.append(filter2) + depth.append(repeat2) + if self.block_num >= 3: + filter3 = self.filter_num3[tokens[4]] + repeat3 = self.repeat3[tokens[5]] + num_filters.append(filter3) + depth.append(repeat3) + if self.block_num >= 4: + filter4 = self.filter_num4[tokens[6]] + repeat4 = self.repeat4[tokens[7]] + num_filters.append(filter4) + depth.append(repeat4) def net_arch(input): - input = conv_bn_layer( - input, - num_filters=32, - filter_size=3, + conv = conv_bn_layer( + input=input, + filter_size=5, + num_filters=filter1, stride=2, - padding='SAME', - act='sigmoid', - name='resnet_conv1_1') + act='relu', + name='resnet_conv0') + for block in range(len(depth)): + for i in range(depth[block]): + conv = self._bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + name='resnet_depth{}_block{}'.format(i, block)) - return input + if self.output_size == 1: + conv = fluid.layers.fc( + input=conv, + size=self.class_dim, + act=None, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.NormalInitializer(0.0, + 0.01)), + bias_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.ConstantInitializer(0))) + + return conv return net_arch + def _shortcut(self, input, ch_out, stride, name=None): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1: + return conv_bn_layer( + input=input, + filter_size=1, + num_filters=ch_out, + stride=stride, + name=name + '_conv') + else: + return input + + def _bottleneck_block(self, input, num_filters, stride, name=None): + conv0 = conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=1, + act='relu', + name=name + '_bottleneck_conv0') + conv1 = conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu', + name=name + '_bottleneck_conv1') + conv2 = conv_bn_layer( + input=conv1, + num_filters=num_filters * 4, + filter_size=1, + act=None, + name=name + '_bottleneck_conv2') + + short = self._shortcut( + input, num_filters * 4, stride, name=name + '_shortcut') + return fluid.layers.elementwise_add( + x=short, y=conv2, act='relu', name=name + '_bottleneck_add') diff --git a/paddleslim/nas/search_space/search_space_base.py b/paddleslim/nas/search_space/search_space_base.py index bb1ce0f8a4bbd0b18d36fa9199a6ff814ab13236..6a83f86005a5fb2408f7f85f40dff8a9e5cba819 100644 --- a/paddleslim/nas/search_space/search_space_base.py +++ b/paddleslim/nas/search_space/search_space_base.py @@ -19,10 +19,11 @@ class SearchSpaceBase(object): """Controller for Neural Architecture Search. 
""" - def __init__(self, input_size, output_size, block_num, *argss): + def __init__(self, input_size, output_size, block_num, block_mask, *argss): self.input_size = input_size self.output_size = output_size self.block_num = block_num + self.block_mask = block_mask def init_tokens(self): """Get init tokens in search space. diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py index 926586c67d9e0b73ecd66f107ef897b389c5844f..bb615b9dfca03ed2b289f902f6d75c73543f6fb2 100644 --- a/paddleslim/prune/__init__.py +++ b/paddleslim/prune/__init__.py @@ -11,4 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pruner import Pruner +import pruner +from pruner import * +import auto_pruner +from auto_pruner import * +import controller_server +from controller_server import * +import controller_client +from controller_client import * + +__all__ = [] +__all__ += pruner.__all__ +__all__ += auto_pruner.__all__ +__all__ += controller_server.__all__ +__all__ += controller_client.__all__ diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..b144251a0a9a294094f7101f30958486abcf0543 --- /dev/null +++ b/paddleslim/prune/auto_pruner.py @@ -0,0 +1,226 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import logging +import numpy as np +import paddle.fluid as fluid +from .pruner import Pruner +from ..core import VarWrapper, OpWrapper, GraphWrapper +from ..common import SAController +from ..common import get_logger +from ..analysis import flops + +from ..common import ControllerServer +from ..common import ControllerClient + +__all__ = ["AutoPruner"] + +_logger = get_logger(__name__, level=logging.INFO) + + +class AutoPruner(object): + def __init__(self, + program, + scope, + place, + params=[], + init_ratios=None, + pruned_flops=0.5, + pruned_latency=None, + server_addr=("", 0), + init_temperature=100, + reduce_rate=0.85, + max_try_number=300, + max_client_num=10, + search_steps=300, + max_ratios=[0.9], + min_ratios=[0], + key="auto_pruner", + is_server=True): + """ + Search a group of ratios used to prune program. + Args: + program(Program): The program to be pruned. + scope(Scope): The scope to be pruned. + place(fluid.Place): The device place of parameters. + params(list): The names of parameters to be pruned. + init_ratios(list|float): Init ratios used to pruned parameters in `params`. + List means ratios used for pruning each parameter in `params`. + The length of `init_ratios` should be equal to length of params when `init_ratios` is a list. + If it is a scalar, all the parameters in `params` will be pruned by uniform ratio. + None means get a group of init ratios by `pruned_flops` of `pruned_latency`. Default: None. + pruned_flops(float): The percent of FLOPS to be pruned. 
+                Default: None.
+            pruned_latency(float): The percent of latency to be pruned. Default: None.
+            server_addr(tuple): A tuple of server ip and server port for controller server.
+            init_temperature(float): The init temperature used in simulated annealing search strategy.
+            reduce_rate(float): The decay rate used in simulated annealing search strategy.
+            max_try_number(int): The max number of attempts to generate legal tokens.
+            max_client_num(int): The max number of connections to the controller server.
+            search_steps(int): The number of search steps.
+            max_ratios(float|list): Max ratios used to prune parameters in `params`.
+                A list means one max ratio for each parameter in `params`; its length
+                should be equal to the length of `params`. If it is a scalar, it will
+                be used for all the parameters in `params`.
+            min_ratios(float|list): Min ratios used to prune parameters in `params`.
+                A list means one min ratio for each parameter in `params`; its length
+                should be equal to the length of `params`. If it is a scalar, it will
+                be used for all the parameters in `params`.
+            key(str): Identity used in communication between controller server and clients.
+            is_server(bool): Whether current host is controller server. Default: True.
+        """
+
+        self._program = program
+        self._scope = scope
+        self._place = place
+        self._params = params
+        self._init_ratios = init_ratios
+        self._pruned_flops = pruned_flops
+        self._pruned_latency = pruned_latency
+        self._reduce_rate = reduce_rate
+        self._init_temperature = init_temperature
+        self._max_try_number = max_try_number
+        self._is_server = is_server
+
+        self._range_table = self._get_range_table(min_ratios, max_ratios)
+
+        self._pruner = Pruner()
+        if self._pruned_flops:
+            self._base_flops = flops(program)
+            _logger.info("AutoPruner - base flops: {};".format(
+                self._base_flops))
+        if self._pruned_latency:
+            # NOTE: `latency` is neither defined nor imported in this module
+            # yet, so latency-constrained search is still a placeholder.
+            self._base_latency = latency(program)
+
+        if self._init_ratios is None:
+            self._init_ratios = self._get_init_ratios(
+                self._program, self._params, self._pruned_flops,
+                self._pruned_latency)
+        init_tokens = self._ratios2tokens(self._init_ratios)
+        _logger.info("range table: {}".format(self._range_table))
+        controller = SAController(self._range_table, self._reduce_rate,
+                                  self._init_temperature, self._max_try_number,
+                                  init_tokens, self._constrain_func)
+
+        server_ip, server_port = server_addr
+        if server_ip is None or server_ip == "":
+            server_ip = self._get_host_ip()
+
+        self._controller_server = ControllerServer(
+            controller=controller,
+            address=(server_ip, server_port),
+            max_client_num=max_client_num,
+            search_steps=search_steps,
+            key=key)
+
+        # start the controller server only on the server host
+        if self._is_server:
+            self._controller_server.start()
+
+        self._controller_client = ControllerClient(
+            self._controller_server.ip(),
+            self._controller_server.port(),
+            key=key)
+
+        self._iter = 0
+        self._param_backup = {}
+
+    def _get_host_ip(self):
+        return socket.gethostbyname(socket.gethostname())
+
+    def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
+        # NOTE: deriving init ratios from the FLOPS/latency constraint is not
+        # implemented yet; callers should pass `init_ratios` explicitly.
+        pass
+
+    def _get_range_table(self, min_ratios, max_ratios):
+        assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
+        assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
+        min_ratios = min_ratios if isinstance(
+            min_ratios, list) else [min_ratios] * len(self._params)
+        max_ratios = max_ratios if isinstance(
+            max_ratios, list) else [max_ratios] * len(self._params)
+        min_tokens = self._ratios2tokens(min_ratios)
+        max_tokens =
self._ratios2tokens(max_ratios) + return (min_tokens, max_tokens) + + def _constrain_func(self, tokens): + ratios = self._tokens2ratios(tokens) + pruned_program = self._pruner.prune( + self._program, + self._scope, + self._params, + ratios, + place=self._place, + only_graph=True) + return flops(pruned_program) < self._base_flops * ( + 1 - self._pruned_flops) + + def prune(self, program, eval_program=None): + """ + Prune program with latest tokens generated by controller. + Args: + program(fluid.Program): The program to be pruned. + Returns: + Program: The pruned program. + """ + self._current_ratios = self._next_ratios() + pruned_program = self._pruner.prune( + program, + self._scope, + self._params, + self._current_ratios, + place=self._place, + only_graph=False, + param_backup=self._param_backup) + pruned_val_program = None + if eval_program is not None: + pruned_val_program = self._pruner.prune( + program, + self._scope, + self._params, + self._current_ratios, + place=self._place, + only_graph=True) + + _logger.info("AutoPruner - pruned ratios: {}".format( + self._current_ratios)) + return pruned_program, pruned_val_program + + def reward(self, score): + """ + Return reward of current pruned program. + Args: + score(float): The score of pruned program. + """ + self._restore(self._scope) + self._param_backup = {} + tokens = self._ratios2tokens(self._current_ratios) + self._controller_client.update(tokens, score) + self._iter += 1 + + def _restore(self, scope): + for param_name in self._param_backup.keys(): + param_t = scope.find_var(param_name).get_tensor() + param_t.set(self._param_backup[param_name], self._place) + + def _next_ratios(self): + tokens = self._controller_client.next_tokens() + return self._tokens2ratios(tokens) + + def _ratios2tokens(self, ratios): + """Convert pruned ratios to tokens. + """ + return [int(ratio / 0.01) for ratio in ratios] + + def _tokens2ratios(self, tokens): + """Convert tokens to pruned ratios. + """ + return [token * 0.01 for token in tokens] diff --git a/paddleslim/prune/controller_client.py b/paddleslim/prune/controller_client.py new file mode 100644 index 0000000000000000000000000000000000000000..f133e8b28f823bba89024fe1473630feb509a616 --- /dev/null +++ b/paddleslim/prune/controller_client.py @@ -0,0 +1,66 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import socket +from ..common import get_logger + +__all__ = ['ControllerClient'] + +_logger = get_logger(__name__, level=logging.INFO) + + +class ControllerClient(object): + """ + Controller client. + """ + + def __init__(self, server_ip=None, server_port=None, key=None): + """ + Args: + server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None. + server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0. + key(str): The key used to identify legal agent for controller server. 
Default: "light-nas" + """ + self.server_ip = server_ip + self.server_port = server_port + self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._key = key + + def update(self, tokens, reward): + """ + Update the controller according to latest tokens and reward. + Args: + tokens(list): The tokens generated in last step. + reward(float): The reward of tokens. + """ + socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + socket_client.connect((self.server_ip, self.server_port)) + tokens = ",".join([str(token) for token in tokens]) + socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward) + .encode()) + tokens = socket_client.recv(1024).decode() + tokens = [int(token) for token in tokens.strip("\n").split(",")] + return tokens + + def next_tokens(self): + """ + Get next tokens. + """ + socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + socket_client.connect((self.server_ip, self.server_port)) + socket_client.send("next_tokens".encode()) + tokens = socket_client.recv(1024).decode() + tokens = [int(token) for token in tokens.strip("\n").split(",")] + return tokens diff --git a/paddleslim/prune/controller_server.py b/paddleslim/prune/controller_server.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc978444656d2650904eedfd37453b6b5e22207 --- /dev/null +++ b/paddleslim/prune/controller_server.py @@ -0,0 +1,128 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import socket +from ..common import get_logger +from threading import Thread +from .lock import lock, unlock + +__all__ = ['ControllerServer'] + +_logger = get_logger(__name__, level=logging.INFO) + + +class ControllerServer(object): + """ + The controller wrapper with a socket server to handle the request of search agent. + """ + + def __init__(self, + controller=None, + address=('', 0), + max_client_num=100, + search_steps=None, + key=None): + """ + Args: + controller(slim.searcher.Controller): The controller used to generate tokens. + address(tuple): The address of current server binding with format (ip, port). Default: ('', 0). + which means setting ip automatically + max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100. + search_steps(int): The total steps of searching. None means never stopping. 
Default: None + """ + self._controller = controller + self._address = address + self._max_client_num = max_client_num + self._search_steps = search_steps + self._closed = False + self._port = address[1] + self._ip = address[0] + self._key = key + self._socket_file = "./controller_server.socket" + + def start(self): + open(self._socket_file, 'a').close() + socket_file = open(self._socket_file, 'r+') + lock(socket_file) + tid = socket_file.readline() + if tid == '': + _logger.info("start controller server...") + tid = self._start() + socket_file.write("tid: {}\nip: {}\nport: {}\n".format( + tid, self._ip, self._port)) + _logger.info("started controller server...") + unlock(socket_file) + socket_file.close() + + def _start(self): + self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket_server.bind(self._address) + self._socket_server.listen(self._max_client_num) + self._port = self._socket_server.getsockname()[1] + self._ip = self._socket_server.getsockname()[0] + _logger.info("ControllerServer - listen on: [{}:{}]".format( + self._ip, self._port)) + thread = Thread(target=self.run) + thread.start() + return str(thread) + + def close(self): + """Close the server.""" + self._closed = True + os.remove(self._socket_file) + _logger.info("server closed!") + + def port(self): + """Get the port.""" + return self._port + + def ip(self): + """Get the ip.""" + return self._ip + + def run(self): + _logger.info("Controller Server run...") + try: + while ((self._search_steps is None) or + (self._controller._iter < + (self._search_steps))) and not self._closed: + conn, addr = self._socket_server.accept() + message = conn.recv(1024).decode() + if message.strip("\n") == "next_tokens": + tokens = self._controller.next_tokens() + tokens = ",".join([str(token) for token in tokens]) + conn.send(tokens.encode()) + else: + _logger.debug("recv message from {}: [{}]".format(addr, + message)) + messages = message.strip('\n').split("\t") + if (len(messages) < 3) or (messages[0] != self._key): + _logger.debug("recv noise from {}: [{}]".format( + addr, message)) + continue + tokens = messages[1] + reward = messages[2] + tokens = [int(token) for token in tokens.split(",")] + self._controller.update(tokens, float(reward)) + tokens = self._controller.next_tokens() + tokens = ",".join([str(token) for token in tokens]) + conn.send(tokens.encode()) + _logger.debug("send message to {}: [{}]".format(addr, + tokens)) + conn.close() + finally: + self._socket_server.close() + self.close() diff --git a/paddleslim/prune/lock.py b/paddleslim/prune/lock.py new file mode 100644 index 0000000000000000000000000000000000000000..5edcd317304f941c2e7c15ad56e95525dea85398 --- /dev/null +++ b/paddleslim/prune/lock.py @@ -0,0 +1,36 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
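For reference, the exchange implemented by ControllerClient and ControllerServer.run() above is a plain TCP request/response protocol: the bare string "next_tokens" fetches the next tokens to evaluate, and a tab-separated "key\ttokens\treward" message reports a result and receives new tokens in reply. A minimal raw-socket sketch (the address here is hypothetical, and a server must already be listening, e.g. one started by AutoPruner with is_server=True):

import socket

SERVER_ADDR = ("127.0.0.1", 8989)  # hypothetical; use the real server address
KEY = "auto_pruner"                # must match the server's `key`

# Ask for the next tokens to evaluate.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(SERVER_ADDR)
sock.send("next_tokens".encode())
tokens = [int(t) for t in sock.recv(1024).decode().strip("\n").split(",")]

# Report a reward for those tokens; the server answers with new tokens.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(SERVER_ADDR)
sock.send("{}\t{}\t{}".format(KEY, ",".join(map(str, tokens)), 0.9).encode())
tokens = [int(t) for t in sock.recv(1024).decode().strip("\n").split(",")]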
+
+# Cross-process file lock helpers; ControllerServer.start() uses them to make
+# sure only one server is launched per socket file.
+import os
+__all__ = ['lock', 'unlock']
+if os.name == 'nt':
+
+    def lock(file):
+        raise NotImplementedError('Windows is not supported.')
+
+    def unlock(file):
+        raise NotImplementedError('Windows is not supported.')
+
+elif os.name == 'posix':
+    from fcntl import flock, LOCK_EX, LOCK_UN
+
+    def lock(file):
+        """Lock the file in local file system."""
+        flock(file.fileno(), LOCK_EX)
+
+    def unlock(file):
+        """Unlock the file in local file system."""
+        flock(file.fileno(), LOCK_UN)
+else:
+    raise RuntimeError("File Locker only supports NT and Posix platforms!")
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 30341f63407aa1b0cc52ec5b43eadead27aec2ab..c7cc9c9e814789d1863251017ecdb19beb41ae42 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -12,11 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import numpy as np
 import paddle.fluid as fluid
-from core import VarWrapper, OpWrapper, GraphWrapper
+import copy
+from ..core import VarWrapper, OpWrapper, GraphWrapper
+from ..common import get_logger
 
-__all__ = ["prune"]
+__all__ = ["Pruner"]
+
+_logger = get_logger(__name__, level=logging.INFO)
 
 
 class Pruner():
@@ -64,10 +69,14 @@ class Pruner():
             params,
             ratios,
             place,
-            lazy=False,
-            only_graph=False,
-            param_backup=None,
-            param_shape_backup=None)
+            lazy=lazy,
+            only_graph=only_graph,
+            param_backup=param_backup,
+            param_shape_backup=param_shape_backup)
+        for op in graph.ops():
+            # keep 'groups' of depthwise convs consistent with the pruned
+            # number of filters
+            if op.type() == 'depthwise_conv2d' or op.type(
+            ) == 'depthwise_conv2d_grad':
+                op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
         return graph.program
 
     def _prune_filters_by_ratio(self,
@@ -93,27 +102,49 @@ class Pruner():
         """
         if params[0].name() in self.pruned_list[0]:
             return
-        param_t = scope.find_var(params[0].name()).get_tensor()
-        pruned_idx = self._cal_pruned_idx(
-            params[0].name(), np.array(param_t), ratio, axis=0)
-        for param in params:
-            assert isinstance(param, VarWrapper)
-            param_t = scope.find_var(param.name()).get_tensor()
-            if param_backup is not None and (param.name() not in param_backup):
-                param_backup[param.name()] = copy.deepcopy(np.array(param_t))
-            pruned_param = self._prune_tensor(
-                np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
-            if not only_graph:
+
+        if only_graph:
+            pruned_num = int(round(params[0].shape()[0] * ratio))
+            for param in params:
+                ori_shape = param.shape()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(ori_shape)
+                new_shape = list(ori_shape)
+                new_shape[0] -= pruned_num
+                param.set_shape(new_shape)
+                _logger.info("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[0].append(param.name())
+            return range(pruned_num)
+
+        else:
+
+            param_t = scope.find_var(params[0].name()).get_tensor()
+            pruned_idx = self._cal_pruned_idx(
+                params[0].name(), np.array(param_t), ratio, axis=0)
+            for param in params:
+                assert isinstance(param, VarWrapper)
+                param_t = scope.find_var(param.name()).get_tensor()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(
+                        np.array(param_t))
+                pruned_param = self._prune_tensor(
+                    np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
                 param_t.set(pruned_param, place)
-            ori_shape = param.shape()
-            if param_shape_backup is not None and (
-                    param.name() not in param_shape_backup):
-                param_shape_backup[param.name()] =
copy.deepcopy(param.shape()) - new_shape = list(param.shape()) - new_shape[0] = pruned_param.shape[0] - param.set_shape(new_shape) - self.pruned_list[0].append(param.name()) - return pruned_idx + ori_shape = param.shape() + if param_shape_backup is not None and ( + param.name() not in param_shape_backup): + param_shape_backup[param.name()] = copy.deepcopy( + param.shape()) + new_shape = list(param.shape()) + new_shape[0] = pruned_param.shape[0] + param.set_shape(new_shape) + _logger.info("prune [{}] from {} to {}".format(param.name( + ), ori_shape, new_shape)) + self.pruned_list[0].append(param.name()) + return pruned_idx def _prune_parameter_by_idx(self, scope, @@ -140,24 +171,44 @@ class Pruner(): """ if params[0].name() in self.pruned_list[pruned_axis]: return - for param in params: - assert isinstance(param, VarWrapper) - param_t = scope.find_var(param.name()).get_tensor() - if param_backup is not None and (param.name() not in param_backup): - param_backup[param.name()] = copy.deepcopy(np.array(param_t)) - pruned_param = self._prune_tensor( - np.array(param_t), pruned_idx, pruned_axis, lazy=lazy) - if not only_graph: + + if only_graph: + pruned_num = len(pruned_idx) + for param in params: + ori_shape = param.shape() + if param_backup is not None and ( + param.name() not in param_backup): + param_backup[param.name()] = copy.deepcopy(ori_shape) + new_shape = list(ori_shape) + new_shape[pruned_axis] -= pruned_num + param.set_shape(new_shape) + _logger.info("prune [{}] from {} to {}".format(param.name( + ), ori_shape, new_shape)) + self.pruned_list[pruned_axis].append(param.name()) + + else: + for param in params: + assert isinstance(param, VarWrapper) + param_t = scope.find_var(param.name()).get_tensor() + if param_backup is not None and ( + param.name() not in param_backup): + param_backup[param.name()] = copy.deepcopy( + np.array(param_t)) + pruned_param = self._prune_tensor( + np.array(param_t), pruned_idx, pruned_axis, lazy=lazy) param_t.set(pruned_param, place) - ori_shape = param.shape() + ori_shape = param.shape() - if param_shape_backup is not None and ( - param.name() not in param_shape_backup): - param_shape_backup[param.name()] = copy.deepcopy(param.shape()) - new_shape = list(param.shape()) - new_shape[pruned_axis] = pruned_param.shape[pruned_axis] - param.set_shape(new_shape) - self.pruned_list[pruned_axis].append(param.name()) + if param_shape_backup is not None and ( + param.name() not in param_shape_backup): + param_shape_backup[param.name()] = copy.deepcopy( + param.shape()) + new_shape = list(param.shape()) + new_shape[pruned_axis] = pruned_param.shape[pruned_axis] + param.set_shape(new_shape) + _logger.info("prune [{}] from {} to {}".format(param.name( + ), ori_shape, new_shape)) + self.pruned_list[pruned_axis].append(param.name()) def _forward_search_related_op(self, graph, param): """ @@ -487,14 +538,16 @@ class Pruner(): visited.append(op.idx()) while len(stack) > 0: top_op = stack.pop() - for parent in graph.pre_ops(top_op): - if parent.idx() not in visited and (not parent.is_bwd_op()): - if ((parent.type() == 'conv2d') or - (parent.type() == 'fc')): - brothers.append(parent) - else: - stack.append(parent) - visited.append(parent.idx()) + if top_op.type().startswith("elementwise_"): + for parent in graph.pre_ops(top_op): + if parent.idx() not in visited and ( + not parent.is_bwd_op()): + if ((parent.type() == 'conv2d') or + (parent.type() == 'fc')): + brothers.append(parent) + else: + stack.append(parent) + visited.append(parent.idx()) for child in 
graph.next_ops(top_op):
                if (child.type() != 'conv2d') and (child.type() != 'fc') and (
diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py
old mode 100644
new mode 100755
index 0db22772d712951ed895f2d2e897142d6ce3c377..8ea9fbe32ee3f8617d9f00a1ce097b715957163e
--- a/paddleslim/quant/quanter.py
+++ b/paddleslim/quant/quanter.py
@@ -20,11 +20,19 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
 from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
 from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
+from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
 from paddle.fluid import core
 
-WEIGHT_QUANTIZATION_TYPES=['abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max']
-ACTIVATION_QUANTIZATION_TYPES=['abs_max','range_abs_max', 'moving_average_abs_max']
+WEIGHT_QUANTIZATION_TYPES = [
+    'abs_max', 'channel_wise_abs_max', 'range_abs_max',
+    'moving_average_abs_max'
+]
+ACTIVATION_QUANTIZATION_TYPES = [
+    'abs_max', 'range_abs_max', 'moving_average_abs_max'
+]
 VALID_DTYPES = ['int8']
+TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul']
+QUANT_DEQUANT_PASS_OP_TYPES = ['elementwise_add', 'pool2d']
 
 _quant_config_default = {
     # weight quantize type, default is 'abs_max'
@@ -38,7 +46,8 @@ _quant_config_default = {
     # ops of name_scope in not_quant_pattern list, will not be quantized
     'not_quant_pattern': ['skip_quant'],
     # ops of type in quantize_op_types, will be quantized
-    'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
+    'quantize_op_types':
+    ['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'],
     # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
     'dtype': 'int8',
     # window size for 'range_abs_max' quantization. default is 10000
@@ -88,6 +97,12 @@ def _parse_configs(user_config):
     assert isinstance(configs['quantize_op_types'], list), \
         "quantize_op_types must be a list"
 
+    for op_type in configs['quantize_op_types']:
+        assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or (
+            op_type in TRANSFORM_PASS_OP_TYPES), "{} is not supported; \
+            supported op types are {}".format(
+            op_type, TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES)
+
     assert isinstance(configs['dtype'], str), \
         "dtype must be a str."
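After this hunk, `quantize_op_types` may freely mix op types that are handled by two different passes, and `_parse_configs` only checks membership in the union of the two lists. A config sketch that would pass validation (assuming, as the defaults above suggest, that keys omitted by the user fall back to `_quant_config_default`):

config = {
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    # 'conv2d' and 'mul' go through QuantizationTransformPass;
    # 'elementwise_add' and 'pool2d' go through AddQuantDequantPass.
    'quantize_op_types': ['conv2d', 'mul', 'elementwise_add', 'pool2d'],
    'dtype': 'int8',
}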
@@ -132,19 +147,37 @@ def quant_aware(program, place, config, scope=None, for_test=False): config = _parse_configs(config) main_graph = IrGraph(core.Graph(program.desc), for_test=for_test) - transform_pass = QuantizationTransformPass( - scope=scope, - place=place, - weight_bits=config['weight_bits'], - activation_bits=config['activation_bits'], - activation_quantize_type=config['activation_quantize_type'], - weight_quantize_type=config['weight_quantize_type'], - window_size=config['window_size'], - moving_rate=config['moving_rate'], - quantizable_op_type=config['quantize_op_types'], - skip_pattern=config['not_quant_pattern']) - - transform_pass.apply(main_graph) + transform_pass_ops = [] + quant_dequant_ops = [] + for op_type in config['quantize_op_types']: + if op_type in TRANSFORM_PASS_OP_TYPES: + transform_pass_ops.append(op_type) + elif op_type in QUANT_DEQUANT_PASS_OP_TYPES: + quant_dequant_ops.append(op_type) + if len(transform_pass_ops) > 0: + transform_pass = QuantizationTransformPass( + scope=scope, + place=place, + weight_bits=config['weight_bits'], + activation_bits=config['activation_bits'], + activation_quantize_type=config['activation_quantize_type'], + weight_quantize_type=config['weight_quantize_type'], + window_size=config['window_size'], + moving_rate=config['moving_rate'], + quantizable_op_type=transform_pass_ops, + skip_pattern=config['not_quant_pattern']) + + transform_pass.apply(main_graph) + + if len(quant_dequant_ops) > 0: + quant_dequant_pass = AddQuantDequantPass( + scope=scope, + place=place, + moving_rate=config['moving_rate'], + quant_bits=config['activation_bits'], + skip_pattern=config['not_quant_pattern'], + quantizable_op_type=quant_dequant_ops) + quant_dequant_pass.apply(main_graph) if for_test: quant_program = main_graph.to_program() @@ -168,7 +201,7 @@ def quant_post(program, place, config, scope=None): pass -def convert(program, scope, place, config, save_int8=False): +def convert(program, place, config, scope=None, save_int8=False): """ add quantization ops in program. the program returned is not trainable. Args: @@ -183,7 +216,7 @@ def convert(program, scope, place, config, save_int8=False): fluid.Program: freezed int8 program which can be used for inference. if save_int8 is False, this value is None. """ - + scope = fluid.global_scope() if not scope else scope test_graph = IrGraph(core.Graph(program.desc), for_test=True) # Freeze the graph after training by adjusting the quantize diff --git a/paddleslim/tests/test_nas_search_space.py b/paddleslim/tests/test_nas_search_space.py deleted file mode 100644 index c2f2af5b8b9fac38a6d8f1273e853aefc6983bff..0000000000000000000000000000000000000000 --- a/paddleslim/tests/test_nas_search_space.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
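Putting the two quantization entry points together, a minimal usage sketch of the updated signatures (`convert` now takes `scope=None` and falls back to `fluid.global_scope()`); here `train_program`, `val_program`, and `config` are assumed to be defined by the caller, and the import path assumes both functions are exported from `paddleslim.quant`:

import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert

place = fluid.CUDAPlace(0)
# Insert fake quant/dequant ops for training and for evaluation.
quant_train_program = quant_aware(train_program, place, config, for_test=False)
quant_eval_program = quant_aware(val_program, place, config, for_test=True)
# ... run quantization-aware training on quant_train_program ...
# Freeze the evaluation graph for inference; scope defaults to the global scope.
inference_program = convert(quant_eval_program, place, config)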
- -import sys -sys.path.append('..') -import unittest -import paddle.fluid as fluid -from nas.search_space_factory import SearchSpaceFactory - -class TestSearchSpace(unittest.TestCase): - def test_searchspace(self): - # if output_size is 1, the model will add fc layer in the end. - config = {'input_size': 224, 'output_size': 7, 'block_num': 5} - space = SearchSpaceFactory() - - my_space = space.get_search_space('MobileNetV2Space', config) - model_arch = my_space.token2arch() - - train_prog = fluid.Program() - startup_prog = fluid.Program() - with fluid.program_guard(train_prog, startup_prog): - input_size= config['input_size'] - model_input = fluid.layers.data(name='model_in', shape=[1, 3, input_size, input_size], dtype='float32', append_batch_size=False) - predict = model_arch(model_input) - self.assertTrue(predict.shape[2] == config['output_size']) - - - #for op in train_prog.global_block().ops: - # print(op.type) - -if __name__ == '__main__': - unittest.main() diff --git a/setup.py b/setup.py index 86421878ed5493c3ab5f8b446f6b62a3b0135975..5ff0a92fdd48668c9447d8625f122d93a168444c 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,8 @@ packages = [ 'paddleslim.nas', 'paddleslim.analysis', 'paddleslim.quant', + 'paddleslim.core', + 'paddleslim.common', ] setup( diff --git a/paddleslim/tests/layers.py b/tests/layers.py similarity index 100% rename from paddleslim/tests/layers.py rename to tests/layers.py diff --git a/tests/test_auto_prune.py b/tests/test_auto_prune.py new file mode 100644 index 0000000000000000000000000000000000000000..c9cdc72c33ce683f2dc3ecbfdf406740ef6e69a8 --- /dev/null +++ b/tests/test_auto_prune.py @@ -0,0 +1,84 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +sys.path.append("../") +import unittest +import paddle.fluid as fluid +from paddleslim.prune import AutoPruner +from paddleslim.analysis import flops +from layers import conv_bn_layer + + +class TestPrune(unittest.TestCase): + def test_prune(self): + main_program = fluid.Program() + startup_program = fluid.Program() + # X X O X O + # conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6 + # | ^ | ^ + # |____________| |____________________| + # + # X: prune output channels + # O: prune input channels + with fluid.program_guard(main_program, startup_program): + input = fluid.data(name="image", shape=[None, 3, 16, 16]) + conv1 = conv_bn_layer(input, 8, 3, "conv1") + conv2 = conv_bn_layer(conv1, 8, 3, "conv2") + sum1 = conv1 + conv2 + conv3 = conv_bn_layer(sum1, 8, 3, "conv3") + conv4 = conv_bn_layer(conv3, 8, 3, "conv4") + sum2 = conv4 + sum1 + conv5 = conv_bn_layer(sum2, 8, 3, "conv5") + conv6 = conv_bn_layer(conv5, 8, 3, "conv6") + + shapes = {} + for param in main_program.global_block().all_parameters(): + shapes[param.name] = param.shape + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + scope = fluid.Scope() + exe.run(startup_program, scope=scope) + + pruned_flops = 0.5 + pruner = AutoPruner( + main_program, + scope, + place, + params=["conv4_weights"], + init_ratios=[0.5], + pruned_flops=0.5, + pruned_latency=None, + server_addr=("", 0), + init_temperature=100, + reduce_rate=0.85, + max_try_number=300, + max_client_num=10, + search_steps=2, + max_ratios=[0.9], + min_ratios=[0], + key="auto_pruner") + + base_flops = flops(main_program) + program = pruner.prune(main_program) + self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops)) + pruner.reward(1) + + program = pruner.prune(main_program) + self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops)) + pruner.reward(1) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_searchspace.py b/tests/test_nas_search_space.py similarity index 80% rename from tests/test_searchspace.py rename to tests/test_nas_search_space.py index c751bdc3d31d1822051436b71035b64ee6963fac..ad373cf146fecb1cf9ea2b3681eaf73e9e65dd3d 100644 --- a/tests/test_searchspace.py +++ b/tests/test_nas_search_space.py @@ -11,15 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import sys -sys.path.append("../") +sys.path.append('..') import unittest import paddle.fluid as fluid -from paddleslim.nas import SearchSpaceFactory +from nas.search_space_factory import SearchSpaceFactory -class TestSearchSpaceFactory(unittest.TestCase): - def test_factory(self): +class TestSearchSpace(unittest.TestCase): + def test_searchspace(self): # if output_size is 1, the model will add fc layer in the end. 
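 # (with output_size == 7 as configured below, no final fc layer is appended;
 # the network ends in a 7x7 feature map, which the shape assertion checks)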
config = {'input_size': 224, 'output_size': 7, 'block_num': 5} space = SearchSpaceFactory() @@ -39,23 +40,30 @@ class TestSearchSpaceFactory(unittest.TestCase): predict = model_arch[0](model_input) self.assertTrue(predict.shape[2] == config['output_size']) + class TestMultiSearchSpace(unittest.TestCase): space = SearchSpaceFactory() - + config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5} config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2} - my_space = space.get_search_space([('MobileNetV2Space', config0), ('ResNetSpace', config1)]) + my_space = space.get_search_space( + [('MobileNetV2Space', config0), ('ResNetSpace', config1)]) model_archs = my_space.token2arch() - + train_prog = fluid.Program() startup_prog = fluid.Program() with fluid.program_guard(train_prog, startup_prog): - input_size= config0['input_size'] - model_input = fluid.layers.data(name='model_in', shape=[1, 3, input_size, input_size], dtype='float32', append_batch_size=False) + input_size = config0['input_size'] + model_input = fluid.layers.data( + name='model_in', + shape=[1, 3, input_size, input_size], + dtype='float32', + append_batch_size=False) for model_arch in model_archs: predict = model_arch(model_input) model_input = predict print(predict) + if __name__ == '__main__': unittest.main() diff --git a/paddleslim/tests/test_prune.py b/tests/test_prune.py similarity index 100% rename from paddleslim/tests/test_prune.py rename to tests/test_prune.py diff --git a/tests/test_sa_nas.py b/tests/test_sa_nas.py new file mode 100644 index 0000000000000000000000000000000000000000..c1bcd08dadf87e24f31af1a525f67aa9a92bd26e --- /dev/null +++ b/tests/test_sa_nas.py @@ -0,0 +1,59 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +sys.path.append("../") +import unittest +import paddle.fluid as fluid +from paddleslim.nas import SANAS +from paddleslim.nas import SearchSpaceFactory +from paddleslim.analysis import flops + + +class TestSANAS(unittest.TestCase): + def test_nas(self): + + factory = SearchSpaceFactory() + config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5} + config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2} + configs = [('MobileNetV2Space', config0), ('ResNetSpace', config1)] + + space = factory.get_search_space([('MobileNetV2Space', config0)]) + origin_arch = space.token2arch()[0] + + main_program = fluid.Program() + s_program = fluid.Program() + with fluid.program_guard(main_program, s_program): + input = fluid.data( + name="input", shape=[None, 3, 224, 224], dtype="float32") + origin_arch(input) + base_flops = flops(main_program) + + search_steps = 3 + sa_nas = SANAS( + configs, max_flops=base_flops, search_steps=search_steps) + + for i in range(search_steps): + archs = sa_nas.next_archs() + main_program = fluid.Program() + s_program = fluid.Program() + with fluid.program_guard(main_program, s_program): + input = fluid.data( + name="input", shape=[None, 3, 224, 224], dtype="float32") + archs[0](input) + sa_nas.reward(1) + self.assertTrue(flops(main_program) < base_flops) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_sensitivity.py b/tests/test_sensitivity.py new file mode 100644 index 0000000000000000000000000000000000000000..e2cfa01d889db2891fd7507b2d4d9aec018a1163 --- /dev/null +++ b/tests/test_sensitivity.py @@ -0,0 +1,69 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +sys.path.append("../") +import unittest +import numpy +import paddle +import paddle.fluid as fluid +from paddleslim.analysis import sensitivity +from layers import conv_bn_layer + + +class TestSensitivity(unittest.TestCase): + def test_sensitivity(self): + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + input = fluid.data(name="image", shape=[None, 1, 28, 28]) + label = fluid.data(name="label", shape=[None, 1], dtype="int64") + conv1 = conv_bn_layer(input, 8, 3, "conv1") + conv2 = conv_bn_layer(conv1, 8, 3, "conv2") + sum1 = conv1 + conv2 + conv3 = conv_bn_layer(sum1, 8, 3, "conv3") + conv4 = conv_bn_layer(conv3, 8, 3, "conv4") + sum2 = conv4 + sum1 + conv5 = conv_bn_layer(sum2, 8, 3, "conv5") + conv6 = conv_bn_layer(conv5, 8, 3, "conv6") + out = fluid.layers.fc(conv6, size=10, act='softmax') + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + eval_program = main_program.clone(for_test=True) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(startup_program) + + val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + + def eval_func(program, scope): + feeder = fluid.DataFeeder( + feed_list=['image', 'label'], place=place, program=program) + acc_set = [] + for data in val_reader(): + acc_np = exe.run(program=program, + scope=scope, + feed=feeder.feed(data), + fetch_list=[acc_top1]) + acc_set.append(float(acc_np[0])) + acc_val_mean = numpy.array(acc_set).mean() + print("acc_val_mean: {}".format(acc_val_mean)) + return acc_val_mean + + sensitivity(eval_program, + fluid.global_scope(), place, ["conv4_weights"], eval_func, + "./sensitivities_file") + + +if __name__ == '__main__': + unittest.main()
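A closing note on the sensitivity test: the call is given "./sensitivities_file" as an output path, presumably so partial results survive an interrupted run, and it returns the measured data. A sketch of how the result might be consumed (the {param_name: {pruned_ratio: metric_loss}} return format is an assumption, not something this patch confirms):

sens = sensitivity(eval_program,
                   fluid.global_scope(), place, ["conv4_weights"], eval_func,
                   "./sensitivities_file")
# Hypothetical post-processing: for each parameter, keep the largest pruning
# ratio whose measured metric loss stays under 1%.
for param, losses in sens.items():
    safe = [r for r, loss in sorted(losses.items()) if loss < 0.01]
    print(param, max(safe) if safe else 0.0)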