# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os
import socket
import sys
import threading
import time

import cloudpickle
import zmq

from parl.remote import remote_constants
from parl.utils import to_str, to_byte, get_ip_address, logger


class Client(object):
    """Base class for the remote client.

    For each training task, there is a global client in the cluster which
    submits jobs to the master node. Different `@parl.remote_class` objects
    connect to the same global client in a training task.

    Attributes:
        submit_job_socket (zmq.Context.socket): A socket which submits job to
                                                the master node.
        pyfiles (bytes): A serialized dictionary containing the code of python
                         files in local working directory.
        executable_path (str): File path of the executable python script.
        start_time (time): A timestamp to record the start time of the
                           program.
    """

    def __init__(self, master_address):
        """
        Args:
            master_address (str): ip address of the master node.
        """
        self.ctx = zmq.Context()
        # Serializes job submissions coming from multiple actor threads.
        self.lock = threading.Lock()
        # Set once the heartbeat REP socket is bound, so `_create_sockets`
        # can safely read `self.heartbeat_master_address`.
        self.heartbeat_socket_initialized = threading.Event()
        self.master_is_alive = True
        self.client_is_alive = True

        self.executable_path = self.get_executable_path()

        # Number of jobs currently monitored by this client.
        self.actor_num = 0

        self._create_sockets(master_address)
        self.pyfiles = self.read_local_files()

    def get_executable_path(self):
        """Return the directory that contains the current executable script.

        Returns:
            str: Directory of `__main__.__file__` when run as a script,
                 otherwise the parent directory of the working directory
                 (e.g. in an interactive session).
        """
        mod = sys.modules['__main__']
        if hasattr(mod, '__file__'):
            executable_path = os.path.abspath(mod.__file__)
        else:
            executable_path = os.getcwd()
        # Use os.path.dirname instead of slicing at the last '/': the latter
        # silently breaks on Windows-style paths.
        executable_path = os.path.dirname(executable_path)
        return executable_path

    def read_local_files(self):
        """Read local python code and store them in a dictionary, which will
        then be sent to the job.

        Returns:
            A cloudpickled dictionary containing the python code in current
            working directory.
        """
        pyfiles = dict()
        for filename in os.listdir('./'):
            if filename.endswith('.py'):
                with open(filename, 'rb') as code_file:
                    pyfiles[filename] = code_file.read()
        return cloudpickle.dumps(pyfiles)

    def _create_sockets(self, master_address):
        """Create the client's initial socket and start the heartbeat thread.

        Each client has 1 socket at start:
        (1) submit_job_socket: submits jobs to master node.

        Args:
            master_address (str): ip address of the master node.

        Raises:
            Exception: If the master node cannot be reached within the
                       heartbeat timeout.
        """
        # submit_job_socket: submits job to master
        self.submit_job_socket = self.ctx.socket(zmq.REQ)
        self.submit_job_socket.linger = 0
        self.submit_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.submit_job_socket.connect("tcp://{}".format(master_address))
        self.start_time = time.time()

        thread = threading.Thread(target=self._reply_heartbeat)
        thread.daemon = True  # do not block interpreter exit
        thread.start()
        # Wait until the heartbeat socket is bound so that
        # `self.heartbeat_master_address` is valid below.
        self.heartbeat_socket_initialized.wait()

        # check if the master is connected properly
        try:
            self.submit_job_socket.send_multipart([
                remote_constants.CLIENT_CONNECT_TAG,
                to_byte(self.heartbeat_master_address),
                to_byte(socket.gethostname())
            ])
            _ = self.submit_job_socket.recv_multipart()
        except zmq.error.Again as e:
            logger.warning("[Client] Can not connect to the master, please "
                           "check if master is started and ensure the input "
                           "address {} is correct.".format(master_address))
            self.master_is_alive = False
            raise Exception("Client can not connect to the master, please "
                            "check if master is started and ensure the input "
                            "address {} is correct.".format(master_address))

    def _reply_heartbeat(self):
        """Reply heartbeat signals to the master node.

        Runs in a daemon thread; binds a REP socket to a random port,
        publishes its address via `self.heartbeat_master_address`, then
        answers the master's pings with status info until either side dies.
        """
        # NOTE: do not name this local `socket` — that would shadow the
        # imported `socket` module used by `_create_sockets`.
        reply_socket = self.ctx.socket(zmq.REP)
        reply_socket.linger = 0
        reply_socket.setsockopt(zmq.RCVTIMEO,
                                remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
        heartbeat_master_port = reply_socket.bind_to_random_port(
            addr="tcp://*")
        self.heartbeat_master_address = "{}:{}".format(get_ip_address(),
                                                       heartbeat_master_port)
        self.heartbeat_socket_initialized.set()

        while self.client_is_alive and self.master_is_alive:
            try:
                message = reply_socket.recv_multipart()
                elapsed_time = datetime.timedelta(
                    seconds=int(time.time() - self.start_time))
                reply_socket.send_multipart([
                    remote_constants.HEARTBEAT_TAG,
                    to_byte(self.executable_path),
                    to_byte(str(self.actor_num)),
                    to_byte(str(elapsed_time))
                ])
            except zmq.error.Again as e:
                # Timed out waiting for the master's ping; assume it is dead.
                logger.warning("[Client] Cannot connect to the master."
                               "Please check if it is still alive.")
                self.master_is_alive = False
        reply_socket.close(0)
        logger.warning("Client exit replying heartbeat for master.")

    def _check_and_monitor_job(self, job_heartbeat_address,
                               ping_heartbeat_address):
        """Check a job is alive, then start monitoring it.

        Sometimes the client may receive a job that is dead, thus we have to
        check if this job is still alive before sending it to the actor.

        Args:
            job_heartbeat_address (str): address used for long-term heartbeat
                                         monitoring of the job.
            ping_heartbeat_address (str): address used for the initial
                                          liveness ping.

        Returns:
            bool: True if the job answered the ping and a monitor thread was
                  started; False if the job is already dead.
        """
        # job_heartbeat_socket: sends heartbeat signal to job
        job_heartbeat_socket = self.ctx.socket(zmq.REQ)
        job_heartbeat_socket.linger = 0
        # Short timeout for the initial ping only.
        job_heartbeat_socket.setsockopt(zmq.RCVTIMEO, int(0.9 * 1000))
        job_heartbeat_socket.connect("tcp://" + ping_heartbeat_address)
        try:
            job_heartbeat_socket.send_multipart(
                [remote_constants.HEARTBEAT_TAG])
            job_heartbeat_socket.recv_multipart()
        except zmq.error.Again:
            job_heartbeat_socket.close(0)
            logger.error(
                "[Client] connects to a finished job, will try again, ping_heartbeat_address:{}"
                .format(ping_heartbeat_address))
            return False
        # Re-point the socket at the long-term heartbeat address and relax
        # the timeout for ongoing monitoring.
        job_heartbeat_socket.disconnect("tcp://" + ping_heartbeat_address)
        job_heartbeat_socket.connect("tcp://" + job_heartbeat_address)
        job_heartbeat_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)

        # a thread for sending heartbeat signals to job
        thread = threading.Thread(
            target=self._create_job_monitor, args=(job_heartbeat_socket, ))
        thread.daemon = True
        thread.start()
        return True

    def _create_job_monitor(self, job_heartbeat_socket):
        """Send heartbeat signals to check target's status.

        Decrements `self.actor_num` when the job stops answering; always
        closes the socket before returning.

        Args:
            job_heartbeat_socket (zmq.Context.socket): connected REQ socket
                                                       to the job.
        """
        job_is_alive = True
        while job_is_alive and self.client_is_alive:
            try:
                job_heartbeat_socket.send_multipart(
                    [remote_constants.HEARTBEAT_TAG])
                _ = job_heartbeat_socket.recv_multipart()
                time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
            except zmq.error.Again as e:
                # Job stopped answering: it has finished or died.
                job_is_alive = False
                self.actor_num -= 1
            except zmq.error.ZMQError as e:
                # Context/socket torn down (e.g. client shutdown).
                break
        job_heartbeat_socket.close(0)

    def submit_job(self):
        """Send a job to the Master node.

        When a `@parl.remote_class` object is created, the global client
        sends a job to the master node. Then the master node will allocate
        a vacant job from its job pool to the remote object.

        Returns:
            job_address(str): IP address of the job. None if there is no
                              available CPU in the cluster.

        Raises:
            Exception: If the master is not alive.
            NotImplementedError: If the master replies with an unknown tag.
        """
        if self.master_is_alive:
            while True:
                # The lock prevents multiple actors from submitting jobs at
                # the same time. Using `with` guarantees the lock is released
                # even if recv_multipart raises (e.g. a zmq timeout) —
                # otherwise every other actor thread would deadlock.
                with self.lock:
                    self.submit_job_socket.send_multipart([
                        remote_constants.CLIENT_SUBMIT_TAG,
                        to_byte(self.heartbeat_master_address)
                    ])
                    message = self.submit_job_socket.recv_multipart()

                tag = message[0]
                if tag == remote_constants.NORMAL_TAG:
                    job_address = to_str(message[1])
                    job_heartbeat_address = to_str(message[2])
                    ping_heartbeat_address = to_str(message[3])

                    # Only hand the job to the actor if it is still alive.
                    check_result = self._check_and_monitor_job(
                        job_heartbeat_address, ping_heartbeat_address)
                    if check_result:
                        self.actor_num += 1
                        return job_address

                # no vacant CPU resources, cannot submit a new job
                elif tag == remote_constants.CPU_TAG:
                    job_address = None
                    # wait 1 second to avoid requesting in a high frequency.
                    time.sleep(1)
                    return job_address
                else:
                    raise NotImplementedError
        else:
            raise Exception("Client can not submit job to the master, "
                            "please check if master is connected.")
        return None


