提交 513e6961 编写于 作者: W wanghaoshuang

Add sa controller

Add controller server and client
Refine auto pruner
上级 a78c4b8a
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import controller
from controller import *
import sa_controller
from sa_controller import *
import log_helper
from log_helper import *
__all__ = []
__all__ += controller.__all__
__all__ += sa_controller.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The controller used to search hyperparameters or neural architecture"""
import copy
import math
import numpy as np
__all__ = ['EvolutionaryController']
class EvolutionaryController(object):
"""Abstract controller for all evolutionary searching method.
"""
def __init__(self, *args, **kwargs):
pass
def update(self, tokens, reward):
"""Update the status of controller according current tokens and reward.
Args:
tokens(list<int>): A solution of searching task.
reward(list<int>): The reward of tokens.
"""
raise NotImplementedError('Abstract method.')
def reset(self, range_table, constrain_func=None):
"""Reset the controller.
Args:
range_table(list<int>): It is used to define the searching space of controller.
The tokens[i] generated by controller should be in [0, range_table[i]).
constrain_func(function): It is used to check whether tokens meet the constraint.
None means there is no constraint. Default: None.
"""
raise NotImplementedError('Abstract method.')
def next_tokens(self):
"""Generate new tokens.
"""
raise NotImplementedError('Abstract method.')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import logging
__all__ = ['get_logger']
def get_logger(name, level, fmt=None):
"""
Get logger from logging with given name, level and format without
setting logging basicConfig. For setting basicConfig in paddle
will disable basicConfig setting after import paddle.
Args:
name (str): The logger name.
level (logging.LEVEL): The base level of the logger
fmt (str): Format of logger output
Returns:
logging.Logger: logging logger with given setttings
Examples:
.. code-block:: python
logger = log_helper.get_logger(__name__, logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
"""
logger = logging.getLogger(name)
logger.setLevel(level)
handler = logging.StreamHandler()
if fmt:
formatter = logging.Formatter(fmt=fmt)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The controller used to search hyperparameters or neural architecture"""
import copy
import math
import logging
import numpy as np
from .controller import EvolutionaryController
from log_helper import get_logger
__all__ = ["SAController"]
_logger = get_logger(__name__, level=logging.INFO)
class SAController(EvolutionaryController):
"""Simulated annealing controller."""
def __init__(self,
range_table=None,
reduce_rate=0.85,
init_temperature=1024,
max_iter_number=300,
init_tokens=None,
constrain_func=None):
"""Initialize.
Args:
range_table(list<int>): Range table.
reduce_rate(float): The decay rate of temperature.
init_temperature(float): Init temperature.
max_iter_number(int): max iteration number.
init_tokens(list<int>): The initial tokens.
constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
"""
super(SAController, self).__init__()
self._range_table = range_table
assert isinstance(self._range_table, tuple) and (len(self._range_table) == 2)
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_iter_number = max_iter_number
self._reward = -1
self._tokens = init_tokens
self._constrain_func = constrain_func
self._max_reward = -1
self._best_tokens = None
self._iter = 0
def __getstate__(self):
d = {}
for key in self.__dict__:
if key != "_constrain_func":
d[key] = self.__dict__[key]
return d
def update(self, tokens, reward):
"""
Update the controller according to latest tokens and reward.
Args:
tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens.
"""
self._iter += 1
temperature = self._init_temperature * self._reduce_rate**self._iter
if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)):
self._reward = reward
self._tokens = tokens
if reward > self._max_reward:
self._max_reward = reward
self._best_tokens = tokens
_logger.info("iter: {}; max_reward: {}; best_tokens: {}".format(
self._iter, self._max_reward, self._best_tokens))
_logger.info("current_reward: {}; current tokens: {}".format(
self._reward, self._tokens))
def next_tokens(self, control_token=None):
"""
Get next tokens.
"""
if control_token:
tokens = control_token[:]
else:
tokens = self._tokens
new_tokens = tokens[:]
index = int(len(self._range_table[0]) * np.random.random())
new_tokens[index] = np.random.randint(self._range_table[0][index], self._range_table[1][index]+1)
_logger.info("change index[{}] from {} to {}".format(index, tokens[
index], new_tokens[index]))
if self._constrain_func is None:
return new_tokens
for _ in range(self._max_iter_number):
if not self._constrain_func(new_tokens):
index = int(len(self._range_table) * np.random.random())
new_tokens = tokens[:]
new_tokens[index] = np.random.randint(self._range_table[index])
else:
break
return new_tokens
......@@ -15,6 +15,13 @@ import pruner
from pruner import *
import auto_pruner
from auto_pruner import *
import controller_server
from controller_server import *
import controller_client
from controller_client import *
__all__ = []
__all__ += pruner.__all__
__all__ += auto_pruner.__all__
__all__ += controller_server.__all__
__all__ += controller_client.__all__
......@@ -24,12 +24,21 @@ __all__ = ["AutoPruner"]
class AutoPruner(object):
def __init__(self,
program,
params=[],
init_ratios=None,
pruned_flops=0.5,
pruned_latency=None,
server_addr=("", ""),
search_strategy="sa"):
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_iter_number=300,
max_client_num=10,
search_steps=300,
max_ratios=[0.9],
min_ratios=[0],
key="auto_pruner"
):
"""
Search a group of ratios used to prune program.
Args:
......@@ -45,71 +54,86 @@ class AutoPruner(object):
search_strategy(str): The search strategy. Default: 'sa'.
"""
# step1: Create controller server. And start server if current host match server_ip.
self._controller_server = ControllerServer(
addr=(server_ip, server_port), search_strategy="sa")
self._program = program
self._params = params
self._init_ratios = init_ratios
self._pruned_flops = pruned_flops
self._pruned_latency = pruned_latency
self._pruner = Pruner()
self._controller_agent = None
self._base_flops = None
self._base_latency = None
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_try_number = max_try_number
def prune(self, program, scope, place):
assert isinstance(self._max_ratios, float) or isinstance(self._max_ratios)
self._range_table = self._get_range_table(min_ratios, max_ratios)
if self._controller_agent is None:
self._controller_agent = PrunerAgent(
addr=self._controller_server.addr, self._range_table)
self._pruner = Pruner()
if self._pruned_flops:
self._base_flops = flops(program)
if self._pruned_latency:
self._base_latency = latency(program)
if self._init_ratios is None:
self._init_ratios = self._get_init_ratios(
program, self._params, self._pruned_flops,
self,_program, self._params, self._pruned_flops,
self._pruned_latency)
self._current_ratios = self._init_ratios
else:
self._current_ratios = self._controller_agent.next_ratios()
init_tokens = self._ratios2tokens(self._init_ratios)
if self._base_flops == None:
self._base_flops = flops(program)
for i in range(self._max_try_num):
controller = SAController(self._range_table,
self._reduce_rate,
self._init_temperature,
self._max_try_number,
init_tokens,
self._constrain_func)
self._controller_server = ControllerServer(
controller=controller,
addr=server_addr,
max_client_num,
search_steps,
key=key)
self._controller_client = ControllerClient(server_addr, key=key)
self._iter = 0
def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
pass
def _get_range_table(self, min_ratios, max_ratios):
assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
min_ratios = min_ratios if isinstance(min_ratios, list) else [min_ratios]
max_ratios = max_ratios if isinstance(max_ratios, list) else [max_ratios]
min_tokens = self._ratios2tokens(min_ratios)
max_tokens = self._ratios2tokens(max_ratios)
return (min_tokens, max_tokens)
def _constrain_func(self, tokens):
ratios = self._tokens2ratios(tokens)
pruned_program = self._pruner.prune(
program,
scope,
self._params,
self._current_ratios,
only_graph=True)
if flops(pruned_program) < self._base_flops * (
1 - self._pruned_flops):
break
self._current_ratios = self._controller_agent.illegal_ratios(
self._current_ratios)
return flops(pruned_program) < self._base_flops
def prune(self, program, scope, place):
self._current_ratios = self._next_ratios()
pruned_program = self._pruner.prune(program, scope, self._params,
self._current_ratios)
return pruned_program
def reward(self, score):
self._controller_agent.reward(self._current_ratios, score)
tokens = self.ratios2tokens(self._current_ratios)
self._controller_client.reward(tokens, score)
self._iter += 1
class PrunerAgent(object):
"""
The agent used to talk with controller server.
"""
def __init__(self, server_attr=("", ""), range_table):
self._range_table = range_table
self._controller_client = ControllerClient(server_attr)
self._controller_client.send_range_table(range_table)
def next_ratios(self):
def _next_ratios(self):
tokens = self._controller_client.next_tokens()
self._tokens2ratios(tokens)
def illegal_ratios(self, ratios):
tokens = self._ratios2tokens(ratios)
tokens = self._controller_client.illegal_tokens(tokens)
return self._tokens2ratios(tokens)
def _ratios2tokens(self, ratios):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
__all__ = ['ControllerClient']
class ControllerClient(object):
"""
Controller client.
"""
def __init__(self, server_ip=None, server_port=None, key=None):
"""
Args:
server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
key(str): The key used to identify legal agent for controller server. Default: "light-nas"
"""
self.server_ip = server_ip
self.server_port = server_port
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key
def update(self, tokens, reward):
"""
Update the controller according to latest tokens and reward.
Args:
tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
.encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
def next_tokens(self):
"""
Get next tokens.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
socket_client.send("next_tokens".encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
from threading import Thread
__all__ = ['ControllerServer']
class ControllerServer(object):
"""
The controller wrapper with a socket server to handle the request of search agent.
"""
def __init__(self,
controller=None,
address=('', 0),
max_client_num=100,
search_steps=None,
key=None):
"""
Args:
controller(slim.searcher.Controller): The controller used to generate tokens.
address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
which means setting ip automatically
max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
search_steps(int): The total steps of searching. None means never stopping. Default: None
"""
self._controller = controller
self._address = address
self._max_client_num = max_client_num
self._search_steps = search_steps
self._closed = False
self._port = address[1]
self._ip = address[0]
self._key = key
def start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket_server.bind(self._address)
self._socket_server.listen(self._max_client_num)
self._port = self._socket_server.getsockname()[1]
self._ip = self._socket_server.getsockname()[0]
_logger.info("listen on: [{}:{}]".format(self._ip, self._port))
thread = Thread(target=self.run)
thread.start()
return str(thread)
def close(self):
"""Close the server."""
self._closed = True
def port(self):
"""Get the port."""
return self._port
def ip(self):
"""Get the ip."""
return self._ip
def run(self):
_logger.info("Controller Server run...")
while ((self._search_steps is None) or
(self._controller._iter <
(self._search_steps))) and not self._closed:
conn, addr = self._socket_server.accept()
message = conn.recv(1024).decode()
if message.strip("\n") == "next_tokens":
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
else:
_logger.info("recv message from {}: [{}]".format(addr, message))
messages = message.strip('\n').split("\t")
if (len(messages) < 3) or (messages[0] != self._key):
_logger.info("recv noise from {}: [{}]".format(addr,
message))
continue
tokens = messages[1]
reward = messages[2]
tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward))
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
_logger.info("send message to {}: [{}]".format(addr, tokens))
conn.close()
self._socket_server.close()
_logger.info("server closed!")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册