未验证 提交 7d1ec566 编写于 作者: C ceci3 提交者: GitHub

add early stop (#106)

上级 abb30ef2
early-stop
========
早停算法接口在实验中如何使用
MedianStop
------
.. py:class:: paddleslim.nas.early_stop.MedianStop(strategy, start_epoch, mode)
`源代码 <>`_
MedianStop是利用历史较好实验的中间结果来判断当前实验是否有运行完成的必要,如果当前实验在中间步骤的结果差于历史记录的实验列表中相同步骤的结果的中值,则代表当前实验是较差的实验,可以提前终止。参考 `Google Vizier: A Service for Black-Box Optimization <https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46180.pdf>`_.
**参数:**
- **strategy<class instance>** - 搜索策略的实例,例如是SANAS的实例。
- **start_epoch<int>** - 起始epoch,代表从第几个epoch开始监控实验中间结果。
- **mode<str>** - 中间结果是越大越好还是越小越好,在'minimize'和'maxmize'之间选择。默认:'maxmize'。
**返回:**
一个MedianStop的实例
**示例代码:**
.. code-block:: python
from paddleslim.nas import SANAS
from paddleslim.nas.early_stop import MedianStop
config = [('MobileNetV2Space')]
sanas = SANAS(config, server_addr=("", 8732), save_checkpoint=None)
earlystop = MedianStop(sanas, start_epoch = 2)
.. py:method:: get_status(step, result, epochs):
获取当前实验当前result的状态。
**参数:**
- **step<int>** - 当前实验是当前client中的第几个实验。
- **result<float>** - 当前实验的中间步骤的result,可以为损失值,也可以为准确率等指标,只要和`mode`对应即可。
- **epochs<int>** - 在搜索过程中每个实验需要运行的总得epoch数量。
**返回:**
返回当前实验在当前epoch的状态,为`GOOD`或者`BAD`,如果为`BAD`,则代表当前实验可以早停。
**示例代码:**
.. code-block:: python
from paddleslim.nas import SANAS
from paddleslim.nas.early_stop import MedianStop
steps = 10
epochs = 7
config = [('MobileNetV2Space')]
sanas = SANAS(config, server_addr=("", 8732), save_checkpoint=None)
earlystop = MedianStop(sanas, 2)
for step in range(steps):
archs = sanas.next_archs()[0]
for epoch in range(epochs):
for data in train_reader():
loss = archs(data)
for data in test_reader():
loss = archs(data)
avg_cost = np.mean(loss)
status = earlystop.get_status(step, avg_cost, epochs)
if status == 'BAD':
break;
sanas.reward(avg_cost)
......@@ -113,7 +113,7 @@ class SAController(EvolutionaryController):
return self._current_tokens
def update(self, tokens, reward, iter, client_num):
def update(self, tokens, reward, iter, client_num=1):
"""
Update the controller according to latest tokens and reward.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .early_stop import EarlyStopBase
from .median_stop import MedianStop
__all__ = ['EarlyStopBase', 'MedianStop']
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['EarlyStopBase']
class EarlyStopBase(object):
""" Abstract early Stop algorithm.
"""
def get_status(self, iter, result):
"""Get experiment status.
"""
raise NotImplementedError(
'get_status in Early Stop algorithm NOT implemented.')
def client_end(self):
""" Stop a client, this function may useful for the client that result is better and better.
"""
raise NotImplementedError(
'client_end in Early Stop algorithm NOT implemented.')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .median_stop import MedianStop
__all__ = ['MedianStop']
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from multiprocessing.managers import BaseManager
from ..early_stop import EarlyStopBase
from ....common.log_helper import get_logger
PublicAuthKey = u'AbcXyz3'
__all__ = ['MedianStop']
_logger = get_logger(__name__, level=logging.INFO)
completed_history = dict()
def return_completed_history():
return completed_history
class MedianStop(EarlyStopBase):
"""
Median Stop, reference:
Args:
strategy<class instance>: the stategy of search.
start_epoch<int>: which step to start early stop algorithm.
mode<str>: bigger is better or smaller is better, chooice in ['maxmize', 'minimize']. Default: maxmize.
"""
def __init__(self, strategy, start_epoch, mode='maxmize'):
self._start_epoch = start_epoch
self._running_history = dict()
self._strategy = strategy
self._mode = mode
self._is_server = self._strategy._is_server
self._manager = self._start_manager()
assert self._mode in [
'maxmize', 'minimize'
], 'mode of MedianStop must be \'maxmize\' or \'minimize\', but received mode is {}'.format(
self._mode)
def _start_manager(self):
self._server_ip = self._strategy._server_ip
self._server_port = self._strategy._server_port + 1
if self._is_server:
BaseManager.register(
'get_completed_history', callable=return_completed_history)
base_manager = BaseManager(
address=(self._server_ip, self._server_port),
authkey=PublicAuthKey)
base_manager.start()
else:
BaseManager.register('get_completed_history')
base_manager = BaseManager(
address=(self._server_ip, self._server_port),
authkey=PublicAuthKey)
base_manager.connect()
return base_manager
def _update_data(self, exp_name, result):
if exp_name not in self._running_history.keys():
self._running_history[exp_name] = []
self._running_history[exp_name].append(result)
def _convert_running2completed(self, exp_name, status):
"""
Convert experiment record from running to complete.
Args:
exp_name<str>: the name of experiment.
status<str>: the status of this experiment.
"""
_logger.debug('the status of this experiment is {}'.format(status))
completed_avg_history = dict()
if exp_name in self._running_history:
if status == "GOOD":
count = 0
history_sum = 0
result = []
for res in self._running_history[exp_name]:
count += 1
history_sum += res
result.append(history_sum / count)
completed_avg_history[exp_name] = result
self._running_history.pop(exp_name)
if len(completed_avg_history) > 0:
while True:
try:
new_dict = self._manager.get_completed_history()
new_dict.update(completed_avg_history)
break
except Exception as err:
_logger.error("update data error: {}".format(err))
def get_status(self, step, result, epochs):
"""
Get current experiment status
Args:
step: step in this client.
result: the result of this epoch.
epochs: whole epochs.
Return:
the status of this experiment.
"""
exp_name = self._strategy._client_name + str(step)
self._update_data(exp_name, result)
_logger.debug("running history after update data: {}".format(
self._running_history))
curr_step = len(self._running_history[exp_name])
status = "GOOD"
if curr_step < self._start_epoch:
return status
res_same_step = []
def list2dict(lists):
res_dict = dict()
for l in lists:
tmp_dict = dict()
tmp_dict[l[0]] = l[1]
res_dict.update(tmp_dict)
return res_dict
while True:
try:
completed_avg_history = self._manager.get_completed_history()
break
except Exception as err:
_logger.error("get status error: {}".format(err))
if len(completed_avg_history.keys()) == 0:
for exp in self._running_history.keys():
if curr_step <= len(self._running_history[exp]):
res_same_step.append(self._running_history[exp][curr_step -
1])
else:
completed_avg_history_dict = list2dict(completed_avg_history.items(
))
for exp in completed_avg_history.keys():
if curr_step <= len(completed_avg_history_dict[exp]):
res_same_step.append(completed_avg_history_dict[exp][
curr_step - 1])
_logger.debug("result of same step in other experiment: {}".format(
res_same_step))
if res_same_step:
res_same_step.sort()
if self._mode == 'maxmize' and result < res_same_step[(
len(res_same_step) - 1) // 2]:
status = "BAD"
if self._mode == 'minimize' and result > res_same_step[len(
res_same_step) // 2]:
status = "BAD"
if curr_step == epochs:
self._convert_running2completed(exp_name, status)
return status
def __del__(self):
if self._is_server:
self._manager.shutdown()
......@@ -118,9 +118,9 @@ class SANAS(object):
self._key = str(self._configs)
self._current_tokens = init_tokens
server_ip, server_port = server_addr
if server_ip == None or server_ip == "":
server_ip = self._get_host_ip()
self._server_ip, self._server_port = server_addr
if self._server_ip == None or self._server_ip == "":
self._server_ip = self._get_host_ip()
factory = SearchSpaceFactory()
self._search_space = factory.get_search_space(configs)
......@@ -171,7 +171,7 @@ class SANAS(object):
max_client_num = 100
self._controller_server = ControllerServer(
controller=self._controller,
address=(server_ip, server_port),
address=(self._server_ip, self._server_port),
max_client_num=max_client_num,
search_steps=search_steps,
key=self._key)
......@@ -179,8 +179,8 @@ class SANAS(object):
server_port = self._controller_server.port()
self._controller_client = ControllerClient(
server_ip,
server_port,
self._server_ip,
self._server_port,
key=self._key,
client_name=self._client_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册