client.py 7.2 KB
Newer Older
F
fuyw 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cloudpickle
import os
import threading
import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.remote import remote_constants


class Client(object):
    """Base class for the remote client.

    For each training task, there is a global client in the cluster which
    submits jobs to the master node. Different `@parl.remote_class` objects
    connect to the same global client in a training task.

    Attributes:
        submit_job_socket (zmq.Context.socket): A socket which submits job to
                                                the master node.
        pyfiles (bytes): A serialized dictionary containing the code of python
                         files in local working directory.
        heartbeat_master_address (str): "{ip}:{port}" of the socket on which
                         this client answers the master's heartbeats; None
                         until the heartbeat thread has bound that socket.
        master_is_alive (bool): Cleared when the master fails the initial
                         connection check or stops sending heartbeats.
        client_is_alive (bool): Cleared to stop the heartbeat reply loop.
    """

    def __init__(self, master_address):
        """
        Args:
            master_address (str): address of the master node, in
                "{ip}:{port}" format.

        Raises:
            Exception: if the heartbeat socket cannot be created or the
                master node cannot be reached.
        """
        self.ctx = zmq.Context()
        self.lock = threading.Lock()
        self.heartbeat_socket_initialized = threading.Event()
        # Filled in by the heartbeat thread; stays None if its setup fails.
        self.heartbeat_master_address = None
        self.master_is_alive = True
        self.client_is_alive = True

        self._create_sockets(master_address)
        self.pyfiles = self.read_local_files()

    def read_local_files(self):
        """Read local python code and store them in a dictionary, which will
        then be sent to the job.

        Only regular files directly inside the current working directory
        whose names end with ".py" are read (no recursion).

        Returns:
            A cloudpickled dictionary mapping each file name to its raw
            (bytes) content.
        """
        pyfiles = dict()
        for filename in os.listdir('./'):
            # Skip directories (or other non-regular entries) whose name
            # happens to end with ".py"; open() would fail on them.
            if filename.endswith('.py') and os.path.isfile(filename):
                with open(filename, 'rb') as code_file:
                    pyfiles[filename] = code_file.read()
        return cloudpickle.dumps(pyfiles)

    def _create_sockets(self, master_address):
        """Create the socket used to talk to the master and check the link.

        Each client has one socket at start:

        (1) submit_job_socket: submits jobs to the master node.

        A daemon thread running `_reply_heartbeat` is also started. This
        method waits until that thread has bound its heartbeat socket, then
        sends the heartbeat address to the master with CLIENT_CONNECT_TAG.

        Args:
            master_address (str): "{ip}:{port}" of the master node.

        Raises:
            Exception: if the heartbeat socket could not be set up, or the
                master does not reply within HEARTBEAT_TIMEOUT_S seconds.
        """
        # submit_job_socket: submits job to master
        self.submit_job_socket = self.ctx.socket(zmq.REQ)
        self.submit_job_socket.linger = 0
        self.submit_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.submit_job_socket.connect("tcp://{}".format(master_address))

        thread = threading.Thread(target=self._reply_heartbeat, daemon=True)
        thread.start()
        self.heartbeat_socket_initialized.wait()
        if self.heartbeat_master_address is None:
            # The heartbeat thread failed before it could bind its socket.
            self.submit_job_socket.close(0)
            raise Exception("Client failed to create the heartbeat socket.")

        # check if the master is connected properly
        try:
            self.submit_job_socket.send_multipart([
                remote_constants.CLIENT_CONNECT_TAG,
                to_byte(self.heartbeat_master_address)
            ])
            self.submit_job_socket.recv_multipart()
        except zmq.error.Again as e:
            logger.warning("[Client] Can not connect to the master, please "
                           "check if master is started and ensure the input "
                           "address {} is correct.".format(master_address))
            # Also makes the heartbeat thread leave its loop.
            self.master_is_alive = False
            # The constructor is about to fail, so nobody else can use it.
            self.submit_job_socket.close(0)
            raise Exception("Client can not connect to the master, please "
                            "check if master is started and ensure the input "
                            "address {} is correct.".format(master_address)) from e

    def _reply_heartbeat(self):
        """Reply heartbeat signals to the Master node.

        Runs in the daemon thread started by `_create_sockets`:

        1. Binds a REP socket to a random port.
        2. Publishes its address through `heartbeat_master_address`.
        3. Signals `heartbeat_socket_initialized`.
        4. Answers each incoming message with HEARTBEAT_TAG until the client
           is marked not alive, or no message arrives within
           HEARTBEAT_RCVTIMEO_S seconds (the master is then marked dead).
        """
        heartbeat_socket = None
        try:
            heartbeat_socket = self.ctx.socket(zmq.REP)
            heartbeat_socket.linger = 0
            heartbeat_socket.setsockopt(
                zmq.RCVTIMEO, remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
            heartbeat_master_port = \
                heartbeat_socket.bind_to_random_port(addr="tcp://*")
            self.heartbeat_master_address = "{}:{}".format(
                get_ip_address(), heartbeat_master_port)
        except Exception:
            if heartbeat_socket is not None:
                heartbeat_socket.close(0)
            raise
        finally:
            # Always wake up `_create_sockets`, even when setup failed, so the
            # constructor cannot block forever on this event.
            self.heartbeat_socket_initialized.set()

        while self.client_is_alive and self.master_is_alive:
            try:
                heartbeat_socket.recv_multipart()
                heartbeat_socket.send_multipart([remote_constants.HEARTBEAT_TAG])
            except zmq.error.Again:
                logger.warning("[Client] Cannot connect to the master. "
                               "Please check if it is still alive.")
                self.master_is_alive = False
        heartbeat_socket.close(0)
        logger.warning("Client exit replying heartbeat for master.")

    def submit_job(self):
        """Send a job to the Master node.

        When a `@parl.remote_class` object is created, the global client
        sends a job to the master node. Then the master node will allocate
        a vacant job from its job pool to the remote object.

        Returns:
            The address of the job (str). Returns None when the master
            replies with CPU_TAG, meaning no CPU resources are vacant.

        Raises:
            Exception: if the master is already known to be unreachable.
            NotImplementedError: if the master replies with an unknown tag.
            zmq.error.Again: if the master does not reply within the
                socket's receive timeout.
        """
        if not self.master_is_alive:
            raise Exception("Client can not submit job to the master, "
                            "please check if master is connected.")

        # A lock to prevent multiple actors from submitting jobs at the same
        # time. `with` guarantees release even if send/recv raises (e.g. a
        # receive timeout), so a single failure cannot deadlock later callers.
        with self.lock:
            self.submit_job_socket.send_multipart([
                remote_constants.CLIENT_SUBMIT_TAG,
                to_byte(self.heartbeat_master_address)
            ])
            message = self.submit_job_socket.recv_multipart()

        tag = message[0]
        if tag == remote_constants.NORMAL_TAG:
            return to_str(message[1])
        # no vacant CPU resources, can not submit a new job
        if tag == remote_constants.CPU_TAG:
            return None
        raise NotImplementedError


# Process-wide client shared by every `@parl.remote_class` object; set by
# `connect()` and cleared by `disconnect()`.
GLOBAL_CLIENT = None


def connect(master_address):
    """Create a global client which connects to the master node.

    Calling it again while a global client already exists is a no-op.

    .. code-block:: python

        parl.connect(master_address='localhost:1234')

    Args:
        master_address (str): The address of the Master node to connect to,
            in "{ip}:{port}" format.

    Raises:
        AssertionError: if `master_address` is not in "{ip}:{port}" format.
        Exception: An exception is raised if the master node is not started.
    """
    global GLOBAL_CLIENT
    # Raised explicitly rather than via `assert` so the check also runs under
    # `python -O`; AssertionError is kept for backward compatibility.
    if len(master_address.split(":")) != 2:
        raise AssertionError("please input address in {ip}:{port} format")
    if GLOBAL_CLIENT is None:
        GLOBAL_CLIENT = Client(master_address)


def get_global_client():
    """Get the global client.

    Returns:
        The global client created by `connect()`.

    Raises:
        AssertionError: if no client is connected (`connect()` was never
            called, or `disconnect()` has been called since).
    """
    # Raised explicitly (as in `connect`) so the check survives `python -O`.
    if GLOBAL_CLIENT is None:
        raise AssertionError(
            "Cannot get the client to submit the job, have you connected "
            "to the cluster by calling parl.connect(master_address)?")
    return GLOBAL_CLIENT


def disconnect():
    """Disconnect the global client from the master node.

    Clears the client's `client_is_alive` flag, which the client's heartbeat
    loop checks before each iteration, and drops the global reference so a
    later `connect()` creates a fresh client. Calling it when no client is
    connected is a no-op.
    """
    global GLOBAL_CLIENT
    if GLOBAL_CLIENT is None:
        return
    GLOBAL_CLIENT.client_is_alive = False
    GLOBAL_CLIENT = None