diff --git a/demo/nas/darts_nas.py b/demo/nas/sanas_darts_space.py similarity index 100% rename from demo/nas/darts_nas.py rename to demo/nas/sanas_darts_space.py diff --git a/docs/zh_cn/api_cn/nas_api.rst b/docs/zh_cn/api_cn/nas_api.rst index 8cda6f23041b7f084fa362b3cfcd2034a7aec3e2..d906f20d70d572c6cef02046bdcf2fc061321f9e 100644 --- a/docs/zh_cn/api_cn/nas_api.rst +++ b/docs/zh_cn/api_cn/nas_api.rst @@ -179,7 +179,7 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习 - **server_addr(tuple)** - RLNAS中Controller的地址,包括server的ip地址和端口号,如果ip地址为None或者为""的话则默认使用本机ip。默认:("", 8881)。 - **is_server(bool)** - 当前实例是否要启动一个server。默认:True。 - **is_sync(bool)** - 是否使用同步模式更新Controller,该模式仅在多client下有差别。默认:False。 -- **save_controller(str|None)** - 保存Controller的checkpoint的文件目录,如果设置为None的话则不保存checkpoint。默认:None 。 +- **save_controller(str|None|False)** - 保存Controller的checkpoint的文件目录,如果设置为None的话则保存checkpoint到默认路径 ``./.rlnas_controller`` ,如果设置为False的话则不保存checkpoint。默认:None 。 - **load_controller(str|None)** - 加载Controller的checkpoint的文件目录,如果设置为None的话则不加载checkpoint。默认:None。 - **\*\*kwargs** - 附加的参数,由具体强化学习算法决定,`LSTM`和`DDPG`的附加参数请参考note。 diff --git a/docs/zh_cn/tutorials/darts_nas_turorial.ipynb b/docs/zh_cn/tutorials/sanas_darts_space.ipynb similarity index 100% rename from docs/zh_cn/tutorials/darts_nas_turorial.ipynb rename to docs/zh_cn/tutorials/sanas_darts_space.ipynb diff --git a/docs/zh_cn/tutorials/darts_nas_turorial.md b/docs/zh_cn/tutorials/sanas_darts_space.md similarity index 97% rename from docs/zh_cn/tutorials/darts_nas_turorial.md rename to docs/zh_cn/tutorials/sanas_darts_space.md index 2f99a6c1ca3fd4d8e63d25d3c83d1f6b19807f57..d934aa23b4e4801f3cc6f66d0d897c25e0d825da 100644 --- a/docs/zh_cn/tutorials/darts_nas_turorial.md +++ b/docs/zh_cn/tutorials/sanas_darts_space.md @@ -261,14 +261,14 @@ sa_nas.reward(float(valid_top1_list[-1] + valid_top1_list[-2]) / 2) ### 10. 
利用demo下的脚本启动搜索 -搜索文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/darts_nas.py),搜索过程中限制模型参数量为不大于3.77M。 +搜索文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/sanas_darts_space.py),搜索过程中限制模型参数量为不大于3.77M。 ```python cd demo/nas/ python darts_nas.py ``` ### 11. 利用demo下的脚本启动最终实验 -最终实验文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/darts_nas.py),最终实验需要训练600epoch。以下示例输入token为`[5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8]`。 +最终实验文件位于: [darts_sanas_demo](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/nas/sanas_darts_space.py),最终实验需要训练600epoch。以下示例输入token为`[5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8]`。 ```python cd demo/nas/ python darts_nas.py --token 5 5 0 5 5 10 7 7 5 7 7 11 10 12 10 0 5 3 10 8 --retain_epoch 600 diff --git a/paddleslim/common/client.py b/paddleslim/common/client.py index 3d2ea9a75baba4e19fadc91cc4541724a600a6b1..ead7de800c3b43fb648f2506cd7c848bb430414b 100644 --- a/paddleslim/common/client.py +++ b/paddleslim/common/client.py @@ -49,31 +49,38 @@ class Client(object): ConnectMessage.TIMEOUT * 1000) client_address = "{}:{}".format(self._ip, self._port) self._client_socket.connect("tcp://{}".format(client_address)) - self._client_socket.send_multipart( - [ConnectMessage.INIT, self._client_name]) + self._client_socket.send_multipart([ + pickle.dumps(ConnectMessage.INIT), pickle.dumps(self._client_name) + ]) message = self._client_socket.recv_multipart() - if message[0] != ConnectMessage.INIT_DONE: + if pickle.loads(message[0]) != ConnectMessage.INIT_DONE: _logger.error("Client {} init failure, Please start it again". 
format(self._client_name)) pid = os.getpid() os.kill(pid, signal.SIGTERM) - _logger.info("Client {}: connect to server {}".format( + _logger.info("Client {}: connect to server success!!!".format( + self._client_name)) + _logger.debug("Client {}: connect to server {}".format( self._client_name, client_address)) def _connect_wait_socket(self, port): self._wait_socket = self._ctx.socket(zmq.REQ) wait_address = "{}:{}".format(self._ip, port) self._wait_socket.connect("tcp://{}".format(wait_address)) - self._wait_socket.send_multipart( - [ConnectMessage.WAIT_PARAMS, self._client_name]) + self._wait_socket.send_multipart([ + pickle.dumps(ConnectMessage.WAIT_PARAMS), + pickle.dumps(self._client_name) + ]) message = self._wait_socket.recv_multipart() - return message[0] + return pickle.loads(message[0]) def next_tokens(self, obs, is_inference=False): _logger.debug("Client: requests for weight {}".format( self._client_name)) - self._client_socket.send_multipart( - [ConnectMessage.GET_WEIGHT, self._client_name]) + self._client_socket.send_multipart([ + pickle.dumps(ConnectMessage.GET_WEIGHT), + pickle.dumps(self._client_name) + ]) try: message = self._client_socket.recv_multipart() except zmq.error.Again as e: @@ -95,8 +102,8 @@ class Client(object): params_grad = compute_grad(self._params_dict, current_params_dict) _logger.debug("Client: update weight {}".format(self._client_name)) self._client_socket.send_multipart([ - ConnectMessage.UPDATE_WEIGHT, self._client_name, - pickle.dumps(params_grad) + pickle.dumps(ConnectMessage.UPDATE_WEIGHT), + pickle.dumps(self._client_name), pickle.dumps(params_grad) ]) _logger.debug("Client: update done {}".format(self._client_name)) @@ -108,29 +115,33 @@ class Client(object): format(e)) os._exit(0) - if message[0] == ConnectMessage.WAIT: + if pickle.loads(message[0]) == ConnectMessage.WAIT: _logger.debug("Client: self.init_wait: {}".format(self.init_wait)) if not self.init_wait: wait_port = pickle.loads(message[1]) wait_signal = 
self._connect_wait_socket(wait_port) self.init_wait = True else: - wait_signal = message[0] + wait_signal = pickle.loads(message[0]) while wait_signal != ConnectMessage.OK: time.sleep(1) - self._wait_socket.send_multipart( - [ConnectMessage.WAIT_PARAMS, self._client_name]) + self._wait_socket.send_multipart([ + pickle.dumps(ConnectMessage.WAIT_PARAMS), + pickle.dumps(self._client_name) + ]) wait_signal = self._wait_socket.recv_multipart() - wait_signal = wait_signal[0] + wait_signal = pickle.loads(wait_signal[0]) _logger.debug("Client: {} {}".format(self._client_name, wait_signal)) - return message[0] + return pickle.loads(message[0]) def __del__(self): try: - self._client_socket.send_multipart( - [ConnectMessage.EXIT, self._client_name]) + self._client_socket.send_multipart([ + pickle.dumps(ConnectMessage.EXIT), + pickle.dumps(self._client_name) + ]) _ = self._client_socket.recv_multipart() except: pass diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py index 008639c0a6a372f4f5d13dffc7f433b919486547..5103dfed8f371938fb151d45d388b23fc9a9f01a 100644 --- a/paddleslim/common/controller_server.py +++ b/paddleslim/common/controller_server.py @@ -62,7 +62,8 @@ class ControllerServer(object): self._socket_server.listen(self._max_client_num) self._port = self._socket_server.getsockname()[1] self._ip = self._socket_server.getsockname()[0] - _logger.info("ControllerServer - listen on: [{}:{}]".format( + _logger.info("ControllerServer Start!!!") + _logger.debug("ControllerServer - listen on: [{}:{}]".format( self._ip, self._port)) thread = Thread(target=self.run) thread.setDaemon(True) diff --git a/paddleslim/common/server.py b/paddleslim/common/server.py index e072f6e549567c47bda345d2fbe2aa56553782fc..47264b3a0154d921b387bfac942e00d7fbf2a4fc 100644 --- a/paddleslim/common/server.py +++ b/paddleslim/common/server.py @@ -65,7 +65,8 @@ class Server(object): server_address = "{}:{}".format(self._ip, self._port) 
self._server_socket.bind("tcp://{}".format(server_address)) self._server_socket.linger = 0 - _logger.info("ControllerServer - listen on: [{}]".format( + _logger.info("ControllerServer Start!!!") + _logger.debug("ControllerServer - listen on: [{}]".format( server_address)) thread = threading.Thread(target=self.run, args=()) thread.setDaemon(True) @@ -100,19 +101,20 @@ class Server(object): try: while self._server_alive: message = self._wait_socket.recv_multipart() - cmd = message[0] - client_name = message[1] + cmd = pickle.loads(message[0]) + client_name = pickle.loads(message[1]) if cmd == ConnectMessage.WAIT_PARAMS: _logger.debug("Server: wait for params") self._lock.acquire() self._wait_socket.send_multipart([ - ConnectMessage.OK - if self._done else ConnectMessage.WAIT + pickle.dumps(ConnectMessage.OK) + if self._done else pickle.dumps(ConnectMessage.WAIT) ]) if self._done and client_name in self._client: self._client.remove(client_name) if len(self._client) == 0: - self.save_params() + if self._save_controller != False: + self.save_params() self._done = False self._lock.release() else: @@ -127,11 +129,11 @@ class Server(object): try: sum_params_dict = dict() message = self._server_socket.recv_multipart() - cmd = message[0] - client_name = message[1] + cmd = pickle.loads(message[0]) + client_name = pickle.loads(message[1]) if cmd == ConnectMessage.INIT: self._server_socket.send_multipart( - [ConnectMessage.INIT_DONE]) + [pickle.dumps(ConnectMessage.INIT_DONE)]) _logger.debug("Server: init client {}".format( client_name)) self._client_dict[client_name] = 0 @@ -161,7 +163,7 @@ class Server(object): self._done = True self._server_socket.send_multipart([ - ConnectMessage.WAIT, + pickle.dumps(ConnectMessage.WAIT), pickle.dumps(self._wait_port) ]) else: @@ -174,16 +176,17 @@ class Server(object): self._max_update_times = self._client_dict[ client_name] self._lock.release() - self.save_params() + if self._save_controller != False: + self.save_params() 
self._server_socket.send_multipart( - [ConnectMessage.OK]) + [pickle.dumps(ConnectMessage.OK)]) elif cmd == ConnectMessage.EXIT: self._client_dict.pop(client_name) if client_name in self._client: self._client.remove(client_name) self._server_socket.send_multipart( - [ConnectMessage.EXIT]) + [pickle.dumps(ConnectMessage.EXIT)]) except zmq.error.Again as e: _logger.error(e) self.close() diff --git a/paddleslim/nas/early_stop/median_stop/median_stop.py b/paddleslim/nas/early_stop/median_stop/median_stop.py index 42aa0dff0db00dd875d8f401962c681ba2b6ba9b..1a1cd7cd67846701dac1620a671afdd5dc4ddbae 100644 --- a/paddleslim/nas/early_stop/median_stop/median_stop.py +++ b/paddleslim/nas/early_stop/median_stop/median_stop.py @@ -60,14 +60,14 @@ class MedianStop(EarlyStopBase): 'get_completed_history', callable=return_completed_history) base_manager = BaseManager( address=(self._server_ip, self._server_port), - authkey=PublicAuthKey) + authkey=PublicAuthKey.encode()) base_manager.start() else: BaseManager.register('get_completed_history') base_manager = BaseManager( address=(self._server_ip, self._server_port), - authkey=PublicAuthKey) + authkey=PublicAuthKey.encode()) base_manager.connect() return base_manager diff --git a/paddleslim/nas/rl_nas.py b/paddleslim/nas/rl_nas.py index 5b84382da77f344ee24b4753219ede0e2f610ef1..93c64c5994783229ea32d247a38d6d30c6283495 100644 --- a/paddleslim/nas/rl_nas.py +++ b/paddleslim/nas/rl_nas.py @@ -126,6 +126,10 @@ class RLNAS(object): return archs + @property + def tokens(self): + return self._current_tokens + def reward(self, rewards, **kwargs): """ reward the score and to train controller @@ -143,6 +147,7 @@ class RLNAS(object): """ final_tokens = self._controller_client.next_tokens( batch_obs, is_inference=True) + self._current_tokens = final_tokens _logger.info("Final tokens: {}".format(final_tokens)) archs = [] for token in final_tokens: diff --git a/tests/test_earlystop.py b/tests/test_earlystop.py new file mode 100644 index 
0000000000000000000000000000000000000000..6af4de6df5336dc9c3faa97cf48f73df00fc2d5d --- /dev/null +++ b/tests/test_earlystop.py @@ -0,0 +1,42 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import paddle +from paddleslim.nas import SANAS +from paddleslim.nas.early_stop import MedianStop +steps = 5 +epochs = 5 + + +class TestMedianStop(unittest.TestCase): + def test_median_stop(self): + config = [('MobileNetV2Space')] + sanas = SANAS(config, server_addr=("", 8732), save_checkpoint=None) + earlystop = MedianStop(sanas, 2) + avg_loss = 1.0 + for step in range(steps): + status = earlystop.get_status(step, avg_loss, epochs) + self.assertTrue(status, 'GOOD') + + avg_loss = 0.5 + for step in range(steps): + status = earlystop.get_status(step, avg_loss, epochs) + if step < 2: + self.assertTrue(status, 'GOOD') + else: + self.assertTrue(status, 'BAD') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_rl_nas.py b/tests/test_rl_nas.py new file mode 100644 index 0000000000000000000000000000000000000000..cd81eaf15a1f71cf040898da63b6a27fb0b10af8 --- /dev/null +++ b/tests/test_rl_nas.py @@ -0,0 +1,128 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import paddle.fluid as fluid +from paddleslim.nas import RLNAS +from paddleslim.analysis import flops +import numpy as np + + +def compute_op_num(program): + params = {} + ch_list = [] + for block in program.blocks: + for param in block.all_parameters(): + if len(param.shape) == 4: + params[param.name] = param.shape + ch_list.append(int(param.shape[0])) + return params, ch_list + + +class TestRLNAS(unittest.TestCase): + def setUp(self): + self.init_test_case() + port = np.random.randint(8337, 8773) + self.rlnas = RLNAS( + key='lstm', + configs=self.configs, + server_addr=("", port), + is_sync=False, + controller_batch_size=1, + lstm_num_layers=1, + hidden_size=10, + temperature=1.0, + save_controller=False) + + def init_test_case(self): + self.configs = [('MobileNetV2BlockSpace', {'block_mask': [0]})] + self.filter_num = np.array([ + 3, 4, 8, 12, 16, 24, 32, 48, 64, 80, 96, 128, 144, 160, 192, 224, + 256, 320, 384, 512 + ]) + self.k_size = np.array([3, 5]) + self.multiply = np.array([1, 2, 3, 4, 5, 6]) + self.repeat = np.array([1, 2, 3, 4, 5, 6]) + + def check_chnum_convnum(self, program, current_tokens): + channel_exp = self.multiply[current_tokens[0]] + filter_num = self.filter_num[current_tokens[1]] + repeat_num = self.repeat[current_tokens[2]] + + conv_list, ch_pro = compute_op_num(program) + ### assert conv number + self.assertTrue((repeat_num * 3) == len( + conv_list + ), "the number of conv is NOT match, the number compute from token: {}, actual conv number: {}". 
+ format(repeat_num * 3, len(conv_list))) + + ### assert number of channels + ch_token = [] + init_ch_num = 32 + for i in range(repeat_num): + ch_token.append(init_ch_num * channel_exp) + ch_token.append(init_ch_num * channel_exp) + ch_token.append(filter_num) + init_ch_num = filter_num + + self.assertTrue( + str(ch_token) == str(ch_pro), + "channel num is WRONG, channel num from token is {}, channel num come fom program is {}". + format(str(ch_token), str(ch_pro))) + + def test_all_function(self): + ### unittest for next_archs + next_program = fluid.Program() + startup_program = fluid.Program() + token2arch_program = fluid.Program() + + with fluid.program_guard(next_program, startup_program): + inputs = fluid.data( + name='input', shape=[None, 3, 32, 32], dtype='float32') + archs = self.rlnas.next_archs(1)[0] + current_tokens = self.rlnas.tokens + for arch in archs: + output = arch(inputs) + inputs = output + self.check_chnum_convnum(next_program, current_tokens[0]) + + ### unittest for reward + self.assertTrue(self.rlnas.reward(float(1.0)), "reward is False") + + ### uniitest for tokens2arch + with fluid.program_guard(token2arch_program, startup_program): + inputs = fluid.data( + name='input', shape=[None, 3, 32, 32], dtype='float32') + arch = self.rlnas.tokens2arch(self.rlnas.tokens[0]) + for arch in archs: + output = arch(inputs) + inputs = output + self.check_chnum_convnum(token2arch_program, self.rlnas.tokens[0]) + + def test_final_archs(self): + ### unittest for final_archs + final_program = fluid.Program() + final_startup_program = fluid.Program() + with fluid.program_guard(final_program, final_startup_program): + inputs = fluid.data( + name='input', shape=[None, 3, 32, 32], dtype='float32') + archs = self.rlnas.final_archs(1)[0] + current_tokens = self.rlnas.tokens + for arch in archs: + output = arch(inputs) + inputs = output + self.check_chnum_convnum(final_program, current_tokens[0]) + + +if __name__ == '__main__': + unittest.main()