From 7d1ec56646db34787a897d295809f20b2007430e Mon Sep 17 00:00:00 2001 From: ceci3 Date: Wed, 8 Apr 2020 13:35:54 +0800 Subject: [PATCH] add early stop (#106) --- docs/zh_cn/api_cn/early_stop.rst | 73 +++++++ paddleslim/common/sa_controller.py | 2 +- paddleslim/nas/early_stop/__init__.py | 18 ++ paddleslim/nas/early_stop/early_stop.py | 32 +++ .../nas/early_stop/median_stop/__init__.py | 17 ++ .../nas/early_stop/median_stop/median_stop.py | 184 ++++++++++++++++++ paddleslim/nas/sa_nas.py | 12 +- 7 files changed, 331 insertions(+), 7 deletions(-) create mode 100644 docs/zh_cn/api_cn/early_stop.rst create mode 100644 paddleslim/nas/early_stop/__init__.py create mode 100644 paddleslim/nas/early_stop/early_stop.py create mode 100644 paddleslim/nas/early_stop/median_stop/__init__.py create mode 100644 paddleslim/nas/early_stop/median_stop/median_stop.py diff --git a/docs/zh_cn/api_cn/early_stop.rst b/docs/zh_cn/api_cn/early_stop.rst new file mode 100644 index 00000000..57557bcf --- /dev/null +++ b/docs/zh_cn/api_cn/early_stop.rst @@ -0,0 +1,73 @@ +early-stop +======== +早停算法接口在实验中如何使用 + +MedianStop +------ + +.. py:class:: paddleslim.nas.early_stop.MedianStop(strategy, start_epoch, mode) + +`源代码 <>`_ + +MedianStop是利用历史较好实验的中间结果来判断当前实验是否有运行完成的必要,如果当前实验在中间步骤的结果差于历史记录的实验列表中相同步骤的结果的中值,则代表当前实验是较差的实验,可以提前终止。参考 `Google Vizier: A Service for Black-Box Optimization `_. + +**参数:** + +- **strategy** - 搜索策略的实例,例如是SANAS的实例。 +- **start_epoch** - 起始epoch,代表从第几个epoch开始监控实验中间结果。 +- **mode** - 中间结果是越大越好还是越小越好,在'minimize'和'maxmize'之间选择。默认:'maxmize'。 + +**返回:** +一个MedianStop的实例 + +**示例代码:** + +.. code-block:: python + + from paddleslim.nas import SANAS + from paddleslim.nas.early_stop import MedianStop + config = [('MobileNetV2Space')] + sanas = SANAS(config, server_addr=("", 8732), save_checkpoint=None) + earlystop = MedianStop(sanas, start_epoch = 2) + + .. py:method:: get_status(step, result, epochs): + + 获取当前实验当前result的状态。 + + **参数:** + - **step** - 当前实验是当前client中的第几个实验。 + - **result** - 当前实验的中间步骤的result,可以为损失值,也可以为准确率等指标,只要和`mode`对应即可。 + - **epochs** - 在搜索过程中每个实验需要运行的总得epoch数量。 + + **返回:** + 返回当前实验在当前epoch的状态,为`GOOD`或者`BAD`,如果为`BAD`,则代表当前实验可以早停。 + + **示例代码:** + + .. code-block:: python + from paddleslim.nas import SANAS + from paddleslim.nas.early_stop import MedianStop + + steps = 10 + epochs = 7 + + config = [('MobileNetV2Space')] + sanas = SANAS(config, server_addr=("", 8732), save_checkpoint=None) + earlystop = MedianStop(sanas, 2) + + for step in range(steps): + archs = sanas.next_archs()[0] + for epoch in range(epochs): + for data in train_reader(): + loss = archs(data) + + for data in test_reader(): + loss = archs(data) + avg_cost = np.mean(loss) + + status = earlystop.get_status(step, avg_cost, epochs) + if status == 'BAD': + break; + + sanas.reward(avg_cost) + diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py index aa150b1b..3f6c1c98 100644 --- a/paddleslim/common/sa_controller.py +++ b/paddleslim/common/sa_controller.py @@ -113,7 +113,7 @@ class SAController(EvolutionaryController): return self._current_tokens - def update(self, tokens, reward, iter, client_num): + def update(self, tokens, reward, iter, client_num=1): """ Update the controller according to latest tokens and reward. diff --git a/paddleslim/nas/early_stop/__init__.py b/paddleslim/nas/early_stop/__init__.py new file mode 100644 index 00000000..748208aa --- /dev/null +++ b/paddleslim/nas/early_stop/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from .early_stop import EarlyStopBase +from .median_stop import MedianStop + +__all__ = ['EarlyStopBase', 'MedianStop'] diff --git a/paddleslim/nas/early_stop/early_stop.py b/paddleslim/nas/early_stop/early_stop.py new file mode 100644 index 00000000..6d05d8d1 --- /dev/null +++ b/paddleslim/nas/early_stop/early_stop.py @@ -0,0 +1,32 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['EarlyStopBase'] + + +class EarlyStopBase(object): + """ Abstract early Stop algorithm. + """ + + def get_status(self, iter, result): + """Get experiment status. + """ + raise NotImplementedError( + 'get_status in Early Stop algorithm NOT implemented.') + + def client_end(self): + """ Stop a client, this function may useful for the client that result is better and better. + """ + raise NotImplementedError( + 'client_end in Early Stop algorithm NOT implemented.') diff --git a/paddleslim/nas/early_stop/median_stop/__init__.py b/paddleslim/nas/early_stop/median_stop/__init__.py new file mode 100644 index 00000000..290ea074 --- /dev/null +++ b/paddleslim/nas/early_stop/median_stop/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from .median_stop import MedianStop + +__all__ = ['MedianStop'] diff --git a/paddleslim/nas/early_stop/median_stop/median_stop.py b/paddleslim/nas/early_stop/median_stop/median_stop.py new file mode 100644 index 00000000..42aa0dff --- /dev/null +++ b/paddleslim/nas/early_stop/median_stop/median_stop.py @@ -0,0 +1,184 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from multiprocessing.managers import BaseManager +from ..early_stop import EarlyStopBase +from ....common.log_helper import get_logger + +PublicAuthKey = u'AbcXyz3' + +__all__ = ['MedianStop'] + +_logger = get_logger(__name__, level=logging.INFO) + +completed_history = dict() + + +def return_completed_history(): + return completed_history + + +class MedianStop(EarlyStopBase): + """ + Median Stop, reference: + Args: + strategy: the stategy of search. + start_epoch: which step to start early stop algorithm. + mode: bigger is better or smaller is better, chooice in ['maxmize', 'minimize']. Default: maxmize. + """ + + def __init__(self, strategy, start_epoch, mode='maxmize'): + self._start_epoch = start_epoch + self._running_history = dict() + self._strategy = strategy + self._mode = mode + self._is_server = self._strategy._is_server + self._manager = self._start_manager() + assert self._mode in [ + 'maxmize', 'minimize' + ], 'mode of MedianStop must be \'maxmize\' or \'minimize\', but received mode is {}'.format( + self._mode) + + def _start_manager(self): + self._server_ip = self._strategy._server_ip + self._server_port = self._strategy._server_port + 1 + + if self._is_server: + BaseManager.register( + 'get_completed_history', callable=return_completed_history) + base_manager = BaseManager( + address=(self._server_ip, self._server_port), + authkey=PublicAuthKey) + + base_manager.start() + else: + BaseManager.register('get_completed_history') + base_manager = BaseManager( + address=(self._server_ip, self._server_port), + authkey=PublicAuthKey) + base_manager.connect() + return base_manager + + def _update_data(self, exp_name, result): + if exp_name not in self._running_history.keys(): + self._running_history[exp_name] = [] + self._running_history[exp_name].append(result) + + def _convert_running2completed(self, exp_name, status): + """ + Convert experiment record from running to complete. + + Args: + exp_name: the name of experiment. + status: the status of this experiment. + """ + _logger.debug('the status of this experiment is {}'.format(status)) + completed_avg_history = dict() + if exp_name in self._running_history: + if status == "GOOD": + count = 0 + history_sum = 0 + result = [] + for res in self._running_history[exp_name]: + count += 1 + history_sum += res + result.append(history_sum / count) + completed_avg_history[exp_name] = result + self._running_history.pop(exp_name) + + if len(completed_avg_history) > 0: + while True: + try: + new_dict = self._manager.get_completed_history() + new_dict.update(completed_avg_history) + break + except Exception as err: + _logger.error("update data error: {}".format(err)) + + def get_status(self, step, result, epochs): + """ + Get current experiment status + + Args: + step: step in this client. + result: the result of this epoch. + epochs: whole epochs. + + Return: + the status of this experiment. + """ + exp_name = self._strategy._client_name + str(step) + self._update_data(exp_name, result) + + _logger.debug("running history after update data: {}".format( + self._running_history)) + + curr_step = len(self._running_history[exp_name]) + status = "GOOD" + if curr_step < self._start_epoch: + return status + + res_same_step = [] + + def list2dict(lists): + res_dict = dict() + for l in lists: + tmp_dict = dict() + tmp_dict[l[0]] = l[1] + res_dict.update(tmp_dict) + return res_dict + + while True: + try: + completed_avg_history = self._manager.get_completed_history() + break + except Exception as err: + _logger.error("get status error: {}".format(err)) + + if len(completed_avg_history.keys()) == 0: + for exp in self._running_history.keys(): + if curr_step <= len(self._running_history[exp]): + res_same_step.append(self._running_history[exp][curr_step - + 1]) + else: + completed_avg_history_dict = list2dict(completed_avg_history.items( + )) + + for exp in completed_avg_history.keys(): + if curr_step <= len(completed_avg_history_dict[exp]): + res_same_step.append(completed_avg_history_dict[exp][ + curr_step - 1]) + + _logger.debug("result of same step in other experiment: {}".format( + res_same_step)) + if res_same_step: + res_same_step.sort() + + if self._mode == 'maxmize' and result < res_same_step[( + len(res_same_step) - 1) // 2]: + status = "BAD" + + if self._mode == 'minimize' and result > res_same_step[len( + res_same_step) // 2]: + status = "BAD" + + if curr_step == epochs: + self._convert_running2completed(exp_name, status) + + return status + + def __del__(self): + if self._is_server: + self._manager.shutdown() diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index 5f6ae5b4..60472cd6 100644 --- a/paddleslim/nas/sa_nas.py +++ b/paddleslim/nas/sa_nas.py @@ -118,9 +118,9 @@ class SANAS(object): self._key = str(self._configs) self._current_tokens = init_tokens - server_ip, server_port = server_addr - if server_ip == None or server_ip == "": - server_ip = self._get_host_ip() + self._server_ip, self._server_port = server_addr + if self._server_ip == None or self._server_ip == "": + self._server_ip = self._get_host_ip() factory = SearchSpaceFactory() self._search_space = factory.get_search_space(configs) @@ -171,7 +171,7 @@ class SANAS(object): max_client_num = 100 self._controller_server = ControllerServer( controller=self._controller, - address=(server_ip, server_port), + address=(self._server_ip, self._server_port), max_client_num=max_client_num, search_steps=search_steps, key=self._key) @@ -179,8 +179,8 @@ class SANAS(object): server_port = self._controller_server.port() self._controller_client = ControllerClient( - server_ip, - server_port, + self._server_ip, + self._server_port, key=self._key, client_name=self._client_name) -- GitLab