# controller_server.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging
import socket
import time
from .log_helper import get_logger
from threading import Thread
from .lock import lock, unlock

__all__ = ["ControllerServer"]

# Module-wide logger, created at INFO level through the project's get_logger helper.
_logger = get_logger(__name__, level=logging.INFO)


class ControllerServer(object):
    """The controller wrapper with a socket server to handle the requests of search agents.

    Each accepted connection carries exactly one newline-stripped request:

    * ``"next_tokens"``: replies with the controller's next tokens, comma-separated.
    * ``"current_info"``: replies with ``str()`` of a dict holding
      ``best_tokens``, ``best_reward`` and ``current_tokens``.
    * anything else: a tab-separated report with the fields
      ``key, tokens, reward, iter, client_name`` (``tokens`` being
      comma-separated ints). It is fed to ``controller.update()`` and answered
      with ``"ok"``. Reports with fewer than five fields or a different key
      are ignored.

    Args:
        controller(slim.searcher.Controller): The controller used to generate tokens.
        address(tuple): The address of current server binding with format (ip, port).
                        Default: ('', 0), which binds all interfaces on a port
                        chosen automatically.
        max_client_num(int): Backlog passed to ``listen()``, i.e. how many pending
                             connections may queue up. Default: 100.
        search_steps(int|None): The total steps of searching. None means never stopping. Default: None.
        key(str|None): Value the first field of a report must equal for the report
                       to be accepted. Default: None.
    """

    def __init__(self,
                 controller=None,
                 address=('', 0),
                 max_client_num=100,
                 search_steps=None,
                 key=None):
        """Store the configuration; no socket is opened until ``start()``."""
        self._controller = controller
        self._address = address
        self._max_client_num = max_client_num
        self._search_steps = search_steps
        self._closed = False
        # Provisional values; start() replaces them with the actually bound address.
        self._port = address[1]
        self._ip = address[0]
        self._key = key
        # Number of clients currently considered alive (kept equal to len(self._client)).
        self._client_num = 0
        # client_name -> time.time() of that client's most recent report.
        self._client = dict()
        # A client silent for longer than this many seconds is treated as stopped.
        self._compare_time = 172800  ### 48 hours

    def start(self):
        """Bind and listen on the configured address, then serve in a daemon thread.

        Returns:
            str: The string representation of the started serving thread.
        """
        self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket_server.bind(self._address)
        self._socket_server.listen(self._max_client_num)
        # Read the real address back: port 0 in the config lets the OS choose one.
        self._ip, self._port = self._socket_server.getsockname()
        _logger.info("ControllerServer - listen on: [{}:{}]".format(
            self._ip, self._port))
        thread = Thread(target=self.run)
        # Attribute form instead of the deprecated setDaemon().
        thread.daemon = True
        thread.start()
        return str(thread)

    def close(self):
        """Close the server.

        Only sets a flag; the serving loop re-checks it between connections.
        """
        self._closed = True
        _logger.info("server closed!")

    def port(self):
        """Get the port."""
        return self._port

    def ip(self):
        """Get the ip."""
        return self._ip

    def run(self):
        """Start the server.

        Serves one request per accepted connection until ``search_steps``
        iterations are reached or ``close()`` was called. Any exception stops
        serving (it is logged). The listening socket is closed and the server
        marked closed on every exit path.
        """
        _logger.info("Controller Server run...")
        try:
            while ((self._search_steps is None) or
                   (self._controller._iter < self._search_steps)
                   ) and not self._closed:
                conn, addr = self._socket_server.accept()
                try:
                    self._handle_request(conn, addr)
                finally:
                    # Close on every path, including ignored noise messages.
                    conn.close()
        except Exception as err:
            _logger.error(err)
        finally:
            self._socket_server.close()
            self.close()

    def _handle_request(self, conn, addr):
        """Read one request from ``conn`` and dispatch it by command."""
        # NOTE: a single recv of at most 1024 bytes; longer or fragmented
        # messages are not reassembled.
        message = conn.recv(1024).decode()
        _logger.debug(message)
        command = message.strip("\n")
        if command == "next_tokens":
            tokens = self._controller.next_tokens()
            conn.send(",".join([str(token) for token in tokens]).encode())
        elif command == "current_info":
            current_info = {
                'best_tokens': self._controller.best_tokens,
                'best_reward': self._controller.max_reward,
                'current_tokens': self._controller.current_tokens,
            }
            conn.send(str(current_info).encode())
        else:
            self._handle_update(conn, addr, message)

    def _handle_update(self, conn, addr, message):
        """Apply a tab-separated (key, tokens, reward, iter, client_name) report."""
        _logger.debug("recv message from {}: [{}]".format(addr, message))
        messages = message.strip('\n').split("\t")
        if (len(messages) < 5) or (messages[0] != self._key):
            # Too few fields or a foreign key: ignore without replying.
            # NOTE(review): with key=None no string field can ever match, so all
            # reports are dropped; confirm how agents encode a missing key.
            _logger.debug("recv noise from {}: [{}]".format(addr, message))
            return
        tokens, reward, iter_num, client_name = messages[1:5]
        self._register_client(client_name)
        _logger.debug("client: {}, client_num: {}, compare_time: {}".format(
            self._client, self._client_num, self._compare_time))
        tokens = [int(token) for token in tokens.split(",")]
        self._controller.update(tokens,
                                float(reward),
                                int(iter_num), int(self._client_num))
        conn.send("ok".encode())
        _logger.debug("send message to {}: [{}]".format(addr, tokens))

    def _register_client(self, client_name):
        """Record a report from ``client_name`` and drop clients that look stopped.

        The time since this client's previous report is taken as the duration of
        one search step, and ``_compare_time`` becomes twice that interval. Any
        client silent for longer than ``_compare_time`` is removed, but the last
        remaining client is always kept.
        """
        now = time.time()
        last_seen = self._client.get(client_name)
        if last_seen is None:
            self._client_num += 1
        else:
            self._compare_time = 2 * (now - last_seen)
        self._client[client_name] = now

        # Iterate over a snapshot: popping from a dict while iterating its live
        # keys view raises RuntimeError on Python 3. A single `now` is used for
        # every age, so the reporting client (age 0) never evicts itself.
        for key_client in list(self._client):
            if (now - self._client[key_client] > self._compare_time and
                    len(self._client) > 1):
                self._client.pop(key_client)
                self._client_num -= 1