未验证 提交 a521a961 编写于 作者: C ceci3 提交者: GitHub

Fix qa (#247)

* fix

* fix some trouble

* update
上级 7b77db07
...@@ -179,7 +179,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习 ...@@ -179,7 +179,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
- **server_addr(tuple)** - RLNAS中Controller的地址,包括server的ip地址和端口号,如果ip地址为None或者为""的话则默认使用本机ip。默认:("", 8881)。 - **server_addr(tuple)** - RLNAS中Controller的地址,包括server的ip地址和端口号,如果ip地址为None或者为""的话则默认使用本机ip。默认:("", 8881)。
- **is_server(bool)** - 当前实例是否要启动一个server。默认:True。 - **is_server(bool)** - 当前实例是否要启动一个server。默认:True。
- **is_sync(bool)** - 是否使用同步模式更新Controller,该模式仅在多client下有差别。默认:False。 - **is_sync(bool)** - 是否使用同步模式更新Controller,该模式仅在多client下有差别。默认:False。
- **save_controller(str|None)** - 保存Controller的checkpoint的文件目录,如果设置为None的话则不保存checkpoint。默认:None 。 - **save_controller(str|None|False)** - 保存Controller的checkpoint的文件目录,如果设置为None的话则保存checkpoint到默认路径 ``./.rlnas_controller`` ,如果设置为False的话则不保存checkpoint。默认:None 。
- **load_controller(str|None)** - 加载Controller的checkpoint的文件目录,如果设置为None的话则不加载checkpoint。默认:None。 - **load_controller(str|None)** - 加载Controller的checkpoint的文件目录,如果设置为None的话则不加载checkpoint。默认:None。
- **\*\*kwargs** - 附加的参数,由具体强化学习算法决定,`LSTM`和`DDPG`的附加参数请参考note。 - **\*\*kwargs** - 附加的参数,由具体强化学习算法决定,`LSTM`和`DDPG`的附加参数请参考note。
......
...@@ -261,14 +261,14 @@ sa_nas.reward(float(valid_top1_list[-1] + valid_top1_list[-2]) / 2) ...@@ -261,14 +261,14 @@ sa_nas.reward(float(valid_top1_list[-1] + valid_top1_list[-2]) / 2)
### 10. 利用demo下的脚本启动搜索 ### 10. 利用demo下的脚本启动搜索
搜索文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/darts_nas.py),搜索过程中限制模型参数量为不大于3.77M。 搜索文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/sanas_darts_nas.py),搜索过程中限制模型参数量为不大于3.77M。
```python ```python
cd demo/nas/ cd demo/nas/
python darts_nas.py python darts_nas.py
``` ```
### 11. 利用demo下的脚本启动最终实验 ### 11. 利用demo下的脚本启动最终实验
最终实验文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/darts_nas.py),最终实验需要训练600epoch。以下示例输入token为`[5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8]` 最终实验文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/sanas_darts_nas.py),最终实验需要训练600epoch。以下示例输入token为`[5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8]`
```python ```python
cd demo/nas/ cd demo/nas/
python darts_nas.py --token 5 5 0 5 5 10 7 7 5 7 7 11 10 12 10 0 5 3 10 8 --retain_epoch 600 python darts_nas.py --token 5 5 0 5 5 10 7 7 5 7 7 11 10 12 10 0 5 3 10 8 --retain_epoch 600
......
...@@ -49,31 +49,38 @@ class Client(object): ...@@ -49,31 +49,38 @@ class Client(object):
ConnectMessage.TIMEOUT * 1000) ConnectMessage.TIMEOUT * 1000)
client_address = "{}:{}".format(self._ip, self._port) client_address = "{}:{}".format(self._ip, self._port)
self._client_socket.connect("tcp://{}".format(client_address)) self._client_socket.connect("tcp://{}".format(client_address))
self._client_socket.send_multipart( self._client_socket.send_multipart([
[ConnectMessage.INIT, self._client_name]) pickle.dumps(ConnectMessage.INIT), pickle.dumps(self._client_name)
])
message = self._client_socket.recv_multipart() message = self._client_socket.recv_multipart()
if message[0] != ConnectMessage.INIT_DONE: if pickle.loads(message[0]) != ConnectMessage.INIT_DONE:
_logger.error("Client {} init failure, Please start it again". _logger.error("Client {} init failure, Please start it again".
format(self._client_name)) format(self._client_name))
pid = os.getpid() pid = os.getpid()
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
_logger.info("Client {}: connect to server {}".format( _logger.info("Client {}: connect to server success!!!".format(
self._client_name))
_logger.debug("Client {}: connect to server {}".format(
self._client_name, client_address)) self._client_name, client_address))
def _connect_wait_socket(self, port): def _connect_wait_socket(self, port):
self._wait_socket = self._ctx.socket(zmq.REQ) self._wait_socket = self._ctx.socket(zmq.REQ)
wait_address = "{}:{}".format(self._ip, port) wait_address = "{}:{}".format(self._ip, port)
self._wait_socket.connect("tcp://{}".format(wait_address)) self._wait_socket.connect("tcp://{}".format(wait_address))
self._wait_socket.send_multipart( self._wait_socket.send_multipart([
[ConnectMessage.WAIT_PARAMS, self._client_name]) pickle.dumps(ConnectMessage.WAIT_PARAMS),
pickle.dumps(self._client_name)
])
message = self._wait_socket.recv_multipart() message = self._wait_socket.recv_multipart()
return message[0] return pickle.loads(message[0])
def next_tokens(self, obs, is_inference=False): def next_tokens(self, obs, is_inference=False):
_logger.debug("Client: requests for weight {}".format( _logger.debug("Client: requests for weight {}".format(
self._client_name)) self._client_name))
self._client_socket.send_multipart( self._client_socket.send_multipart([
[ConnectMessage.GET_WEIGHT, self._client_name]) pickle.dumps(ConnectMessage.GET_WEIGHT),
pickle.dumps(self._client_name)
])
try: try:
message = self._client_socket.recv_multipart() message = self._client_socket.recv_multipart()
except zmq.error.Again as e: except zmq.error.Again as e:
...@@ -95,8 +102,8 @@ class Client(object): ...@@ -95,8 +102,8 @@ class Client(object):
params_grad = compute_grad(self._params_dict, current_params_dict) params_grad = compute_grad(self._params_dict, current_params_dict)
_logger.debug("Client: update weight {}".format(self._client_name)) _logger.debug("Client: update weight {}".format(self._client_name))
self._client_socket.send_multipart([ self._client_socket.send_multipart([
ConnectMessage.UPDATE_WEIGHT, self._client_name, pickle.dumps(ConnectMessage.UPDATE_WEIGHT),
pickle.dumps(params_grad) pickle.dumps(self._client_name), pickle.dumps(params_grad)
]) ])
_logger.debug("Client: update done {}".format(self._client_name)) _logger.debug("Client: update done {}".format(self._client_name))
...@@ -108,29 +115,33 @@ class Client(object): ...@@ -108,29 +115,33 @@ class Client(object):
format(e)) format(e))
os._exit(0) os._exit(0)
if message[0] == ConnectMessage.WAIT: if pickle.loads(message[0]) == ConnectMessage.WAIT:
_logger.debug("Client: self.init_wait: {}".format(self.init_wait)) _logger.debug("Client: self.init_wait: {}".format(self.init_wait))
if not self.init_wait: if not self.init_wait:
wait_port = pickle.loads(message[1]) wait_port = pickle.loads(message[1])
wait_signal = self._connect_wait_socket(wait_port) wait_signal = self._connect_wait_socket(wait_port)
self.init_wait = True self.init_wait = True
else: else:
wait_signal = message[0] wait_signal = pickle.loads(message[0])
while wait_signal != ConnectMessage.OK: while wait_signal != ConnectMessage.OK:
time.sleep(1) time.sleep(1)
self._wait_socket.send_multipart( self._wait_socket.send_multipart([
[ConnectMessage.WAIT_PARAMS, self._client_name]) pickle.dumps(ConnectMessage.WAIT_PARAMS),
pickle.dumps(self._client_name)
])
wait_signal = self._wait_socket.recv_multipart() wait_signal = self._wait_socket.recv_multipart()
wait_signal = wait_signal[0] wait_signal = pickle.loads(wait_signal[0])
_logger.debug("Client: {} {}".format(self._client_name, _logger.debug("Client: {} {}".format(self._client_name,
wait_signal)) wait_signal))
return message[0] return pickle.loads(message[0])
def __del__(self): def __del__(self):
try: try:
self._client_socket.send_multipart( self._client_socket.send_multipart([
[ConnectMessage.EXIT, self._client_name]) pickle.dumps(ConnectMessage.EXIT),
pickle.dumps(self._client_name)
])
_ = self._client_socket.recv_multipart() _ = self._client_socket.recv_multipart()
except: except:
pass pass
......
...@@ -62,7 +62,8 @@ class ControllerServer(object): ...@@ -62,7 +62,8 @@ class ControllerServer(object):
self._socket_server.listen(self._max_client_num) self._socket_server.listen(self._max_client_num)
self._port = self._socket_server.getsockname()[1] self._port = self._socket_server.getsockname()[1]
self._ip = self._socket_server.getsockname()[0] self._ip = self._socket_server.getsockname()[0]
_logger.info("ControllerServer - listen on: [{}:{}]".format( _logger.info("ControllerServer Start!!!")
_logger.debug("ControllerServer - listen on: [{}:{}]".format(
self._ip, self._port)) self._ip, self._port))
thread = Thread(target=self.run) thread = Thread(target=self.run)
thread.setDaemon(True) thread.setDaemon(True)
......
...@@ -65,7 +65,8 @@ class Server(object): ...@@ -65,7 +65,8 @@ class Server(object):
server_address = "{}:{}".format(self._ip, self._port) server_address = "{}:{}".format(self._ip, self._port)
self._server_socket.bind("tcp://{}".format(server_address)) self._server_socket.bind("tcp://{}".format(server_address))
self._server_socket.linger = 0 self._server_socket.linger = 0
_logger.info("ControllerServer - listen on: [{}]".format( _logger.info("ControllerServer Start!!!")
_logger.debug("ControllerServer - listen on: [{}]".format(
server_address)) server_address))
thread = threading.Thread(target=self.run, args=()) thread = threading.Thread(target=self.run, args=())
thread.setDaemon(True) thread.setDaemon(True)
...@@ -100,19 +101,20 @@ class Server(object): ...@@ -100,19 +101,20 @@ class Server(object):
try: try:
while self._server_alive: while self._server_alive:
message = self._wait_socket.recv_multipart() message = self._wait_socket.recv_multipart()
cmd = message[0] cmd = pickle.loads(message[0])
client_name = message[1] client_name = pickle.loads(message[1])
if cmd == ConnectMessage.WAIT_PARAMS: if cmd == ConnectMessage.WAIT_PARAMS:
_logger.debug("Server: wait for params") _logger.debug("Server: wait for params")
self._lock.acquire() self._lock.acquire()
self._wait_socket.send_multipart([ self._wait_socket.send_multipart([
ConnectMessage.OK pickle.dumps(ConnectMessage.OK)
if self._done else ConnectMessage.WAIT if self._done else pickle.dumps(ConnectMessage.WAIT)
]) ])
if self._done and client_name in self._client: if self._done and client_name in self._client:
self._client.remove(client_name) self._client.remove(client_name)
if len(self._client) == 0: if len(self._client) == 0:
self.save_params() if self._save_controller != False:
self.save_params()
self._done = False self._done = False
self._lock.release() self._lock.release()
else: else:
...@@ -127,11 +129,11 @@ class Server(object): ...@@ -127,11 +129,11 @@ class Server(object):
try: try:
sum_params_dict = dict() sum_params_dict = dict()
message = self._server_socket.recv_multipart() message = self._server_socket.recv_multipart()
cmd = message[0] cmd = pickle.loads(message[0])
client_name = message[1] client_name = pickle.loads(message[1])
if cmd == ConnectMessage.INIT: if cmd == ConnectMessage.INIT:
self._server_socket.send_multipart( self._server_socket.send_multipart(
[ConnectMessage.INIT_DONE]) [pickle.dumps(ConnectMessage.INIT_DONE)])
_logger.debug("Server: init client {}".format( _logger.debug("Server: init client {}".format(
client_name)) client_name))
self._client_dict[client_name] = 0 self._client_dict[client_name] = 0
...@@ -161,7 +163,7 @@ class Server(object): ...@@ -161,7 +163,7 @@ class Server(object):
self._done = True self._done = True
self._server_socket.send_multipart([ self._server_socket.send_multipart([
ConnectMessage.WAIT, pickle.dumps(ConnectMessage.WAIT),
pickle.dumps(self._wait_port) pickle.dumps(self._wait_port)
]) ])
else: else:
...@@ -174,16 +176,17 @@ class Server(object): ...@@ -174,16 +176,17 @@ class Server(object):
self._max_update_times = self._client_dict[ self._max_update_times = self._client_dict[
client_name] client_name]
self._lock.release() self._lock.release()
self.save_params() if self._save_controller != False:
self.save_params()
self._server_socket.send_multipart( self._server_socket.send_multipart(
[ConnectMessage.OK]) [pickle.dumps(ConnectMessage.OK)])
elif cmd == ConnectMessage.EXIT: elif cmd == ConnectMessage.EXIT:
self._client_dict.pop(client_name) self._client_dict.pop(client_name)
if client_name in self._client: if client_name in self._client:
self._client.remove(client_name) self._client.remove(client_name)
self._server_socket.send_multipart( self._server_socket.send_multipart(
[ConnectMessage.EXIT]) [pickle.dumps(ConnectMessage.EXIT)])
except zmq.error.Again as e: except zmq.error.Again as e:
_logger.error(e) _logger.error(e)
self.close() self.close()
......
...@@ -60,14 +60,14 @@ class MedianStop(EarlyStopBase): ...@@ -60,14 +60,14 @@ class MedianStop(EarlyStopBase):
'get_completed_history', callable=return_completed_history) 'get_completed_history', callable=return_completed_history)
base_manager = BaseManager( base_manager = BaseManager(
address=(self._server_ip, self._server_port), address=(self._server_ip, self._server_port),
authkey=PublicAuthKey) authkey=PublicAuthKey.encode())
base_manager.start() base_manager.start()
else: else:
BaseManager.register('get_completed_history') BaseManager.register('get_completed_history')
base_manager = BaseManager( base_manager = BaseManager(
address=(self._server_ip, self._server_port), address=(self._server_ip, self._server_port),
authkey=PublicAuthKey) authkey=PublicAuthKey.encode())
base_manager.connect() base_manager.connect()
return base_manager return base_manager
......
...@@ -126,6 +126,10 @@ class RLNAS(object): ...@@ -126,6 +126,10 @@ class RLNAS(object):
return archs return archs
@property
def tokens(self):
return self._current_tokens
def reward(self, rewards, **kwargs): def reward(self, rewards, **kwargs):
""" """
reward the score and to train controller reward the score and to train controller
...@@ -143,6 +147,7 @@ class RLNAS(object): ...@@ -143,6 +147,7 @@ class RLNAS(object):
""" """
final_tokens = self._controller_client.next_tokens( final_tokens = self._controller_client.next_tokens(
batch_obs, is_inference=True) batch_obs, is_inference=True)
self._current_tokens = final_tokens
_logger.info("Final tokens: {}".format(final_tokens)) _logger.info("Final tokens: {}".format(final_tokens))
archs = [] archs = []
for token in final_tokens: for token in final_tokens:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddleslim.nas import SANAS
from paddleslim.nas.early_stop import MedianStop
steps = 5
epochs = 5
class TestMedianStop(unittest.TestCase):
def test_median_stop(self):
config = [('MobileNetV2Space')]
sanas = SANAS(config, server_addr=("", 8732), save_checkpoint=None)
earlystop = MedianStop(sanas, 2)
avg_loss = 1.0
for step in range(steps):
status = earlystop.get_status(step, avg_loss, epochs)
self.assertTrue(status, 'GOOD')
avg_loss = 0.5
for step in range(steps):
status = earlystop.get_status(step, avg_loss, epochs)
if step < 2:
self.assertTrue(status, 'GOOD')
else:
self.assertTrue(status, 'BAD')
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
from paddleslim.nas import RLNAS
from paddleslim.analysis import flops
import numpy as np
def compute_op_num(program):
params = {}
ch_list = []
for block in program.blocks:
for param in block.all_parameters():
if len(param.shape) == 4:
params[param.name] = param.shape
ch_list.append(int(param.shape[0]))
return params, ch_list
class TestRLNAS(unittest.TestCase):
def setUp(self):
self.init_test_case()
port = np.random.randint(8337, 8773)
self.rlnas = RLNAS(
key='lstm',
configs=self.configs,
server_addr=("", port),
is_sync=False,
controller_batch_size=1,
lstm_num_layers=1,
hidden_size=10,
temperature=1.0,
save_controller=False)
def init_test_case(self):
self.configs = [('MobileNetV2BlockSpace', {'block_mask': [0]})]
self.filter_num = np.array([
3, 4, 8, 12, 16, 24, 32, 48, 64, 80, 96, 128, 144, 160, 192, 224,
256, 320, 384, 512
])
self.k_size = np.array([3, 5])
self.multiply = np.array([1, 2, 3, 4, 5, 6])
self.repeat = np.array([1, 2, 3, 4, 5, 6])
def check_chnum_convnum(self, program, current_tokens):
channel_exp = self.multiply[current_tokens[0]]
filter_num = self.filter_num[current_tokens[1]]
repeat_num = self.repeat[current_tokens[2]]
conv_list, ch_pro = compute_op_num(program)
### assert conv number
self.assertTrue((repeat_num * 3) == len(
conv_list
), "the number of conv is NOT match, the number compute from token: {}, actual conv number: {}".
format(repeat_num * 3, len(conv_list)))
### assert number of channels
ch_token = []
init_ch_num = 32
for i in range(repeat_num):
ch_token.append(init_ch_num * channel_exp)
ch_token.append(init_ch_num * channel_exp)
ch_token.append(filter_num)
init_ch_num = filter_num
self.assertTrue(
str(ch_token) == str(ch_pro),
"channel num is WRONG, channel num from token is {}, channel num come fom program is {}".
format(str(ch_token), str(ch_pro)))
def test_all_function(self):
### unittest for next_archs
next_program = fluid.Program()
startup_program = fluid.Program()
token2arch_program = fluid.Program()
with fluid.program_guard(next_program, startup_program):
inputs = fluid.data(
name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = self.rlnas.next_archs(1)[0]
current_tokens = self.rlnas.tokens
for arch in archs:
output = arch(inputs)
inputs = output
self.check_chnum_convnum(next_program, current_tokens[0])
### unittest for reward
self.assertTrue(self.rlnas.reward(float(1.0)), "reward is False")
### uniitest for tokens2arch
with fluid.program_guard(token2arch_program, startup_program):
inputs = fluid.data(
name='input', shape=[None, 3, 32, 32], dtype='float32')
arch = self.rlnas.tokens2arch(self.rlnas.tokens[0])
for arch in archs:
output = arch(inputs)
inputs = output
self.check_chnum_convnum(token2arch_program, self.rlnas.tokens[0])
def test_final_archs(self):
### unittest for final_archs
final_program = fluid.Program()
final_startup_program = fluid.Program()
with fluid.program_guard(final_program, final_startup_program):
inputs = fluid.data(
name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = self.rlnas.final_archs(1)[0]
current_tokens = self.rlnas.tokens
for arch in archs:
output = arch(inputs)
inputs = output
self.check_chnum_convnum(final_program, current_tokens[0])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册