sa_nas.py 7.7 KB
Newer Older
W
wanghaoshuang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

C
ceci3 已提交
15
import os
W
wanghaoshuang 已提交
16 17 18
import socket
import logging
import numpy as np
C
ceci3 已提交
19
import json
20
import hashlib
21
import time
W
wanghaoshuang 已提交
22 23 24 25 26 27 28 29
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops

from ..common import ControllerServer
from ..common import ControllerClient
W
wanghaoshuang 已提交
30
from .search_space import SearchSpaceFactory
W
wanghaoshuang 已提交
31 32 33 34 35 36 37 38 39

__all__ = [
    "SANAS",
]

# Module logger at INFO level so the range table and sampled tokens logged
# by SANAS are visible by default.
_logger = get_logger(__name__, level=logging.INFO)


class SANAS(object):
    """Neural architecture search driven by a simulated-annealing controller.

    The host created with ``is_server=True`` owns the ``SAController`` and
    runs a ``ControllerServer``; every host (server included) talks to that
    server through a ``ControllerClient`` to fetch candidate tokens and to
    report their scores.
    """

    def __init__(self,
                 configs,
                 server_addr=("", 8881),
                 init_temperature=None,
                 reduce_rate=0.85,
                 search_steps=300,
                 init_tokens=None,
                 save_checkpoint='nas_checkpoint',
                 load_checkpoint=None,
                 is_server=True):
        """
        Search network architectures with a simulated annealing strategy.
        Args:
            configs(list<tuple>): A list of search space configuration with format [(key, {input_size, output_size, block_num, block_mask})].
                                  `key` is the name of search space with data type str. `input_size` and `output_size` are
                                  input size and output size of searched sub-network. `block_num` is the number of blocks in searched network, `block_mask` is a list consisting of 0 and 1, 0 means normal block, 1 means reduction block.
            server_addr(tuple): A tuple of server ip and server port for controller server. An empty (or None) ip means the ip of the current host. Default: ("", 8881).
            init_temperature(float|None): The init temperature used in simulated annealing search strategy. Default: None.
            reduce_rate(float): The decay rate used in simulated annealing search strategy. Default: 0.85.
            search_steps(int): The steps of searching. Default: 300.
            init_tokens(list|None): Init tokens user can set by yourself. Default: None.
            save_checkpoint(string|None): The directory of checkpoint to save, if set to None, not save checkpoint. Default: 'nas_checkpoint'.
            load_checkpoint(string|None): The directory of checkpoint to load, if set to None, not load checkpoint. Only used when is_server is True. Default: None.
            is_server(bool): Whether current host is controller server. Default: True.
        Raises:
            ValueError: If is_server is False and no server ip is given.
            IOError: If load_checkpoint is set but contains no 'sanas.checkpoints' file.
        """
        # A pure client has no local server to fall back on, so it must be
        # told where the controller server lives. Explicit raise (not assert)
        # so the check survives `python -O`.
        if not is_server and not server_addr[0]:
            raise ValueError(
                "You should set the IP and port of server when is_server is False."
            )

        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._is_server = is_server
        self._configs = configs
        self._init_tokens = init_tokens
        # Random per-instance name sent to the server with every request.
        self._client_name = hashlib.md5(
            str(time.time() + np.random.randint(1, 10000)).encode(
                "utf-8")).hexdigest()
        # Shared with server and client; presumably lets the server reject
        # clients searching a different space (TODO confirm in ControllerServer).
        self._key = str(self._configs)
        self._current_tokens = init_tokens
        # Count of rewards reported so far; resumed from the checkpoint below.
        self._iter = 0

        server_ip, server_port = server_addr
        if not server_ip:
            server_ip = self._get_host_ip()

        factory = SearchSpaceFactory()
        self._search_space = factory.get_search_space(configs)

        # create controller server
        if self._is_server:
            init_tokens = self._search_space.init_tokens(self._init_tokens)
            range_table = self._search_space.range_table()
            # SAController receives (all-zero list, range_table); presumably
            # the per-token lower and upper bounds.
            range_table = (len(range_table) * [0], range_table)
            _logger.info("range table: {}".format(range_table))

            state = self._controller_state(load_checkpoint, init_tokens)
            self._iter = state['iters']

            self._controller = SAController(
                range_table,
                self._reduce_rate,
                self._init_temperature,
                max_try_times=500,
                constrain_func=None,
                checkpoints=save_checkpoint,
                **state)

            max_client_num = 100
            self._controller_server = ControllerServer(
                controller=self._controller,
                address=(server_ip, server_port),
                max_client_num=max_client_num,
                search_steps=search_steps,
                key=self._key)
            self._controller_server.start()
            # Ask the running server for its port; it may differ from the
            # requested one.
            server_port = self._controller_server.port()

        self._controller_client = ControllerClient(
            server_ip,
            server_port,
            key=self._key,
            client_name=self._client_name)

    @staticmethod
    def _read_checkpoint(checkpoint_dir):
        """Load the JSON checkpoint stored in `checkpoint_dir`.

        Args:
            checkpoint_dir(str): Directory expected to hold 'sanas.checkpoints'.
        Returns:
            dict: The decoded checkpoint.
        Raises:
            IOError: If the checkpoint file does not exist.
        """
        checkpoint_path = os.path.join(checkpoint_dir, 'sanas.checkpoints')
        # Check the file itself, not just the directory, so a directory
        # without a checkpoint gets the friendly message. IOError keeps
        # Python 2 compatibility (it is OSError on Python 3).
        if not os.path.isfile(checkpoint_path):
            raise IOError(
                'load checkpoint file NOT EXIST!!! Please check the directory '
                'of checkpoint!!! ({})'.format(checkpoint_path))
        with open(checkpoint_path, 'r') as f:
            return json.load(f)

    @classmethod
    def _controller_state(cls, load_checkpoint, init_tokens):
        """Build the resumable SAController keyword arguments.

        Args:
            load_checkpoint(str|None): Checkpoint directory, or None for a fresh search.
            init_tokens(list): Tokens to start from when not resuming.
        Returns:
            dict: Keys init_tokens, reward, max_reward, iters, best_tokens, searched.
        """
        if load_checkpoint is None:
            return {
                'init_tokens': init_tokens,
                'reward': -1,
                'max_reward': -1,
                'iters': 0,
                'best_tokens': None,
                'searched': None,
            }
        scene = cls._read_checkpoint(load_checkpoint)
        return {
            'init_tokens': scene['_tokens'],
            'reward': scene['_reward'],
            'max_reward': scene['_max_reward'],
            'iters': scene['_iter'],
            'best_tokens': scene['_best_tokens'],
            'searched': scene['_searched'],
        }

    def _get_host_ip(self):
        """Return the IP address that the current host name resolves to."""
        return socket.gethostbyname(socket.gethostname())

    def tokens2arch(self, tokens):
        """
        Convert tokens to network architectures.
        Args:
            tokens(list): Tokens describing one candidate in the search space.
        Returns:
            list<function>: A list of functions that define networks.
        """
        return self._search_space.token2arch(tokens)

    def current_info(self):
        """
        Get current information, including best tokens, best reward in all the search, and current tokens.
        Only usable on the controller-server host (is_server=True); clients
        never create `self._controller`.
        Returns:
            dict<name, value>: a dictionary with keys 'best_tokens', 'best_reward' and 'current_tokens'.
        """
        return {
            'best_tokens': self._controller.best_tokens,
            'best_reward': self._controller.max_reward,
            'current_tokens': self._controller.current_tokens,
        }

    def next_archs(self):
        """
        Fetch the next tokens from the controller server and build their networks.
        Returns:
            list<function>: A list of functions that define networks.
        """
        self._current_tokens = self._controller_client.next_tokens()
        _logger.info("current tokens: {}".format(self._current_tokens))
        return self._search_space.token2arch(self._current_tokens)

    def reward(self, score):
        """
        Report the score of the current searched network to the controller server.
        Args:
            score(float): The score of current searched network.
        Returns:
            bool: True means updating successfully while false means failure.
        """
        # The counter advances before sending, so the first report carries 1.
        self._iter += 1
        return self._controller_client.update(self._current_tokens, score,
                                              self._iter)