GLOBAL_CLIENT = None


def connect(master_address):
    """Create a global client which connects to the master node.

    .. code-block:: python

        parl.connect(master_address='localhost:1234')

    Args:
        master_address (str): The address of the Master node to connect to.

    Raises:
        Exception: An exception is raised if the master node is not started.
    """
    assert len(master_address.split(":")) == 2, "please input address in " +\
        "{ip}:{port} format"
    global GLOBAL_CLIENT
    if GLOBAL_CLIENT is None:
        GLOBAL_CLIENT = Client(master_address)


def get_global_client():
    """Get the global client.

    Returns:
        The global client.
    """
    global GLOBAL_CLIENT
    assert GLOBAL_CLIENT is not None, "Cannot get the client to submit the" +\
        " job, have you connected to the cluster by calling " +\
        "parl.connect(master_ip, master_port)?"
    return GLOBAL_CLIENT


def disconnect():
    """Disconnect the global client from the master node."""
    global GLOBAL_CLIENT
    if GLOBAL_CLIENT is not None:
        # Heartbeat/monitor threads poll this flag and exit on their own.
        GLOBAL_CLIENT.client_is_alive = False
        GLOBAL_CLIENT = None
    else:
        logger.info(
            "No client to be released. Please make sure that you have call `parl.connect`"
        )