Commit dcb16294 authored by F fuyw, committed by Bo Zhou

Fuyw (#131)

* Add parl monitor.

* Add clustermonitor.py

* Fix thread safety bug.

* Fix js bugs.

* Fix End of Files.

* Fix cluster_test bug.

* Fix yapf error.

* Fix bugs in tests/job_center_test.py.

* Fix bugs in tests/job_center_test.py.

* add test_cluster_monitor.py

* add worker.exit in test_cluster_monitor.py
Parent b29c45a6
include parl/remote/static/logo.png
recursive-include parl/remote/templates *.html
recursive-include parl/remote/static/css *.css
recursive-include parl/remote/static/js *.js
@@ -13,7 +13,10 @@
# limitations under the License.
import cloudpickle
import datetime
import os
import socket
import sys
import threading
import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger
@@ -33,6 +36,8 @@ class Client(object):
            the master node.
        pyfiles (bytes): A serialized dictionary containing the code of python
            files in the local working directory.
        executable_path (str): File path of the executable python script.
        start_time (float): A timestamp recording the start time of the program.
    """
@@ -47,9 +52,23 @@ class Client(object):
        self.master_is_alive = True
        self.client_is_alive = True

        self.executable_path = self.get_executable_path()
        self.actor_num = 0

        self._create_sockets(master_address)
        self.pyfiles = self.read_local_files()
    def get_executable_path(self):
        """Return the path of the directory containing the executable script."""
        mod = sys.modules['__main__']
        if hasattr(mod, '__file__'):
            executable_path = os.path.abspath(mod.__file__)
        else:
            executable_path = os.getcwd()
        executable_path = executable_path[:executable_path.rfind('/')]
        return executable_path
    def read_local_files(self):
        """Read local python code and store it in a dictionary, which will
        then be sent to the job.
@@ -78,7 +97,7 @@ class Client(object):
        self.submit_job_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        self.submit_job_socket.connect("tcp://{}".format(master_address))
        self.start_time = time.time()
        thread = threading.Thread(target=self._reply_heartbeat)
        thread.setDaemon(True)
        thread.start()
@@ -88,7 +107,8 @@ class Client(object):
        try:
            self.submit_job_socket.send_multipart([
                remote_constants.CLIENT_CONNECT_TAG,
                to_byte(self.heartbeat_master_address),
                to_byte(socket.gethostname())
            ])
            _ = self.submit_job_socket.recv_multipart()
        except zmq.error.Again as e:
@@ -115,7 +135,14 @@ class Client(object):
        while self.client_is_alive and self.master_is_alive:
            try:
                message = socket.recv_multipart()
                elapsed_time = datetime.timedelta(
                    seconds=int(time.time() - self.start_time))
                socket.send_multipart([
                    remote_constants.HEARTBEAT_TAG,
                    to_byte(self.executable_path),
                    to_byte(str(self.actor_num)),
                    to_byte(str(elapsed_time))
                ])
            except zmq.error.Again as e:
                logger.warning("[Client] Cannot connect to the master."
@@ -169,6 +196,7 @@ class Client(object):
            except zmq.error.Again as e:
                job_is_alive = False
                self.actor_num -= 1
            except zmq.error.ZMQError as e:
                break
@@ -207,6 +235,7 @@ class Client(object):
        check_result = self._check_and_monitor_job(
            job_heartbeat_address, ping_heartbeat_address)
        if check_result:
            self.actor_num += 1
            return job_address

        # no vacant CPU resources, cannot submit a new job
...
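For reference, a minimal sketch (not part of the commit) of the heartbeat frames the client now sends each time the master pings it. The helper name and values are hypothetical, and `to_byte` is replaced by explicit `encode` calls so the snippet is self-contained; the master consumes these frames in `ClusterMonitor.update_client_status`:

```python
import datetime
import time

def build_client_heartbeat(executable_path, actor_num, start_time):
    """Assemble the multipart frames a client heartbeat reply carries."""
    elapsed_time = datetime.timedelta(seconds=int(time.time() - start_time))
    return [
        b'[HEARTBEAT]',                     # remote_constants.HEARTBEAT_TAG
        executable_path.encode('utf-8'),    # to_byte(self.executable_path)
        str(actor_num).encode('utf-8'),     # to_byte(str(self.actor_num))
        str(elapsed_time).encode('utf-8'),  # to_byte(str(elapsed_time))
    ]

start = time.time() - 90                    # pretend the client started 90s ago
frames = build_client_heartbeat('/home/user/train', 4, start)
assert frames[2] == b'4' and frames[3] == b'0:01:30'
```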
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cloudpickle
import threading
from collections import defaultdict, deque
from parl.utils import to_str
class ClusterMonitor(object):
    """The cluster monitor watches the cluster status.

    Attributes:
        status (dict): A dict to store worker status and client status.
    """
def __init__(self):
self.status = {
'workers': defaultdict(dict),
'clients': defaultdict(dict)
}
self.lock = threading.Lock()
def add_worker_status(self, worker_address, hostname):
"""Record worker status when it is connected to the cluster.
Args:
worker_address (str): worker ip address
hostname (str): worker hostname
"""
self.lock.acquire()
worker_status = self.status['workers'][worker_address]
worker_status['load_value'] = deque(maxlen=10)
worker_status['load_time'] = deque(maxlen=10)
worker_status['hostname'] = hostname
self.lock.release()
    def update_client_status(self, client_status, client_address,
                             client_hostname):
        """Update client status with the message sent by the client heartbeat.

        Args:
            client_status (list): client heartbeat message
                (heartbeat_tag, file_path, actor_num, elapsed_time).
            client_address (str): client ip address.
            client_hostname (str): client hostname.
        """
self.lock.acquire()
self.status['clients'][client_address] = {
'client_address': client_hostname,
'file_path': to_str(client_status[1]),
'actor_num': int(to_str(client_status[2])),
'time': to_str(client_status[3])
}
self.lock.release()
    def update_worker_status(self, update_status, worker_address, vacant_cpus,
                             total_cpus):
        """Update a worker's status.

        Args:
            update_status (list): worker heartbeat message
                (heartbeat_tag, vacant_memory, used_memory, load_time, load_value).
            worker_address (str): worker ip address.
            vacant_cpus (int): vacant cpu number.
            total_cpus (int): total cpu number.
        """
self.lock.acquire()
worker_status = self.status['workers'][worker_address]
worker_status['vacant_memory'] = float(to_str(update_status[1]))
worker_status['used_memory'] = float(to_str(update_status[2]))
worker_status['load_time'].append(to_str(update_status[3]))
worker_status['load_value'].append(float(update_status[4]))
worker_status['vacant_cpus'] = vacant_cpus
worker_status['used_cpus'] = total_cpus - vacant_cpus
self.lock.release()
def drop_worker_status(self, worker_address):
"""Drop worker status when it exits.
Args:
worker_address (str): IP address of the exited worker.
"""
self.lock.acquire()
self.status['workers'].pop(worker_address)
self.lock.release()
    def drop_cluster_status(self, client_address):
        """Drop client status when the client exits.

        Args:
            client_address (str): IP address of the exited client.
        """
self.lock.acquire()
self.status['clients'].pop(client_address)
self.lock.release()
def get_status(self):
"""Return a cloudpickled status."""
self.lock.acquire()
status = cloudpickle.dumps(self.status)
self.lock.release()
return status
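A quick illustration (not part of the commit) of how these pieces fit together; the addresses and numbers are hypothetical, and `get_status` round-trips through cloudpickle exactly as the master's MONITOR_TAG handler does:

```python
import cloudpickle
from parl.remote.cluster_monitor import ClusterMonitor

monitor = ClusterMonitor()
monitor.add_worker_status('172.18.182.39:8001', 'node-1')
# Frame 0 is the heartbeat tag; frames 1-4 carry
# (vacant_memory, used_memory, load_time, load_value).
monitor.update_worker_status(
    [b'[HEARTBEAT]', b'12.5', b'3.5', b'12:00:00', b'0.8'],
    '172.18.182.39:8001', vacant_cpus=3, total_cpus=4)

status = cloudpickle.loads(monitor.get_status())
assert status['workers']['172.18.182.39:8001']['used_cpus'] == 1
```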
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from collections import defaultdict


class JobCenter(object):
@@ -20,12 +21,19 @@ class JobCenter(object):
    Attributes:
        job_pool (dict): A dict to store the jobs of vacant CPUs.
        worker_dict (dict): A dict to store connected workers.
        worker_hostname (dict): A dict to record worker hostnames.
        worker_vacant_jobs (dict): A dict to record how many vacant jobs
            each worker has.
        master_ip (str): IP address of the master node.
    """

    def __init__(self, master_ip):
        self.job_pool = dict()
        self.worker_dict = {}
        self.worker_hostname = defaultdict(int)
        self.worker_vacant_jobs = {}
        self.lock = threading.Lock()
        self.master_ip = master_ip
    @property
    def cpu_num(self):
@@ -38,7 +46,7 @@ class JobCenter(object):
        return len(self.worker_dict)
    def add_worker(self, worker):
        """When a new worker connects, add its hostname to worker_hostname.

        Args:
            worker (InitializedWorker): New worker with initialized jobs.
@@ -47,13 +55,26 @@ class JobCenter(object):
        self.worker_dict[worker.worker_address] = worker
        for job in worker.initialized_jobs:
            self.job_pool[job.job_address] = job
        self.worker_vacant_jobs[worker.worker_address] = len(
            worker.initialized_jobs)

        if self.master_ip and worker.worker_address.split(
                ':')[0] == self.master_ip:
            self.worker_hostname[worker.worker_address] = "Master"
            self.master_ip = None
        else:
            self.worker_hostname[worker.hostname] += 1
            self.worker_hostname[worker.worker_address] = "{}:{}".format(
                worker.hostname, self.worker_hostname[worker.hostname])
        self.lock.release()
    def drop_worker(self, worker_address):
        """Remove jobs from job_pool when a worker dies.

        Args:
            worker_address (str): The worker_address of the worker to be
                removed from the job center.
        """
        self.lock.acquire()
        worker = self.worker_dict[worker_address]
@@ -61,11 +82,12 @@ class JobCenter(object):
            if job.job_address in self.job_pool:
                self.job_pool.pop(job.job_address)
        self.worker_dict.pop(worker_address)
        self.worker_vacant_jobs.pop(worker_address)
        self.lock.release()
    def request_job(self):
        """Return a job_address when the client submits a job.

        If there is no vacant CPU in the cluster, this will return None.

        Return:
@@ -75,6 +97,8 @@ class JobCenter(object):
        job = None
        if len(self.job_pool):
            job_address, job = self.job_pool.popitem()
            self.worker_vacant_jobs[job.worker_address] -= 1
            assert self.worker_vacant_jobs[job.worker_address] >= 0
        self.lock.release()
        return job
@@ -101,12 +125,30 @@ class JobCenter(object):
        if killed_job_address in self.job_pool:
            self.job_pool.pop(killed_job_address)

        to_del_idx = None
        for i, job in enumerate(
                self.worker_dict[worker_address].initialized_jobs):
            if job.job_address == killed_job_address:
                to_del_idx = i
                break
        del self.worker_dict[worker_address].initialized_jobs[to_del_idx]

        self.worker_dict[worker_address].initialized_jobs.append(new_job)
        if killed_job_address not in self.job_pool:
            self.worker_vacant_jobs[worker_address] += 1

        self.lock.release()
def get_vacant_cpu(self, worker_address):
"""Return vacant cpu number of a worker."""
return self.worker_vacant_jobs[worker_address]
def get_total_cpu(self, worker_address):
"""Return total cpu number of a worker."""
return len(self.worker_dict[worker_address].initialized_jobs)
def get_hostname(self, worker_address):
"""Return the hostname of a worker."""
return self.worker_hostname[worker_address]
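A short sketch (not part of the commit) of the hostname bookkeeping above, assuming a hypothetical cluster where the master machine is 10.0.0.1 and two other machines happen to share the hostname `node-1`:

```python
from parl.remote.job_center import JobCenter

# The first worker whose IP matches master_ip is labeled "Master";
# later workers with the same hostname get a per-hostname counter.
job_center = JobCenter(master_ip='10.0.0.1')
# After three add_worker() calls, get_hostname() would return:
#   job_center.get_hostname('10.0.0.1:8001')  ->  'Master'
#   job_center.get_hostname('10.0.0.2:8001')  ->  'node-1:1'
#   job_center.get_hostname('10.0.0.3:8001')  ->  'node-1:2'
```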
@@ -17,10 +17,11 @@ import pickle
import threading
import time
import zmq
from collections import deque, defaultdict
from parl.utils import to_str, to_byte, logger, get_ip_address
from parl.remote import remote_constants
from parl.remote.job_center import JobCenter
from parl.remote.cluster_monitor import ClusterMonitor
import cloudpickle
import time
@@ -46,8 +47,11 @@ class Master(object):
        client_socket (zmq.Context.socket): A socket that receives submitted
                                            jobs from the client, and later sends
                                            job_address back to the client.
        master_ip(str): The ip address of the master node.
        cpu_num(int): The number of available CPUs in the cluster.
        worker_num(int): The number of workers connected to this cluster.
        cluster_monitor(ClusterMonitor): A monitor that records worker status and client status.
        client_hostname(dict): A dict to store the hostname of each client address.

    Args:
        port: The ip port that the master node binds to.
logger.set_dir(os.path.expanduser('~/.parl_data/master/')) logger.set_dir(os.path.expanduser('~/.parl_data/master/'))
self.ctx = zmq.Context() self.ctx = zmq.Context()
self.master_ip = get_ip_address()
self.client_socket = self.ctx.socket(zmq.REP) self.client_socket = self.ctx.socket(zmq.REP)
self.client_socket.bind("tcp://*:{}".format(port)) self.client_socket.bind("tcp://*:{}".format(port))
self.client_socket.linger = 0 self.client_socket.linger = 0
self.port = port self.port = port
self.job_center = JobCenter() self.job_center = JobCenter(self.master_ip)
self.cluster_monitor = ClusterMonitor()
self.master_is_alive = True self.master_is_alive = True
self.client_hostname = defaultdict(int)
def _get_status(self):
return self.cluster_monitor.get_status()
    def _create_worker_monitor(self, worker_address):
        """When a new worker connects to the master, a socket is created to
        send heartbeat signals to the worker.
        """
@@ -74,17 +83,22 @@ class Master(object):
        worker_heartbeat_socket.linger = 0
        worker_heartbeat_socket.setsockopt(
            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
        worker_heartbeat_socket.connect("tcp://" + worker_address)

        connected = True
        while connected and self.master_is_alive:
            try:
                worker_heartbeat_socket.send_multipart(
                    [remote_constants.HEARTBEAT_TAG])
                worker_status = worker_heartbeat_socket.recv_multipart()
                vacant_cpus = self.job_center.get_vacant_cpu(worker_address)
                total_cpus = self.job_center.get_total_cpu(worker_address)
                self.cluster_monitor.update_worker_status(
                    worker_status, worker_address, vacant_cpus, total_cpus)
                time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
            except zmq.error.Again as e:
                self.job_center.drop_worker(worker_address)
                self.cluster_monitor.drop_worker_status(worker_address)
                logger.warning("\n[Master] Cannot connect to the worker " +
                               "{}. ".format(worker_address) +
                               "Worker_pool will drop this worker.")
@@ -112,9 +126,16 @@ class Master(object):
            try:
                client_heartbeat_socket.send_multipart(
                    [remote_constants.HEARTBEAT_TAG])
                client_status = client_heartbeat_socket.recv_multipart()

                self.cluster_monitor.update_client_status(
                    client_status, client_heartbeat_address,
                    self.client_hostname[client_heartbeat_address])

            except zmq.error.Again as e:
                client_is_alive = False

                self.cluster_monitor.drop_cluster_status(
                    client_heartbeat_address)
                logger.warning("[Master] cannot connect to the client " +
                               "{}. ".format(client_heartbeat_address) +
                               "Please check if it is still alive.")
@@ -152,19 +173,25 @@ class Master(object):
        if tag == remote_constants.WORKER_CONNECT_TAG:
            self.client_socket.send_multipart([remote_constants.NORMAL_TAG])

        elif tag == remote_constants.MONITOR_TAG:
            status = self._get_status()
            self.client_socket.send_multipart(
                [remote_constants.NORMAL_TAG, status])

        elif tag == remote_constants.WORKER_INITIALIZED_TAG:
            initialized_worker = cloudpickle.loads(message[1])
            worker_address = initialized_worker.worker_address
            self.job_center.add_worker(initialized_worker)
            hostname = self.job_center.get_hostname(worker_address)
            self.cluster_monitor.add_worker_status(worker_address, hostname)
            logger.info("A new worker {} is added, ".format(worker_address) +
                        "the cluster has {} CPUs.\n".format(self.cpu_num))

            # a thread for sending heartbeat signals to `worker.address`
            thread = threading.Thread(
                target=self._create_worker_monitor,
                args=(initialized_worker.worker_address, ))
            thread.start()

            self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
@@ -172,6 +199,8 @@ class Master(object):
        # a client connects to the master
        elif tag == remote_constants.CLIENT_CONNECT_TAG:
            client_heartbeat_address = to_str(message[1])
            client_hostname = to_str(message[2])
            self.client_hostname[client_heartbeat_address] = client_hostname
            logger.info(
                "Client {} is connected.".format(client_heartbeat_address))
@@ -192,7 +221,7 @@ class Master(object):
                remote_constants.NORMAL_TAG,
                to_byte(job.job_address),
                to_byte(job.client_heartbeat_address),
                to_byte(job.ping_heartbeat_address),
            ])
            self._print_workers()
        else:
...
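The new MONITOR_TAG handler is a plain REQ/REP round trip; a minimal sketch of the requesting side (not part of the commit; the master address is hypothetical, and the byte literals mirror `remote_constants`):

```python
import cloudpickle
import zmq

ctx = zmq.Context()
socket = ctx.socket(zmq.REQ)
socket.connect('tcp://localhost:8010')  # hypothetical master address
socket.send_multipart([b'[MONITOR]'])   # remote_constants.MONITOR_TAG
tag, payload = socket.recv_multipart()  # [NORMAL_TAG, cloudpickled status]
status = cloudpickle.loads(payload)
print(sorted(status))                   # ['clients', 'workers']
```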
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class InitializedJob(object):
    def __init__(self, job_address, worker_heartbeat_address,
                 client_heartbeat_address, ping_heartbeat_address,
@@ -20,7 +22,7 @@ class InitializedJob(object):
            job_address(str): Job address to which the new task connects.
            worker_heartbeat_address(str): Optional. The address to which the worker sends heartbeat signals.
            client_heartbeat_address(str): Address to which the client sends heartbeat signals.
            ping_heartbeat_address(str): The server address to which the client sends ping signals.
                                         The signal is used to check if the job is alive.
            worker_address(str): Worker's server address that receives commands from the master.
            pid(int): Optional. Process id of the job.
@@ -36,8 +38,8 @@ class InitializedJob(object):
class InitializedWorker(object):
    def __init__(self, master_heartbeat_address, initialized_jobs, cpu_num,
                 hostname):
        """
        Args:
            master_heartbeat_address(str): The address to which the master sends heartbeat signals.
                                           It also serves as the worker_address identifying this worker.
@@ -45,7 +47,7 @@ class InitializedWorker(object):
            initialized_jobs(list): A list of ``InitializedJob`` containing the information for initialized jobs.
            cpu_num(int): The number of CPUs used in this worker.
            hostname(str): The hostname of the worker machine.
        """
        self.worker_address = master_heartbeat_address
        self.master_heartbeat_address = master_heartbeat_address
        self.initialized_jobs = initialized_jobs
        self.cpu_num = cpu_num
        self.hostname = hostname
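With the `worker_address` parameter dropped, constructing a worker record now looks like this (a sketch, not part of the commit; the address is illustrative):

```python
import socket
from parl.remote.message import InitializedWorker

worker = InitializedWorker(
    master_heartbeat_address='172.18.182.39:48724',
    initialized_jobs=[],
    cpu_num=4,
    hostname=socket.gethostname())
# The heartbeat address now doubles as the worker identifier.
assert worker.worker_address == worker.master_heartbeat_address
```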
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import pickle
import random
import time
import zmq
import threading
from flask import Flask, render_template, jsonify
app = Flask(__name__)
@app.route('/')
@app.route('/workers')
def worker():
return render_template('workers.html')
@app.route('/clients')
def clients():
return render_template('clients.html')
class ClusterMonitor(object):
"""A monitor which requests the cluster status every 10 seconds.
"""
def __init__(self, master_address):
ctx = zmq.Context()
self.socket = ctx.socket(zmq.REQ)
self.socket.setsockopt(zmq.RCVTIMEO, 10000)
self.socket.connect('tcp://{}'.format(master_address))
self.data = None
thread = threading.Thread(target=self.run)
thread.setDaemon(True)
thread.start()
def run(self):
master_is_alive = True
while master_is_alive:
try:
self.socket.send_multipart([b'[MONITOR]'])
msg = self.socket.recv_multipart()
status = pickle.loads(msg[1])
data = {'workers': [], 'clients': []}
master_idx = None
for idx, worker in enumerate(status['workers'].values()):
worker['load_time'] = list(worker['load_time'])
worker['load_value'] = list(worker['load_value'])
if worker['hostname'] == 'Master':
master_idx = idx
data['workers'].append(worker)
                    if master_idx is not None and master_idx != 0:
master_worker = data['workers'].pop(master_idx)
data['workers'] = [master_worker] + data['workers']
data['clients'] = list(status['clients'].values())
self.data = data
time.sleep(10)
except zmq.error.Again as e:
master_is_alive = False
self.socket.close(0)
def get_data(self):
assert self.data is not None
return self.data
@app.route('/cluster')
def cluster():
data = CLUSTER_MONITOR.get_data()
return jsonify(data)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--monitor_port', default=1234, type=int)
parser.add_argument('--address', default='localhost:1234', type=str)
args = parser.parse_args()
CLUSTER_MONITOR = ClusterMonitor(args.address)
app.run(host="0.0.0.0", port=args.monitor_port)
@@ -16,6 +16,7 @@ CPU_TAG = b'[CPU]'
CONNECT_TAG = b'[CONNECT]'
HEARTBEAT_TAG = b'[HEARTBEAT]'
KILLJOB_TAG = b'[KILLJOB]'
MONITOR_TAG = b'[MONITOR]'

WORKER_CONNECT_TAG = b'[WORKER_CONNECT]'
WORKER_INITIALIZED_TAG = b'[WORKER_INITIALIZED]'
...
@@ -13,14 +13,18 @@
# limitations under the License.
import click
import socket
import locale
import sys
import random
import os
import multiprocessing
import subprocess
import threading
import warnings
import zmq
from multiprocessing import Process
from parl.utils import get_ip_address

# A flag to mark if parl is started from a command line
os.environ['XPARL'] = 'True'
@@ -34,13 +38,20 @@ if sys.version_info.major == 3:
    warnings.simplefilter("ignore", ResourceWarning)
def get_free_tcp_port():
tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tcp.bind(('', 0))
addr, port = tcp.getsockname()
tcp.close()
return port
def is_port_available(port):
    """Check if a port is in use.

    Returns True if the port is available for connection.
    """
    port = int(port)
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    available = sock.connect_ex(('localhost', port))
    sock.close()
@@ -48,7 +59,6 @@ def is_port_available(port):

def is_master_started(address):
    ctx = zmq.Context()
    socket = ctx.socket(zmq.REQ)
    socket.linger = 0
@@ -80,12 +90,12 @@ def start_master(port, cpu_num):
    if not is_port_available(port):
        raise Exception(
            "The master address localhost:{} already in use.".format(port))
    cpu_num = cpu_num if cpu_num else multiprocessing.cpu_count()
    start_file = __file__.replace('scripts.pyc', 'start.py')
    start_file = start_file.replace('scripts.py', 'start.py')

    command = [sys.executable, start_file, "--name", "master", "--port", port]
    p = subprocess.Popen(command)

    command = [
        sys.executable, start_file, "--name", "worker", "--address",
        "localhost:" + str(port), "--cpu_num",
@@ -94,8 +104,37 @@ def start_master(port, cpu_num):
    # Redirect the output to DEVNULL to solve the warning log.
    FNULL = open(os.devnull, 'w')
    p = subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT)

    monitor_port = get_free_tcp_port()
    command = [
        sys.executable, '{}/monitor.py'.format(__file__[:__file__.rfind('/')]),
        "--monitor_port",
        str(monitor_port), "--address", "localhost:" + str(port)
    ]
    p = subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT)
    FNULL.close()
cluster_info = """
# The Parl cluster is started at localhost:{}.
# A local worker with {} CPUs is connected to the cluster.
## If you want to check cluster status, visit:
http://{}:{}.
## If you want to add more CPU resources, call:
xparl connect --address localhost:{}
## If you want to shutdown the cluster, call:
xparl stop""".format(port, cpu_num, get_ip_address(), monitor_port,
port)
click.echo(cluster_info)
@click.command("connect", short_help="Start a worker node.") @click.command("connect", short_help="Start a worker node.")
@click.option( @click.option(
...@@ -125,6 +164,8 @@ def stop(): ...@@ -125,6 +164,8 @@ def stop():
subprocess.call([command], shell=True) subprocess.call([command], shell=True)
command = ("pkill -f remote/job.py") command = ("pkill -f remote/job.py")
subprocess.call([command], shell=True) subprocess.call([command], shell=True)
command = ("pkill -f remote/monitor.py")
subprocess.call([command], shell=True)
cli.add_command(start_worker) cli.add_command(start_worker)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
function createDivs(res, divs) {
var elem = document.getElementById("main");
var worker_num = res.workers.length; // 8
var curr_num = elem.children.length; // 0
divs = (divs < worker_num) ? divs : worker_num;
if (curr_num < divs) {
for (var i = curr_num; i < divs; i++) {
var workerDiv = document.createElement("div");
workerDiv.id = `w${i}`;
if (i === 0) {
workerDiv.innerHTML = `<p class="card-header" id="${i}">Master</p>`;
} else {
workerDiv.innerHTML = `<p class="card-header" id="${i}">Worker ${res.workers[i].hostname}</p>`;
}
var cardDiv = document.createElement("div");
cardDiv.className = "row";
var card = '';
for (var j = 0; j < 3; j++)
card += `<div class="col-lg-4"><div class="card mb-1"><div id="w${i}c${j}" class="card-body" style="height: 110px;"></div></div></div>`;
cardDiv.innerHTML = card;
workerDiv.appendChild(cardDiv);
elem.appendChild(workerDiv);
for (var j = 0; j < 3; j++)
imgHandle[`w${i}c${j}`] = echarts.init(document.getElementById(`w${i}c${j}`));
};
} else if (curr_num > worker_num) {
for (var i = curr_num - 1; i >= worker_num; i--) {
delete imgHandle[`w${i}c0`];
delete imgHandle[`w${i}c1`];
delete imgHandle[`w${i}c2`];
var workerDiv = document.getElementById(`w${i}`);
elem.removeChild(workerDiv);
}
}
}
function addPlots(res, record, imgHandle, begin, end) {
var worker_num = res.workers.length;
var record_num = Object.keys(record).length;
end = (end < worker_num) ? end : worker_num;
for (var i = begin; i < end; i++) {
var worker = res.workers[i];
var cpuOption = {
color: ["#7B68EE", "#6495ED"],
legend: {
orient: 'vertical',
x: 'left',
data: ['Used CPU', 'Vacant CPU'],
textStyle: {
fontSize: 8,
}
},
series: [
{
type: "pie",
radius: "80%",
label: {
normal: {
formatter: "{c}",
show: true,
position: "inner",
fontSize: 16,
}
},
data: [
{ value: worker.used_cpus, name: "Used CPU" },
{ value: worker.vacant_cpus, name: "Vacant CPU" }
]
}
]
};
var memoryOption = {
color: ["#FF8C00", "#FF4500"],
legend: {
orient: "vertical",
x: "left",
data: ["Used Memory", "Vacant Memory"],
textStyle: {
fontSize: 8,
}
},
series: [
{
name: "Memory",
type: "pie",
radius: "80%",
label: {
normal: {
formatter: "{c}",
show: true,
position: "inner",
fontSize: 12,
}
},
data: [
{ value: worker.used_memory, name: "Used Memory" },
{ value: worker.vacant_memory, name: "Vacant Memory" }
]
}
]
};
var loadOption = {
grid:{
x:30,
y:25,
x2:20,
y2:20,
borderWidth:1
},
xAxis: {
type: "category",
data: worker.load_time,
},
yAxis: {
type: "value",
name: "Average CPU load (%)",
splitNumber:3,
nameTextStyle:{
padding: [0, 0, 0, 60],
fontSize: 10,
}
},
series: [
{
data: worker.load_value,
type: "line"
}
]
};
if (i < record_num && worker.worker_address === record[i].worker_address) {
if (worker.cpu_num !== record[i].cpu_num) {
imgHandle[`w${i}c0`].setOption(cpuOption);
}
if (worker.used_memory !== record[i].used_memory) {
imgHandle[`w${i}c1`].setOption(memoryOption);
}
imgHandle[`w${i}c2`].setOption(loadOption);
} else {
var workerTitle = document.getElementById(`${i}`);
workerTitle.innerText = i===0 ? "Master" : `Worker ${worker.hostname}`;
imgHandle[`w${i}c0`].setOption(cpuOption);
imgHandle[`w${i}c1`].setOption(memoryOption);
imgHandle[`w${i}c2`].setOption(loadOption);
}
        record[i] = {
            worker_address: worker.worker_address,
            used_cpu: worker.used_cpu,
            vacant_cpu: worker.vacant_cpu,
            used_memory: worker.used_memory,
            vacant_memory: worker.vacant_memory
        };
}
if (end < record_num) {
for (var i = end; i < record_num; i++)
delete record[i]
}
};
function autoTable(res) {
var table = document.getElementById("table");
table.innerHTML = "";
var rows = res.clients.length;
for(var i=0; i< rows; i++){
var tr = document.createElement('tr');
var s1 = `<th scope="row">${i+1}</th>`;
var s2 = `<td>${res.clients[i].file_path}</td>`;
var s3 = `<td>${res.clients[i].client_address}</td>`;
var s4 = `<td>${res.clients[i].actor_num}</td>`;
var s5 = `<td>${res.clients[i].time}</td>`;
tr.innerHTML = s1 + s2 + s3 + s4 + s5;
table.appendChild(tr);
}
};
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Parl Cluster</title>
<script type="text/javascript" src="../static/js/jquery.min.js"></script>
<script src="../static/js/echarts.min.js"></script>
<script src="../static/js/parl.js"></script>
<link rel="stylesheet" href="../static/css/bootstrap-parl.min.css">
</head>
<body>
<nav class="navbar navbar-expand-lg navbar-light bg-dark fixed-top">
<div class="container">
<a class="navbar-brand">
<img src="../static/logo.png" style="height: 30px">
</a>
<div class="collapse navbar-collapse" id="navbarSupportedContent">
<ul class="navbar-nav">
<li class="nav-item" id="worker_nav">
<a class="btn text-white" href="workers">Worker</a>
</li>
<li class="nav-item" id="client_nav">
<a class="btn text-white" href="clients">Client</a>
</li>
</ul>
</div>
</div>
</nav>
<div class="container" id="client_container">
<h5 class="font-weight-light text-center text-lg-left mt-4 mb-4">
Client status
</h5>
<table class="table table-striped">
<thead>
<tr>
<th scope="col">#</th>
<th scope="col">Path</th>
<th scope="col">Client ID</th>
<th scope="col">Actor Num</th>
<th scope="col">Time (min)</th>
</tr>
</thead>
<tbody id='table'>
<th colspan="5">Loading Data...</th>
</tbody>
</table>
</div>
<script>
var res = {};
$(document).ready(function () {
$.get('cluster', function (data, status) {
res = data;
autoTable(res);
});
setInterval(function () {
$.get('cluster', function (data, status) {
res = data;
autoTable(res);
});
}, 10000);
});
</script>
</body>
</html>
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Parl Cluster</title>
<script type="text/javascript" src="../static/js/jquery.min.js"></script>
<script src="../static/js/echarts.min.js"></script>
<script src="../static/js/parl.js"></script>
<link rel="stylesheet" href="../static/css/bootstrap-parl.min.css" />
</head>
<body>
<nav class="navbar navbar-expand-lg navbar-light bg-dark fixed-top">
<div class="container">
<a class="navbar-brand">
<img src="../static/logo.png" style="height: 30px" />
</a>
<div class="collapse navbar-collapse" id="navbarSupportedContent">
<ul class="navbar-nav">
<li class="nav-item" id="worker_nav">
<a class="btn text-white" href="workers">Worker</a>
</li>
<li class="nav-item" id="client_nav">
<a class="btn text-white" href="clients">Client</a>
</li>
</ul>
</div>
</div>
</nav>
<div class="container" id="worker_container">
<h5 class="font-weight-light text-center text-lg-left mt-4 mb-4">
Worker CPU usage, Memory usage and average loads.
</h5>
<div id="main"></div>
</div>
<script>
var record = {};
var imgHandle = {};
var res = {};
var div_num = 5;
var start_num = 5;
var delta = 3;
$(document).ready(function() {
console.log('After ready.');
$.get("cluster", function(data, status) {
res = data;
console.log('Get first data.', res);
createDivs(res, start_num);
addPlots(res, record, imgHandle, 0, start_num);
});
setInterval(function() {
$.get("cluster", function(data, status) {
res = data;
console.log('Interval', res);
createDivs(res, start_num);
addPlots(res, record, imgHandle, 0, start_num);
});
}, 10000);
$(window).on("scroll", function() {
if (monitor === 1) {
var scrollTop = $(document).scrollTop();
var windowHeight = $(window).height();
var bodyHeight = $(document).height() - windowHeight;
var scrollPercentage = scrollTop / bodyHeight;
if (scrollPercentage > 0.25) {
if (div_num < res.workers.length) {
div_num += delta;
createDivs(res, div_num);
addPlots(res, record, imgHandle, div_num - delta, div_num);
}
}
}
});
});
</script>
</body>
</html>
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.monitor import ClusterMonitor
import time
import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
import timeout_decorator
import subprocess
@parl.remote_class
class Actor(object):
def __init__(self, arg1=None, arg2=None):
self.arg1 = arg1
self.arg2 = arg2
def get_arg1(self):
return self.arg1
def get_arg2(self):
return self.arg2
def set_arg1(self, value):
self.arg1 = value
def set_arg2(self, value):
self.arg2 = value
def get_unable_serialize_object(self):
return UnableSerializeObject()
def add_one(self, value):
value += 1
return value
def add(self, x, y):
time.sleep(3)
return x + y
def will_raise_exception_func(self):
x = 1 / 0
class TestClusterMonitor(unittest.TestCase):
def tearDown(self):
disconnect()
def test_one_worker(self):
port = 1439
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:{}'.format(port), 1)
cluster_monitor = ClusterMonitor('localhost:{}'.format(port))
time.sleep(1)
self.assertEqual(1, len(cluster_monitor.data['workers']))
worker.exit()
time.sleep(40)
self.assertEqual(0, len(cluster_monitor.data['workers']))
master.exit()
def test_twenty_worker(self):
port = 1440
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
workers = []
for _ in range(20):
worker = Worker('localhost:{}'.format(port), 1)
workers.append(worker)
cluster_monitor = ClusterMonitor('localhost:{}'.format(port))
time.sleep(1)
self.assertEqual(20, len(cluster_monitor.data['workers']))
for i in range(10):
workers[i].exit()
time.sleep(40)
self.assertEqual(10, len(cluster_monitor.data['workers']))
for i in range(10, 20):
workers[i].exit()
time.sleep(40)
self.assertEqual(0, len(cluster_monitor.data['workers']))
master.exit()
def test_add_actor(self):
port = 1441
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:{}'.format(port), 1)
cluster_monitor = ClusterMonitor('localhost:{}'.format(port))
time.sleep(1)
self.assertEqual(0, len(cluster_monitor.data['clients']))
parl.connect('localhost:{}'.format(port))
time.sleep(10)
self.assertEqual(1, len(cluster_monitor.data['clients']))
self.assertEqual(1, cluster_monitor.data['workers'][0]['vacant_cpus'])
actor = Actor()
time.sleep(20)
self.assertEqual(0, cluster_monitor.data['workers'][0]['vacant_cpus'])
self.assertEqual(1, cluster_monitor.data['workers'][0]['used_cpus'])
self.assertEqual(1, cluster_monitor.data['clients'][0]['actor_num'])
del actor
time.sleep(40)
self.assertEqual(0, cluster_monitor.data['clients'][0]['actor_num'])
self.assertEqual(1, cluster_monitor.data['workers'][0]['vacant_cpus'])
worker.exit()
master.exit()
if __name__ == '__main__':
unittest.main()
@@ -13,6 +13,7 @@
# limitations under the License.
import unittest
import socket
from parl.remote.job_center import JobCenter
from parl.remote.message import InitializedWorker, InitializedJob
@@ -22,11 +23,13 @@ class InitializedWorker(object):
                 worker_address,
                 master_heartbeat_address='localhost:8010',
                 initialized_jobs=[],
                 cpu_num=4,
                 hostname=None):
        self.worker_address = worker_address
        self.master_heartbeat_address = master_heartbeat_address
        self.initialized_jobs = initialized_jobs
        self.cpu_num = cpu_num
        self.hostname = hostname
class ImportTest(unittest.TestCase):
@@ -38,12 +41,14 @@ class ImportTest(unittest.TestCase):
                worker_heartbeat_address='172.18.182.39:48724',
                client_heartbeat_address='172.18.182.39:48725',
                ping_heartbeat_address='172.18.182.39:48726',
                worker_address='172.18.182.39:8001',
                pid=1234)
            jobs.append(job)

        self.worker1 = InitializedWorker(
            worker_address='172.18.182.39:8001',
            initialized_jobs=jobs,
            hostname=socket.gethostname())
        jobs = []
        for i in range(5):
@@ -52,16 +57,18 @@ class ImportTest(unittest.TestCase):
                worker_heartbeat_address='172.18.182.39:48724',
                client_heartbeat_address='172.18.182.39:48725',
                ping_heartbeat_address='172.18.182.39:48726',
                worker_address='172.18.182.39:8002',
                pid=1234)
            jobs.append(job)

        self.worker2 = InitializedWorker(
            worker_address='172.18.182.39:8002',
            initialized_jobs=jobs,
            hostname=socket.gethostname())
    def test_add_worker(self):
        job_center = JobCenter('localhost')
        job_center.add_worker(self.worker1)
        job_center.add_worker(self.worker2)
@@ -70,7 +77,7 @@ class ImportTest(unittest.TestCase):
                         self.worker1)
    def test_drop_worker(self):
        job_center = JobCenter('localhost')
        job_center.add_worker(self.worker1)
        job_center.add_worker(self.worker2)
        job_center.drop_worker(self.worker2.worker_address)
@@ -81,7 +88,7 @@ class ImportTest(unittest.TestCase):
        self.assertEqual(len(job_center.worker_dict), 1)

    def test_request_job(self):
        job_center = JobCenter('localhost')
        job_address1 = job_center.request_job()
        self.assertTrue(job_address1 is None)
@@ -91,7 +98,7 @@ class ImportTest(unittest.TestCase):
        self.assertEqual(len(job_center.job_pool), 4)
    def test_reset_job(self):
        job_center = JobCenter('localhost')
        job_center.add_worker(self.worker1)
        job_address = job_center.request_job()
@@ -103,7 +110,7 @@ class ImportTest(unittest.TestCase):
    def test_update_job(self):
        job_center = JobCenter('localhost')
        job_center.add_worker(self.worker1)
        job_center.add_worker(self.worker2)
@@ -142,7 +149,7 @@ class ImportTest(unittest.TestCase):
        self.assertEqual(5, len(self.worker1.initialized_jobs))

    def test_cpu_num(self):
        job_center = JobCenter('localhost')
        job_center.add_worker(self.worker1)
        self.assertEqual(job_center.cpu_num, 5)
        job_center.add_worker(self.worker2)
@@ -151,7 +158,7 @@ class ImportTest(unittest.TestCase):
        self.assertEqual(job_center.cpu_num, 9)

    def test_worker_num(self):
        job_center = JobCenter('localhost')
        job_center.add_worker(self.worker1)
        self.assertEqual(job_center.worker_num, 1)
        job_center.add_worker(self.worker2)
...
@@ -15,13 +15,16 @@
import cloudpickle
import multiprocessing
import os
import psutil
import signal
import socket
import subprocess
import sys
import time
import threading
import warnings
import zmq
from datetime import datetime
from parl.utils import get_ip_address, to_byte, to_str, logger
from parl.remote import remote_constants
@@ -29,9 +32,6 @@ from parl.remote.message import InitializedWorker
from parl.remote.status import WorkerStatus
from six.moves import queue

class Worker(object):
    """Worker provides the cpu computation resources for the cluster.
@@ -53,9 +53,6 @@ class Worker(object):
        master_address (str): Master's ip address.
        request_master_socket (zmq.Context.socket): A socket which sends job
                                                    address to the master node.
        reply_job_socket (zmq.Context.socket): A socket which receives
                                               job_address from the job.
        kill_job_socket (zmq.Context.socket): A socket that receives commands to kill the job from jobs.
@@ -101,17 +98,17 @@ class Worker(object):
        self.cpu_num = multiprocessing.cpu_count()

    def _create_sockets(self):
        """ Each worker has three sockets at start:

        (1) request_master_socket: sends job address to master node.
        (2) reply_job_socket: receives job_address from subprocess.
        (3) kill_job_socket: receives commands to kill the job from jobs.

        When a job starts, a new heartbeat socket is created to receive
        heartbeat signals from the job.
        """
        self.worker_ip = get_ip_address()

        # request_master_socket: sends job address to master
        self.request_master_socket = self.ctx.socket(zmq.REQ)
@@ -121,17 +118,6 @@ class Worker(object):
        self.request_master_socket.setsockopt(zmq.RCVTIMEO, 500)
        self.request_master_socket.connect("tcp://" + self.master_address)

        # reply_job_socket: receives job_address from subprocess
        self.reply_job_socket = self.ctx.socket(zmq.REP)
        self.reply_job_socket.linger = 0
@@ -165,16 +151,20 @@ class Worker(object):
args=("master {}".format(self.master_address), )) args=("master {}".format(self.master_address), ))
self.reply_master_hearbeat_thread.start() self.reply_master_hearbeat_thread.start()
self.heartbeat_socket_initialized.wait() self.heartbeat_socket_initialized.wait()
initialized_worker = InitializedWorker(self.reply_master_address,
self.master_heartbeat_address,
initialized_jobs, self.cpu_num)
for job in initialized_jobs:
job.worker_address = self.master_heartbeat_address
initialized_worker = InitializedWorker(self.master_heartbeat_address,
initialized_jobs, self.cpu_num,
socket.gethostname())
self.request_master_socket.send_multipart([ self.request_master_socket.send_multipart([
remote_constants.WORKER_INITIALIZED_TAG, remote_constants.WORKER_INITIALIZED_TAG,
cloudpickle.dumps(initialized_worker) cloudpickle.dumps(initialized_worker)
]) ])
_ = self.request_master_socket.recv_multipart() _ = self.request_master_socket.recv_multipart()
self.worker_status = WorkerStatus(self.reply_master_address, self.worker_status = WorkerStatus(self.master_heartbeat_address,
initialized_jobs, self.cpu_num) initialized_jobs, self.cpu_num)
    def _fill_job_buffer(self):
@@ -204,6 +194,9 @@ class Worker(object):
            self.reply_job_address
        ]

        if sys.version_info.major == 3:
            warnings.simplefilter("ignore", ResourceWarning)

        # avoid that many jobs are killed and restarted at the same time.
        self.lock.acquire()
@@ -220,7 +213,6 @@ class Worker(object):
                    [remote_constants.NORMAL_TAG,
                     to_byte(self.kill_job_address)])
                initialized_job = cloudpickle.loads(job_message[1])
                new_jobs.append(initialized_job)

                # a thread for sending heartbeat signals to job
@@ -237,6 +229,7 @@ class Worker(object):
        if success:
            while True:
                initialized_job = self.job_buffer.get()
                initialized_job.worker_address = self.master_heartbeat_address
                if initialized_job.is_alive:
                    self.worker_status.add_job(initialized_job)
                    if not initialized_job.is_alive:  # make sure that the job is still alive.
@@ -307,6 +300,15 @@ class Worker(object):
            # detect whether `self.worker_is_alive` is True periodically
            pass
def _get_worker_status(self):
now = datetime.strftime(datetime.now(), '%H:%M:%S')
virtual_memory = psutil.virtual_memory()
total_memory = round(virtual_memory[0] / (1024**3), 2)
used_memory = round(virtual_memory[3] / (1024**3), 2)
vacant_memory = round(total_memory - used_memory, 2)
load_average = round(psutil.getloadavg()[0], 2)
return (vacant_memory, used_memory, now, load_average)
    def _reply_heartbeat(self, target):
        """Worker will kill its jobs when it loses connection with the master.
        """
@@ -319,13 +321,25 @@ class Worker(object):
        socket.bind_to_random_port("tcp://*")
        self.master_heartbeat_address = "{}:{}".format(self.worker_ip,
                                                       heartbeat_master_port)
logger.set_dir(
os.path.expanduser('~/.parl_data/worker/{}'.format(
self.master_heartbeat_address)))
self.heartbeat_socket_initialized.set() self.heartbeat_socket_initialized.set()
logger.info("[Worker] Connect to the master node successfully. " logger.info("[Worker] Connect to the master node successfully. "
"({} CPUs)".format(self.cpu_num)) "({} CPUs)".format(self.cpu_num))
while self.master_is_alive and self.worker_is_alive: while self.master_is_alive and self.worker_is_alive:
try: try:
message = socket.recv_multipart() message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG]) worker_status = self._get_worker_status()
socket.send_multipart([
remote_constants.HEARTBEAT_TAG,
to_byte(str(worker_status[0])),
to_byte(str(worker_status[1])),
to_byte(worker_status[2]),
to_byte(str(worker_status[3]))
])
            except zmq.error.Again as e:
                self.master_is_alive = False
            except zmq.error.ContextTerminated as e:
...
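For clarity, the worker heartbeat frames produced by `_get_worker_status` line up with the indices consumed in `ClusterMonitor.update_worker_status`; a small sketch (not part of the commit) with made-up numbers:

```python
# Frames sent by the worker on each heartbeat (index 0 is the tag):
frames = [
    b'[HEARTBEAT]',  # remote_constants.HEARTBEAT_TAG
    b'12.5',         # vacant memory in GB
    b'3.5',          # used memory in GB
    b'12:00:00',     # load_time formatted as %H:%M:%S
    b'0.8',          # 1-minute load average from psutil.getloadavg()
]
# ClusterMonitor.update_worker_status() consumes them by index:
vacant_memory = float(frames[1])  # 12.5
load_value = float(frames[4])     # 0.8
```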
@@ -62,6 +62,7 @@ setup(
    long_description_content_type='text/markdown',
    url='https://github.com/PaddlePaddle/PARL',
    packages=_find_packages(),
    include_package_data=True,
    package_data={'': ['*.so']},
    install_requires=[
        "termcolor>=1.1.0",
@@ -72,6 +73,8 @@ setup(
        "tensorboardX==1.8",
        "tb-nightly==1.15.0a20190801",
        "click",
        "flask",
        "psutil",
    ],
    classifiers=[
        'Intended Audience :: Developers',
...