# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import socket
import logging
import numpy as np
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops

from ..common import ControllerServer
from ..common import ControllerClient
from .search_space import SearchSpaceFactory

# Names exported by `from <this module> import *`.
__all__ = ["SANAS"]

# Module-level logger, created through the package's `get_logger` helper
# with the INFO level.
_logger = get_logger(__name__, level=logging.INFO)


class SANAS(object):
    """Client/server front end for simulated-annealing architecture search.

    Every instance owns a ``ControllerClient`` used to fetch tokens and to
    report scores. When ``is_server`` is True the instance additionally
    creates and starts a ``ControllerServer`` wrapping an ``SAController``.
    """

    def __init__(self,
                 configs,
                 server_addr=("", 8881),
                 init_temperature=100,
                 reduce_rate=0.85,
                 search_steps=300,
                 key="sa_nas",
                 is_server=False):
        """
        Set up the search space, the controller client and, optionally, the controller server.
        Args:
            configs(list<tuple>): A list of search space configuration with format (key, input_size, output_size, block_num).
                                  `key` is the name of search space with data type str. `input_size` and `output_size`  are
                                   input size and output size of searched sub-network. `block_num` is the number of blocks in searched network.
            server_addr(tuple): A tuple of server ip and server port for controller server.
                                An empty or None ip is replaced by the ip of the current host.
            init_temperature(float): The init temperature used in simulated annealing search strategy.
            reduce_rate(float): The decay rate used in simulated annealing search strategy.
            search_steps(int): The steps of searching.
            key(str): Identity used in communication between controller server and clients.
            is_server(bool): Whether current host is controller server. Default: False.
                             Note that with the default `server_addr` a non-server instance is
                             rejected, so clients must always pass an explicit server ip.
        Raises:
            AssertionError: If `is_server` is False and the ip in `server_addr` is an empty string.
        """
        # Validate before touching anything else so that a misconfigured
        # client fails fast, without side effects.
        if not is_server:
            assert server_addr[
                0] != "", "You should set the IP and port of server when is_server is False."

        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._is_server = is_server
        self._configs = configs
        # These two were read below without ever being assigned, which made
        # the server path fail with AttributeError. No constraint can be
        # configured through this constructor, so both are None.
        # NOTE(review): confirm SAController accepts None for both arguments.
        self._max_try_number = None
        self._constrain_func = None
        self._iter = 0

        # The search space is needed on every host: next_archs() uses it to
        # turn tokens into architectures on clients as well as on the server.
        factory = SearchSpaceFactory()
        self._search_space = factory.get_search_space(configs)

        server_ip, server_port = server_addr
        if server_ip is None or server_ip == "":
            server_ip = self._get_host_ip()

        # create controller server
        if self._is_server:
            init_tokens = self._search_space.init_tokens()
            range_table = self._search_space.range_table()
            # Pair the table with an all-zero list of the same length.
            range_table = (len(range_table) * [0], range_table)
            _logger.info("range table: {}".format(range_table))
            controller = SAController(
                range_table, self._reduce_rate, self._init_temperature,
                self._max_try_number, init_tokens, self._constrain_func)

            max_client_num = 100
            self._controller_server = ControllerServer(
                controller=controller,
                address=(server_ip, server_port),
                max_client_num=max_client_num,
                search_steps=search_steps,
                key=key)
            self._controller_server.start()

        # Every instance, server included, talks to the controller through a
        # client.
        self._controller_client = ControllerClient(
            server_ip, server_port, key=key)

    def _get_host_ip(self):
        """Return the address that the local host name resolves to."""
        return socket.gethostbyname(socket.gethostname())

    def next_archs(self):
        """
        Get next network architectures.
        The tokens fetched from the controller are remembered so that the
        following `reward` call reports a score for them.
        Returns:
            list<function>: A list of functions that define networks.
        """
        self._current_tokens = self._controller_client.next_tokens()
        archs = self._search_space.token2arch(self._current_tokens)
        return archs

    def reward(self, score):
        """
        Return reward of current searched network.
        Must be called after `next_archs`, which records the tokens the
        score belongs to.
        Args:
            score(float): The score of current searched network.
        Returns:
            bool: True means updating successfully while false means failure.
        """
        self._iter += 1
        return self._controller_client.update(self._current_tokens, score,
                                              self._iter)