未验证 提交 500b189f 编写于 作者: B Bo Zhou 提交者: GitHub

stop using cloudpickle to transmit the remote class (#281)

* stop using cloudpickle to transmit the remote class
上级 7d6e8719
...@@ -22,6 +22,3 @@ COPY ./requirements.txt /root/ ...@@ -22,6 +22,3 @@ COPY ./requirements.txt /root/
RUN apt-get install -y libgflags-dev libgoogle-glog-dev libomp-dev unzip RUN apt-get install -y libgflags-dev libgoogle-glog-dev libomp-dev unzip
RUN apt-get install -y libgtest-dev && cd /usr/src/gtest && mkdir build \ RUN apt-get install -y libgtest-dev && cd /usr/src/gtest && mkdir build \
&& cd build && cmake .. && make && cp libgtest*.a /usr/local/lib && cd build && cmake .. && make && cp libgtest*.a /usr/local/lib
RUN apt-get update
RUN apt-get install -y iputils-ping
...@@ -69,7 +69,7 @@ function run_test_with_gpu() { ...@@ -69,7 +69,7 @@ function run_test_with_gpu() {
Running unit tests with GPU... Running unit tests with GPU...
======================================== ========================================
EOF EOF
ctest --output-on-failure -j10 ctest --output-on-failure -j20 --verbose
cd ${REPO_ROOT} cd ${REPO_ROOT}
rm -rf ${REPO_ROOT}/build rm -rf ${REPO_ROOT}/build
} }
...@@ -90,7 +90,7 @@ function run_test_with_cpu() { ...@@ -90,7 +90,7 @@ function run_test_with_cpu() {
===================================================== =====================================================
EOF EOF
if [ $# -eq 1 ];then if [ $# -eq 1 ];then
ctest --output-on-failure -j10 ctest --output-on-failure -j20 --verbose
else else
ctest --output-on-failure ctest --output-on-failure
fi fi
...@@ -145,7 +145,8 @@ function main() { ...@@ -145,7 +145,8 @@ function main() {
;; ;;
test) test)
# test code compatibility in environments with various python versions # test code compatibility in environments with various python versions
declare -a envs=("py36_torch" "py37_torch" "py27" "py36" "py37") #declare -a envs=("py36_torch" "py37_torch" "py27" "py36" "py37")
declare -a envs=("py27" "py36")
for env in "${envs[@]}";do for env in "${envs[@]}";do
cd /work cd /work
source ~/.bashrc source ~/.bashrc
...@@ -169,6 +170,10 @@ function main() { ...@@ -169,6 +170,10 @@ function main() {
pip install -r .teamcity/requirements_torch.txt pip install -r .teamcity/requirements_torch.txt
run_test_with_cpu $env "DIS_TESTING_TORCH" run_test_with_cpu $env "DIS_TESTING_TORCH"
fi fi
# clean env
export LC_ALL=C.UTF-8
export LANG=C.UTF-8
xparl stop
done done
run_test_with_gpu run_test_with_gpu
......
...@@ -33,6 +33,7 @@ function(py_test TARGET_NAME) ...@@ -33,6 +33,7 @@ function(py_test TARGET_NAME)
add_test(NAME ${TARGET_NAME} add_test(NAME ${TARGET_NAME}
COMMAND python -u ${py_test_SRCS} ${py_test_ARGS} COMMAND python -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 300)
endfunction() endfunction()
function(import_test TARGET_NAME) function(import_test TARGET_NAME)
......
...@@ -20,7 +20,6 @@ import sys ...@@ -20,7 +20,6 @@ import sys
import threading import threading
import zmq import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.utils.communication import ping
from parl.remote import remote_constants from parl.remote import remote_constants
import time import time
...@@ -109,6 +108,13 @@ class Client(object): ...@@ -109,6 +108,13 @@ class Client(object):
with open(file, 'rb') as f: with open(file, 'rb') as f:
content = f.read() content = f.read()
pyfiles['other_files'][file] = content pyfiles['other_files'][file] = content
# append entry file to code list
main_file = sys.argv[0]
with open(main_file, 'rb') as code_file:
code = code_file.read()
# parl/remote/remote_decorator.py -> remote_decorator.py
file_name = main_file.split(os.sep)[-1]
pyfiles['python_files'][file_name] = code
except AssertionError as e: except AssertionError as e:
raise Exception( raise Exception(
'Failed to create the client, the file {} does not exist.'. 'Failed to create the client, the file {} does not exist.'.
...@@ -162,6 +168,7 @@ class Client(object): ...@@ -162,6 +168,7 @@ class Client(object):
self.heartbeat_master_address = "{}:{}".format(get_ip_address(), self.heartbeat_master_address = "{}:{}".format(get_ip_address(),
heartbeat_master_port) heartbeat_master_port)
self.heartbeat_socket_initialized.set() self.heartbeat_socket_initialized.set()
connected = False
while self.client_is_alive and self.master_is_alive: while self.client_is_alive and self.master_is_alive:
try: try:
message = socket.recv_multipart() message = socket.recv_multipart()
...@@ -173,9 +180,16 @@ class Client(object): ...@@ -173,9 +180,16 @@ class Client(object):
to_byte(str(self.actor_num)), to_byte(str(self.actor_num)),
to_byte(str(elapsed_time)) to_byte(str(elapsed_time))
]) ])
connected = True
except zmq.error.Again as e: except zmq.error.Again as e:
logger.warning("[Client] Cannot connect to the master." if connected:
"Please check if it is still alive.") logger.warning("[Client] Cannot connect to the master."
"Please check if it is still alive.")
else:
logger.warning(
"[Client] Cannot connect to the master."
"Please check the firewall between client and master.(e.g., ping the master IP)"
)
self.master_is_alive = False self.master_is_alive = False
socket.close(0) socket.close(0)
logger.warning("Client exit replying heartbeat for master.") logger.warning("Client exit replying heartbeat for master.")
...@@ -331,10 +345,6 @@ def connect(master_address, distributed_files=[]): ...@@ -331,10 +345,6 @@ def connect(master_address, distributed_files=[]):
"{ip}:{port} format" "{ip}:{port} format"
global GLOBAL_CLIENT global GLOBAL_CLIENT
addr = master_address.split(":")[0] addr = master_address.split(":")[0]
assert ping(
addr
) == 0, "Error occurs in connection with {}. PARL failed to ping this IP.".format(
master_address)
cur_process_id = os.getpid() cur_process_id = os.getpid()
if GLOBAL_CLIENT is None: if GLOBAL_CLIENT is None:
GLOBAL_CLIENT = Client(master_address, cur_process_id, GLOBAL_CLIENT = Client(master_address, cur_process_id,
......
...@@ -69,7 +69,7 @@ class Job(object): ...@@ -69,7 +69,7 @@ class Job(object):
self.run_job_process.start() self.run_job_process.start()
""" """
NOTE: NOTE:
In Windows, it will raise errors when creating threading.Lock before starting multiprocess.Process. In Windows, it will raise errors when creating threading.Lock before starting multiprocess.Process.
""" """
self.lock = threading.Lock() self.lock = threading.Lock()
self._create_sockets() self._create_sockets()
...@@ -87,7 +87,7 @@ class Job(object): ...@@ -87,7 +87,7 @@ class Job(object):
_ = self.kill_job_socket.recv_multipart() _ = self.kill_job_socket.recv_multipart()
except zmq.error.Again as e: except zmq.error.Again as e:
pass pass
os._exit(1) os._exit(0)
def _create_sockets(self): def _create_sockets(self):
"""Create five sockets for each job in main process. """Create five sockets for each job in main process.
...@@ -243,7 +243,7 @@ class Job(object): ...@@ -243,7 +243,7 @@ class Job(object):
the python files to the job. Later, the job will save these files to a the python files to the job. Later, the job will save these files to a
temporary directory and add the temporary directory to Python's working temporary directory and add the temporary directory to Python's working
directory. directory.
Args: Args:
reply_socket (socket): main socket to accept commands of remote object. reply_socket (socket): main socket to accept commands of remote object.
job_address (String): address of reply_socket. job_address (String): address of reply_socket.
...@@ -301,8 +301,12 @@ class Job(object): ...@@ -301,8 +301,12 @@ class Job(object):
if tag == remote_constants.INIT_OBJECT_TAG: if tag == remote_constants.INIT_OBJECT_TAG:
try: try:
cls = cloudpickle.loads(message[1]) file_name, class_name = cloudpickle.loads(message[1])
#/home/nlp-ol/Firework/baidu/nlp/evokit/python_api/es_agent -> es_agent
file_name = file_name.split(os.sep)[-1]
args, kwargs = cloudpickle.loads(message[2]) args, kwargs = cloudpickle.loads(message[2])
mod = __import__(file_name)
cls = getattr(mod, class_name)._original
obj = cls(*args, **kwargs) obj = cls(*args, **kwargs)
except Exception as e: except Exception as e:
traceback_str = str(traceback.format_exc()) traceback_str = str(traceback.format_exc())
......
...@@ -18,6 +18,7 @@ import threading ...@@ -18,6 +18,7 @@ import threading
import time import time
import zmq import zmq
import numpy as np import numpy as np
import inspect
from parl.utils import get_ip_address, logger, to_str, to_byte from parl.utils import get_ip_address, logger, to_str, to_byte
from parl.utils.communication import loads_argument, loads_return,\ from parl.utils.communication import loads_argument, loads_return,\
...@@ -55,7 +56,7 @@ def remote_class(*args, **kwargs): ...@@ -55,7 +56,7 @@ def remote_class(*args, **kwargs):
actor = Actor() actor = Actor()
actor.step() actor.step()
# Set maximum memory usage to 300 MB for each object. # Set maximum memory usage to 300 MB for each object.
@parl.remote_class(max_memory=300) @parl.remote_class(max_memory=300)
class LimitedActor(object): class LimitedActor(object):
... ...
...@@ -113,10 +114,11 @@ def remote_class(*args, **kwargs): ...@@ -113,10 +114,11 @@ def remote_class(*args, **kwargs):
self.job_shutdown = False self.job_shutdown = False
self.send_file(self.job_socket) self.send_file(self.job_socket)
file_name = inspect.getfile(cls)[:-3]
class_name = cls.__name__
self.job_socket.send_multipart([ self.job_socket.send_multipart([
remote_constants.INIT_OBJECT_TAG, remote_constants.INIT_OBJECT_TAG,
cloudpickle.dumps(cls), cloudpickle.dumps([file_name, class_name]),
cloudpickle.dumps([args, kwargs]), cloudpickle.dumps([args, kwargs]),
]) ])
message = self.job_socket.recv_multipart() message = self.job_socket.recv_multipart()
...@@ -128,6 +130,7 @@ def remote_class(*args, **kwargs): ...@@ -128,6 +130,7 @@ def remote_class(*args, **kwargs):
def __del__(self): def __del__(self):
"""Delete the remote class object and release remote resources.""" """Delete the remote class object and release remote resources."""
self.job_socket.setsockopt(zmq.RCVTIMEO, 1 * 1000)
if not self.job_shutdown: if not self.job_shutdown:
try: try:
self.job_socket.send_multipart( self.job_socket.send_multipart(
...@@ -212,6 +215,7 @@ def remote_class(*args, **kwargs): ...@@ -212,6 +215,7 @@ def remote_class(*args, **kwargs):
return wrapper return wrapper
RemoteWrapper._original = cls
return RemoteWrapper return RemoteWrapper
max_memory = kwargs.get('max_memory') max_memory = kwargs.get('max_memory')
......
...@@ -22,6 +22,7 @@ import threading ...@@ -22,6 +22,7 @@ import threading
from parl.remote.client import disconnect from parl.remote.client import disconnect
from parl.remote import exceptions from parl.remote import exceptions
import subprocess import subprocess
from parl.utils import logger
@parl.remote_class @parl.remote_class
...@@ -62,20 +63,24 @@ class TestCluster(unittest.TestCase): ...@@ -62,20 +63,24 @@ class TestCluster(unittest.TestCase):
disconnect() disconnect()
def test_actor_exception(self): def test_actor_exception(self):
master = Master(port=1235) logger.info("running:test_actor_exception")
master = Master(port=8235)
th = threading.Thread(target=master.run) th = threading.Thread(target=master.run)
th.start() th.start()
time.sleep(3) time.sleep(3)
worker1 = Worker('localhost:1235', 1) worker1 = Worker('localhost:8235', 1)
for _ in range(3): for _ in range(3):
if master.cpu_num == 1: if master.cpu_num == 1:
break break
time.sleep(10) time.sleep(10)
self.assertEqual(1, master.cpu_num) self.assertEqual(1, master.cpu_num)
parl.connect('localhost:1235') logger.info("running:test_actor_exception: 0")
parl.connect('localhost:8235')
logger.info("running:test_actor_exception: 1")
with self.assertRaises(exceptions.RemoteError): with self.assertRaises(exceptions.RemoteError):
actor = Actor(abcd='a bug') actor = Actor(abcd='a bug')
logger.info("running:test_actor_exception: 2")
actor2 = Actor() actor2 = Actor()
for _ in range(3): for _ in range(3):
...@@ -88,14 +93,15 @@ class TestCluster(unittest.TestCase): ...@@ -88,14 +93,15 @@ class TestCluster(unittest.TestCase):
master.exit() master.exit()
worker1.exit() worker1.exit()
def test_actor_exception(self): def test_actor_exception_2(self):
master = Master(port=1236) logger.info("running: test_actor_exception_2")
master = Master(port=8236)
th = threading.Thread(target=master.run) th = threading.Thread(target=master.run)
th.start() th.start()
time.sleep(3) time.sleep(3)
worker1 = Worker('localhost:1236', 1) worker1 = Worker('localhost:8236', 1)
self.assertEqual(1, master.cpu_num) self.assertEqual(1, master.cpu_num)
parl.connect('localhost:1236') parl.connect('localhost:8236')
actor = Actor() actor = Actor()
try: try:
actor.will_raise_exception_func() actor.will_raise_exception_func()
...@@ -114,14 +120,15 @@ class TestCluster(unittest.TestCase): ...@@ -114,14 +120,15 @@ class TestCluster(unittest.TestCase):
master.exit() master.exit()
def test_reset_actor(self): def test_reset_actor(self):
logger.info("running: test_reset_actor")
# start the master # start the master
master = Master(port=1237) master = Master(port=8237)
th = threading.Thread(target=master.run) th = threading.Thread(target=master.run)
th.start() th.start()
time.sleep(3) time.sleep(3)
worker1 = Worker('localhost:1237', 4) worker1 = Worker('localhost:8237', 4)
parl.connect('localhost:1237') parl.connect('localhost:8237')
for _ in range(10): for _ in range(10):
actor = Actor() actor = Actor()
ret = actor.add_one(1) ret = actor.add_one(1)
...@@ -138,19 +145,20 @@ class TestCluster(unittest.TestCase): ...@@ -138,19 +145,20 @@ class TestCluster(unittest.TestCase):
master.exit() master.exit()
def test_add_worker(self): def test_add_worker(self):
master = Master(port=1234) logger.info("running: test_add_worker")
master = Master(port=8234)
th = threading.Thread(target=master.run) th = threading.Thread(target=master.run)
th.start() th.start()
time.sleep(1) time.sleep(1)
worker1 = Worker('localhost:1234', 4) worker1 = Worker('localhost:8234', 4)
for _ in range(3): for _ in range(3):
if master.cpu_num == 4: if master.cpu_num == 4:
break break
time.sleep(10) time.sleep(10)
self.assertEqual(master.cpu_num, 4) self.assertEqual(master.cpu_num, 4)
worker2 = Worker('localhost:1234', 4) worker2 = Worker('localhost:8234', 4)
for _ in range(3): for _ in range(3):
if master.cpu_num == 8: if master.cpu_num == 8:
break break
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import unittest
from parl.remote.client import disconnect
class TestPingMaster(unittest.TestCase):
    """Check that `parl.connect` rejects the master address used below."""

    def tearDown(self):
        """Call `disconnect()` after every test method."""
        disconnect()

    def test_throw_exception(self):
        """Connecting to 176.2.3.4:8080 is expected to raise AssertionError."""
        master_address = "176.2.3.4:8080"
        # Callable form of assertRaises; equivalent to the `with` form.
        self.assertRaises(AssertionError, parl.connect, master_address)
# Run the test case through unittest when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
...@@ -78,6 +78,7 @@ class Worker(object): ...@@ -78,6 +78,7 @@ class Worker(object):
# create a thread that waits commands from the job to kill the job. # create a thread that waits commands from the job to kill the job.
self.kill_job_thread = threading.Thread(target=self._reply_kill_job) self.kill_job_thread = threading.Thread(target=self._reply_kill_job)
self.kill_job_thread.setDaemon(True)
self.kill_job_thread.start() self.kill_job_thread.start()
self._create_jobs() self._create_jobs()
...@@ -169,6 +170,7 @@ class Worker(object): ...@@ -169,6 +170,7 @@ class Worker(object):
def _fill_job_buffer(self): def _fill_job_buffer(self):
"""An endless loop that adds initialized job into the job buffer""" """An endless loop that adds initialized job into the job buffer"""
initialized_jobs = []
while self.worker_is_alive: while self.worker_is_alive:
if self.job_buffer.full() is False: if self.job_buffer.full() is False:
job_num = self.cpu_num - self.job_buffer.qsize() job_num = self.cpu_num - self.job_buffer.qsize()
...@@ -178,13 +180,7 @@ class Worker(object): ...@@ -178,13 +180,7 @@ class Worker(object):
self.job_buffer.put(job) self.job_buffer.put(job)
time.sleep(0.02) time.sleep(0.02)
self.exit()
# release jobs if the worker is not alive
for job in initialized_jobs:
try:
os.kill(job.pid, signal.SIGTERM)
except OSError:
pass
def _init_jobs(self, job_num): def _init_jobs(self, job_num):
"""Create jobs. """Create jobs.
...@@ -223,6 +219,7 @@ class Worker(object): ...@@ -223,6 +219,7 @@ class Worker(object):
# a thread for sending heartbeat signals to job # a thread for sending heartbeat signals to job
thread = threading.Thread( thread = threading.Thread(
target=self._create_job_monitor, args=(initialized_job, )) target=self._create_job_monitor, args=(initialized_job, ))
thread.setDaemon(True)
thread.start() thread.start()
self.lock.release() self.lock.release()
assert len(new_jobs) > 0, "init jobs failed" assert len(new_jobs) > 0, "init jobs failed"
...@@ -354,15 +351,19 @@ class Worker(object): ...@@ -354,15 +351,19 @@ class Worker(object):
break break
socket.close(0) socket.close(0)
logger.warning( logger.warning(
"[Worker] lost connection with the master, will exit replying heartbeat for master." "[Worker] lost connection with the master, will exit reply heartbeat for master."
) )
self.worker_status.clear() self.worker_status.clear()
# exit the worker # exit the worker
self.worker_is_alive = False self.worker_is_alive = False
self.exit()
def exit(self): def exit(self):
"""close the worker""" """close the worker"""
self.worker_is_alive = False self.worker_is_alive = False
command = ('ps aux | grep "remote/job.py.*{}"'.format(
self.reply_job_address) + " | awk '{print $2}' | xargs kill -9")
subprocess.call([command], shell=True)
def run(self): def run(self):
"""Keep running until it lost connection with the master. """Keep running until it lost connection with the master.
......
...@@ -16,12 +16,9 @@ import cloudpickle ...@@ -16,12 +16,9 @@ import cloudpickle
import pyarrow import pyarrow
import subprocess import subprocess
import os import os
from parl.utils import _IS_WINDOWS
from parl.utils import SerializeError, DeserializeError from parl.utils import SerializeError, DeserializeError
__all__ = [ __all__ = ['dumps_argument', 'loads_argument', 'dumps_return', 'loads_return']
'dumps_argument', 'loads_argument', 'dumps_return', 'loads_return', 'ping'
]
# Reference: https://github.com/apache/arrow/blob/f88474c84e7f02e226eb4cc32afef5e2bbc6e5b4/python/pyarrow/tests/test_serialization.py#L658-L682 # Reference: https://github.com/apache/arrow/blob/f88474c84e7f02e226eb4cc32afef5e2bbc6e5b4/python/pyarrow/tests/test_serialization.py#L658-L682
...@@ -120,23 +117,3 @@ def loads_return(data): ...@@ -120,23 +117,3 @@ def loads_return(data):
raise DeserializeError(e) raise DeserializeError(e)
return ret return ret
#Reference: https://stackoverflow.com/questions/2953462/pinging-servers-in-python
def ping(host):
    """Send one ping request to `host` and return the command's exit code.

    Remember that a host may not respond to a ping (ICMP) request even if
    the host name is valid.

    Args:
        host (str): host name or IP address passed to the system `ping`
            command.

    Returns:
        int: exit code of the `ping` command; 0 when the command
            succeeded, non-zero otherwise. Note that this is NOT a
            boolean, so compare against 0 rather than testing truthiness.

    Raises:
        OSError: if the `ping` executable cannot be launched.
    """
    # The option for the number of packets is `-n` on Windows, `-c` elsewhere.
    param = '-n' if _IS_WINDOWS else '-c'
    # Building the command. Ex: "ping -c 1 google.com"
    command = ['ping', param, '1', host]
    # Discard the command's output. The `with` block guarantees the null
    # device is closed even when launching `ping` raises.
    with open(os.devnull, 'w') as devnull:
        return subprocess.call(
            command, stdout=devnull, stderr=subprocess.STDOUT)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册