未验证 提交 87d75b98 编写于 作者: C ceci3 提交者: GitHub

support set init tokens by user for sanas (#31)


* add set init_token for nas
上级 fad49014
......@@ -26,17 +26,22 @@ class ControllerClient(object):
Controller client.
"""
def __init__(self, server_ip=None, server_port=None, key=None):
def __init__(self,
server_ip=None,
server_port=None,
key=None,
client_name=None):
"""
Args:
server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
server_port(int|None): The port that controller server listens on. 0 means getting usable port automatically. Default: None.
key(str|None): The key used to identify legal agent for controller server. Default: None.
client_name(str|None): Name of the current client; it is sent with every update so that the server can count clients. Default: None.
"""
self.server_ip = server_ip
self.server_port = server_port
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key
self._client_name = client_name
def update(self, tokens, reward, iter):
"""
......@@ -48,8 +53,8 @@ class ControllerClient(object):
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}\t{}".format(self._key, tokens, reward,
iter).encode())
socket_client.send("{}\t{}\t{}\t{}\t{}".format(
self._key, tokens, reward, iter, self._client_name).encode())
response = socket_client.recv(1024).decode()
if response.strip('\n').split("\t") == "ok":
return True
......
......@@ -15,6 +15,7 @@
import os
import logging
import socket
import time
from .log_helper import get_logger
from threading import Thread
from .lock_utils import lock, unlock
......@@ -41,7 +42,8 @@ class ControllerServer(object):
address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
which means setting ip automatically
max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
search_steps(int): The total steps of searching. None means never stopping. Default: None
search_steps(int|None): The total steps of searching. None means never stopping. Default: None
key(str|None): Config information. Default: None.
"""
self._controller = controller
self._address = address
......@@ -51,6 +53,9 @@ class ControllerServer(object):
self._port = address[1]
self._ip = address[0]
self._key = key
self._client_num = 0
self._client = dict()
self._compare_time = 172800 ### 48 hours
def start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
......@@ -93,15 +98,43 @@ class ControllerServer(object):
_logger.debug("recv message from {}: [{}]".format(addr,
message))
messages = message.strip('\n').split("\t")
if (len(messages) < 4) or (messages[0] != self._key):
if (len(messages) < 5) or (messages[0] != self._key):
_logger.debug("recv noise from {}: [{}]".format(
addr, message))
continue
tokens = messages[1]
reward = messages[2]
iter = messages[3]
client_name = messages[4]
one_step_time = -1
if client_name in self._client.keys():
current_time = time.time() - self._client[client_name]
if current_time > one_step_time:
one_step_time = current_time
self._compare_time = 2 * one_step_time
if client_name not in self._client.keys():
self._client[client_name] = time.time()
self._client_num += 1
self._client[client_name] = time.time()
for key_client in self._client.keys():
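### If a client has not requested tokens within twice the time of one training step, we consider this client stopped.
### NOTE(review): popping from self._client while iterating self._client.keys() raises RuntimeError on Python 3; iterate over list(self._client) instead.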
if (time.time() - self._client[key_client]
) > self._compare_time and len(self._client.keys(
)) > 1:
self._client.pop(key_client)
self._client_num -= 1
_logger.info(
"client: {}, client_num: {}, compare_time: {}".format(
self._client, self._client_num,
self._compare_time))
tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward), int(iter))
self._controller.update(tokens,
float(reward),
int(iter), int(self._client_num))
response = "ok"
conn.send(response.encode())
_logger.debug("send message to {}: [{}]".format(addr,
......
......@@ -34,7 +34,7 @@ class SAController(EvolutionaryController):
def __init__(self,
range_table=None,
reduce_rate=0.85,
init_temperature=1024,
init_temperature=None,
max_try_times=300,
init_tokens=None,
reward=-1,
......@@ -68,12 +68,20 @@ class SAController(EvolutionaryController):
self._max_try_times = max_try_times
self._reward = reward
self._tokens = init_tokens
if init_temperature == None:
if init_tokens == None:
self._init_temperature = 10.0
else:
self._init_temperature = 1.0
self._constrain_func = constrain_func
self._max_reward = max_reward
self._best_tokens = best_tokens
self._iter = iters
self._checkpoints = checkpoints
self._searched = searched if searched != None else dict()
self._current_token = init_tokens
def __getstate__(self):
d = {}
......@@ -92,9 +100,9 @@ class SAController(EvolutionaryController):
@property
def current_tokens(self):
    """Return the tokens most recently handed to the controller.

    Returns:
        The value of ``self._current_tokens`` when it exists, otherwise
        ``self._tokens``.
    """
    try:
        return self._current_tokens
    except AttributeError:
        # The constructor stores the initial tokens as `_current_token`
        # (singular), so `_current_tokens` does not exist until update() has
        # run once. `_tokens` holds the same initial tokens at that point.
        return self._tokens
def update(self, tokens, reward, iter):
def update(self, tokens, reward, iter, client_num):
"""
Update the controller according to latest tokens and reward.
Args:
......@@ -105,7 +113,9 @@ class SAController(EvolutionaryController):
if iter > self._iter:
self._iter = iter
self._searched[str(tokens)] = reward
temperature = self._init_temperature * self._reduce_rate**self._iter
temperature = self._init_temperature * self._reduce_rate**(client_num *
self._iter)
self._current_tokens = tokens
if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)):
self._reward = reward
......@@ -117,6 +127,9 @@ class SAController(EvolutionaryController):
"Controller - iter: {}; best_reward: {}, best tokens: {}, current_reward: {}; current tokens: {}".
format(self._iter, self._max_reward, self._best_tokens, reward,
tokens))
_logger.debug(
'Controller - iter: {}, controller current tokens: {}, controller current reward: {}'.
format(self._iter, self._tokens, self._reward))
if self._checkpoints != None:
self._save_checkpoint(self._checkpoints)
......
......@@ -18,6 +18,7 @@ import logging
import numpy as np
import json
import hashlib
import time
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
......@@ -37,12 +38,13 @@ class SANAS(object):
def __init__(self,
configs,
server_addr=("", 8881),
init_temperature=100,
init_temperature=None,
reduce_rate=0.85,
search_steps=300,
init_tokens=None,
save_checkpoint='nas_checkpoint',
load_checkpoint=None,
is_server=False):
is_server=True):
"""
Search a group of ratios used to prune program.
Args:
......@@ -50,9 +52,10 @@ class SANAS(object):
`key` is the name of search space with data type str. `input_size` and `output_size` are
input size and output size of searched sub-network. `block_num` is the number of blocks in searched network, `block_mask` is a list consists by 0 and 1, 0 means normal block, 1 means reduction block.
server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy.
search_steps(int): The steps of searching.
init_temperature(float|None): The init temperature used in simulated annealing search strategy. Default: None.
reduce_rate(float): The decay rate used in simulated annealing search strategy. Default: 0.85.
search_steps(int): The steps of searching. Default: 300.
init_tokens(list|None): Initial tokens supplied by the user; None means the search space generates them. Default: None.
save_checkpoint(string|None): The directory of checkpoint to save, if set to None, not save checkpoint. Default: 'nas_checkpoint'.
load_checkpoint(string|None): The directory of checkpoint to load, if set to None, not load checkpoint. Default: None.
is_server(bool): Whether current host is controller server. Default: True.
......@@ -64,7 +67,12 @@ class SANAS(object):
self._init_temperature = init_temperature
self._is_server = is_server
self._configs = configs
self._key = hashlib.md5(str(self._configs).encode("utf-8")).hexdigest()
self._init_tokens = init_tokens
self._client_name = hashlib.md5(
str(time.time() + np.random.randint(1, 10000)).encode(
"utf-8")).hexdigest()
self._key = str(self._configs)
self._current_tokens = init_tokens
server_ip, server_port = server_addr
if server_ip == None or server_ip == "":
......@@ -75,7 +83,7 @@ class SANAS(object):
# create controller server
if self._is_server:
init_tokens = self._search_space.init_tokens()
init_tokens = self._search_space.init_tokens(self._init_tokens)
range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table)
_logger.info("range table: {}".format(range_table))
......@@ -127,7 +135,10 @@ class SANAS(object):
server_port = self._controller_server.port()
self._controller_client = ControllerClient(
server_ip, server_port, key=self._key)
server_ip,
server_port,
key=self._key,
client_name=self._client_name)
if is_server and load_checkpoint != None:
self._iter = scene['_iter']
......@@ -164,6 +175,7 @@ class SANAS(object):
list<function>: A list of functions that define networks.
"""
self._current_tokens = self._controller_client.next_tokens()
_logger.info("current tokens: {}".format(self._current_tokens))
archs = self._search_space.token2arch(self._current_tokens)
return archs
......
......@@ -97,16 +97,19 @@ class CombineSearchSpace(object):
space = cls(input_size, output_size, block_num, block_mask=block_mask)
return space
def init_tokens(self, tokens=None):
    """
    Combine init tokens.

    Args:
        tokens(list|None): Tokens supplied by the user. When not None they
            are returned unchanged. Default: None.

    Returns:
        list: The user tokens, or the concatenation of the init tokens of
            every space in ``self.spaces``.
    """
    if tokens is not None:
        # NOTE(review): self.single_token_num is not refreshed on this path
        # (unchanged behavior) - confirm that no caller relies on it.
        return tokens

    combined = []
    self.single_token_num = []
    for space in self.spaces:
        # Ask each space exactly once: sub-spaces may generate their tokens
        # randomly, so a second call only to measure the length is wasted
        # work.
        space_tokens = space.init_tokens()
        combined.extend(space_tokens)
        self.single_token_num.append(len(space_tokens))
    return combined
def range_table(self):
"""
......
......@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE
from .utils import compute_downsample_num, check_points
from .utils import compute_downsample_num, check_points, get_random_tokens
__all__ = ["InceptionABlockSpace", "InceptionCBlockSpace"]
### TODO add asymmetric kernel of conv when paddle-lite support
......@@ -58,10 +58,7 @@ class InceptionABlockSpace(SearchSpaceBase):
"""
The initial token.
"""
if self.block_mask != None:
return [0] * (len(self.block_mask) * 9)
else:
return [0] * (self.block_num * 9)
return get_random_tokens(self.range_table)
def range_table(self):
"""
......@@ -290,10 +287,7 @@ class InceptionCBlockSpace(SearchSpaceBase):
"""
The initial token.
"""
if self.block_mask != None:
return [0] * (len(self.block_mask) * 11)
else:
return [0] * (self.block_num * 11)
return get_random_tokens(self.range_table)
def range_table(self):
"""
......
......@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE
from .utils import compute_downsample_num, check_points
from .utils import compute_downsample_num, check_points, get_random_tokens
__all__ = ["MobileNetV1BlockSpace", "MobileNetV2BlockSpace"]
......@@ -60,10 +60,7 @@ class MobileNetV2BlockSpace(SearchSpaceBase):
self.scale = scale
def init_tokens(self):
    """Return randomly initialized tokens for this search space.

    Returns:
        list<int>: The result of ``get_random_tokens`` applied to
            ``self.range_table()``.
    """
    # range_table is a method: call it so that get_random_tokens receives
    # the table itself rather than the bound method, which is not iterable.
    return get_random_tokens(self.range_table())
def range_table(self):
range_table_base = []
......@@ -308,10 +305,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
self.scale = scale
def init_tokens(self):
    """Return randomly initialized tokens for this search space.

    Returns:
        list<int>: The result of ``get_random_tokens`` applied to
            ``self.range_table()``.
    """
    # range_table is a method: call it so that get_random_tokens receives
    # the table itself rather than the bound method, which is not iterable.
    return get_random_tokens(self.range_table())
def range_table(self):
range_table_base = []
......
......@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE
from .utils import check_points
from .utils import check_points, get_random_tokens
__all__ = ["ResNetSpace"]
......@@ -47,8 +47,7 @@ class ResNetSpace(SearchSpaceBase):
"""
The initial token.
"""
init_token_base = [0, 0, 0, 0, 0, 0, 0, 0]
return init_token_base
return [1, 1, 2, 2, 3, 4, 3, 1]
def range_table(self):
"""
......
......@@ -22,7 +22,7 @@ from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE
from .utils import compute_downsample_num, check_points
from .utils import compute_downsample_num, check_points, get_random_tokens
__all__ = ["ResNetBlockSpace"]
......@@ -40,14 +40,11 @@ class ResNetBlockSpace(SearchSpaceBase):
self.downsample_num, self.block_num)
self.filter_num = np.array(
[48, 64, 96, 128, 160, 192, 224, 256, 320, 384, 512, 640])
self.repeat = np.array([0, 1, 2])
self.repeat = np.array([0, 1, 2, 3, 4, 6, 7, 8, 10, 12, 14, 16])
self.k_size = np.array([3, 5])
def init_tokens(self):
    """Return randomly initialized tokens for this search space.

    Returns:
        list<int>: The result of ``get_random_tokens`` applied to
            ``self.range_table()``.
    """
    # range_table is a method: call it so that get_random_tokens receives
    # the table itself rather than the bound method, which is not iterable.
    return get_random_tokens(self.range_table())
def range_table(self):
range_table_base = []
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import math
import numpy as np
def compute_downsample_num(input_size, output_size):
......@@ -36,3 +37,11 @@ def check_points(count, points):
return (True if count in points else False)
else:
return (True if count == points else False)
def get_random_tokens(range_table):
    """Draw one random token for every entry of ``range_table``.

    Args:
        range_table(list<int>): For each token position, the bound that the
            random draw is scaled by.

    Returns:
        list<int>: ``tokens[i]`` equals ``floor(range_table[i] * r)`` where
            ``r`` is a fresh sample of ``np.random.rand()`` in ``[0, 1)``.
    """
    # np.random.rand() returns a scalar. The previous rand(1) returned a
    # one-element array, and converting such an array with int() is
    # deprecated by NumPy. One sample is still consumed per entry.
    return [
        int(np.floor(max_value * np.random.rand()))
        for max_value in range_table
    ]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册