# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import signal
import zmq
import socket
import logging
import time
import threading
import six
if six.PY2:
    import cPickle as pickle
else:
    import pickle
from .log_helper import get_logger
from .RL_controller.utils import compute_grad, ConnectMessage

_logger = get_logger(__name__, level=logging.INFO)


class Client(object):
    """Client side of the controller's ZeroMQ request/reply protocol.

    On construction the client connects a ``zmq.REQ`` socket to the server at
    ``address`` and performs an INIT handshake. ``next_tokens`` fetches the
    current controller parameters from the server and asks the local
    controller for tokens. ``update`` runs the local controller update, sends
    the resulting parameter gradient back and, when the server answers
    ``ConnectMessage.WAIT``, polls a second "wait" socket until it receives
    ``ConnectMessage.OK``.

    Every frame on the wire is a pickled Python object.
    NOTE(review): frames received from the server are unpickled without any
    validation, so the server must be trusted.

    Args:
        controller: object providing ``next_tokens(obs, params_dict=...,
            is_inference=...)`` and ``update(rewards, params_dict, **kwargs)``.
        address: indexable ``(ip, port)`` pair of the server.
        client_name: identifier sent to the server with every request.
    """

    def __init__(self, controller, address, client_name):
        self._controller = controller
        self._address = address
        self._ip = self._address[0]
        self._port = self._address[1]
        self._client_name = client_name
        # Parameters received by the most recent next_tokens() call.
        self._params_dict = None
        # True once the wait socket has been connected by update().
        self.init_wait = False
        self._wait_socket = None
        self._connect_server()

    @staticmethod
    def _send(sock, *parts):
        """Pickle every item of ``parts`` and send them as one multipart message."""
        sock.send_multipart([pickle.dumps(part) for part in parts])

    def _recv_or_exit(self, where):
        """Receive one reply on the request socket, exiting the process on timeout.

        Args:
            where: short label inserted into the error message.

        Returns:
            The list of raw frames returned by ``recv_multipart``.
        """
        try:
            return self._client_socket.recv_multipart()
        except zmq.error.Again as e:
            _logger.error(
                "CANNOT recv params from server in {}, Please check whether the server is alive!!! {}".
                format(where, e))
            # NOTE(review): exits with status 0 even though this is a failure;
            # kept as-is because launchers may depend on it.
            os._exit(0)

    def _poll_wait_socket(self):
        """Send one WAIT_PARAMS request on the wait socket and return the decoded reply."""
        self._send(self._wait_socket, ConnectMessage.WAIT_PARAMS,
                   self._client_name)
        # NOTE(review): no receive timeout is set on the wait socket, so this
        # call blocks for as long as the server does not answer.
        message = self._wait_socket.recv_multipart()
        return pickle.loads(message[0])

    def _connect_server(self):
        """Open the request socket and perform the INIT handshake.

        If the server's reply is not ``ConnectMessage.INIT_DONE`` the current
        process is sent ``SIGTERM``.

        Raises:
            zmq.error.Again: no reply arrived within the receive timeout.
        """
        self._ctx = zmq.Context()
        self._client_socket = self._ctx.socket(zmq.REQ)
        ### NOTE: change the method to exit client when server is dead if there are better solutions
        # RCVTIMEO is expressed in milliseconds.
        self._client_socket.setsockopt(zmq.RCVTIMEO,
                                       ConnectMessage.TIMEOUT * 1000)
        client_address = "{}:{}".format(self._ip, self._port)
        self._client_socket.connect("tcp://{}".format(client_address))
        self._send(self._client_socket, ConnectMessage.INIT,
                   self._client_name)
        try:
            message = self._client_socket.recv_multipart()
        except zmq.error.Again:
            _logger.error(
                "Client {}: no reply from server {} during init, Please check whether the server is alive!!!".
                format(self._client_name, client_address))
            raise
        if pickle.loads(message[0]) != ConnectMessage.INIT_DONE:
            _logger.error("Client {} init failure, Please start it again".
                          format(self._client_name))
            os.kill(os.getpid(), signal.SIGTERM)
            # If SIGTERM is handled or delivered late, do not fall through
            # and report a successful connection.
            return
        _logger.info("Client {}: connect to server success!!!".format(
            self._client_name))
        _logger.debug("Client {}: connect to server {}".format(
            self._client_name, client_address))

    def _connect_wait_socket(self, port):
        """Connect the wait socket to ``port`` on the server host and poll it once.

        Args:
            port: port number received from the server in a WAIT reply.

        Returns:
            The decoded first frame of the server's reply.
        """
        self._wait_socket = self._ctx.socket(zmq.REQ)
        wait_address = "{}:{}".format(self._ip, port)
        self._wait_socket.connect("tcp://{}".format(wait_address))
        return self._poll_wait_socket()

    def next_tokens(self, obs, is_inference=False):
        """Fetch the current parameters from the server and generate tokens.

        Args:
            obs: observation forwarded unchanged to the controller.
            is_inference: forwarded unchanged to the controller.

        Returns:
            Whatever ``controller.next_tokens`` returns.
        """
        _logger.debug("Client: requests for weight {}".format(
            self._client_name))
        self._send(self._client_socket, ConnectMessage.GET_WEIGHT,
                   self._client_name)
        message = self._recv_or_exit("next_tokens")
        # Kept on the instance because update() needs the same parameters.
        self._params_dict = pickle.loads(message[0])
        tokens = self._controller.next_tokens(
            obs, params_dict=self._params_dict, is_inference=is_inference)
        _logger.debug("Client: client_name is {}, current token is {}".format(
            self._client_name, tokens))
        return tokens

    def update(self, rewards, **kwargs):
        """Run a local controller update and push the gradient to the server.

        Must be called after ``next_tokens``. If the server answers
        ``ConnectMessage.WAIT`` the call blocks, polling the wait socket once
        per second, until the server answers ``ConnectMessage.OK``.

        Args:
            rewards: forwarded unchanged to ``controller.update``.
            **kwargs: forwarded unchanged to ``controller.update``.

        Returns:
            The decoded first frame of the server's reply to the update.

        Raises:
            AssertionError: ``next_tokens`` has not been called yet.
        """
        assert self._params_dict is not None, "Please call next_tokens to get token first, then call update"
        current_params_dict = self._controller.update(
            rewards, self._params_dict, **kwargs)
        params_grad = compute_grad(self._params_dict, current_params_dict)
        _logger.debug("Client: update weight {}".format(self._client_name))
        self._send(self._client_socket, ConnectMessage.UPDATE_WEIGHT,
                   self._client_name, params_grad)
        _logger.debug("Client: update done {}".format(self._client_name))

        message = self._recv_or_exit("rewards")
        reply = pickle.loads(message[0])

        if reply == ConnectMessage.WAIT:
            _logger.debug("Client: self.init_wait: {}".format(self.init_wait))
            if not self.init_wait:
                # The first WAIT reply carries the wait port as its second frame.
                wait_port = pickle.loads(message[1])
                wait_signal = self._connect_wait_socket(wait_port)
                self.init_wait = True
            else:
                wait_signal = reply
            while wait_signal != ConnectMessage.OK:
                time.sleep(1)
                wait_signal = self._poll_wait_socket()
                _logger.debug("Client: {} {}".format(self._client_name,
                                                     wait_signal))

        return reply

    def __del__(self):
        """Best-effort goodbye to the server, then close every socket.

        Never raises: the instance may be only partially initialised, the
        server may be gone, or the interpreter may be shutting down.
        """
        client_socket = getattr(self, "_client_socket", None)
        wait_socket = getattr(self, "_wait_socket", None)
        try:
            if client_socket is not None:
                client_socket.send_multipart([
                    pickle.dumps(ConnectMessage.EXIT),
                    pickle.dumps(self._client_name)
                ])
                _ = client_socket.recv_multipart()
        except Exception:
            # Deliberately ignored: failing to say goodbye must not stop cleanup.
            pass
        finally:
            for sock in (client_socket, wait_socket):
                if sock is None:
                    continue
                try:
                    # linger=0 drops any undelivered frame so closing cannot
                    # stall context teardown when the server is unreachable.
                    sock.close(linger=0)
                except Exception:
                    pass