未验证 提交 a521a961 编写于 作者: C ceci3 提交者: GitHub

Fix qa (#247)

* fix

* fix some trouble

* update
上级 7b77db07
......@@ -179,7 +179,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
- **server_addr(tuple)** - RLNAS中Controller的地址,包括server的ip地址和端口号,如果ip地址为None或者为""的话则默认使用本机ip。默认:("", 8881)。
- **is_server(bool)** - 当前实例是否要启动一个server。默认:True。
- **is_sync(bool)** - 是否使用同步模式更新Controller,该模式仅在多client下有差别。默认:False。
- **save_controller(str|None)** - 保存Controller的checkpoint的文件目录,如果设置为None的话则不保存checkpoint。默认:None 。
- **save_controller(str|None|False)** - 保存Controller的checkpoint的文件目录,如果设置为None的话则保存checkpoint到默认路径 ``./.rlnas_controller`` ,如果设置为False的话则不保存checkpoint。默认:None 。
- **load_controller(str|None)** - 加载Controller的checkpoint的文件目录,如果设置为None的话则不加载checkpoint。默认:None。
- **\*\*kwargs** - 附加的参数,由具体强化学习算法决定,`LSTM`和`DDPG`的附加参数请参考note。
......
......@@ -261,14 +261,14 @@ sa_nas.reward(float(valid_top1_list[-1] + valid_top1_list[-2]) / 2)
### 10. 利用demo下的脚本启动搜索
搜索文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/darts_nas.py),搜索过程中限制模型参数量为不大于3.77M。
搜索文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/sanas_darts_nas.py),搜索过程中限制模型参数量为不大于3.77M。
```python
cd demo/nas/
python darts_nas.py
```
### 11. 利用demo下的脚本启动最终实验
最终实验文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/darts_nas.py),最终实验需要训练600epoch。以下示例输入token为`[5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8]`
最终实验文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/sanas_darts_nas.py),最终实验需要训练600epoch。以下示例输入token为`[5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8]`
```python
cd demo/nas/
python darts_nas.py --token 5 5 0 5 5 10 7 7 5 7 7 11 10 12 10 0 5 3 10 8 --retain_epoch 600
......
......@@ -49,31 +49,38 @@ class Client(object):
ConnectMessage.TIMEOUT * 1000)
client_address = "{}:{}".format(self._ip, self._port)
self._client_socket.connect("tcp://{}".format(client_address))
self._client_socket.send_multipart(
[ConnectMessage.INIT, self._client_name])
self._client_socket.send_multipart([
pickle.dumps(ConnectMessage.INIT), pickle.dumps(self._client_name)
])
message = self._client_socket.recv_multipart()
if message[0] != ConnectMessage.INIT_DONE:
if pickle.loads(message[0]) != ConnectMessage.INIT_DONE:
_logger.error("Client {} init failure, Please start it again".
format(self._client_name))
pid = os.getpid()
os.kill(pid, signal.SIGTERM)
_logger.info("Client {}: connect to server {}".format(
_logger.info("Client {}: connect to server success!!!".format(
self._client_name))
_logger.debug("Client {}: connect to server {}".format(
self._client_name, client_address))
def _connect_wait_socket(self, port):
self._wait_socket = self._ctx.socket(zmq.REQ)
wait_address = "{}:{}".format(self._ip, port)
self._wait_socket.connect("tcp://{}".format(wait_address))
self._wait_socket.send_multipart(
[ConnectMessage.WAIT_PARAMS, self._client_name])
self._wait_socket.send_multipart([
pickle.dumps(ConnectMessage.WAIT_PARAMS),
pickle.dumps(self._client_name)
])
message = self._wait_socket.recv_multipart()
return message[0]
return pickle.loads(message[0])
def next_tokens(self, obs, is_inference=False):
_logger.debug("Client: requests for weight {}".format(
self._client_name))
self._client_socket.send_multipart(
[ConnectMessage.GET_WEIGHT, self._client_name])
self._client_socket.send_multipart([
pickle.dumps(ConnectMessage.GET_WEIGHT),
pickle.dumps(self._client_name)
])
try:
message = self._client_socket.recv_multipart()
except zmq.error.Again as e:
......@@ -95,8 +102,8 @@ class Client(object):
params_grad = compute_grad(self._params_dict, current_params_dict)
_logger.debug("Client: update weight {}".format(self._client_name))
self._client_socket.send_multipart([
ConnectMessage.UPDATE_WEIGHT, self._client_name,
pickle.dumps(params_grad)
pickle.dumps(ConnectMessage.UPDATE_WEIGHT),
pickle.dumps(self._client_name), pickle.dumps(params_grad)
])
_logger.debug("Client: update done {}".format(self._client_name))
......@@ -108,29 +115,33 @@ class Client(object):
format(e))
os._exit(0)
if message[0] == ConnectMessage.WAIT:
if pickle.loads(message[0]) == ConnectMessage.WAIT:
_logger.debug("Client: self.init_wait: {}".format(self.init_wait))
if not self.init_wait:
wait_port = pickle.loads(message[1])
wait_signal = self._connect_wait_socket(wait_port)
self.init_wait = True
else:
wait_signal = message[0]
wait_signal = pickle.loads(message[0])
while wait_signal != ConnectMessage.OK:
time.sleep(1)
self._wait_socket.send_multipart(
[ConnectMessage.WAIT_PARAMS, self._client_name])
self._wait_socket.send_multipart([
pickle.dumps(ConnectMessage.WAIT_PARAMS),
pickle.dumps(self._client_name)
])
wait_signal = self._wait_socket.recv_multipart()
wait_signal = wait_signal[0]
wait_signal = pickle.loads(wait_signal[0])
_logger.debug("Client: {} {}".format(self._client_name,
wait_signal))
return message[0]
return pickle.loads(message[0])
def __del__(self):
try:
self._client_socket.send_multipart(
[ConnectMessage.EXIT, self._client_name])
self._client_socket.send_multipart([
pickle.dumps(ConnectMessage.EXIT),
pickle.dumps(self._client_name)
])
_ = self._client_socket.recv_multipart()
except:
pass
......
......@@ -62,7 +62,8 @@ class ControllerServer(object):
self._socket_server.listen(self._max_client_num)
self._port = self._socket_server.getsockname()[1]
self._ip = self._socket_server.getsockname()[0]
_logger.info("ControllerServer - listen on: [{}:{}]".format(
_logger.info("ControllerServer Start!!!")
_logger.debug("ControllerServer - listen on: [{}:{}]".format(
self._ip, self._port))
thread = Thread(target=self.run)
thread.setDaemon(True)
......
......@@ -65,7 +65,8 @@ class Server(object):
server_address = "{}:{}".format(self._ip, self._port)
self._server_socket.bind("tcp://{}".format(server_address))
self._server_socket.linger = 0
_logger.info("ControllerServer - listen on: [{}]".format(
_logger.info("ControllerServer Start!!!")
_logger.debug("ControllerServer - listen on: [{}]".format(
server_address))
thread = threading.Thread(target=self.run, args=())
thread.setDaemon(True)
......@@ -100,19 +101,20 @@ class Server(object):
try:
while self._server_alive:
message = self._wait_socket.recv_multipart()
cmd = message[0]
client_name = message[1]
cmd = pickle.loads(message[0])
client_name = pickle.loads(message[1])
if cmd == ConnectMessage.WAIT_PARAMS:
_logger.debug("Server: wait for params")
self._lock.acquire()
self._wait_socket.send_multipart([
ConnectMessage.OK
if self._done else ConnectMessage.WAIT
pickle.dumps(ConnectMessage.OK)
if self._done else pickle.dumps(ConnectMessage.WAIT)
])
if self._done and client_name in self._client:
self._client.remove(client_name)
if len(self._client) == 0:
self.save_params()
if self._save_controller != False:
self.save_params()
self._done = False
self._lock.release()
else:
......@@ -127,11 +129,11 @@ class Server(object):
try:
sum_params_dict = dict()
message = self._server_socket.recv_multipart()
cmd = message[0]
client_name = message[1]
cmd = pickle.loads(message[0])
client_name = pickle.loads(message[1])
if cmd == ConnectMessage.INIT:
self._server_socket.send_multipart(
[ConnectMessage.INIT_DONE])
[pickle.dumps(ConnectMessage.INIT_DONE)])
_logger.debug("Server: init client {}".format(
client_name))
self._client_dict[client_name] = 0
......@@ -161,7 +163,7 @@ class Server(object):
self._done = True
self._server_socket.send_multipart([
ConnectMessage.WAIT,
pickle.dumps(ConnectMessage.WAIT),
pickle.dumps(self._wait_port)
])
else:
......@@ -174,16 +176,17 @@ class Server(object):
self._max_update_times = self._client_dict[
client_name]
self._lock.release()
self.save_params()
if self._save_controller != False:
self.save_params()
self._server_socket.send_multipart(
[ConnectMessage.OK])
[pickle.dumps(ConnectMessage.OK)])
elif cmd == ConnectMessage.EXIT:
self._client_dict.pop(client_name)
if client_name in self._client:
self._client.remove(client_name)
self._server_socket.send_multipart(
[ConnectMessage.EXIT])
[pickle.dumps(ConnectMessage.EXIT)])
except zmq.error.Again as e:
_logger.error(e)
self.close()
......
......@@ -60,14 +60,14 @@ class MedianStop(EarlyStopBase):
'get_completed_history', callable=return_completed_history)
base_manager = BaseManager(
address=(self._server_ip, self._server_port),
authkey=PublicAuthKey)
authkey=PublicAuthKey.encode())
base_manager.start()
else:
BaseManager.register('get_completed_history')
base_manager = BaseManager(
address=(self._server_ip, self._server_port),
authkey=PublicAuthKey)
authkey=PublicAuthKey.encode())
base_manager.connect()
return base_manager
......
......@@ -126,6 +126,10 @@ class RLNAS(object):
return archs
@property
def tokens(self):
return self._current_tokens
def reward(self, rewards, **kwargs):
"""
reward the score and to train controller
......@@ -143,6 +147,7 @@ class RLNAS(object):
"""
final_tokens = self._controller_client.next_tokens(
batch_obs, is_inference=True)
self._current_tokens = final_tokens
_logger.info("Final tokens: {}".format(final_tokens))
archs = []
for token in final_tokens:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from paddleslim.nas import SANAS
from paddleslim.nas.early_stop import MedianStop
steps = 5
epochs = 5
class TestMedianStop(unittest.TestCase):
def test_median_stop(self):
config = [('MobileNetV2Space')]
sanas = SANAS(config, server_addr=("", 8732), save_checkpoint=None)
earlystop = MedianStop(sanas, 2)
avg_loss = 1.0
for step in range(steps):
status = earlystop.get_status(step, avg_loss, epochs)
self.assertTrue(status, 'GOOD')
avg_loss = 0.5
for step in range(steps):
status = earlystop.get_status(step, avg_loss, epochs)
if step < 2:
self.assertTrue(status, 'GOOD')
else:
self.assertTrue(status, 'BAD')
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
from paddleslim.nas import RLNAS
from paddleslim.analysis import flops
import numpy as np
def compute_op_num(program):
params = {}
ch_list = []
for block in program.blocks:
for param in block.all_parameters():
if len(param.shape) == 4:
params[param.name] = param.shape
ch_list.append(int(param.shape[0]))
return params, ch_list
class TestRLNAS(unittest.TestCase):
def setUp(self):
self.init_test_case()
port = np.random.randint(8337, 8773)
self.rlnas = RLNAS(
key='lstm',
configs=self.configs,
server_addr=("", port),
is_sync=False,
controller_batch_size=1,
lstm_num_layers=1,
hidden_size=10,
temperature=1.0,
save_controller=False)
def init_test_case(self):
self.configs = [('MobileNetV2BlockSpace', {'block_mask': [0]})]
self.filter_num = np.array([
3, 4, 8, 12, 16, 24, 32, 48, 64, 80, 96, 128, 144, 160, 192, 224,
256, 320, 384, 512
])
self.k_size = np.array([3, 5])
self.multiply = np.array([1, 2, 3, 4, 5, 6])
self.repeat = np.array([1, 2, 3, 4, 5, 6])
def check_chnum_convnum(self, program, current_tokens):
channel_exp = self.multiply[current_tokens[0]]
filter_num = self.filter_num[current_tokens[1]]
repeat_num = self.repeat[current_tokens[2]]
conv_list, ch_pro = compute_op_num(program)
### assert conv number
self.assertTrue((repeat_num * 3) == len(
conv_list
), "the number of conv is NOT match, the number compute from token: {}, actual conv number: {}".
format(repeat_num * 3, len(conv_list)))
### assert number of channels
ch_token = []
init_ch_num = 32
for i in range(repeat_num):
ch_token.append(init_ch_num * channel_exp)
ch_token.append(init_ch_num * channel_exp)
ch_token.append(filter_num)
init_ch_num = filter_num
self.assertTrue(
str(ch_token) == str(ch_pro),
"channel num is WRONG, channel num from token is {}, channel num come fom program is {}".
format(str(ch_token), str(ch_pro)))
def test_all_function(self):
### unittest for next_archs
next_program = fluid.Program()
startup_program = fluid.Program()
token2arch_program = fluid.Program()
with fluid.program_guard(next_program, startup_program):
inputs = fluid.data(
name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = self.rlnas.next_archs(1)[0]
current_tokens = self.rlnas.tokens
for arch in archs:
output = arch(inputs)
inputs = output
self.check_chnum_convnum(next_program, current_tokens[0])
### unittest for reward
self.assertTrue(self.rlnas.reward(float(1.0)), "reward is False")
### uniitest for tokens2arch
with fluid.program_guard(token2arch_program, startup_program):
inputs = fluid.data(
name='input', shape=[None, 3, 32, 32], dtype='float32')
arch = self.rlnas.tokens2arch(self.rlnas.tokens[0])
for arch in archs:
output = arch(inputs)
inputs = output
self.check_chnum_convnum(token2arch_program, self.rlnas.tokens[0])
def test_final_archs(self):
### unittest for final_archs
final_program = fluid.Program()
final_startup_program = fluid.Program()
with fluid.program_guard(final_program, final_startup_program):
inputs = fluid.data(
name='input', shape=[None, 3, 32, 32], dtype='float32')
archs = self.rlnas.final_archs(1)[0]
current_tokens = self.rlnas.tokens
for arch in archs:
output = arch(inputs)
inputs = output
self.check_chnum_convnum(final_program, current_tokens[0])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册