未验证 提交 87d75b98 编写于 作者: C ceci3 提交者: GitHub

support set init tokens by user for sanas (#31)


* add set init_token for nas
上级 fad49014
...@@ -26,17 +26,22 @@ class ControllerClient(object): ...@@ -26,17 +26,22 @@ class ControllerClient(object):
Controller client. Controller client.
""" """
def __init__(self, server_ip=None, server_port=None, key=None): def __init__(self,
server_ip=None,
server_port=None,
key=None,
client_name=None):
""" """
Args: Args:
server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None. server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0. server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
key(str): The key used to identify legal agent for controller server. Default: "light-nas" key(str): The key used to identify legal agent for controller server. Default: "light-nas"
client_name(str): Current client name, random generate for counting client number. Default: None.
""" """
self.server_ip = server_ip self.server_ip = server_ip
self.server_port = server_port self.server_port = server_port
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key self._key = key
self._client_name = client_name
def update(self, tokens, reward, iter): def update(self, tokens, reward, iter):
""" """
...@@ -48,8 +53,8 @@ class ControllerClient(object): ...@@ -48,8 +53,8 @@ class ControllerClient(object):
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port)) socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens]) tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}\t{}".format(self._key, tokens, reward, socket_client.send("{}\t{}\t{}\t{}\t{}".format(
iter).encode()) self._key, tokens, reward, iter, self._client_name).encode())
response = socket_client.recv(1024).decode() response = socket_client.recv(1024).decode()
if response.strip('\n').split("\t") == "ok": if response.strip('\n').split("\t") == "ok":
return True return True
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
import logging import logging
import socket import socket
import time
from .log_helper import get_logger from .log_helper import get_logger
from threading import Thread from threading import Thread
from .lock_utils import lock, unlock from .lock_utils import lock, unlock
...@@ -41,7 +42,8 @@ class ControllerServer(object): ...@@ -41,7 +42,8 @@ class ControllerServer(object):
address(tuple): The address of current server binding with format (ip, port). Default: ('', 0). address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
which means setting ip automatically which means setting ip automatically
max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100. max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
search_steps(int): The total steps of searching. None means never stopping. Default: None search_steps(int|None): The total steps of searching. None means never stopping. Default: None
key(str|None): Config information. Default: None.
""" """
self._controller = controller self._controller = controller
self._address = address self._address = address
...@@ -51,6 +53,9 @@ class ControllerServer(object): ...@@ -51,6 +53,9 @@ class ControllerServer(object):
self._port = address[1] self._port = address[1]
self._ip = address[0] self._ip = address[0]
self._key = key self._key = key
self._client_num = 0
self._client = dict()
self._compare_time = 172800 ### 48 hours
def start(self): def start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
...@@ -93,15 +98,43 @@ class ControllerServer(object): ...@@ -93,15 +98,43 @@ class ControllerServer(object):
_logger.debug("recv message from {}: [{}]".format(addr, _logger.debug("recv message from {}: [{}]".format(addr,
message)) message))
messages = message.strip('\n').split("\t") messages = message.strip('\n').split("\t")
if (len(messages) < 4) or (messages[0] != self._key): if (len(messages) < 5) or (messages[0] != self._key):
_logger.debug("recv noise from {}: [{}]".format( _logger.debug("recv noise from {}: [{}]".format(
addr, message)) addr, message))
continue continue
tokens = messages[1] tokens = messages[1]
reward = messages[2] reward = messages[2]
iter = messages[3] iter = messages[3]
client_name = messages[4]
one_step_time = -1
if client_name in self._client.keys():
current_time = time.time() - self._client[client_name]
if current_time > one_step_time:
one_step_time = current_time
self._compare_time = 2 * one_step_time
if client_name not in self._client.keys():
self._client[client_name] = time.time()
self._client_num += 1
self._client[client_name] = time.time()
for key_client in self._client.keys():
### if a client not request token in double train one tokens' time, we think this client was stoped.
if (time.time() - self._client[key_client]
) > self._compare_time and len(self._client.keys(
)) > 1:
self._client.pop(key_client)
self._client_num -= 1
_logger.info(
"client: {}, client_num: {}, compare_time: {}".format(
self._client, self._client_num,
self._compare_time))
tokens = [int(token) for token in tokens.split(",")] tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward), int(iter)) self._controller.update(tokens,
float(reward),
int(iter), int(self._client_num))
response = "ok" response = "ok"
conn.send(response.encode()) conn.send(response.encode())
_logger.debug("send message to {}: [{}]".format(addr, _logger.debug("send message to {}: [{}]".format(addr,
......
...@@ -34,7 +34,7 @@ class SAController(EvolutionaryController): ...@@ -34,7 +34,7 @@ class SAController(EvolutionaryController):
def __init__(self, def __init__(self,
range_table=None, range_table=None,
reduce_rate=0.85, reduce_rate=0.85,
init_temperature=1024, init_temperature=None,
max_try_times=300, max_try_times=300,
init_tokens=None, init_tokens=None,
reward=-1, reward=-1,
...@@ -68,12 +68,20 @@ class SAController(EvolutionaryController): ...@@ -68,12 +68,20 @@ class SAController(EvolutionaryController):
self._max_try_times = max_try_times self._max_try_times = max_try_times
self._reward = reward self._reward = reward
self._tokens = init_tokens self._tokens = init_tokens
if init_temperature == None:
if init_tokens == None:
self._init_temperature = 10.0
else:
self._init_temperature = 1.0
self._constrain_func = constrain_func self._constrain_func = constrain_func
self._max_reward = max_reward self._max_reward = max_reward
self._best_tokens = best_tokens self._best_tokens = best_tokens
self._iter = iters self._iter = iters
self._checkpoints = checkpoints self._checkpoints = checkpoints
self._searched = searched if searched != None else dict() self._searched = searched if searched != None else dict()
self._current_token = init_tokens
def __getstate__(self): def __getstate__(self):
d = {} d = {}
...@@ -92,9 +100,9 @@ class SAController(EvolutionaryController): ...@@ -92,9 +100,9 @@ class SAController(EvolutionaryController):
@property @property
def current_tokens(self): def current_tokens(self):
return self._tokens return self._current_tokens
def update(self, tokens, reward, iter): def update(self, tokens, reward, iter, client_num):
""" """
Update the controller according to latest tokens and reward. Update the controller according to latest tokens and reward.
Args: Args:
...@@ -105,7 +113,9 @@ class SAController(EvolutionaryController): ...@@ -105,7 +113,9 @@ class SAController(EvolutionaryController):
if iter > self._iter: if iter > self._iter:
self._iter = iter self._iter = iter
self._searched[str(tokens)] = reward self._searched[str(tokens)] = reward
temperature = self._init_temperature * self._reduce_rate**self._iter temperature = self._init_temperature * self._reduce_rate**(client_num *
self._iter)
self._current_tokens = tokens
if (reward > self._reward) or (np.random.random() <= math.exp( if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)): (reward - self._reward) / temperature)):
self._reward = reward self._reward = reward
...@@ -117,6 +127,9 @@ class SAController(EvolutionaryController): ...@@ -117,6 +127,9 @@ class SAController(EvolutionaryController):
"Controller - iter: {}; best_reward: {}, best tokens: {}, current_reward: {}; current tokens: {}". "Controller - iter: {}; best_reward: {}, best tokens: {}, current_reward: {}; current tokens: {}".
format(self._iter, self._max_reward, self._best_tokens, reward, format(self._iter, self._max_reward, self._best_tokens, reward,
tokens)) tokens))
_logger.debug(
'Controller - iter: {}, controller current tokens: {}, controller current reward: {}'.
format(self._iter, self._tokens, self._reward))
if self._checkpoints != None: if self._checkpoints != None:
self._save_checkpoint(self._checkpoints) self._save_checkpoint(self._checkpoints)
......
...@@ -18,6 +18,7 @@ import logging ...@@ -18,6 +18,7 @@ import logging
import numpy as np import numpy as np
import json import json
import hashlib import hashlib
import time
import paddle.fluid as fluid import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController from ..common import SAController
...@@ -37,12 +38,13 @@ class SANAS(object): ...@@ -37,12 +38,13 @@ class SANAS(object):
def __init__(self, def __init__(self,
configs, configs,
server_addr=("", 8881), server_addr=("", 8881),
init_temperature=100, init_temperature=None,
reduce_rate=0.85, reduce_rate=0.85,
search_steps=300, search_steps=300,
init_tokens=None,
save_checkpoint='nas_checkpoint', save_checkpoint='nas_checkpoint',
load_checkpoint=None, load_checkpoint=None,
is_server=False): is_server=True):
""" """
Search a group of ratios used to prune program. Search a group of ratios used to prune program.
Args: Args:
...@@ -50,9 +52,10 @@ class SANAS(object): ...@@ -50,9 +52,10 @@ class SANAS(object):
`key` is the name of search space with data type str. `input_size` and `output_size` are `key` is the name of search space with data type str. `input_size` and `output_size` are
input size and output size of searched sub-network. `block_num` is the number of blocks in searched network, `block_mask` is a list consists by 0 and 1, 0 means normal block, 1 means reduction block. input size and output size of searched sub-network. `block_num` is the number of blocks in searched network, `block_mask` is a list consists by 0 and 1, 0 means normal block, 1 means reduction block.
server_addr(tuple): A tuple of server ip and server port for controller server. server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy. init_temperature(float|None): The init temperature used in simulated annealing search strategy. Default: None.
reduce_rate(float): The decay rate used in simulated annealing search strategy. reduce_rate(float): The decay rate used in simulated annealing search strategy. Default: None.
search_steps(int): The steps of searching. search_steps(int): The steps of searching. Default: 300.
init_token(list): Init tokens user can set by yourself. Default: None.
save_checkpoint(string|None): The directory of checkpoint to save, if set to None, not save checkpoint. Default: 'nas_checkpoint'. save_checkpoint(string|None): The directory of checkpoint to save, if set to None, not save checkpoint. Default: 'nas_checkpoint'.
load_checkpoint(string|None): The directory of checkpoint to load, if set to None, not load checkpoint. Default: None. load_checkpoint(string|None): The directory of checkpoint to load, if set to None, not load checkpoint. Default: None.
is_server(bool): Whether current host is controller server. Default: True. is_server(bool): Whether current host is controller server. Default: True.
...@@ -64,7 +67,12 @@ class SANAS(object): ...@@ -64,7 +67,12 @@ class SANAS(object):
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._is_server = is_server self._is_server = is_server
self._configs = configs self._configs = configs
self._key = hashlib.md5(str(self._configs).encode("utf-8")).hexdigest() self._init_tokens = init_tokens
self._client_name = hashlib.md5(
str(time.time() + np.random.randint(1, 10000)).encode(
"utf-8")).hexdigest()
self._key = str(self._configs)
self._current_tokens = init_tokens
server_ip, server_port = server_addr server_ip, server_port = server_addr
if server_ip == None or server_ip == "": if server_ip == None or server_ip == "":
...@@ -75,7 +83,7 @@ class SANAS(object): ...@@ -75,7 +83,7 @@ class SANAS(object):
# create controller server # create controller server
if self._is_server: if self._is_server:
init_tokens = self._search_space.init_tokens() init_tokens = self._search_space.init_tokens(self._init_tokens)
range_table = self._search_space.range_table() range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table) range_table = (len(range_table) * [0], range_table)
_logger.info("range table: {}".format(range_table)) _logger.info("range table: {}".format(range_table))
...@@ -127,7 +135,10 @@ class SANAS(object): ...@@ -127,7 +135,10 @@ class SANAS(object):
server_port = self._controller_server.port() server_port = self._controller_server.port()
self._controller_client = ControllerClient( self._controller_client = ControllerClient(
server_ip, server_port, key=self._key) server_ip,
server_port,
key=self._key,
client_name=self._client_name)
if is_server and load_checkpoint != None: if is_server and load_checkpoint != None:
self._iter = scene['_iter'] self._iter = scene['_iter']
...@@ -164,6 +175,7 @@ class SANAS(object): ...@@ -164,6 +175,7 @@ class SANAS(object):
list<function>: A list of functions that define networks. list<function>: A list of functions that define networks.
""" """
self._current_tokens = self._controller_client.next_tokens() self._current_tokens = self._controller_client.next_tokens()
_logger.info("current tokens: {}".format(self._current_tokens))
archs = self._search_space.token2arch(self._current_tokens) archs = self._search_space.token2arch(self._current_tokens)
return archs return archs
......
...@@ -97,16 +97,19 @@ class CombineSearchSpace(object): ...@@ -97,16 +97,19 @@ class CombineSearchSpace(object):
space = cls(input_size, output_size, block_num, block_mask=block_mask) space = cls(input_size, output_size, block_num, block_mask=block_mask)
return space return space
def init_tokens(self): def init_tokens(self, tokens=None):
""" """
Combine init tokens. Combine init tokens.
""" """
tokens = [] if tokens is None:
self.single_token_num = [] tokens = []
for space in self.spaces: self.single_token_num = []
tokens.extend(space.init_tokens()) for space in self.spaces:
self.single_token_num.append(len(space.init_tokens())) tokens.extend(space.init_tokens())
return tokens self.single_token_num.append(len(space.init_tokens()))
return tokens
else:
return tokens
def range_table(self): def range_table(self):
""" """
......
...@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr ...@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE from .search_space_registry import SEARCHSPACE
from .utils import compute_downsample_num, check_points from .utils import compute_downsample_num, check_points, get_random_tokens
__all__ = ["InceptionABlockSpace", "InceptionCBlockSpace"] __all__ = ["InceptionABlockSpace", "InceptionCBlockSpace"]
### TODO add asymmetric kernel of conv when paddle-lite support ### TODO add asymmetric kernel of conv when paddle-lite support
...@@ -58,10 +58,7 @@ class InceptionABlockSpace(SearchSpaceBase): ...@@ -58,10 +58,7 @@ class InceptionABlockSpace(SearchSpaceBase):
""" """
The initial token. The initial token.
""" """
if self.block_mask != None: return get_random_tokens(self.range_table)
return [0] * (len(self.block_mask) * 9)
else:
return [0] * (self.block_num * 9)
def range_table(self): def range_table(self):
""" """
...@@ -290,10 +287,7 @@ class InceptionCBlockSpace(SearchSpaceBase): ...@@ -290,10 +287,7 @@ class InceptionCBlockSpace(SearchSpaceBase):
""" """
The initial token. The initial token.
""" """
if self.block_mask != None: return get_random_tokens(self.range_table)
return [0] * (len(self.block_mask) * 11)
else:
return [0] * (self.block_num * 11)
def range_table(self): def range_table(self):
""" """
......
...@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr ...@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE from .search_space_registry import SEARCHSPACE
from .utils import compute_downsample_num, check_points from .utils import compute_downsample_num, check_points, get_random_tokens
__all__ = ["MobileNetV1BlockSpace", "MobileNetV2BlockSpace"] __all__ = ["MobileNetV1BlockSpace", "MobileNetV2BlockSpace"]
...@@ -60,10 +60,7 @@ class MobileNetV2BlockSpace(SearchSpaceBase): ...@@ -60,10 +60,7 @@ class MobileNetV2BlockSpace(SearchSpaceBase):
self.scale = scale self.scale = scale
def init_tokens(self): def init_tokens(self):
if self.block_mask != None: return get_random_tokens(self.range_table)
return [0] * (len(self.block_mask) * 4)
else:
return [0] * (self.block_num * 4)
def range_table(self): def range_table(self):
range_table_base = [] range_table_base = []
...@@ -308,10 +305,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase): ...@@ -308,10 +305,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
self.scale = scale self.scale = scale
def init_tokens(self): def init_tokens(self):
if self.block_mask != None: return get_random_tokens(self.range_table)
return [0] * (len(self.block_mask) * 3)
else:
return [0] * (self.block_num * 3)
def range_table(self): def range_table(self):
range_table_base = [] range_table_base = []
......
...@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr ...@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE from .search_space_registry import SEARCHSPACE
from .utils import check_points from .utils import check_points, get_random_tokens
__all__ = ["ResNetSpace"] __all__ = ["ResNetSpace"]
...@@ -47,8 +47,7 @@ class ResNetSpace(SearchSpaceBase): ...@@ -47,8 +47,7 @@ class ResNetSpace(SearchSpaceBase):
""" """
The initial token. The initial token.
""" """
init_token_base = [0, 0, 0, 0, 0, 0, 0, 0] return [1, 1, 2, 2, 3, 4, 3, 1]
return init_token_base
def range_table(self): def range_table(self):
""" """
......
...@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr ...@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE from .search_space_registry import SEARCHSPACE
from .utils import compute_downsample_num, check_points from .utils import compute_downsample_num, check_points, get_random_tokens
__all__ = ["ResNetBlockSpace"] __all__ = ["ResNetBlockSpace"]
...@@ -40,14 +40,11 @@ class ResNetBlockSpace(SearchSpaceBase): ...@@ -40,14 +40,11 @@ class ResNetBlockSpace(SearchSpaceBase):
self.downsample_num, self.block_num) self.downsample_num, self.block_num)
self.filter_num = np.array( self.filter_num = np.array(
[48, 64, 96, 128, 160, 192, 224, 256, 320, 384, 512, 640]) [48, 64, 96, 128, 160, 192, 224, 256, 320, 384, 512, 640])
self.repeat = np.array([0, 1, 2]) self.repeat = np.array([0, 1, 2, 3, 4, 6, 7, 8, 10, 12, 14, 16])
self.k_size = np.array([3, 5]) self.k_size = np.array([3, 5])
def init_tokens(self): def init_tokens(self):
if self.block_mask != None: return get_random_tokens(self.range_table)
return [0] * (len(self.block_mask) * 6)
else:
return [0] * (self.block_num * 6)
def range_table(self): def range_table(self):
range_table_base = [] range_table_base = []
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
import numpy as np
def compute_downsample_num(input_size, output_size): def compute_downsample_num(input_size, output_size):
...@@ -36,3 +37,11 @@ def check_points(count, points): ...@@ -36,3 +37,11 @@ def check_points(count, points):
return (True if count in points else False) return (True if count in points else False)
else: else:
return (True if count == points else False) return (True if count == points else False)
def get_random_tokens(range_table):
tokens = []
for idx, max_value in enumerate(range_table):
tokens_idx = int(np.floor(range_table[idx] * np.random.rand(1)))
tokens.append(tokens_idx)
return tokens
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册