sa_nas.py 4.8 KB
Newer Older
W
wanghaoshuang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import socket
import logging
import numpy as np
18
import hashlib
W
wanghaoshuang 已提交
19 20 21 22 23 24 25 26
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops

from ..common import ControllerServer
from ..common import ControllerClient
W
wanghaoshuang 已提交
27
from .search_space import SearchSpaceFactory
W
wanghaoshuang 已提交
28 29 30 31 32 33 34 35 36

# Public names exported by `from <module> import *`: only the SANAS class.
__all__ = ["SANAS"]

# Module-level logger obtained from the package's get_logger helper,
# requested at INFO level.
_logger = get_logger(__name__, level=logging.INFO)


class SANAS(object):
    """Architecture search front-end backed by a simulated-annealing controller.

    Every instance owns a ``ControllerClient``. When ``is_server`` is True the
    instance additionally builds an ``SAController`` and starts a
    ``ControllerServer`` that the clients talk to.
    """

    def __init__(self,
                 configs,
                 server_addr=("", 8881),
                 init_temperature=100,
                 reduce_rate=0.85,
                 search_steps=300,
                 key="sa_nas",
                 is_server=False):
        """
        Set up the search space, the controller client and, on the server
        host, the controller server.
        Args:
            configs(list<tuple>): A list of search space configuration with format (key, input_size, output_size, block_num).
                                  `key` is the name of search space with data type str. `input_size` and `output_size` are
                                  input size and output size of searched sub-network. `block_num` is the number of blocks in searched network.
            server_addr(tuple): A tuple of server ip and server port for controller server.
                                An ip of None or "" is replaced by the ip of the current host.
            init_temperature(float): The init temperature used in simulated annealing search strategy.
            reduce_rate(float): The decay rate used in simulated annealing search strategy.
            search_steps(int): The steps of searching. Only forwarded to the controller server,
                               i.e. only used when is_server is True.
            key(str): Identity used in communication between controller server and clients.
                      NOTE: this argument is currently not read; the identity actually used
                      is the MD5 hex digest of str(configs).
            is_server(bool): Whether current host is controller server. Default: False.
        """
        if not is_server:
            assert server_addr[
                0] != "", "You should set the IP and port of server when is_server is False."
        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._is_server = is_server
        self._configs = configs
        # hashlib only accepts bytes on Python 3, so encode the textual form
        # of the configs before hashing it.
        self._key = hashlib.md5(str(self._configs).encode("utf-8")).hexdigest()

        server_ip, server_port = server_addr
        if server_ip is None or server_ip == "":
            server_ip = self._get_host_ip()

        # The search space is required by next_archs() on every host, not only
        # on the server, so it is created unconditionally.
        factory = SearchSpaceFactory()
        self._search_space = factory.get_search_space(configs)

        # create controller server
        if self._is_server:
            init_tokens = self._search_space.init_tokens()
            range_table = self._search_space.range_table()
            # Pair the table with an equally long list of zeros -- presumably
            # (lower bounds, upper bounds); confirm against SAController.
            range_table = (len(range_table) * [0], range_table)
            _logger.info("range table: {}".format(range_table))
            controller = SAController(
                range_table,
                self._reduce_rate,
                self._init_temperature,
                max_try_times=None,
                init_tokens=init_tokens,
                constrain_func=None)

            max_client_num = 100
            self._controller_server = ControllerServer(
                controller=controller,
                address=(server_ip, server_port),
                max_client_num=max_client_num,
                search_steps=search_steps,
                key=self._key)
            self._controller_server.start()
            # Use the port reported by the started server for the local client.
            server_port = self._controller_server.port()

        self._controller_client = ControllerClient(
            server_ip, server_port, key=self._key)

        # Number of rewards reported so far by this instance.
        self._iter = 0

    def _get_host_ip(self):
        """Return the address that the local host name resolves to."""
        return socket.gethostbyname(socket.gethostname())

    def next_archs(self):
        """
        Get next network architectures.

        Fetches the next tokens from the controller client, remembers them for
        the following call to `reward`, and converts them with the search
        space's `token2arch`.
        Returns:
            list<function>: A list of functions that define networks
            (the value returned by `token2arch`).
        """
        self._current_tokens = self._controller_client.next_tokens()
        archs = self._search_space.token2arch(self._current_tokens)
        return archs

    def reward(self, score):
        """
        Return reward of current searched network.

        Must be called after `next_archs`, because it reports the tokens
        stored by that call.
        Args:
            score(float): The score of current searched network.
        Returns:
            bool: True means updating successfully while false means failure
            (the value returned by the controller client's `update`).
        """
        self._iter += 1
        return self._controller_client.update(self._current_tokens, score,
                                              self._iter)