sa_nas.py 10.9 KB
Newer Older
W
wanghaoshuang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

C
ceci3 已提交
15
import os
W
wanghaoshuang 已提交
16 17 18
import socket
import logging
import numpy as np
C
ceci3 已提交
19
import json
20
import hashlib
21
import time
W
wanghaoshuang 已提交
22 23 24 25 26 27 28
import paddle.fluid as fluid
from ..common import SAController
from ..common import get_logger
from ..analysis import flops

from ..common import ControllerServer
from ..common import ControllerClient
W
wanghaoshuang 已提交
29
from .search_space import SearchSpaceFactory
W
wanghaoshuang 已提交
30 31 32 33 34 35 36

# Public names exported by `from <module> import *`: only the SANAS class.
__all__ = ["SANAS"]

# Module-level logger, created through the project's `get_logger` helper with
# the INFO level; used by SANAS to report the range table and current tokens.
_logger = get_logger(__name__, level=logging.INFO)


class SANAS(object):
    """
    SANAS(Simulated Annealing Neural Architecture Search) is a neural architecture search algorithm
    based on simulated annealing, used in discrete search task generally.

    Args:
        configs(list<tuple>): A list of search space configuration with format [(key, {input_size,
                              output_size, block_num, block_mask})]. `key` is the name of search space
                              with data type str. `input_size` and `output_size` are input size and
                              output size of searched sub-network. `block_num` is the number of blocks
                              in searched network, `block_mask` is a list consists by 0 and 1, 0 means
                              normal block, 1 means reduction block.
        server_addr(tuple): Server address, including ip and port of server. If ip is None or "", will
                            use host ip if is_server = True. Default: ("", 8881).
        init_temperature(float): Initial temperature in SANAS. If init_temperature and init_tokens are None,
                                 default initial temperature is 10.0, if init_temperature is None and
                                 init_tokens is not None, default initial temperature is 1.0. The detail
                                 configuration about the init_temperature please reference Note. Default: None.
        reduce_rate(float): Reduce rate in SANAS. The detail configuration about the reduce_rate please
                            reference Note. Default: 0.85.
        search_steps(int): The steps of searching. Default: 300.
        init_tokens(list|None): Initial token. If init_tokens is None, SANAS will random generate initial
                                tokens. Default: None.
        save_checkpoint(string|None): The directory of checkpoint to save, if set to None, not save checkpoint.
                                      Default: 'nas_checkpoint'.
        load_checkpoint(string|None): The directory of checkpoint to load, if set to None, not load checkpoint.
                                      Default: None.
        is_server(bool): Whether current host is controller server. Default: True.

    .. note::
        - Why need to set initial temperature and reduce rate:

          - SA algorithm preserve a base token(initial token is the first base token, can be set by
            yourself or random generate) and base score(initial score is -1), next token will be
            generated based on base token. During the search, if the score which is obtained by the
            model corresponding to the token is greater than the score which is saved in SA corresponding to
            base token, current token saved as base token certainly; if score which is obtained by the model
            corresponding to the token is less than the score which is saved in SA corresponding to base token,
            current token saved as base token with a certain probability.
          - For initial temperature, higher is more unstable, it means that SA has a strong possibility to save
            current token as base token if current score is smaller than base score saved in SA.
          - For initial temperature, lower is more stable, it means that SA has a small possibility to save
            current token as base token if current score is smaller than base score saved in SA.
          - For reduce rate, higher means SA algorithm has slower convergence.
          - For reduce rate, lower means SA algorithm has faster convergence.

        - How to set initial temperature and reduce rate:

          - If there is a better initial token, and want to search based on this token, we suggest start search
            experiment in the steady state of the SA algorithm, initial temperature can be set to a small value,
            such as 1.0, and reduce rate can be set to a large value, such as 0.85. If you want to start search
            experiment based on the better token with greedy algorithm, which only saved current token as base
            token if current score higher than base score saved in SA algorithm, reduce rate can be set to an
            extremely small value, such as 0.85 ** 10.

          - If initial token is generated randomly, it means initial token is a worse token, we suggest start
            search experiment in the unstable state of the SA algorithm, explore all random tokens as much as
            possible, and get a better token. Initial temperature can be set a higher value, such as 1000.0,
            and reduce rate can be set to a small value.
    """

    def __init__(self,
                 configs,
                 server_addr=("", 8881),
                 init_temperature=None,
                 reduce_rate=0.85,
                 search_steps=300,
                 init_tokens=None,
                 save_checkpoint='nas_checkpoint',
                 load_checkpoint=None,
                 is_server=True):
        # A pure client cannot guess where the server lives, so an empty IP is
        # rejected up front.
        if not is_server:
            assert server_addr[
                0] != "", "You should set the IP and port of server when is_server is False."
        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._is_server = is_server
        self._configs = configs
        self._init_tokens = init_tokens
        # Client name: MD5 hex digest of the current time plus a random
        # integer, used only as an identifier (not for security).
        self._client_name = hashlib.md5(
            str(time.time() + np.random.randint(1, 10000)).encode(
                "utf-8")).hexdigest()
        # The stringified configs serve as the key handed to both the
        # controller server and the controller client.
        self._key = str(self._configs)
        self._current_tokens = init_tokens

        server_ip, server_port = server_addr
        if server_ip is None or server_ip == "":
            server_ip = self._get_host_ip()

        factory = SearchSpaceFactory()
        self._search_space = factory.get_search_space(configs)

        # Iteration counter to start from; only a server that restores a
        # checkpoint starts above zero.
        start_iter = 0

        # create controller server
        if self._is_server:
            init_tokens = self._search_space.init_tokens(self._init_tokens)
            range_table = self._search_space.range_table()
            # Pair the table with an all-zero list of the same length before
            # handing it to the controller.
            range_table = (len(range_table) * [0], range_table)
            _logger.info("range table: {}".format(range_table))

            state = self._load_search_state(load_checkpoint, init_tokens)
            start_iter = state['_iter']

            self._controller = SAController(
                range_table,
                self._reduce_rate,
                self._init_temperature,
                max_try_times=50000,
                init_tokens=state['_tokens'],
                reward=state['_reward'],
                max_reward=state['_max_reward'],
                iters=state['_iter'],
                best_tokens=state['_best_tokens'],
                constrain_func=None,
                checkpoints=save_checkpoint,
                searched=state['_searched'])

            max_client_num = 100
            self._controller_server = ControllerServer(
                controller=self._controller,
                address=(server_ip, server_port),
                max_client_num=max_client_num,
                search_steps=search_steps,
                key=self._key)
            self._controller_server.start()
            # Use the port reported by the started server for the local
            # client instead of the requested one.
            server_port = self._controller_server.port()

        self._controller_client = ControllerClient(
            server_ip,
            server_port,
            key=self._key,
            client_name=self._client_name)

        self._iter = start_iter

    @staticmethod
    def _load_search_state(load_checkpoint, init_tokens):
        """
        Build the state the controller starts from.
        Args:
            load_checkpoint(string|None): Directory holding a 'sanas.checkpoints' JSON file, or
                                          None to start a fresh search.
            init_tokens(list|None): Tokens used as the starting tokens when no checkpoint is loaded.
        Returns:
            dict: A dictionary with the keys '_tokens', '_reward', '_max_reward', '_best_tokens',
                  '_iter' and '_searched'. Without a checkpoint the values are init_tokens, -1, -1,
                  None, 0 and None; with a checkpoint they are read from the JSON file.
        Raises:
            AssertionError: If load_checkpoint is given but the path does not exist.
            KeyError: If the checkpoint file lacks one of the expected keys.
        """
        if load_checkpoint is None:
            return {
                '_tokens': init_tokens,
                '_reward': -1,
                '_max_reward': -1,
                '_best_tokens': None,
                '_iter': 0,
                '_searched': None,
            }

        assert os.path.exists(
            load_checkpoint
        ), 'load checkpoint file NOT EXIST!!! Please check the directory of checkpoint!!!'
        checkpoint_path = os.path.join(load_checkpoint, 'sanas.checkpoints')
        with open(checkpoint_path, 'r') as f:
            scene = json.load(f)

        # Keep only the fields the controller consumes, in a fixed order so a
        # malformed checkpoint fails on the first missing key.
        wanted = ('_tokens', '_reward', '_max_reward', '_best_tokens',
                  '_iter', '_searched')
        return {name: scene[name] for name in wanted}

    def _get_host_ip(self):
        """
        Resolve the IP address used when no server IP is given.
        Returns:
            str: The address of 'localhost' on posix systems, otherwise the address resolved
                 from the machine's host name.
        """
        if os.name == 'posix':
            return socket.gethostbyname('localhost')
        else:
            return socket.gethostbyname(socket.gethostname())

    def tokens2arch(self, tokens):
        """
        Convert tokens to model architectures.
        Args:
            tokens(list): A list of token. The length and range based on search space.
        Returns:
            list<function>: A model architecture instance according to tokens.
        """
        return self._search_space.token2arch(tokens)

    def current_info(self):
        """
        Get current search information from the controller.
        Returns:
            dict<name, value>: The dictionary returned by the controller client's
                               `request_current_info()`. It is described as holding the best tokens
                               and best reward of the whole search plus the current token/reward;
                               the exact keys are defined by the controller (confirm there).
        """
        current_dict = self._controller_client.request_current_info()
        return current_dict

    def next_archs(self):
        """
        Get next model architectures.
        The tokens fetched from the controller client are remembered as the current tokens,
        so a following call to `reward` reports the score for them.
        Returns:
            list<function>: A list of instance of model architecture.
        """
        self._current_tokens = self._controller_client.next_tokens()
        _logger.info("current tokens: {}".format(self._current_tokens))
        archs = self._search_space.token2arch(self._current_tokens)
        return archs

    def reward(self, score):
        """
        Return reward of current searched network.
        The internal iteration counter is increased by one before the current tokens, the
        score and the counter are sent to the controller.
        Args:
            score(float): The score of current searched network, bigger is better.
        Returns:
            bool: True means updating successfully while false means failure.
        """
        self._iter += 1
        return self._controller_client.update(self._current_tokens, score,
                                              self._iter)