From 50f3bd312379806c99623013d082d73c0d194fe8 Mon Sep 17 00:00:00 2001 From: Bo Zhou <2466956298@qq.com> Date: Sat, 3 Aug 2019 16:09:57 +0800 Subject: [PATCH] Compatibility (#118) * fix the vital issue on compatibility * resolve the warning log * yapf * yapf --- parl/remote/client.py | 8 ++- parl/remote/job.py | 38 ++++++------- parl/remote/master.py | 15 +++--- parl/remote/remote_decorator.py | 2 - parl/remote/scripts.py | 36 ++++++++----- parl/remote/worker.py | 96 +++++++++++++++++++-------------- parl/utils/logger.py | 16 ++---- 7 files changed, 116 insertions(+), 95 deletions(-) diff --git a/parl/remote/client.py b/parl/remote/client.py index fb28a2d..d80c0b7 100644 --- a/parl/remote/client.py +++ b/parl/remote/client.py @@ -18,6 +18,7 @@ import threading import zmq from parl.utils import to_str, to_byte, get_ip_address, logger from parl.remote import remote_constants +import time class Client(object): @@ -78,7 +79,8 @@ class Client(object): zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000) self.submit_job_socket.connect("tcp://{}".format(master_address)) - thread = threading.Thread(target=self._reply_heartbeat, daemon=True) + thread = threading.Thread(target=self._reply_heartbeat) + thread.setDaemon(True) thread.start() self.heartbeat_socket_initialized.wait() @@ -127,7 +129,7 @@ class Client(object): When a `@parl.remote_class` object is created, the global client sends a job to the master node. Then the master node will allocate - a vacant job from its job pool to the remote object. + a vacant job from its job pool to the remote object. Returns: IP address of the job. @@ -151,6 +153,8 @@ class Client(object): # no vacant CPU resources, can not submit a new job elif tag == remote_constants.CPU_TAG: job_address = None + # wait 1 second to avoid requesting in a high frequency. + time.sleep(1) else: raise NotImplementedError else: diff --git a/parl/remote/job.py b/parl/remote/job.py index 1b05893..5096e17 100644 --- a/parl/remote/job.py +++ b/parl/remote/job.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '' +os.environ['XPARL'] = 'True' import argparse import cloudpickle import pickle @@ -27,9 +30,6 @@ from parl.utils.communication import loads_argument, loads_return,\ from parl.remote import remote_constants from parl.utils.exceptions import SerializeError, DeserializeError -import os -os.environ['CUDA_VISIBLE_DEVICES'] = '' - class Job(object): """Base class for the job. @@ -41,7 +41,6 @@ class Job(object): def __init__(self, worker_address): self.job_is_alive = True - self.heartbeat_socket_initialized = threading.Event() self.worker_address = worker_address self._create_sockets() @@ -68,19 +67,9 @@ class Job(object): reply_thread = threading.Thread( target=self._reply_heartbeat, - args=("worker {}".format(self.worker_address), ), - daemon=True) + args=("worker {}".format(self.worker_address), )) + reply_thread.setDaemon(True) reply_thread.start() - self.heartbeat_socket_initialized.wait() - # job_socket: sends job_address and heartbeat_address to worker - self.job_socket = self.ctx.socket(zmq.REQ) - self.job_socket.connect("tcp://{}".format(self.worker_address)) - self.job_socket.send_multipart([ - remote_constants.NORMAL_TAG, - to_byte(self.job_address), - to_byte(self.heartbeat_worker_address) - ]) - _ = self.job_socket.recv_multipart() def _reply_heartbeat(self, target): """reply heartbeat signals to the target""" @@ -90,9 +79,20 @@ class Job(object): remote_constants.HEARTBEAT_RCVTIMEO_S * 1000) socket.linger = 0 heartbeat_worker_port = socket.bind_to_random_port(addr="tcp://*") - self.heartbeat_worker_address = "{}:{}".format(self.job_ip, - heartbeat_worker_port) - self.heartbeat_socket_initialized.set() + heartbeat_worker_address = "{}:{}".format(self.job_ip, + heartbeat_worker_port) + + # job_socket: sends job_address and heartbeat_address to worker + job_socket = self.ctx.socket(zmq.REQ) + job_socket.connect("tcp://{}".format(self.worker_address)) + job_socket.send_multipart([ + remote_constants.NORMAL_TAG, + to_byte(self.job_address), + to_byte(heartbeat_worker_address), + to_byte(str(os.getpid())) + ]) + _ = job_socket.recv_multipart() + # a flag to decide when to exit heartbeat loop self.worker_is_alive = True while self.worker_is_alive and self.job_is_alive: diff --git a/parl/remote/master.py b/parl/remote/master.py index 42adc58..384e0aa 100644 --- a/parl/remote/master.py +++ b/parl/remote/master.py @@ -41,7 +41,7 @@ class Master(object): Attributes: worker_pool (dict): A dict to store connected workers. - job_pool (list): A list to store the job address of vacant cpu, when + job_pool (list): A list to store the job address of vacant cpu, when this number is 0, the master will refuse to create new remote object. client_job_dict (dict): A dict of list to record the job submitted by @@ -201,8 +201,9 @@ class Master(object): self.worker_pool[worker.address] = worker self.worker_locks[worker.address] = threading.Lock() - logger.info("A new worker {} is added, ".format(worker.address) + - "cluster has {} CPUs.\n".format(len(self.job_pool))) + logger.info( + "A new worker {} is added, ".format(worker.address) + + "the cluster has {} CPUs.\n".format(len(self.job_pool))) # a thread for sending heartbeat signals to `worker.address` thread = threading.Thread( @@ -210,8 +211,8 @@ class Master(object): args=( worker_heartbeat_address, worker.address, - ), - daemon=True) + )) + thread.setDaemon(True) thread.start() self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) @@ -224,8 +225,8 @@ class Master(object): thread = threading.Thread( target=self._create_client_monitor, - args=(client_heartbeat_address, ), - daemon=True) + args=(client_heartbeat_address, )) + thread.setDaemon(True) thread.start() self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) diff --git a/parl/remote/remote_decorator.py b/parl/remote/remote_decorator.py index 785c2d5..07f9330 100644 --- a/parl/remote/remote_decorator.py +++ b/parl/remote/remote_decorator.py @@ -110,7 +110,6 @@ def remote_class(cls): _ = self.job_socket.recv_multipart() except zmq.error.Again as e: logger.error("Job socket failed.") - logger.info("[connect_job] job_address:{}".format(job_address)) def __del__(self): """Delete the remote class object and release remote resources.""" @@ -138,7 +137,6 @@ def remote_class(cls): logger.warning("No vacant cpu resources at present, " "will try {} times later.".format(cnt)) cnt -= 1 - time.sleep(1) return None def __getattr__(self, attr): diff --git a/parl/remote/scripts.py b/parl/remote/scripts.py index b361d83..d348b4a 100644 --- a/parl/remote/scripts.py +++ b/parl/remote/scripts.py @@ -14,11 +14,13 @@ import click import locale +import sys import os import subprocess import threading import warnings from multiprocessing import Process +from parl.utils import logger # A flag to mark if parl is started from a command line os.environ['XPARL'] = 'True' @@ -27,18 +29,22 @@ os.environ['XPARL'] = 'True' # to use ASCII as encoding for the environment` error. locale.setlocale(locale.LC_ALL, "en_US.UTF-8") -warnings.simplefilter("ignore", ResourceWarning) +#TODO: this line will cause error in python2/macOS +if sys.version_info.major == 3: + warnings.simplefilter("ignore", ResourceWarning) -def is_port_in_use(port): +def is_port_available(port): """ Check if a port is used. - True if the port is not available. Otherwise, this port can be used for - connection. + True if the port is available for connection. """ + port = int(port) import socket - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', int(port))) == 0 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + available = sock.connect_ex(('localhost', port)) + sock.close() + return available def is_master_started(address): @@ -71,22 +77,24 @@ def cli(): help="Set number of cpu manually. If not set, it will use all " "cpus of this machine.") def start_master(port, cpu_num): - if is_port_in_use(port): + if not is_port_available(port): raise Exception( "The master address localhost:{} already in use.".format(port)) cpu_num = str(cpu_num) if cpu_num else '' - command = [ - "python", "{}/start.py".format(__file__[:-11]), "--name", "master", - "--port", port - ] + start_file = __file__.replace('scripts.pyc', 'start.py') + start_file = start_file.replace('scripts.py', 'start.py') + command = ["python", start_file, "--name", "master", "--port", port] p = subprocess.Popen(command) command = [ - "python", "{}/start.py".format(__file__[:-11]), "--name", "worker", - "--address", "localhost:" + str(port), "--cpu_num", + "python", start_file, "--name", "worker", "--address", + "localhost:" + str(port), "--cpu_num", str(cpu_num) ] - p = subprocess.Popen(command) + # Redirect the output to DEVNULL to solve the warning log. + FNULL = open(os.devnull, 'w') + p = subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT) + FNULL.close() @click.command("connect", short_help="Start a worker node.") diff --git a/parl/remote/worker.py b/parl/remote/worker.py index 2001db2..b60ec3a 100644 --- a/parl/remote/worker.py +++ b/parl/remote/worker.py @@ -15,15 +15,20 @@ import cloudpickle import multiprocessing import os +import signal import subprocess import sys import time import threading +import warnings import zmq from parl.utils import get_ip_address, to_byte, to_str, logger from parl.remote import remote_constants +if sys.version_info.major == 3: + warnings.simplefilter("ignore", ResourceWarning) + class WorkerInfo(object): """A WorkerInfo object records the computation resources of a worker. @@ -138,7 +143,7 @@ class Worker(object): self.master_is_alive = False return - self._init_jobs() + self._init_jobs(job_num=self.cpu_num) self.request_master_socket.setsockopt( zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000) @@ -146,8 +151,8 @@ class Worker(object): list(self.job_pid.keys())) reply_thread = threading.Thread( target=self._reply_heartbeat, - args=("master {}".format(self.master_address), ), - daemon=True) + args=("master {}".format(self.master_address), )) + reply_thread.setDaemon(True) reply_thread.start() self.heartbeat_socket_initialized.wait() @@ -158,55 +163,66 @@ class Worker(object): ]) _ = self.request_master_socket.recv_multipart() - def _init_job(self): - """Create one job.""" + def _init_jobs(self, job_num): + """Create jobs. + + Args: + job_num(int): the number of jobs to create. + """ + job_file = __file__.replace('worker.pyc', 'job.py') + job_file = job_file.replace('worker.py', 'job.py') command = [ - "python", "{}/job.py".format(__file__[:-10]), "--worker_address", - self.reply_job_address + "python", job_file, "--worker_address", self.reply_job_address ] - with open(os.devnull, "w") as null: - pid = subprocess.Popen(command, stdout=null, stderr=null) - - self.lock.acquire() - job_message = self.reply_job_socket.recv_multipart() - self.reply_job_socket.send_multipart([remote_constants.NORMAL_TAG]) - job_address = to_str(job_message[1]) - heartbeat_job_address = to_str(job_message[2]) - self.job_pid[job_address] = pid - self.lock.release() - - # a thread for sending heartbeat signals to job - thread = threading.Thread( - target=self._create_job_monitor, - args=( - job_address, - heartbeat_job_address, - ), - daemon=True) - thread.start() - return job_address - - def _init_jobs(self): - """Create cpu_num jobs when the worker is created.""" - job_threads = [] - for _ in range(self.cpu_num): - t = threading.Thread(target=self._init_job, daemon=True) - t.start() - job_threads.append(t) - for th in job_threads: - th.join() + # Redirect the output to DEVNULL + FNULL = open(os.devnull, 'w') + for _ in range(job_num): + pid = subprocess.Popen( + command, stdout=FNULL, stderr=subprocess.STDOUT) + FNULL.close() + + new_job_address = [] + for _ in range(job_num): + job_message = self.reply_job_socket.recv_multipart() + self.reply_job_socket.send_multipart([remote_constants.NORMAL_TAG]) + job_address = to_str(job_message[1]) + new_job_address.append(job_address) + heartbeat_job_address = to_str(job_message[2]) + pid = to_str(job_message[3]) + self.job_pid[job_address] = int(pid) + + # a thread for sending heartbeat signals to job + thread = threading.Thread( + target=self._create_job_monitor, + args=( + job_address, + heartbeat_job_address, + )) + thread.setDaemon(True) + thread.start() + assert len(new_job_address) > 0, "init jobs failed" + if len(new_job_address) > 1: + return new_job_address + else: + return new_job_address[0] def _kill_job(self, job_address): """kill problematic job process and update worker information""" if job_address in self.job_pid: - self.job_pid[job_address].kill() + self.lock.acquire() + pid = self.job_pid[job_address] + try: + os.kill(pid, signal.SIGTERM) + except OSError: + logger.warn("job:{} has been killed before".format(pid)) self.job_pid.pop(job_address) logger.warning("Worker kills job process {},".format(job_address)) + self.lock.release() # When a old job is killed, the worker will create a new job. if self.master_is_alive: - new_job_address = self._init_job() + new_job_address = self._init_jobs(job_num=1) self.lock.acquire() self.request_master_socket.send_multipart([ diff --git a/parl/utils/logger.py b/parl/utils/logger.py index 3e84cef..2e837ae 100644 --- a/parl/utils/logger.py +++ b/parl/utils/logger.py @@ -18,7 +18,6 @@ import os import os.path import sys from termcolor import colored -import shutil __all__ = ['set_dir', 'get_dir', 'set_level'] @@ -86,16 +85,10 @@ def _getlogger(): def create_file_after_first_call(func_name): def call(*args, **kwargs): global _logger - if LOG_DIR is None: - + if LOG_DIR is None and hasattr(mod, '__file__'): basename = os.path.basename(mod.__file__) - if basename.rfind('.') == -1: - basename = basename - else: - basename = basename[:basename.rfind('.')] - auto_dirname = os.path.join('log_dir', basename) - - shutil.rmtree(auto_dirname, ignore_errors=True) + auto_dirname = os.path.join('log_dir', + basename[:basename.rfind('.')]) set_dir(auto_dirname) func = getattr(_logger, func_name) @@ -165,4 +158,5 @@ def get_dir(): # Will save log to log_dir/main_file_name/log.log by default mod = sys.modules['__main__'] -_logger.info("Argv: " + ' '.join(sys.argv)) +if hasattr(mod, '__file__') and 'XPARL' not in os.environ: + _logger.info("Argv: " + ' '.join(sys.argv)) -- GitLab