未验证 提交 500b189f 编写于 作者: B Bo Zhou 提交者: GitHub

stop using picklecloud to transmit the remote class (#281)

* stop using pickle cloud to transmit the remote class
上级 7d6e8719
......@@ -22,6 +22,3 @@ COPY ./requirements.txt /root/
RUN apt-get install -y libgflags-dev libgoogle-glog-dev libomp-dev unzip
RUN apt-get install -y libgtest-dev && cd /usr/src/gtest && mkdir build \
&& cd build && cmake .. && make && cp libgtest*.a /usr/local/lib
RUN apt-get update
RUN apt-get install -y iputils-ping
......@@ -69,7 +69,7 @@ function run_test_with_gpu() {
Running unit tests with GPU...
========================================
EOF
ctest --output-on-failure -j10
ctest --output-on-failure -j20 --verbose
cd ${REPO_ROOT}
rm -rf ${REPO_ROOT}/build
}
......@@ -90,7 +90,7 @@ function run_test_with_cpu() {
=====================================================
EOF
if [ $# -eq 1 ];then
ctest --output-on-failure -j10
ctest --output-on-failure -j20 --verbose
else
ctest --output-on-failure
fi
......@@ -145,7 +145,8 @@ function main() {
;;
test)
# test code compatibility in environments with various python versions
declare -a envs=("py36_torch" "py37_torch" "py27" "py36" "py37")
#declare -a envs=("py36_torch" "py37_torch" "py27" "py36" "py37")
declare -a envs=("py27" "py36")
for env in "${envs[@]}";do
cd /work
source ~/.bashrc
......@@ -169,6 +170,10 @@ function main() {
pip install -r .teamcity/requirements_torch.txt
run_test_with_cpu $env "DIS_TESTING_TORCH"
fi
# clean env
export LC_ALL=C.UTF-8
export LANG=C.UTF-8
xparl stop
done
run_test_with_gpu
......
......@@ -33,6 +33,7 @@ function(py_test TARGET_NAME)
add_test(NAME ${TARGET_NAME}
COMMAND python -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 300)
endfunction()
function(import_test TARGET_NAME)
......
......@@ -20,7 +20,6 @@ import sys
import threading
import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.utils.communication import ping
from parl.remote import remote_constants
import time
......@@ -109,6 +108,13 @@ class Client(object):
with open(file, 'rb') as f:
content = f.read()
pyfiles['other_files'][file] = content
# append entry file to code list
main_file = sys.argv[0]
with open(main_file, 'rb') as code_file:
code = code_file.read()
# parl/remote/remote_decorator.py -> remote_decorator.py
file_name = main_file.split(os.sep)[-1]
pyfiles['python_files'][file_name] = code
except AssertionError as e:
raise Exception(
'Failed to create the client, the file {} does not exist.'.
......@@ -162,6 +168,7 @@ class Client(object):
self.heartbeat_master_address = "{}:{}".format(get_ip_address(),
heartbeat_master_port)
self.heartbeat_socket_initialized.set()
connected = False
while self.client_is_alive and self.master_is_alive:
try:
message = socket.recv_multipart()
......@@ -173,9 +180,16 @@ class Client(object):
to_byte(str(self.actor_num)),
to_byte(str(elapsed_time))
])
connected = True
except zmq.error.Again as e:
logger.warning("[Client] Cannot connect to the master."
"Please check if it is still alive.")
if connected:
logger.warning("[Client] Cannot connect to the master."
"Please check if it is still alive.")
else:
logger.warning(
"[Client] Cannot connect to the master."
"Please check the firewall between client and master.(e.g., ping the master IP)"
)
self.master_is_alive = False
socket.close(0)
logger.warning("Client exit replying heartbeat for master.")
......@@ -331,10 +345,6 @@ def connect(master_address, distributed_files=[]):
"{ip}:{port} format"
global GLOBAL_CLIENT
addr = master_address.split(":")[0]
assert ping(
addr
) == 0, "Error occurs in connection with {}. PARL failed to ping this IP.".format(
master_address)
cur_process_id = os.getpid()
if GLOBAL_CLIENT is None:
GLOBAL_CLIENT = Client(master_address, cur_process_id,
......
......@@ -69,7 +69,7 @@ class Job(object):
self.run_job_process.start()
"""
NOTE:
In Windows, it will raise errors when creating threading.Lock before starting multiprocess.Process.
In Windows, it will raise errors when creating threading.Lock before starting multiprocess.Process.
"""
self.lock = threading.Lock()
self._create_sockets()
......@@ -87,7 +87,7 @@ class Job(object):
_ = self.kill_job_socket.recv_multipart()
except zmq.error.Again as e:
pass
os._exit(1)
os._exit(0)
def _create_sockets(self):
"""Create five sockets for each job in main process.
......@@ -243,7 +243,7 @@ class Job(object):
the python files to the job. Later, the job will save these files to a
temporary directory and add the temporary directory to Python's working
directory.
Args:
reply_socket (socket): main socket to accept commands of remote object.
job_address (String): address of reply_socket.
......@@ -301,8 +301,12 @@ class Job(object):
if tag == remote_constants.INIT_OBJECT_TAG:
try:
cls = cloudpickle.loads(message[1])
file_name, class_name = cloudpickle.loads(message[1])
#/home/nlp-ol/Firework/baidu/nlp/evokit/python_api/es_agent -> es_agent
file_name = file_name.split(os.sep)[-1]
args, kwargs = cloudpickle.loads(message[2])
mod = __import__(file_name)
cls = getattr(mod, class_name)._original
obj = cls(*args, **kwargs)
except Exception as e:
traceback_str = str(traceback.format_exc())
......
......@@ -18,6 +18,7 @@ import threading
import time
import zmq
import numpy as np
import inspect
from parl.utils import get_ip_address, logger, to_str, to_byte
from parl.utils.communication import loads_argument, loads_return,\
......@@ -55,7 +56,7 @@ def remote_class(*args, **kwargs):
actor = Actor()
actor.step()
# Set maximum memory usage to 300 MB for each object.
# Set maximum memory usage to 300 MB for each object.
@parl.remote_class(max_memory=300)
class LimitedActor(object):
...
......@@ -113,10 +114,11 @@ def remote_class(*args, **kwargs):
self.job_shutdown = False
self.send_file(self.job_socket)
file_name = inspect.getfile(cls)[:-3]
class_name = cls.__name__
self.job_socket.send_multipart([
remote_constants.INIT_OBJECT_TAG,
cloudpickle.dumps(cls),
cloudpickle.dumps([file_name, class_name]),
cloudpickle.dumps([args, kwargs]),
])
message = self.job_socket.recv_multipart()
......@@ -128,6 +130,7 @@ def remote_class(*args, **kwargs):
def __del__(self):
"""Delete the remote class object and release remote resources."""
self.job_socket.setsockopt(zmq.RCVTIMEO, 1 * 1000)
if not self.job_shutdown:
try:
self.job_socket.send_multipart(
......@@ -212,6 +215,7 @@ def remote_class(*args, **kwargs):
return wrapper
RemoteWrapper._original = cls
return RemoteWrapper
max_memory = kwargs.get('max_memory')
......
......@@ -22,6 +22,7 @@ import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
import subprocess
from parl.utils import logger
@parl.remote_class
......@@ -62,20 +63,24 @@ class TestCluster(unittest.TestCase):
disconnect()
def test_actor_exception(self):
master = Master(port=1235)
logger.info("running:test_actor_exception")
master = Master(port=8235)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:1235', 1)
worker1 = Worker('localhost:8235', 1)
for _ in range(3):
if master.cpu_num == 1:
break
time.sleep(10)
self.assertEqual(1, master.cpu_num)
parl.connect('localhost:1235')
logger.info("running:test_actor_exception: 0")
parl.connect('localhost:8235')
logger.info("running:test_actor_exception: 1")
with self.assertRaises(exceptions.RemoteError):
actor = Actor(abcd='a bug')
logger.info("running:test_actor_exception: 2")
actor2 = Actor()
for _ in range(3):
......@@ -88,14 +93,15 @@ class TestCluster(unittest.TestCase):
master.exit()
worker1.exit()
def test_actor_exception(self):
master = Master(port=1236)
def test_actor_exception_2(self):
logger.info("running: test_actor_exception_2")
master = Master(port=8236)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:1236', 1)
worker1 = Worker('localhost:8236', 1)
self.assertEqual(1, master.cpu_num)
parl.connect('localhost:1236')
parl.connect('localhost:8236')
actor = Actor()
try:
actor.will_raise_exception_func()
......@@ -114,14 +120,15 @@ class TestCluster(unittest.TestCase):
master.exit()
def test_reset_actor(self):
logger.info("running: test_reset_actor")
# start the master
master = Master(port=1237)
master = Master(port=8237)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:1237', 4)
parl.connect('localhost:1237')
worker1 = Worker('localhost:8237', 4)
parl.connect('localhost:8237')
for _ in range(10):
actor = Actor()
ret = actor.add_one(1)
......@@ -138,19 +145,20 @@ class TestCluster(unittest.TestCase):
master.exit()
def test_add_worker(self):
master = Master(port=1234)
logger.info("running: test_add_worker")
master = Master(port=8234)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1234', 4)
worker1 = Worker('localhost:8234', 4)
for _ in range(3):
if master.cpu_num == 4:
break
time.sleep(10)
self.assertEqual(master.cpu_num, 4)
worker2 = Worker('localhost:1234', 4)
worker2 = Worker('localhost:8234', 4)
for _ in range(3):
if master.cpu_num == 8:
break
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import unittest
from parl.remote.client import disconnect
class TestPingMaster(unittest.TestCase):
    """Exercise ``parl.connect`` against a fixed master address."""

    def tearDown(self):
        """Call ``disconnect()`` after every test case."""
        disconnect()

    def test_throw_exception(self):
        """``parl.connect`` on the address below must raise ``AssertionError``."""
        master_address = "176.2.3.4:8080"
        self.assertRaises(AssertionError, parl.connect, master_address)
if __name__ == '__main__':
unittest.main()
......@@ -78,6 +78,7 @@ class Worker(object):
# create a thread that waits commands from the job to kill the job.
self.kill_job_thread = threading.Thread(target=self._reply_kill_job)
self.kill_job_thread.setDaemon(True)
self.kill_job_thread.start()
self._create_jobs()
......@@ -169,6 +170,7 @@ class Worker(object):
def _fill_job_buffer(self):
"""An endless loop that adds initialized job into the job buffer"""
initialized_jobs = []
while self.worker_is_alive:
if self.job_buffer.full() is False:
job_num = self.cpu_num - self.job_buffer.qsize()
......@@ -178,13 +180,7 @@ class Worker(object):
self.job_buffer.put(job)
time.sleep(0.02)
# release jobs if the worker is not alive
for job in initialized_jobs:
try:
os.kill(job.pid, signal.SIGTERM)
except OSError:
pass
self.exit()
def _init_jobs(self, job_num):
"""Create jobs.
......@@ -223,6 +219,7 @@ class Worker(object):
# a thread for sending heartbeat signals to job
thread = threading.Thread(
target=self._create_job_monitor, args=(initialized_job, ))
thread.setDaemon(True)
thread.start()
self.lock.release()
assert len(new_jobs) > 0, "init jobs failed"
......@@ -354,15 +351,19 @@ class Worker(object):
break
socket.close(0)
logger.warning(
"[Worker] lost connection with the master, will exit replying heartbeat for master."
"[Worker] lost connection with the master, will exit reply heartbeat for master."
)
self.worker_status.clear()
# exit the worker
self.worker_is_alive = False
self.exit()
def exit(self):
    """Mark the worker as not alive and kill its matching job processes.

    Sets ``worker_is_alive`` to ``False`` and then runs a shell pipeline
    that sends ``kill -9`` to every process whose ``ps aux`` line matches
    ``remote/job.py.*<reply_job_address>``.
    """
    self.worker_is_alive = False
    # List the processes and keep only the lines for this worker's jobs.
    job_filter = 'ps aux | grep "remote/job.py.*{}"'.format(
        self.reply_job_address)
    # Take the PID column and force-kill each PID.
    kill_stage = " | awk '{print $2}' | xargs kill -9"
    command = job_filter + kill_stage
    subprocess.call([command], shell=True)
def run(self):
"""Keep running until it lost connection with the master.
......
......@@ -16,12 +16,9 @@ import cloudpickle
import pyarrow
import subprocess
import os
from parl.utils import _IS_WINDOWS
from parl.utils import SerializeError, DeserializeError
__all__ = [
'dumps_argument', 'loads_argument', 'dumps_return', 'loads_return', 'ping'
]
__all__ = ['dumps_argument', 'loads_argument', 'dumps_return', 'loads_return']
# Reference: https://github.com/apache/arrow/blob/f88474c84e7f02e226eb4cc32afef5e2bbc6e5b4/python/pyarrow/tests/test_serialization.py#L658-L682
......@@ -120,23 +117,3 @@ def loads_return(data):
raise DeserializeError(e)
return ret
#Reference: https://stackoverflow.com/questions/2953462/pinging-servers-in-python
def ping(host):
    """Send one ICMP echo request to ``host`` and return the exit status.

    Remember that a host may not respond to a ping (ICMP) request even if
    the host name is valid.

    Args:
        host (str): host name or IP address handed to the system ``ping``.

    Returns:
        int: exit status of the ``ping`` process. This is a return code,
        not a boolean; callers should compare it with ``0``.
    """
    # The packet-count option is ``-n`` on Windows and ``-c`` elsewhere.
    param = '-n' if _IS_WINDOWS else '-c'
    # Building the command. Ex: "ping -c 1 google.com"
    command = ['ping', param, '1', host]
    # The ``with`` block closes the devnull handle even when launching
    # ``ping`` fails; ``subprocess.call`` waits for the child to finish.
    with open(os.devnull, 'w') as fnull:
        return subprocess.call(
            command, stdout=fnull, stderr=subprocess.STDOUT)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册