提交 97f1776b 编写于 作者: W wanghaoshuang

Add light-nas

上级 a0d17e44
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import controller
from controller import *
import sa_controller
from sa_controller import *
import log_helper
from log_helper import *
import controller_server
from controller_server import *
import controller_client
from controller_client import *
__all__ = []
__all__ += controller.__all__
__all__ += sa_controller.__all__
__all__ += controller_server.__all__
__all__ += controller_client.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The controller used to search hyperparameters or neural architecture"""
import copy
import math
import numpy as np
__all__ = ['EvolutionaryController']
class EvolutionaryController(object):
"""Abstract controller for all evolutionary searching method.
"""
def __init__(self, *args, **kwargs):
pass
def update(self, tokens, reward):
"""Update the status of controller according current tokens and reward.
Args:
tokens(list<int>): A solution of searching task.
reward(list<int>): The reward of tokens.
"""
raise NotImplementedError('Abstract method.')
def reset(self, range_table, constrain_func=None):
"""Reset the controller.
Args:
range_table(list<int>): It is used to define the searching space of controller.
The tokens[i] generated by controller should be in [0, range_table[i]).
constrain_func(function): It is used to check whether tokens meet the constraint.
None means there is no constraint. Default: None.
"""
raise NotImplementedError('Abstract method.')
def next_tokens(self):
"""Generate new tokens.
"""
raise NotImplementedError('Abstract method.')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import socket
from log_helper import get_logger
__all__ = ['ControllerClient']
_logger = get_logger(__name__, level=logging.INFO)
class ControllerClient(object):
"""
Controller client.
"""
def __init__(self, server_ip=None, server_port=None, key=None):
"""
Args:
server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
key(str): The key used to identify legal agent for controller server. Default: "light-nas"
"""
self.server_ip = server_ip
self.server_port = server_port
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key
def update(self, tokens, reward):
"""
Update the controller according to latest tokens and reward.
Args:
tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
.encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
def next_tokens(self):
"""
Get next tokens.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
socket_client.send("next_tokens".encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import socket
from .log_helper import get_logger
from threading import Thread
from .lock import lock, unlock
__all__ = ['ControllerServer']
_logger = get_logger(__name__, level=logging.INFO)
class ControllerServer(object):
"""
The controller wrapper with a socket server to handle the request of search agent.
"""
def __init__(self,
controller=None,
address=('', 0),
max_client_num=100,
search_steps=None,
key=None):
"""
Args:
controller(slim.searcher.Controller): The controller used to generate tokens.
address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
which means setting ip automatically
max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
search_steps(int): The total steps of searching. None means never stopping. Default: None
"""
self._controller = controller
self._address = address
self._max_client_num = max_client_num
self._search_steps = search_steps
self._closed = False
self._port = address[1]
self._ip = address[0]
self._key = key
self._socket_file = "./controller_server.socket"
def start(self):
open(self._socket_file, 'a').close()
socket_file = open(self._socket_file, 'r+')
lock(socket_file)
tid = socket_file.readline()
if tid == '':
_logger.info("start controller server...")
tid = self._start()
socket_file.write("tid: {}\nip: {}\nport: {}\n".format(
tid, self._ip, self._port))
_logger.info("started controller server...")
unlock(socket_file)
socket_file.close()
def _start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket_server.bind(self._address)
self._socket_server.listen(self._max_client_num)
self._port = self._socket_server.getsockname()[1]
self._ip = self._socket_server.getsockname()[0]
_logger.info("ControllerServer - listen on: [{}:{}]".format(
self._ip, self._port))
thread = Thread(target=self.run)
thread.start()
return str(thread)
def close(self):
"""Close the server."""
self._closed = True
os.remove(self._socket_file)
_logger.info("server closed!")
def port(self):
"""Get the port."""
return self._port
def ip(self):
"""Get the ip."""
return self._ip
def run(self):
_logger.info("Controller Server run...")
try:
while ((self._search_steps is None) or
(self._controller._iter <
(self._search_steps))) and not self._closed:
conn, addr = self._socket_server.accept()
message = conn.recv(1024).decode()
if message.strip("\n") == "next_tokens":
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
else:
_logger.debug("recv message from {}: [{}]".format(addr,
message))
messages = message.strip('\n').split("\t")
if (len(messages) < 3) or (messages[0] != self._key):
_logger.debug("recv noise from {}: [{}]".format(
addr, message))
continue
tokens = messages[1]
reward = messages[2]
tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward))
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
_logger.debug("send message to {}: [{}]".format(addr,
tokens))
conn.close()
finally:
self._socket_server.close()
self.close()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import logging
__all__ = ['get_logger']
def get_logger(name, level, fmt=None):
"""
Get logger from logging with given name, level and format without
setting logging basicConfig. For setting basicConfig in paddle
will disable basicConfig setting after import paddle.
Args:
name (str): The logger name.
level (logging.LEVEL): The base level of the logger
fmt (str): Format of logger output
Returns:
logging.Logger: logging logger with given setttings
Examples:
.. code-block:: python
logger = log_helper.get_logger(__name__, logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
"""
logger = logging.getLogger(name)
logger.setLevel(level)
handler = logging.StreamHandler()
if fmt:
formatter = logging.Formatter(fmt=fmt)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The controller used to search hyperparameters or neural architecture"""
import copy
import math
import logging
import numpy as np
from .controller import EvolutionaryController
from log_helper import get_logger
__all__ = ["SAController"]
_logger = get_logger(__name__, level=logging.INFO)
class SAController(EvolutionaryController):
"""Simulated annealing controller."""
def __init__(self,
range_table=None,
reduce_rate=0.85,
init_temperature=1024,
max_iter_number=300,
init_tokens=None,
constrain_func=None):
"""Initialize.
Args:
range_table(list<int>): Range table.
reduce_rate(float): The decay rate of temperature.
init_temperature(float): Init temperature.
max_iter_number(int): max iteration number.
init_tokens(list<int>): The initial tokens.
constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
"""
super(SAController, self).__init__()
self._range_table = range_table
assert isinstance(self._range_table, tuple) and (
len(self._range_table) == 2)
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_iter_number = max_iter_number
self._reward = -1
self._tokens = init_tokens
self._constrain_func = constrain_func
self._max_reward = -1
self._best_tokens = None
self._iter = 0
def __getstate__(self):
d = {}
for key in self.__dict__:
if key != "_constrain_func":
d[key] = self.__dict__[key]
return d
def update(self, tokens, reward):
"""
Update the controller according to latest tokens and reward.
Args:
tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens.
"""
self._iter += 1
temperature = self._init_temperature * self._reduce_rate**self._iter
if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)):
self._reward = reward
self._tokens = tokens
if reward > self._max_reward:
self._max_reward = reward
self._best_tokens = tokens
_logger.info(
"Controller - iter: {}; current_reward: {}; current tokens: {}".
format(self._iter, self._reward, self._tokens))
def next_tokens(self, control_token=None):
"""
Get next tokens.
"""
if control_token:
tokens = control_token[:]
else:
tokens = self._tokens
new_tokens = tokens[:]
index = int(len(self._range_table[0]) * np.random.random())
new_tokens[index] = np.random.randint(self._range_table[0][index],
self._range_table[1][index] + 1)
_logger.debug("change index[{}] from {} to {}".format(index, tokens[
index], new_tokens[index]))
if self._constrain_func is None:
return new_tokens
for _ in range(self._max_iter_number):
if not self._constrain_func(new_tokens):
index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:]
new_tokens[index] = np.random.randint(
self._range_table[0][index],
self._range_table[1][index] + 1)
else:
break
return new_tokens
......@@ -14,6 +14,9 @@
import search_space
from search_space import *
import sa_nas
from sa_nas import *
__all__ = []
__all__ += search_space.__all__
__all__ += sa_nas.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import logging
import numpy as np
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops
from ..common import ControllerServer
from ..common import ControllerClient
__all__ = ["SANAS"]
_logger = get_logger(__name__, level=logging.INFO)
class SANAS(object):
def __init__(self,
configs,
flops=None,
latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_client_num=10,
search_steps=300,
key="sa_nas",
is_server=True):
"""
Search a group of ratios used to prune program.
Args:
configs(list<tuple>): A list of search space configuration with format (key, input_size, output_size, block_num).
`key` is the name of search space with data type str. `input_size` and `output_size` are
input size and output size of searched sub-network. `block_num` is the number of blocks in searched network.
max_flops(int): The max flops of searched network. None means no constrains. Default: None.
max_latency(float): The max latency of searched network. None means no constrains. Default: None.
server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy.
max_try_number(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching.
key(str): Identity used in communication between controller server and clients.
is_server(bool): Whether current host is controller server. Default: True.
"""
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_try_number = max_try_number
self._is_server = is_server
self._max_flops = max_flops
self._max_latency = max_latency
self._configs = configs
factory = SearchSpaceFactory()
self._search_space = factory.get_search_space(configs)
init_tokens = self._search_space.init_tokens()
range_table = self._search_space.range_table()
controller = SAController(range_table, self._reduce_rate,
self._init_temperature, self._max_try_number,
init_tokens, self._constrain_func)
server_ip, server_port = server_addr
if server_ip == None or server_ip == "":
server_ip = self._get_host_ip()
self._controller_server = ControllerServer(
controller=controller,
address=(server_ip, server_port),
max_client_num=max_client_num,
search_steps=search_steps,
key=key)
# create controller server
if self._is_server:
self._controller_server.start()
self._controller_client = ControllerClient(
self._controller_server.ip(),
self._controller_server.port(),
key=key)
self._iter = 0
def _get_host_ip(self):
return socket.gethostbyname(socket.gethostname())
def _constrain_func(self, tokens):
if (self._max_flops is None) and (self._max_latency is None):
return True
archs = self._search_space.token2arch(tokens)
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
i = 0
for config, arch in zip(self._configs, archs):
input = fluid.data(
name="data_{}".format(i),
shape=[None, 3, input_size, input_size],
dtype="float32")
output = arch(input)
i += 1
return flops(main_program) < self._max_flops
def next_archs(self):
"""
Get next network architectures.
Returns:
list<function>: A list of functions that define networks.
"""
self._current_tokens = self._controller_client.next_tokens()
archs = self._search_space.token2arch(self._current_tokens)
return archs
def reward(self, score):
"""
Return reward of current searched network.
Args:
score(float): The score of current searched network.
"""
self._controller_client.update(self._current_tokens, score)
self._iter += 1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.nas import SANAS
class TestSANAS(unittest.TestCase):
def test_nas(self):
factory = SearchSpaceFactory()
config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5}
config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2}
configs = [('MobileNetV2Space', config0), ('ResNetSpace', config1)]
space = factory.get_search_space([('MobileNetV2Space', config0)])
origin_arch = space.token2arch()
main_program = fluid.Program()
s_program = fluid.Program()
with fluid.program_guard(main_program, s_program):
input = fluid.data(
name="input", shape=[3, 224, 224], dtype="float32")
origin_arch(input)
base_flops = flops(main_program)
serch_steps = 3
sa_nas = SANAS(
configs, max_flops=base_flops, search_steps=search_steps)
for i in range(serch_steps):
archs = sa_nas.next_archs()
main_program = fluid.Program()
s_program = fluid.Program()
with fluid.program_guard(main_program, s_program):
input = fluid.data(
name="input", shape=[3, 224, 224], dtype="float32")
archs[0](input)
sa_nas.reward(1)
self.assertTrue(flops(main_program) < base_flops)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册