提交 eff8e974 编写于 作者: B Bo Zhou 提交者: Hongsheng Zeng

Unittest in parallel (#135)

* upgrade unit test to reduce test time

* minor change

* fix the bug in unit tests

* fix the bug in build.sh

* Update build.sh

* Update test.sh

* fix a potential bug that could hang the test

* then

* solve the problem in documentation
上级 f2a1fda6
......@@ -15,12 +15,6 @@
# A dev image based on paddle production image
FROM parl/parl-test:1.1-cuda9.0-cudnn7-docs
FROM parl/parl-test:cuda9.0-cudnn7-v1
COPY ./requirements.txt /root/
# Requirements for python2
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r /root/requirements.txt
# Requirements for python3
RUN pip3.6 install -i https://pypi.tuna.tsinghua.edu.cn/simple -r /root/requirements.txt
......@@ -21,6 +21,8 @@ function init() {
NONE='\033[0m'
REPO_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../" && pwd )"
source ~/.bashrc
export PATH="/root/miniconda3/bin:$PATH"
}
function print_usage() {
......@@ -56,7 +58,8 @@ function check_style() {
}
function run_test_with_gpu() {
export FLAGS_fraction_of_gpu_memory_to_use=0.5
unset CUDA_VISIBLE_DEVICES
export FLAGS_fraction_of_gpu_memory_to_use=0.05
mkdir -p ${REPO_ROOT}/build
cd ${REPO_ROOT}/build
......@@ -66,7 +69,8 @@ function run_test_with_gpu() {
Running unit tests with GPU...
========================================
EOF
ctest --output-on-failure
ctest --output-on-failure -j10
cd ${REPO_ROOT}
rm -rf ${REPO_ROOT}/build
}
......@@ -75,13 +79,22 @@ function run_test_with_cpu() {
mkdir -p ${REPO_ROOT}/build
cd ${REPO_ROOT}/build
cmake ..
if [ $# -eq 1 ];then
cmake ..
else
cmake .. -DIS_TESTING_SERIALLY=ON
fi
cat <<EOF
========================================
Running unit tests with CPU...
========================================
=====================================================
Running unit tests with CPU in the environment: $1
=====================================================
EOF
ctest --output-on-failure
if [ $# -eq 1 ];then
ctest --output-on-failure -j10
else
ctest --output-on-failure
fi
cd ${REPO_ROOT}
rm -rf ${REPO_ROOT}/build
}
......@@ -99,6 +112,7 @@ function run_import_test {
========================================
EOF
ctest --output-on-failure
cd ${REPO_ROOT}
rm -rf ${REPO_ROOT}/build
}
......@@ -116,6 +130,7 @@ function run_docs_test {
========================================
EOF
ctest --output-on-failure
cd ${REPO_ROOT}
rm -rf ${REPO_ROOT}/build
}
......@@ -129,11 +144,26 @@ function main() {
check_style
;;
test)
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple .
pip3.6 install -i https://pypi.tuna.tsinghua.edu.cn/simple .
/root/miniconda3/envs/empty_env/bin/pip install -i https://pypi.tuna.tsinghua.edu.cn/simple .
# test code compatibility in environments with various Python versions
declare -a envs=("py27" "py36" "py37")
for env in "${envs[@]}";do
cd /work
source ~/.bashrc
export PATH="/root/miniconda3/bin:$PATH"
source activate $env
echo ========================================
echo Running tests in $env ..
echo `which pip`
echo ========================================
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple .
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r .teamcity/requirements.txt
run_test_with_cpu $env
run_test_with_cpu $env "DIS_TESTING_SERIALLY"
done
run_test_with_gpu
run_test_with_cpu
#
/root/miniconda3/envs/empty_env/bin/pip install -i https://pypi.tuna.tsinghua.edu.cn/simple .
run_import_test
run_docs_test
;;
......
......@@ -17,34 +17,23 @@ cmake_minimum_required(VERSION 3.0)
enable_testing()
option(WITH_TESTING "Include unit testing" ON)
option(IS_TESTING_SERIALLY "testing scripts that cannot run in parallel" OFF)
option(IS_TESTING_IMPORT "testing import parl" OFF)
option(IS_TESTING_DOCS "testing compling the docs" OFF)
option(IS_TESTING_GPU "testing GPU environment" OFF)
set(PADDLE_PYTHON_PATH "" CACHE STRING "Python path to PaddlePaddle Fluid")
function(py3_test TARGET_NAME)
function(py_test TARGET_NAME)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS ENVS)
cmake_parse_arguments(py3_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
#TODO: add real python2 env.
add_test(NAME ${TARGET_NAME}_with_python3
COMMAND python3.6 ${py3_test_SRCS} ${py3_test_ARGS}
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
COMMAND python -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endfunction()
#function(py2_test TARGET_NAME)
# set(options "")
# set(oneValueArgs "")
# set(multiValueArgs SRCS DEPS ARGS ENVS)
# cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
# #TODO: add real python2 env.
# add_test(NAME ${TARGET_NAME}_with_python2
# COMMAND python ${py_test_SRCS} ${py_test_ARGS}
# WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
#endfunction()
function(import_test TARGET_NAME)
set(options "")
set(oneValueArgs "")
......@@ -67,18 +56,24 @@ if (WITH_TESTING)
if (IS_TESTING_IMPORT)
set(src "parl/tests/import_test")
import_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
elseif (IS_TESTING_DOCS)
elseif (IS_TESTING_DOCS)
docs_test()
elseif (IS_TESTING_SERIALLY)
file(GLOB_RECURSE TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_test_alone.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
endforeach()
else ()
file(GLOB_RECURSE TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_test.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
if (${src} MATCHES ".*remote.*")
if (NOT IS_TESTING_GPU)
py3_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
py_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
endif()
else()
py3_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
py_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
endif()
endforeach()
endif()
......
......@@ -2,6 +2,7 @@
cd "$(dirname "$0")"
source ~/.bashrc
export PATH="/root/miniconda3/bin:$PATH"
source deactivate
source activate docs
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple /work/
make html
......@@ -91,7 +91,6 @@ class Job(object):
target=self._reply_worker_heartbeat,
args=(worker_heartbeat_socket, ))
worker_thread.setDaemon(True)
worker_thread.start()
# a thread that reply heartbeat signals from the client
client_heartbeat_socket, client_heartbeat_address = self._create_heartbeat_server(
......@@ -110,6 +109,7 @@ class Job(object):
[remote_constants.NORMAL_TAG,
cloudpickle.dumps(initialized_job)])
message = self.job_socket.recv_multipart()
worker_thread.start()
assert message[0] == remote_constants.NORMAL_TAG
# create the kill_job_socket
......@@ -160,7 +160,10 @@ class Job(object):
self.kill_job_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(self.job_address)])
_ = self.kill_job_socket.recv_multipart()
try:
_ = self.kill_job_socket.recv_multipart()
except zmq.error.Again as e:
pass
logger.warning("[Job]lost connection with the client, will exit")
os._exit(1)
......@@ -262,6 +265,7 @@ class Job(object):
# receive source code from the actor and append them to the environment variables.
envdir = self.wait_for_files()
sys.path.append(envdir)
self.client_is_alive = True
self.client_thread.start()
try:
......@@ -278,7 +282,11 @@ class Job(object):
self.kill_job_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(self.job_address)])
_ = self.kill_job_socket.recv_multipart()
try:
_ = self.kill_job_socket.recv_multipart()
except zmq.error.Again as e:
pass
os._exit(1)
def single_task(self, obj):
"""An infinite loop waiting for commands from the remote object.
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.monitor import ClusterMonitor
import time
import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
import timeout_decorator
import subprocess
class UnableSerializeObject(object):
    """Helper whose instances cannot be pickled.

    ``Actor.get_unable_serialize_object`` referenced this name, but the file
    neither defined nor imported it, so calling that method raised NameError.
    """

    def __init__(self):
        # A thread lock is not picklable, so an object holding one is not either.
        self.lock = threading.Lock()

    def __getstate__(self):
        # Refuse serialization explicitly instead of relying on the lock alone.
        raise TypeError("UnableSerializeObject cannot be serialized")


@parl.remote_class
class Actor(object):
    """Simple class wrapped with ``parl.remote_class`` and used as the test actor."""

    def __init__(self, arg1=None, arg2=None):
        self.arg1 = arg1
        self.arg2 = arg2

    def get_arg1(self):
        """Return the stored ``arg1``."""
        return self.arg1

    def get_arg2(self):
        """Return the stored ``arg2``."""
        return self.arg2

    def set_arg1(self, value):
        """Overwrite ``arg1`` with ``value``."""
        self.arg1 = value

    def set_arg2(self, value):
        """Overwrite ``arg2`` with ``value``."""
        self.arg2 = value

    def get_unable_serialize_object(self):
        """Return a new object that cannot be serialized."""
        return UnableSerializeObject()

    def add_one(self, value):
        """Return ``value`` incremented by one."""
        value += 1
        return value

    def add(self, x, y):
        """Return ``x + y`` after sleeping for three seconds."""
        time.sleep(3)
        return x + y

    def will_raise_exception_func(self):
        """Always raise ZeroDivisionError."""
        x = 1 / 0
class TestClusterMonitor(unittest.TestCase):
    """Checks the data ClusterMonitor reports for a local master with one worker."""

    def tearDown(self):
        # Disconnect the parl client after every test.
        disconnect()

    def test_add_actor(self):
        """Create and delete one actor and check the monitor's counters.

        The worker and the master are shut down in a ``finally`` block so
        that a failed assertion does not skip the cleanup and leave the
        thread running ``master.run`` behind.
        """
        port = 1441
        address = 'localhost:{}'.format(port)
        master = Master(port=port)
        th = threading.Thread(target=master.run)
        th.start()
        worker = None
        try:
            time.sleep(1)
            worker = Worker(address, 1)
            cluster_monitor = ClusterMonitor(address)
            time.sleep(1)
            # No client has connected yet.
            self.assertEqual(0, len(cluster_monitor.data['clients']))

            parl.connect(address)
            # Fixed wait; presumably long enough for the monitor data to refresh.
            time.sleep(10)
            self.assertEqual(1, len(cluster_monitor.data['clients']))
            self.assertEqual(
                1, cluster_monitor.data['workers'][0]['vacant_cpus'])

            actor = Actor()
            time.sleep(20)
            # The single CPU of the worker is now reported as used.
            self.assertEqual(
                0, cluster_monitor.data['workers'][0]['vacant_cpus'])
            self.assertEqual(
                1, cluster_monitor.data['workers'][0]['used_cpus'])
            self.assertEqual(
                1, cluster_monitor.data['clients'][0]['actor_num'])

            del actor
            time.sleep(40)
            # After the actor is deleted the CPU is reported as vacant again.
            self.assertEqual(
                0, cluster_monitor.data['clients'][0]['actor_num'])
            self.assertEqual(
                1, cluster_monitor.data['workers'][0]['vacant_cpus'])
        finally:
            # Same order as before: worker first, then master.
            if worker is not None:
                worker.exit()
            master.exit()


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.monitor import ClusterMonitor
import time
import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
import timeout_decorator
import subprocess
class UnableSerializeObject(object):
    """Helper whose instances cannot be pickled.

    ``Actor.get_unable_serialize_object`` referenced this name, but the file
    neither defined nor imported it, so calling that method raised NameError.
    """

    def __init__(self):
        # A thread lock is not picklable, so an object holding one is not either.
        self.lock = threading.Lock()

    def __getstate__(self):
        # Refuse serialization explicitly instead of relying on the lock alone.
        raise TypeError("UnableSerializeObject cannot be serialized")


@parl.remote_class
class Actor(object):
    """Simple class wrapped with ``parl.remote_class`` and used as the test actor."""

    def __init__(self, arg1=None, arg2=None):
        self.arg1 = arg1
        self.arg2 = arg2

    def get_arg1(self):
        """Return the stored ``arg1``."""
        return self.arg1

    def get_arg2(self):
        """Return the stored ``arg2``."""
        return self.arg2

    def set_arg1(self, value):
        """Overwrite ``arg1`` with ``value``."""
        self.arg1 = value

    def set_arg2(self, value):
        """Overwrite ``arg2`` with ``value``."""
        self.arg2 = value

    def get_unable_serialize_object(self):
        """Return a new object that cannot be serialized."""
        return UnableSerializeObject()

    def add_one(self, value):
        """Return ``value`` incremented by one."""
        value += 1
        return value

    def add(self, x, y):
        """Return ``x + y`` after sleeping for three seconds."""
        time.sleep(3)
        return x + y

    def will_raise_exception_func(self):
        """Always raise ZeroDivisionError."""
        x = 1 / 0
class TestClusterMonitor(unittest.TestCase):
    """Checks the worker count ClusterMonitor reports as workers exit."""

    def tearDown(self):
        # Disconnect the parl client after every test.
        disconnect()

    def test_twenty_worker(self):
        """Start twenty workers, exit them in two batches, check the count.

        Workers that are still running and the master are shut down in a
        ``finally`` block so that a failed assertion does not skip the
        cleanup and leave the thread running ``master.run`` behind.
        """
        port = 1440
        address = 'localhost:{}'.format(port)
        master = Master(port=port)
        th = threading.Thread(target=master.run)
        th.start()

        workers = []
        exited = set()

        def exit_workers(indices):
            # Call exit() at most once per worker, including during cleanup.
            for idx in indices:
                if idx not in exited:
                    exited.add(idx)
                    workers[idx].exit()

        try:
            time.sleep(1)
            for _ in range(20):
                # Track the worker before sleeping so cleanup always sees it.
                workers.append(Worker(address, 1))
                time.sleep(1)

            cluster_monitor = ClusterMonitor(address)
            time.sleep(1)
            self.assertEqual(20, len(cluster_monitor.data['workers']))

            exit_workers(range(10))
            # Fixed wait; presumably long enough for the exits to be noticed.
            time.sleep(40)
            self.assertEqual(10, len(cluster_monitor.data['workers']))

            exit_workers(range(10, 20))
            time.sleep(40)
            self.assertEqual(0, len(cluster_monitor.data['workers']))
        finally:
            exit_workers(range(len(workers)))
            master.exit()


if __name__ == '__main__':
    unittest.main()
......@@ -78,60 +78,6 @@ class TestClusterMonitor(unittest.TestCase):
self.assertEqual(0, len(cluster_monitor.data['workers']))
master.exit()
def test_twenty_worker(self):
port = 1440
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
workers = []
for _ in range(20):
worker = Worker('localhost:{}'.format(port), 1)
workers.append(worker)
cluster_monitor = ClusterMonitor('localhost:{}'.format(port))
time.sleep(1)
self.assertEqual(20, len(cluster_monitor.data['workers']))
for i in range(10):
workers[i].exit()
time.sleep(40)
self.assertEqual(10, len(cluster_monitor.data['workers']))
for i in range(10, 20):
workers[i].exit()
time.sleep(40)
self.assertEqual(0, len(cluster_monitor.data['workers']))
master.exit()
def test_add_actor(self):
port = 1441
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:{}'.format(port), 1)
cluster_monitor = ClusterMonitor('localhost:{}'.format(port))
time.sleep(1)
self.assertEqual(0, len(cluster_monitor.data['clients']))
parl.connect('localhost:{}'.format(port))
time.sleep(10)
self.assertEqual(1, len(cluster_monitor.data['clients']))
self.assertEqual(1, cluster_monitor.data['workers'][0]['vacant_cpus'])
actor = Actor()
time.sleep(20)
self.assertEqual(0, cluster_monitor.data['workers'][0]['vacant_cpus'])
self.assertEqual(1, cluster_monitor.data['workers'][0]['used_cpus'])
self.assertEqual(1, cluster_monitor.data['clients'][0]['actor_num'])
del actor
time.sleep(40)
self.assertEqual(0, cluster_monitor.data['clients'][0]['actor_num'])
self.assertEqual(1, cluster_monitor.data['workers'][0]['vacant_cpus'])
worker.exit()
master.exit()
if __name__ == '__main__':
unittest.main()
......@@ -61,9 +61,6 @@ class Actor(object):
class TestCluster(unittest.TestCase):
def tearDown(self):
disconnect()
#time.sleep(20)
#command = ("pkill -f remote/job.py")
#subprocess.call([command], shell=True)
def test_actor_exception(self):
master = Master(port=1235)
......@@ -136,7 +133,7 @@ class TestCluster(unittest.TestCase):
self.assertEqual(master.cpu_num, 8)
worker2.exit()
time.sleep(30)
time.sleep(50)
self.assertEqual(master.cpu_num, 4)
master.exit()
......
......@@ -45,6 +45,7 @@ class TestCluster(unittest.TestCase):
actor = Actor()
ret = actor.add_one(1)
self.assertEqual(ret, 2)
disconnect()
def _create_actor(self):
for _ in range(2):
......@@ -52,7 +53,7 @@ class TestCluster(unittest.TestCase):
ret = actor.add_one(1)
self.assertEqual(ret, 2)
@timeout_decorator.timeout(seconds=60)
@timeout_decorator.timeout(seconds=300)
def test_connect_and_create_actor_in_multiprocessing_with_connected_in_main_process(
self):
# start the master
......@@ -80,57 +81,6 @@ class TestCluster(unittest.TestCase):
worker1.exit()
master.exit()
@timeout_decorator.timeout(seconds=60)
def test_connect_and_create_actor_in_multiprocessing_without_connected_in_main_process(
self):
# start the master
master = Master(port=8239)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:8239', 4)
proc1 = multiprocessing.Process(
target=self._connect_and_create_actor, args=('localhost:8239', ))
proc2 = multiprocessing.Process(
target=self._connect_and_create_actor, args=('localhost:8239', ))
proc1.start()
proc2.start()
proc1.join()
proc2.join()
self.assertRaises(AssertionError, self._create_actor)
worker1.exit()
master.exit()
@timeout_decorator.timeout(seconds=60)
def test_create_actor_in_multiprocessing(self):
# start the master
master = Master(port=8240)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:8240', 4)
parl.connect('localhost:8240')
proc1 = multiprocessing.Process(target=self._create_actor)
proc2 = multiprocessing.Process(target=self._create_actor)
proc1.start()
proc2.start()
proc1.join()
proc2.join()
# make sure that the client of the main process still works
self._create_actor()
worker1.exit()
master.exit()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
import time
import threading
import timeout_decorator
import multiprocessing
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
@parl.remote_class
class Actor(object):
    """Minimal class wrapped with ``parl.remote_class`` for the test below."""

    def __init__(self, arg1=None, arg2=None):
        # Store both constructor arguments unchanged.
        self.arg1, self.arg2 = arg1, arg2

    def add_one(self, value):
        """Return ``value`` incremented by one."""
        value += 1
        return value
class TestCluster(unittest.TestCase):
    """Creates actors from child processes while the main process is not connected."""

    def tearDown(self):
        # Disconnect the parl client after every test.
        disconnect()

    def _connect_and_create_actor(self, cluster_addr):
        """Connect to ``cluster_addr``, use two actors, then disconnect."""
        parl.connect(cluster_addr)
        for _ in range(2):
            actor = Actor()
            ret = actor.add_one(1)
            self.assertEqual(ret, 2)
        disconnect()

    def _create_actor(self):
        """Create two actors and check ``add_one`` on each, without connecting first."""
        for _ in range(2):
            actor = Actor()
            ret = actor.add_one(1)
            self.assertEqual(ret, 2)

    @timeout_decorator.timeout(seconds=300)
    def test_connect_and_create_actor_in_multiprocessing_without_connected_in_main_process(
            self):
        """Two child processes connect and create actors; the main process must fail to.

        The worker and the master are shut down in a ``finally`` block so
        that a failed assertion or a timeout does not skip the cleanup and
        leave the thread running ``master.run`` behind.
        """
        # start the master
        master = Master(port=8239)
        th = threading.Thread(target=master.run)
        th.start()
        worker1 = None
        try:
            time.sleep(1)
            worker1 = Worker('localhost:8239', 4)

            proc1 = multiprocessing.Process(
                target=self._connect_and_create_actor,
                args=('localhost:8239', ))
            proc2 = multiprocessing.Process(
                target=self._connect_and_create_actor,
                args=('localhost:8239', ))
            proc1.start()
            proc2.start()
            proc1.join()
            proc2.join()

            # The main process never called parl.connect(), so creating an
            # actor here is expected to raise AssertionError.
            self.assertRaises(AssertionError, self._create_actor)
        finally:
            # Same order as before: worker first, then master.
            if worker1 is not None:
                worker1.exit()
            master.exit()


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
import time
import threading
import timeout_decorator
import multiprocessing
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
@parl.remote_class
class Actor(object):
    """Minimal class wrapped with ``parl.remote_class`` for the test below."""

    def __init__(self, arg1=None, arg2=None):
        # Store both constructor arguments unchanged.
        self.arg1, self.arg2 = arg1, arg2

    def add_one(self, value):
        """Return ``value`` incremented by one."""
        value += 1
        return value
class TestCluster(unittest.TestCase):
    """Creates actors from child processes after the main process has connected."""

    def tearDown(self):
        # Disconnect the parl client after every test.
        disconnect()

    def _connect_and_create_actor(self, cluster_addr):
        """Connect to ``cluster_addr``, use two actors, then disconnect."""
        parl.connect(cluster_addr)
        for _ in range(2):
            actor = Actor()
            ret = actor.add_one(1)
            self.assertEqual(ret, 2)
        disconnect()

    def _create_actor(self):
        """Create two actors and check ``add_one`` on each, without connecting first."""
        for _ in range(2):
            actor = Actor()
            ret = actor.add_one(1)
            self.assertEqual(ret, 2)

    @timeout_decorator.timeout(seconds=300)
    def test_create_actor_in_multiprocessing(self):
        """Connect in the main process, then create actors in two child processes.

        The worker and the master are shut down in a ``finally`` block so
        that a failed assertion or a timeout does not skip the cleanup and
        leave the thread running ``master.run`` behind.
        """
        # start the master
        master = Master(port=8240)
        th = threading.Thread(target=master.run)
        th.start()
        worker1 = None
        try:
            time.sleep(1)
            worker1 = Worker('localhost:8240', 4)
            parl.connect('localhost:8240')

            proc1 = multiprocessing.Process(target=self._create_actor)
            proc2 = multiprocessing.Process(target=self._create_actor)
            proc1.start()
            proc2.start()
            proc1.join()
            proc2.join()
            print("[test_create_actor_in_multiprocessing] Join")

            # make sure that the client of the main process still works
            self._create_actor()
        finally:
            # Same order as before: worker first, then master.
            if worker1 is not None:
                worker1.exit()
            master.exit()


if __name__ == '__main__':
    unittest.main()
......@@ -62,27 +62,6 @@ class TestJob(unittest.TestCase):
def tearDown(self):
disconnect()
def test_job_exit_exceptionally(self):
master = Master(port=1334)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1334', 4)
time.sleep(10)
self.assertEqual(worker1.job_buffer.full(), True)
time.sleep(1)
self.assertEqual(master.cpu_num, 4)
print("We are going to kill all the jobs.")
command = ("pkill -f remote/job.py")
subprocess.call([command], shell=True)
parl.connect('localhost:1334')
actor = Actor()
self.assertEqual(actor.add_one(1), 2)
time.sleep(20)
master.exit()
worker1.exit()
@timeout_decorator.timeout(seconds=300)
def test_acor_exit_exceptionally(self):
master = Master(port=1335)
......@@ -94,7 +73,7 @@ class TestJob(unittest.TestCase):
file_path = __file__.replace('reset_job_test', 'simulate_client')
command = [sys.executable, file_path]
proc = subprocess.Popen(command)
time.sleep(10)
time.sleep(20)
self.assertEqual(master.cpu_num, 0)
proc.kill()
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
from parl.utils import logger
import threading
import time
import subprocess
class UnableSerializeObject(object):
    """Helper whose instances cannot be pickled.

    ``Actor.get_unable_serialize_object`` referenced this name, but the file
    neither defined nor imported it, so calling that method raised NameError.
    """

    def __init__(self):
        # A thread lock is not picklable, so an object holding one is not either.
        self.lock = threading.Lock()

    def __getstate__(self):
        # Refuse serialization explicitly instead of relying on the lock alone.
        raise TypeError("UnableSerializeObject cannot be serialized")


@parl.remote_class
class Actor(object):
    """Simple class wrapped with ``parl.remote_class`` and used as the test actor."""

    def __init__(self, arg1=None, arg2=None):
        self.arg1 = arg1
        self.arg2 = arg2

    def get_arg1(self):
        """Return the stored ``arg1``."""
        return self.arg1

    def get_arg2(self):
        """Return the stored ``arg2``."""
        return self.arg2

    def set_arg1(self, value):
        """Overwrite ``arg1`` with ``value``."""
        self.arg1 = value

    def set_arg2(self, value):
        """Overwrite ``arg2`` with ``value``."""
        self.arg2 = value

    def get_unable_serialize_object(self):
        """Return a new object that cannot be serialized."""
        return UnableSerializeObject()

    def add_one(self, value):
        """Return ``value`` incremented by one."""
        value += 1
        return value

    def add(self, x, y):
        """Return ``x + y`` after sleeping for three seconds."""
        time.sleep(3)
        return x + y

    def will_raise_exception_func(self):
        """Always raise ZeroDivisionError."""
        x = 1 / 0
class TestJobAlone(unittest.TestCase):
    """Job-recovery test that kills every ``remote/job.py`` process on the host.

    NOTE(review): the ``_test_alone.py`` glob in the CMake configuration
    suggests this file is meant to run serially because of that ``pkill``;
    confirm against the build scripts.
    """

    def tearDown(self):
        # Disconnect the parl client after every test.
        disconnect()

    def test_job_exit_exceptionally(self):
        """Kill all job processes, then check that an actor can still be created.

        The master and the worker are shut down in a ``finally`` block so
        that a failed assertion does not skip the cleanup and leave the
        thread running ``master.run`` behind.
        """
        master = Master(port=1334)
        th = threading.Thread(target=master.run)
        th.start()
        worker1 = None
        try:
            time.sleep(1)
            worker1 = Worker('localhost:1334', 4)
            time.sleep(10)
            self.assertEqual(worker1.job_buffer.full(), True)
            time.sleep(1)
            self.assertEqual(master.cpu_num, 4)

            print("We are going to kill all the jobs.")
            command = ("pkill -f remote/job.py")
            subprocess.call([command], shell=True)

            parl.connect('localhost:1334')
            actor = Actor()
            self.assertEqual(actor.add_one(1), 2)
            time.sleep(20)
        finally:
            # Same order as before: master first, then worker.
            master.exit()
            if worker1 is not None:
                worker1.exit()


if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册