未验证 提交 50f3bd31 编写于 作者: B Bo Zhou 提交者: GitHub

Compatibility (#118)

* fix the vital issue on compatibility

* resolve the warning log

* yapf

* yapf
上级 a13dcce5
...@@ -18,6 +18,7 @@ import threading ...@@ -18,6 +18,7 @@ import threading
import zmq import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.remote import remote_constants from parl.remote import remote_constants
import time
class Client(object): class Client(object):
...@@ -78,7 +79,8 @@ class Client(object): ...@@ -78,7 +79,8 @@ class Client(object):
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000) zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
self.submit_job_socket.connect("tcp://{}".format(master_address)) self.submit_job_socket.connect("tcp://{}".format(master_address))
thread = threading.Thread(target=self._reply_heartbeat, daemon=True) thread = threading.Thread(target=self._reply_heartbeat)
thread.setDaemon(True)
thread.start() thread.start()
self.heartbeat_socket_initialized.wait() self.heartbeat_socket_initialized.wait()
...@@ -151,6 +153,8 @@ class Client(object): ...@@ -151,6 +153,8 @@ class Client(object):
# no vacant CPU resources, can not submit a new job # no vacant CPU resources, can not submit a new job
elif tag == remote_constants.CPU_TAG: elif tag == remote_constants.CPU_TAG:
job_address = None job_address = None
# wait 1 second to avoid requesting in a high frequency.
time.sleep(1)
else: else:
raise NotImplementedError raise NotImplementedError
else: else:
......
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['XPARL'] = 'True'
import argparse import argparse
import cloudpickle import cloudpickle
import pickle import pickle
...@@ -27,9 +30,6 @@ from parl.utils.communication import loads_argument, loads_return,\ ...@@ -27,9 +30,6 @@ from parl.utils.communication import loads_argument, loads_return,\
from parl.remote import remote_constants from parl.remote import remote_constants
from parl.utils.exceptions import SerializeError, DeserializeError from parl.utils.exceptions import SerializeError, DeserializeError
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
class Job(object): class Job(object):
"""Base class for the job. """Base class for the job.
...@@ -41,7 +41,6 @@ class Job(object): ...@@ -41,7 +41,6 @@ class Job(object):
def __init__(self, worker_address): def __init__(self, worker_address):
self.job_is_alive = True self.job_is_alive = True
self.heartbeat_socket_initialized = threading.Event()
self.worker_address = worker_address self.worker_address = worker_address
self._create_sockets() self._create_sockets()
...@@ -68,19 +67,9 @@ class Job(object): ...@@ -68,19 +67,9 @@ class Job(object):
reply_thread = threading.Thread( reply_thread = threading.Thread(
target=self._reply_heartbeat, target=self._reply_heartbeat,
args=("worker {}".format(self.worker_address), ), args=("worker {}".format(self.worker_address), ))
daemon=True) reply_thread.setDaemon(True)
reply_thread.start() reply_thread.start()
self.heartbeat_socket_initialized.wait()
# job_socket: sends job_address and heartbeat_address to worker
self.job_socket = self.ctx.socket(zmq.REQ)
self.job_socket.connect("tcp://{}".format(self.worker_address))
self.job_socket.send_multipart([
remote_constants.NORMAL_TAG,
to_byte(self.job_address),
to_byte(self.heartbeat_worker_address)
])
_ = self.job_socket.recv_multipart()
def _reply_heartbeat(self, target): def _reply_heartbeat(self, target):
"""reply heartbeat signals to the target""" """reply heartbeat signals to the target"""
...@@ -90,9 +79,20 @@ class Job(object): ...@@ -90,9 +79,20 @@ class Job(object):
remote_constants.HEARTBEAT_RCVTIMEO_S * 1000) remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
socket.linger = 0 socket.linger = 0
heartbeat_worker_port = socket.bind_to_random_port(addr="tcp://*") heartbeat_worker_port = socket.bind_to_random_port(addr="tcp://*")
self.heartbeat_worker_address = "{}:{}".format(self.job_ip, heartbeat_worker_address = "{}:{}".format(self.job_ip,
heartbeat_worker_port) heartbeat_worker_port)
self.heartbeat_socket_initialized.set()
# job_socket: sends job_address and heartbeat_address to worker
job_socket = self.ctx.socket(zmq.REQ)
job_socket.connect("tcp://{}".format(self.worker_address))
job_socket.send_multipart([
remote_constants.NORMAL_TAG,
to_byte(self.job_address),
to_byte(heartbeat_worker_address),
to_byte(str(os.getpid()))
])
_ = job_socket.recv_multipart()
# a flag to decide when to exit heartbeat loop # a flag to decide when to exit heartbeat loop
self.worker_is_alive = True self.worker_is_alive = True
while self.worker_is_alive and self.job_is_alive: while self.worker_is_alive and self.job_is_alive:
......
...@@ -201,8 +201,9 @@ class Master(object): ...@@ -201,8 +201,9 @@ class Master(object):
self.worker_pool[worker.address] = worker self.worker_pool[worker.address] = worker
self.worker_locks[worker.address] = threading.Lock() self.worker_locks[worker.address] = threading.Lock()
logger.info("A new worker {} is added, ".format(worker.address) + logger.info(
"cluster has {} CPUs.\n".format(len(self.job_pool))) "A new worker {} is added, ".format(worker.address) +
"the cluster has {} CPUs.\n".format(len(self.job_pool)))
# a thread for sending heartbeat signals to `worker.address` # a thread for sending heartbeat signals to `worker.address`
thread = threading.Thread( thread = threading.Thread(
...@@ -210,8 +211,8 @@ class Master(object): ...@@ -210,8 +211,8 @@ class Master(object):
args=( args=(
worker_heartbeat_address, worker_heartbeat_address,
worker.address, worker.address,
), ))
daemon=True) thread.setDaemon(True)
thread.start() thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
...@@ -224,8 +225,8 @@ class Master(object): ...@@ -224,8 +225,8 @@ class Master(object):
thread = threading.Thread( thread = threading.Thread(
target=self._create_client_monitor, target=self._create_client_monitor,
args=(client_heartbeat_address, ), args=(client_heartbeat_address, ))
daemon=True) thread.setDaemon(True)
thread.start() thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
......
...@@ -110,7 +110,6 @@ def remote_class(cls): ...@@ -110,7 +110,6 @@ def remote_class(cls):
_ = self.job_socket.recv_multipart() _ = self.job_socket.recv_multipart()
except zmq.error.Again as e: except zmq.error.Again as e:
logger.error("Job socket failed.") logger.error("Job socket failed.")
logger.info("[connect_job] job_address:{}".format(job_address))
def __del__(self): def __del__(self):
"""Delete the remote class object and release remote resources.""" """Delete the remote class object and release remote resources."""
...@@ -138,7 +137,6 @@ def remote_class(cls): ...@@ -138,7 +137,6 @@ def remote_class(cls):
logger.warning("No vacant cpu resources at present, " logger.warning("No vacant cpu resources at present, "
"will try {} times later.".format(cnt)) "will try {} times later.".format(cnt))
cnt -= 1 cnt -= 1
time.sleep(1)
return None return None
def __getattr__(self, attr): def __getattr__(self, attr):
......
...@@ -14,11 +14,13 @@ ...@@ -14,11 +14,13 @@
import click import click
import locale import locale
import sys
import os import os
import subprocess import subprocess
import threading import threading
import warnings import warnings
from multiprocessing import Process from multiprocessing import Process
from parl.utils import logger
# A flag to mark if parl is started from a command line # A flag to mark if parl is started from a command line
os.environ['XPARL'] = 'True' os.environ['XPARL'] = 'True'
...@@ -27,18 +29,22 @@ os.environ['XPARL'] = 'True' ...@@ -27,18 +29,22 @@ os.environ['XPARL'] = 'True'
# to use ASCII as encoding for the environment` error. # to use ASCII as encoding for the environment` error.
locale.setlocale(locale.LC_ALL, "en_US.UTF-8") locale.setlocale(locale.LC_ALL, "en_US.UTF-8")
warnings.simplefilter("ignore", ResourceWarning) #TODO: this line will cause error in python2/macOS
if sys.version_info.major == 3:
warnings.simplefilter("ignore", ResourceWarning)
def is_port_in_use(port): def is_port_available(port):
""" Check if a port is used. """ Check if a port is used.
True if the port is not available. Otherwise, this port can be used for True if the port is available for connection.
connection.
""" """
port = int(port)
import socket import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
return s.connect_ex(('localhost', int(port))) == 0 available = sock.connect_ex(('localhost', port))
sock.close()
return available
def is_master_started(address): def is_master_started(address):
...@@ -71,22 +77,24 @@ def cli(): ...@@ -71,22 +77,24 @@ def cli():
help="Set number of cpu manually. If not set, it will use all " help="Set number of cpu manually. If not set, it will use all "
"cpus of this machine.") "cpus of this machine.")
def start_master(port, cpu_num): def start_master(port, cpu_num):
if is_port_in_use(port): if not is_port_available(port):
raise Exception( raise Exception(
"The master address localhost:{} already in use.".format(port)) "The master address localhost:{} already in use.".format(port))
cpu_num = str(cpu_num) if cpu_num else '' cpu_num = str(cpu_num) if cpu_num else ''
command = [ start_file = __file__.replace('scripts.pyc', 'start.py')
"python", "{}/start.py".format(__file__[:-11]), "--name", "master", start_file = start_file.replace('scripts.py', 'start.py')
"--port", port command = ["python", start_file, "--name", "master", "--port", port]
]
p = subprocess.Popen(command) p = subprocess.Popen(command)
command = [ command = [
"python", "{}/start.py".format(__file__[:-11]), "--name", "worker", "python", start_file, "--name", "worker", "--address",
"--address", "localhost:" + str(port), "--cpu_num", "localhost:" + str(port), "--cpu_num",
str(cpu_num) str(cpu_num)
] ]
p = subprocess.Popen(command) # Redirect the output to DEVNULL to solve the warning log.
FNULL = open(os.devnull, 'w')
p = subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT)
FNULL.close()
@click.command("connect", short_help="Start a worker node.") @click.command("connect", short_help="Start a worker node.")
......
...@@ -15,15 +15,20 @@ ...@@ -15,15 +15,20 @@
import cloudpickle import cloudpickle
import multiprocessing import multiprocessing
import os import os
import signal
import subprocess import subprocess
import sys import sys
import time import time
import threading import threading
import warnings
import zmq import zmq
from parl.utils import get_ip_address, to_byte, to_str, logger from parl.utils import get_ip_address, to_byte, to_str, logger
from parl.remote import remote_constants from parl.remote import remote_constants
if sys.version_info.major == 3:
warnings.simplefilter("ignore", ResourceWarning)
class WorkerInfo(object): class WorkerInfo(object):
"""A WorkerInfo object records the computation resources of a worker. """A WorkerInfo object records the computation resources of a worker.
...@@ -138,7 +143,7 @@ class Worker(object): ...@@ -138,7 +143,7 @@ class Worker(object):
self.master_is_alive = False self.master_is_alive = False
return return
self._init_jobs() self._init_jobs(job_num=self.cpu_num)
self.request_master_socket.setsockopt( self.request_master_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000) zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
...@@ -146,8 +151,8 @@ class Worker(object): ...@@ -146,8 +151,8 @@ class Worker(object):
list(self.job_pid.keys())) list(self.job_pid.keys()))
reply_thread = threading.Thread( reply_thread = threading.Thread(
target=self._reply_heartbeat, target=self._reply_heartbeat,
args=("master {}".format(self.master_address), ), args=("master {}".format(self.master_address), ))
daemon=True) reply_thread.setDaemon(True)
reply_thread.start() reply_thread.start()
self.heartbeat_socket_initialized.wait() self.heartbeat_socket_initialized.wait()
...@@ -158,23 +163,34 @@ class Worker(object): ...@@ -158,23 +163,34 @@ class Worker(object):
]) ])
_ = self.request_master_socket.recv_multipart() _ = self.request_master_socket.recv_multipart()
def _init_job(self): def _init_jobs(self, job_num):
"""Create one job.""" """Create jobs.
Args:
job_num(int): the number of jobs to create.
"""
job_file = __file__.replace('worker.pyc', 'job.py')
job_file = job_file.replace('worker.py', 'job.py')
command = [ command = [
"python", "{}/job.py".format(__file__[:-10]), "--worker_address", "python", job_file, "--worker_address", self.reply_job_address
self.reply_job_address
] ]
with open(os.devnull, "w") as null: # Redirect the output to DEVNULL
pid = subprocess.Popen(command, stdout=null, stderr=null) FNULL = open(os.devnull, 'w')
for _ in range(job_num):
pid = subprocess.Popen(
command, stdout=FNULL, stderr=subprocess.STDOUT)
FNULL.close()
self.lock.acquire() new_job_address = []
for _ in range(job_num):
job_message = self.reply_job_socket.recv_multipart() job_message = self.reply_job_socket.recv_multipart()
self.reply_job_socket.send_multipart([remote_constants.NORMAL_TAG]) self.reply_job_socket.send_multipart([remote_constants.NORMAL_TAG])
job_address = to_str(job_message[1]) job_address = to_str(job_message[1])
new_job_address.append(job_address)
heartbeat_job_address = to_str(job_message[2]) heartbeat_job_address = to_str(job_message[2])
self.job_pid[job_address] = pid pid = to_str(job_message[3])
self.lock.release() self.job_pid[job_address] = int(pid)
# a thread for sending heartbeat signals to job # a thread for sending heartbeat signals to job
thread = threading.Thread( thread = threading.Thread(
...@@ -182,31 +198,31 @@ class Worker(object): ...@@ -182,31 +198,31 @@ class Worker(object):
args=( args=(
job_address, job_address,
heartbeat_job_address, heartbeat_job_address,
), ))
daemon=True) thread.setDaemon(True)
thread.start() thread.start()
return job_address assert len(new_job_address) > 0, "init jobs failed"
if len(new_job_address) > 1:
def _init_jobs(self): return new_job_address
"""Create cpu_num jobs when the worker is created.""" else:
job_threads = [] return new_job_address[0]
for _ in range(self.cpu_num):
t = threading.Thread(target=self._init_job, daemon=True)
t.start()
job_threads.append(t)
for th in job_threads:
th.join()
def _kill_job(self, job_address): def _kill_job(self, job_address):
"""kill problematic job process and update worker information""" """kill problematic job process and update worker information"""
if job_address in self.job_pid: if job_address in self.job_pid:
self.job_pid[job_address].kill() self.lock.acquire()
pid = self.job_pid[job_address]
try:
os.kill(pid, signal.SIGTERM)
except OSError:
logger.warn("job:{} has been killed before".format(pid))
self.job_pid.pop(job_address) self.job_pid.pop(job_address)
logger.warning("Worker kills job process {},".format(job_address)) logger.warning("Worker kills job process {},".format(job_address))
self.lock.release()
# When a old job is killed, the worker will create a new job. # When a old job is killed, the worker will create a new job.
if self.master_is_alive: if self.master_is_alive:
new_job_address = self._init_job() new_job_address = self._init_jobs(job_num=1)
self.lock.acquire() self.lock.acquire()
self.request_master_socket.send_multipart([ self.request_master_socket.send_multipart([
......
...@@ -18,7 +18,6 @@ import os ...@@ -18,7 +18,6 @@ import os
import os.path import os.path
import sys import sys
from termcolor import colored from termcolor import colored
import shutil
__all__ = ['set_dir', 'get_dir', 'set_level'] __all__ = ['set_dir', 'get_dir', 'set_level']
...@@ -86,16 +85,10 @@ def _getlogger(): ...@@ -86,16 +85,10 @@ def _getlogger():
def create_file_after_first_call(func_name): def create_file_after_first_call(func_name):
def call(*args, **kwargs): def call(*args, **kwargs):
global _logger global _logger
if LOG_DIR is None: if LOG_DIR is None and hasattr(mod, '__file__'):
basename = os.path.basename(mod.__file__) basename = os.path.basename(mod.__file__)
if basename.rfind('.') == -1: auto_dirname = os.path.join('log_dir',
basename = basename basename[:basename.rfind('.')])
else:
basename = basename[:basename.rfind('.')]
auto_dirname = os.path.join('log_dir', basename)
shutil.rmtree(auto_dirname, ignore_errors=True)
set_dir(auto_dirname) set_dir(auto_dirname)
func = getattr(_logger, func_name) func = getattr(_logger, func_name)
...@@ -165,4 +158,5 @@ def get_dir(): ...@@ -165,4 +158,5 @@ def get_dir():
# Will save log to log_dir/main_file_name/log.log by default # Will save log to log_dir/main_file_name/log.log by default
mod = sys.modules['__main__'] mod = sys.modules['__main__']
_logger.info("Argv: " + ' '.join(sys.argv)) if hasattr(mod, '__file__') and 'XPARL' not in os.environ:
_logger.info("Argv: " + ' '.join(sys.argv))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册