sa_nas.py 6.0 KB
Newer Older
W
wanghaoshuang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

C
ceci3 已提交
15
import os
W
wanghaoshuang 已提交
16 17 18
import socket
import logging
import numpy as np
C
ceci3 已提交
19
import json
20
import hashlib
W
wanghaoshuang 已提交
21 22 23 24 25 26 27 28
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops

from ..common import ControllerServer
from ..common import ControllerClient
W
wanghaoshuang 已提交
29
from .search_space import SearchSpaceFactory
W
wanghaoshuang 已提交
30 31 32 33 34 35 36 37 38

# Names exported by `from <this module> import *`.
__all__ = ["SANAS"]

# Module-level logger shared by everything in this file, created at INFO level.
_logger = get_logger(__name__, level=logging.INFO)


class SANAS(object):
    """Neural architecture search driven by a simulated-annealing controller.

    An instance builds a search space from ``configs``, optionally starts a
    ``ControllerServer`` wrapping an ``SAController`` (when ``is_server`` is
    True), and always creates a ``ControllerClient`` through which tokens are
    fetched (``next_archs``) and scores are reported (``reward``).
    """

    def __init__(self,
                 configs,
                 server_addr=("", 8881),
                 init_temperature=100,
                 reduce_rate=0.85,
                 search_steps=300,
                 key="sa_nas",
                 save_checkpoint=None,
                 load_checkpoint=None,
                 is_server=False):
        """
        Set up the search space, the optional controller server and the client.
        Args:
            configs(list<tuple>): A list of search space configuration with format (key, input_size, output_size, block_num).
                                  `key` is the name of search space with data type str. `input_size` and `output_size`  are
                                   input size and output size of searched sub-network. `block_num` is the number of blocks in searched network.
            server_addr(tuple): A tuple of server ip and server port for controller server.
                                When the ip is None or "" the ip of the current host is used; an
                                empty ip is rejected when `is_server` is False.
            init_temperature(float): The init temperature used in simulated annealing search strategy.
            reduce_rate(float): The decay rate used in simulated annealing search strategy.
            search_steps(int): The steps of searching.
            key(str): Identity used in communication between controller server and clients.
                      NOTE: this argument is currently ignored; the identity that is actually
                      used is the md5 digest of `str(configs)`. It is kept for backward compatibility.
            save_checkpoint(str|None): Forwarded to the controller as `checkpoints`
                                       (presumably the directory checkpoints are written to -- confirm
                                       against SAController). Default: None.
            load_checkpoint(str|None): Directory holding a JSON file named `sanas.checkpoints` from which
                                       the search state is restored. None starts a fresh search.
                                       Only read when `is_server` is True. Default: None.
            is_server(bool): Whether current host is controller server. Default: False.
        """
        if not is_server:
            assert server_addr[
                0] != "", "You should set the IP and port of server when is_server is False."
        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._is_server = is_server
        self._configs = configs
        # Server and clients agree on an identity derived from the configs.
        self._key = self._config_key(self._configs)

        server_ip, server_port = server_addr
        if server_ip is None or server_ip == "":
            server_ip = self._get_host_ip()

        factory = SearchSpaceFactory()
        self._search_space = factory.get_search_space(configs)

        # create controller server
        if self._is_server:
            init_tokens = self._search_space.init_tokens()
            range_table = self._search_space.range_table()
            # Pair the table with an all-zero list of the same length.
            range_table = (len(range_table) * [0], range_table)
            _logger.info("range table: {}".format(range_table))

            (preinit_tokens, prereward, premax_reward, prebest_tokens,
             preiter) = self._restore_search_state(load_checkpoint,
                                                   init_tokens)

            controller = SAController(
                range_table,
                self._reduce_rate,
                self._init_temperature,
                max_try_times=None,
                init_tokens=preinit_tokens,
                reward=prereward,
                max_reward=premax_reward,
                iters=preiter,
                best_tokens=prebest_tokens,
                constrain_func=None,
                checkpoints=save_checkpoint)

            max_client_num = 100
            self._controller_server = ControllerServer(
                controller=controller,
                address=(server_ip, server_port),
                max_client_num=max_client_num,
                search_steps=search_steps,
                key=self._key)
            self._controller_server.start()
            # Use the port reported by the started server for the local client.
            server_port = self._controller_server.port()

        self._controller_client = ControllerClient(
            server_ip, server_port, key=self._key)

        self._iter = 0

    @staticmethod
    def _config_key(configs):
        """Return the md5 hex digest of ``str(configs)``."""
        raw = str(configs)
        # hashlib needs bytes; on Python 3 ``str`` is text and must be encoded.
        if not isinstance(raw, bytes):
            raw = raw.encode("utf-8")
        return hashlib.md5(raw).hexdigest()

    @staticmethod
    def _restore_search_state(load_checkpoint, init_tokens):
        """Return (init_tokens, reward, max_reward, best_tokens, iter) for the controller.

        With ``load_checkpoint`` set, the values are read from the JSON file
        ``sanas.checkpoints`` inside that directory; otherwise the defaults
        of a fresh search are returned.
        """
        if load_checkpoint is None:
            return None, -1, -1, init_tokens, 0

        assert os.path.exists(load_checkpoint), 'load checkpoint file NOT EXIST!!! Please check the directory of checkpoint!!!'
        checkpoint_path = os.path.join(load_checkpoint, 'sanas.checkpoints')
        # json.load expects an open file object, not a path.
        with open(checkpoint_path, 'r') as checkpoint_file:
            scene = json.load(checkpoint_file)
        return (scene['_init_tokens'], scene['_reward'], scene['_max_reward'],
                scene['_best_tokens'], scene['_iter'])

    def _get_host_ip(self):
        """Return the IP address that the local host name resolves to."""
        return socket.gethostbyname(socket.gethostname())

    def tokens2arch(self, tokens):
        """
        Convert tokens to network architectures using the search space.
        Args:
            tokens: The tokens to convert.
        Returns:
            The result of the search space's `token2arch` for `tokens`.
        """
        return self._search_space.token2arch(tokens)

    def next_archs(self):
        """
        Get next network architectures.
        The tokens fetched from the controller are remembered so that a
        following call to `reward` reports the score for them.
        Returns:
            list<function>: A list of functions that define networks.
        """
        self._current_tokens = self._controller_client.next_tokens()
        archs = self._search_space.token2arch(self._current_tokens)
        return archs

    def reward(self, score):
        """
        Report the reward of the current searched network to the controller.
        `next_archs` must have been called first, since it sets the tokens
        that the score is reported for.
        Args:
            score(float): The score of current searched network.
        Returns:
            bool: True means updating successfully while false means failure.
        """
        self._iter += 1
        return self._controller_client.update(self._current_tokens, score,
                                              self._iter)