Unverified commit 50f3bd31, authored by Bo Zhou and committed by GitHub

Compatibility (#118)

* fix the vital issue on compatibility

* resolve the warning log

* yapf

* yapf
Parent a13dcce5
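The core compatibility fix recurs throughout the diff below: the Python 3-only `threading.Thread(..., daemon=True)` keyword argument is replaced with `thread.setDaemon(True)`, which Python 2 also accepts. A minimal sketch of that idiom, assuming a hypothetical helper name (`start_daemon_thread`) and a dummy target loop:

```python
import threading
import time


def _heartbeat_loop():
    # stand-in for the heartbeat handlers run by Client/Job/Master/Worker
    while True:
        time.sleep(1)


def start_daemon_thread(target, args=()):
    """Start `target` as a daemon thread without the Python 3-only kwarg."""
    thread = threading.Thread(target=target, args=args)
    thread.setDaemon(True)  # portable equivalent of daemon=True
    thread.start()
    return thread


if __name__ == '__main__':
    start_daemon_thread(_heartbeat_loop)
```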
@@ -18,6 +18,7 @@ import threading
import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.remote import remote_constants
import time
class Client(object):
@@ -78,7 +79,8 @@ class Client(object):
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
self.submit_job_socket.connect("tcp://{}".format(master_address))
thread = threading.Thread(target=self._reply_heartbeat, daemon=True)
thread = threading.Thread(target=self._reply_heartbeat)
thread.setDaemon(True)
thread.start()
self.heartbeat_socket_initialized.wait()
@@ -127,7 +129,7 @@ class Client(object):
When a `@parl.remote_class` object is created, the global client
sends a job to the master node. Then the master node will allocate
a vacant job from its job pool to the remote object.
Returns:
IP address of the job.
@@ -151,6 +153,8 @@ class Client(object):
# no vacant CPU resources, can not submit a new job
elif tag == remote_constants.CPU_TAG:
job_address = None
# wait 1 second to avoid requesting at a high frequency.
time.sleep(1)
else:
raise NotImplementedError
else:
......
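The client-side hunks above add a one-second back-off when the master answers a job request with `CPU_TAG` (no vacant CPU), so the client does not poll the master at a high frequency. A simplified, hypothetical sketch of that reply handling; the tag byte values are placeholders for the real constants in `parl.remote.remote_constants`:

```python
import time

# placeholder tag values; the real ones live in parl.remote.remote_constants
NORMAL_TAG = b'[NORMAL]'
CPU_TAG = b'[CPU]'


def handle_submit_reply(message):
    """Interpret the master's multipart reply to a job-submission request."""
    tag = message[0]
    if tag == NORMAL_TAG:
        # the master allocated a vacant job; return its address
        return message[1].decode()
    elif tag == CPU_TAG:
        # no vacant CPU resources: wait 1 second to avoid requesting
        # at a high frequency, then let the caller retry
        time.sleep(1)
        return None
    raise NotImplementedError


if __name__ == '__main__':
    print(handle_submit_reply([CPU_TAG]))  # prints None after about 1 second
```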
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['XPARL'] = 'True'
import argparse
import cloudpickle
import pickle
@@ -27,9 +30,6 @@ from parl.utils.communication import loads_argument, loads_return,\
from parl.remote import remote_constants
from parl.utils.exceptions import SerializeError, DeserializeError
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
class Job(object):
"""Base class for the job.
@@ -41,7 +41,6 @@ class Job(object):
def __init__(self, worker_address):
self.job_is_alive = True
self.heartbeat_socket_initialized = threading.Event()
self.worker_address = worker_address
self._create_sockets()
@@ -68,19 +67,9 @@ class Job(object):
reply_thread = threading.Thread(
target=self._reply_heartbeat,
args=("worker {}".format(self.worker_address), ),
daemon=True)
args=("worker {}".format(self.worker_address), ))
reply_thread.setDaemon(True)
reply_thread.start()
self.heartbeat_socket_initialized.wait()
# job_socket: sends job_address and heartbeat_address to worker
self.job_socket = self.ctx.socket(zmq.REQ)
self.job_socket.connect("tcp://{}".format(self.worker_address))
self.job_socket.send_multipart([
remote_constants.NORMAL_TAG,
to_byte(self.job_address),
to_byte(self.heartbeat_worker_address)
])
_ = self.job_socket.recv_multipart()
def _reply_heartbeat(self, target):
"""reply heartbeat signals to the target"""
@@ -90,9 +79,20 @@ class Job(object):
remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
socket.linger = 0
heartbeat_worker_port = socket.bind_to_random_port(addr="tcp://*")
self.heartbeat_worker_address = "{}:{}".format(self.job_ip,
heartbeat_worker_port)
self.heartbeat_socket_initialized.set()
heartbeat_worker_address = "{}:{}".format(self.job_ip,
heartbeat_worker_port)
# job_socket: sends job_address and heartbeat_address to worker
job_socket = self.ctx.socket(zmq.REQ)
job_socket.connect("tcp://{}".format(self.worker_address))
job_socket.send_multipart([
remote_constants.NORMAL_TAG,
to_byte(self.job_address),
to_byte(heartbeat_worker_address),
to_byte(str(os.getpid()))
])
_ = job_socket.recv_multipart()
# a flag to decide when to exit heartbeat loop
self.worker_is_alive = True
while self.worker_is_alive and self.job_is_alive:
......
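The job-side hunks above move the worker handshake into `_reply_heartbeat` and add one extra frame: the job now reports its own PID, so the worker can later terminate it with `os.kill` instead of holding a `Popen` handle. A self-contained sketch of that handshake; the addresses and the fixed port are made up for illustration:

```python
import os
import zmq

NORMAL_TAG = b'[NORMAL]'  # placeholder for remote_constants.NORMAL_TAG
ctx = zmq.Context()

# worker side: waits for the job's registration message
worker_socket = ctx.socket(zmq.REP)
worker_socket.bind("tcp://127.0.0.1:5500")

# job side: connects to the worker and registers itself
job_socket = ctx.socket(zmq.REQ)
job_socket.connect("tcp://127.0.0.1:5500")
job_socket.send_multipart([
    NORMAL_TAG,
    b"127.0.0.1:6000",          # job_address (made up)
    b"127.0.0.1:6001",          # heartbeat address (made up)
    str(os.getpid()).encode(),  # new in this commit: the job's PID
])

tag, job_address, heartbeat_address, pid = worker_socket.recv_multipart()
worker_socket.send_multipart([NORMAL_TAG])
_ = job_socket.recv_multipart()
print("registered job %s with pid %s" % (job_address.decode(), pid.decode()))
```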
@@ -41,7 +41,7 @@ class Master(object):
Attributes:
worker_pool (dict): A dict to store connected workers.
job_pool (list): A list to store the job addresses of vacant CPUs; when
                 it is empty, the master will refuse to create
                 new remote objects.
client_job_dict (dict): A dict of list to record the job submitted by
@@ -201,8 +201,9 @@ class Master(object):
self.worker_pool[worker.address] = worker
self.worker_locks[worker.address] = threading.Lock()
logger.info("A new worker {} is added, ".format(worker.address) +
"cluster has {} CPUs.\n".format(len(self.job_pool)))
logger.info(
"A new worker {} is added, ".format(worker.address) +
"the cluster has {} CPUs.\n".format(len(self.job_pool)))
# a thread for sending heartbeat signals to `worker.address`
thread = threading.Thread(
@@ -210,8 +211,8 @@ class Master(object):
args=(
worker_heartbeat_address,
worker.address,
),
daemon=True)
))
thread.setDaemon(True)
thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
@@ -224,8 +225,8 @@ class Master(object):
thread = threading.Thread(
target=self._create_client_monitor,
args=(client_heartbeat_address, ),
daemon=True)
args=(client_heartbeat_address, ))
thread.setDaemon(True)
thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
......
@@ -110,7 +110,6 @@ def remote_class(cls):
_ = self.job_socket.recv_multipart()
except zmq.error.Again as e:
logger.error("Job socket failed.")
logger.info("[connect_job] job_address:{}".format(job_address))
def __del__(self):
"""Delete the remote class object and release remote resources."""
@@ -138,7 +137,6 @@ def remote_class(cls):
logger.warning("No vacant cpu resources at present, "
"will try {} times later.".format(cnt))
cnt -= 1
time.sleep(1)
return None
def __getattr__(self, attr):
......
@@ -14,11 +14,13 @@
import click
import locale
import sys
import os
import subprocess
import threading
import warnings
from multiprocessing import Process
from parl.utils import logger
# A flag to mark if parl is started from a command line
os.environ['XPARL'] = 'True'
@@ -27,18 +29,22 @@ os.environ['XPARL'] = 'True'
# to use ASCII as encoding for the environment` error.
locale.setlocale(locale.LC_ALL, "en_US.UTF-8")
warnings.simplefilter("ignore", ResourceWarning)
#TODO: this line will cause error in python2/macOS
if sys.version_info.major == 3:
warnings.simplefilter("ignore", ResourceWarning)
def is_port_in_use(port):
def is_port_available(port):
""" Check if a port is used.
True if the port is not available. Otherwise, this port can be used for
connection.
True if the port is available for connection.
"""
port = int(port)
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', int(port))) == 0
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
available = sock.connect_ex(('localhost', port))
sock.close()
return available
def is_master_started(address):
@@ -71,22 +77,24 @@ def cli():
help="Set number of cpu manually. If not set, it will use all "
"cpus of this machine.")
def start_master(port, cpu_num):
if is_port_in_use(port):
if not is_port_available(port):
raise Exception(
"The master address localhost:{} already in use.".format(port))
cpu_num = str(cpu_num) if cpu_num else ''
command = [
"python", "{}/start.py".format(__file__[:-11]), "--name", "master",
"--port", port
]
start_file = __file__.replace('scripts.pyc', 'start.py')
start_file = start_file.replace('scripts.py', 'start.py')
command = ["python", start_file, "--name", "master", "--port", port]
p = subprocess.Popen(command)
command = [
"python", "{}/start.py".format(__file__[:-11]), "--name", "worker",
"--address", "localhost:" + str(port), "--cpu_num",
"python", start_file, "--name", "worker", "--address",
"localhost:" + str(port), "--cpu_num",
str(cpu_num)
]
p = subprocess.Popen(command)
# Redirect the output to DEVNULL to suppress the warning log.
FNULL = open(os.devnull, 'w')
p = subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT)
FNULL.close()
@click.command("connect", short_help="Start a worker node.")
......
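Two portability details drive the scripts.py hunks above: `with socket.socket(...)` and `subprocess.DEVNULL` are Python 3-only, and `__file__` may point at `scripts.pyc` under Python 2, so the path of `start.py` can no longer be derived by slicing the filename. A hedged sketch of the same ideas (function names are illustrative, and the port check returns an explicit bool):

```python
import os
import socket
import subprocess


def is_port_available(port):
    """True if nothing is listening on localhost:`port` (Python 2/3 safe)."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex returns 0 when the connection succeeds, i.e. when some
        # process already owns the port
        return sock.connect_ex(('localhost', int(port))) != 0
    finally:
        sock.close()


def launch_quiet(command):
    """Run `command` with output discarded (subprocess.DEVNULL is 3.3+ only)."""
    devnull = open(os.devnull, 'w')
    try:
        return subprocess.Popen(command, stdout=devnull,
                                stderr=subprocess.STDOUT)
    finally:
        # the child inherited the file descriptor, so the parent's handle
        # can be closed right away
        devnull.close()
```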
@@ -15,15 +15,20 @@
import cloudpickle
import multiprocessing
import os
import signal
import subprocess
import sys
import time
import threading
import warnings
import zmq
from parl.utils import get_ip_address, to_byte, to_str, logger
from parl.remote import remote_constants
if sys.version_info.major == 3:
warnings.simplefilter("ignore", ResourceWarning)
class WorkerInfo(object):
"""A WorkerInfo object records the computation resources of a worker.
@@ -138,7 +143,7 @@ class Worker(object):
self.master_is_alive = False
return
self._init_jobs()
self._init_jobs(job_num=self.cpu_num)
self.request_master_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
@@ -146,8 +151,8 @@ class Worker(object):
list(self.job_pid.keys()))
reply_thread = threading.Thread(
target=self._reply_heartbeat,
args=("master {}".format(self.master_address), ),
daemon=True)
args=("master {}".format(self.master_address), ))
reply_thread.setDaemon(True)
reply_thread.start()
self.heartbeat_socket_initialized.wait()
@@ -158,55 +163,66 @@ class Worker(object):
])
_ = self.request_master_socket.recv_multipart()
def _init_job(self):
"""Create one job."""
def _init_jobs(self, job_num):
"""Create jobs.
Args:
job_num(int): the number of jobs to create.
"""
job_file = __file__.replace('worker.pyc', 'job.py')
job_file = job_file.replace('worker.py', 'job.py')
command = [
"python", "{}/job.py".format(__file__[:-10]), "--worker_address",
self.reply_job_address
"python", job_file, "--worker_address", self.reply_job_address
]
with open(os.devnull, "w") as null:
pid = subprocess.Popen(command, stdout=null, stderr=null)
self.lock.acquire()
job_message = self.reply_job_socket.recv_multipart()
self.reply_job_socket.send_multipart([remote_constants.NORMAL_TAG])
job_address = to_str(job_message[1])
heartbeat_job_address = to_str(job_message[2])
self.job_pid[job_address] = pid
self.lock.release()
# a thread for sending heartbeat signals to job
thread = threading.Thread(
target=self._create_job_monitor,
args=(
job_address,
heartbeat_job_address,
),
daemon=True)
thread.start()
return job_address
def _init_jobs(self):
"""Create cpu_num jobs when the worker is created."""
job_threads = []
for _ in range(self.cpu_num):
t = threading.Thread(target=self._init_job, daemon=True)
t.start()
job_threads.append(t)
for th in job_threads:
th.join()
# Redirect the output to DEVNULL
FNULL = open(os.devnull, 'w')
for _ in range(job_num):
pid = subprocess.Popen(
command, stdout=FNULL, stderr=subprocess.STDOUT)
FNULL.close()
new_job_address = []
for _ in range(job_num):
job_message = self.reply_job_socket.recv_multipart()
self.reply_job_socket.send_multipart([remote_constants.NORMAL_TAG])
job_address = to_str(job_message[1])
new_job_address.append(job_address)
heartbeat_job_address = to_str(job_message[2])
pid = to_str(job_message[3])
self.job_pid[job_address] = int(pid)
# a thread for sending heartbeat signals to job
thread = threading.Thread(
target=self._create_job_monitor,
args=(
job_address,
heartbeat_job_address,
))
thread.setDaemon(True)
thread.start()
assert len(new_job_address) > 0, "init jobs failed"
if len(new_job_address) > 1:
return new_job_address
else:
return new_job_address[0]
def _kill_job(self, job_address):
"""kill problematic job process and update worker information"""
if job_address in self.job_pid:
self.job_pid[job_address].kill()
self.lock.acquire()
pid = self.job_pid[job_address]
try:
os.kill(pid, signal.SIGTERM)
except OSError:
logger.warn("job:{} has been killed before".format(pid))
self.job_pid.pop(job_address)
logger.warning("Worker kills job process {},".format(job_address))
self.lock.release()
# When an old job is killed, the worker will create a new job.
if self.master_is_alive:
new_job_address = self._init_job()
new_job_address = self._init_jobs(job_num=1)
self.lock.acquire()
self.request_master_socket.send_multipart([
......
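Since the worker now only records a job's PID (received over ZMQ) rather than a `Popen` object, `_kill_job` above terminates the process with `os.kill` and tolerates jobs that have already exited. A small sketch of that pattern, with an illustrative helper name:

```python
import os
import signal


def terminate_by_pid(pid):
    """Send SIGTERM to `pid`; return False if the process was already gone."""
    try:
        os.kill(pid, signal.SIGTERM)
        return True
    except OSError:
        # the process exited (or was killed) before we got here
        return False
```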
@@ -18,7 +18,6 @@ import os
import os.path
import sys
from termcolor import colored
import shutil
__all__ = ['set_dir', 'get_dir', 'set_level']
@@ -86,16 +85,10 @@ def _getlogger():
def create_file_after_first_call(func_name):
def call(*args, **kwargs):
global _logger
if LOG_DIR is None:
if LOG_DIR is None and hasattr(mod, '__file__'):
basename = os.path.basename(mod.__file__)
if basename.rfind('.') == -1:
basename = basename
else:
basename = basename[:basename.rfind('.')]
auto_dirname = os.path.join('log_dir', basename)
shutil.rmtree(auto_dirname, ignore_errors=True)
auto_dirname = os.path.join('log_dir',
basename[:basename.rfind('.')])
set_dir(auto_dirname)
func = getattr(_logger, func_name)
@@ -165,4 +158,5 @@ def get_dir():
# Will save log to log_dir/main_file_name/log.log by default
mod = sys.modules['__main__']
_logger.info("Argv: " + ' '.join(sys.argv))
if hasattr(mod, '__file__') and 'XPARL' not in os.environ:
_logger.info("Argv: " + ' '.join(sys.argv))
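The logger hunks above drop the `shutil.rmtree` wipe, derive the default log directory only when `__main__` actually has a `__file__`, and skip the import-time "Argv" log line when the process was launched by the `xparl` command line (which sets `XPARL` in the environment). A trimmed-down, hypothetical sketch of that guard:

```python
import os
import sys


def default_log_dir():
    """Return 'log_dir/<main file name>', or None if it cannot be derived."""
    mod = sys.modules['__main__']
    if not hasattr(mod, '__file__'):
        return None  # e.g. an interactive session
    basename = os.path.basename(mod.__file__)
    if '.' in basename:
        basename = basename[:basename.rfind('.')]
    return os.path.join('log_dir', basename)


started_by_xparl = 'XPARL' in os.environ  # set by the `xparl` CLI scripts
```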