提交 eff8e974 编写于 作者: B Bo Zhou 提交者: Hongsheng Zeng

Unittest in parallel (#135)

* upgrade unit test to reduce test time

* minor change

* fix the bug in unit tests

* fix the bug in build.sh

* Update build.sh

* Update test.sh

* fix a potential bug that will hang the test

* then

* solve the problem in documentation
上级 f2a1fda6
...@@ -15,12 +15,6 @@ ...@@ -15,12 +15,6 @@
# A dev image based on paddle production image # A dev image based on paddle production image
FROM parl/parl-test:1.1-cuda9.0-cudnn7-docs FROM parl/parl-test:cuda9.0-cudnn7-v1
COPY ./requirements.txt /root/ COPY ./requirements.txt /root/
# Requirements for python2
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r /root/requirements.txt
# Requirements for python3
RUN pip3.6 install -i https://pypi.tuna.tsinghua.edu.cn/simple -r /root/requirements.txt
...@@ -21,6 +21,8 @@ function init() { ...@@ -21,6 +21,8 @@ function init() {
NONE='\033[0m' NONE='\033[0m'
REPO_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../" && pwd )" REPO_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../" && pwd )"
source ~/.bashrc
export PATH="/root/miniconda3/bin:$PATH"
} }
function print_usage() { function print_usage() {
...@@ -56,7 +58,8 @@ function check_style() { ...@@ -56,7 +58,8 @@ function check_style() {
} }
function run_test_with_gpu() { function run_test_with_gpu() {
export FLAGS_fraction_of_gpu_memory_to_use=0.5 unset CUDA_VISIBLE_DEVICES
export FLAGS_fraction_of_gpu_memory_to_use=0.05
mkdir -p ${REPO_ROOT}/build mkdir -p ${REPO_ROOT}/build
cd ${REPO_ROOT}/build cd ${REPO_ROOT}/build
...@@ -66,7 +69,8 @@ function run_test_with_gpu() { ...@@ -66,7 +69,8 @@ function run_test_with_gpu() {
Running unit tests with GPU... Running unit tests with GPU...
======================================== ========================================
EOF EOF
ctest --output-on-failure ctest --output-on-failure -j10
cd ${REPO_ROOT}
rm -rf ${REPO_ROOT}/build rm -rf ${REPO_ROOT}/build
} }
...@@ -75,13 +79,22 @@ function run_test_with_cpu() { ...@@ -75,13 +79,22 @@ function run_test_with_cpu() {
mkdir -p ${REPO_ROOT}/build mkdir -p ${REPO_ROOT}/build
cd ${REPO_ROOT}/build cd ${REPO_ROOT}/build
cmake .. if [ $# -eq 1 ];then
cmake ..
else
cmake .. -DIS_TESTING_SERIALLY=ON
fi
cat <<EOF cat <<EOF
======================================== =====================================================
Running unit tests with CPU... Running unit tests with CPU in the environment: $1
======================================== =====================================================
EOF EOF
ctest --output-on-failure if [ $# -eq 1 ];then
ctest --output-on-failure -j10
else
ctest --output-on-failure
fi
cd ${REPO_ROOT}
rm -rf ${REPO_ROOT}/build rm -rf ${REPO_ROOT}/build
} }
...@@ -99,6 +112,7 @@ function run_import_test { ...@@ -99,6 +112,7 @@ function run_import_test {
======================================== ========================================
EOF EOF
ctest --output-on-failure ctest --output-on-failure
cd ${REPO_ROOT}
rm -rf ${REPO_ROOT}/build rm -rf ${REPO_ROOT}/build
} }
...@@ -116,6 +130,7 @@ function run_docs_test { ...@@ -116,6 +130,7 @@ function run_docs_test {
======================================== ========================================
EOF EOF
ctest --output-on-failure ctest --output-on-failure
cd ${REPO_ROOT}
rm -rf ${REPO_ROOT}/build rm -rf ${REPO_ROOT}/build
} }
...@@ -129,11 +144,26 @@ function main() { ...@@ -129,11 +144,26 @@ function main() {
check_style check_style
;; ;;
test) test)
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple . # test code compability in environments with various python versions
pip3.6 install -i https://pypi.tuna.tsinghua.edu.cn/simple . declare -a envs=("py27" "py36" "py37")
/root/miniconda3/envs/empty_env/bin/pip install -i https://pypi.tuna.tsinghua.edu.cn/simple . for env in "${envs[@]}";do
cd /work
source ~/.bashrc
export PATH="/root/miniconda3/bin:$PATH"
source activate $env
echo ========================================
echo Running tests in $env ..
echo `which pip`
echo ========================================
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple .
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r .teamcity/requirements.txt
run_test_with_cpu $env
run_test_with_cpu $env "DIS_TESTING_SERIALLY"
done
run_test_with_gpu run_test_with_gpu
run_test_with_cpu
#
/root/miniconda3/envs/empty_env/bin/pip install -i https://pypi.tuna.tsinghua.edu.cn/simple .
run_import_test run_import_test
run_docs_test run_docs_test
;; ;;
......
...@@ -17,34 +17,23 @@ cmake_minimum_required(VERSION 3.0) ...@@ -17,34 +17,23 @@ cmake_minimum_required(VERSION 3.0)
enable_testing() enable_testing()
option(WITH_TESTING "Include unit testing" ON) option(WITH_TESTING "Include unit testing" ON)
option(IS_TESTING_SERIALLY "testing scripts that cannot run in parallel" OFF)
option(IS_TESTING_IMPORT "testing import parl" OFF) option(IS_TESTING_IMPORT "testing import parl" OFF)
option(IS_TESTING_DOCS "testing compling the docs" OFF) option(IS_TESTING_DOCS "testing compling the docs" OFF)
option(IS_TESTING_GPU "testing GPU environment" OFF) option(IS_TESTING_GPU "testing GPU environment" OFF)
set(PADDLE_PYTHON_PATH "" CACHE STRING "Python path to PaddlePaddle Fluid") set(PADDLE_PYTHON_PATH "" CACHE STRING "Python path to PaddlePaddle Fluid")
function(py3_test TARGET_NAME) function(py_test TARGET_NAME)
set(options "") set(options "")
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS ENVS) set(multiValueArgs SRCS DEPS ARGS ENVS)
cmake_parse_arguments(py3_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
#TODO: add real python2 env. add_test(NAME ${TARGET_NAME}
add_test(NAME ${TARGET_NAME}_with_python3 COMMAND python -u ${py_test_SRCS} ${py_test_ARGS}
COMMAND python3.6 ${py3_test_SRCS} ${py3_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endfunction() endfunction()
#function(py2_test TARGET_NAME)
# set(options "")
# set(oneValueArgs "")
# set(multiValueArgs SRCS DEPS ARGS ENVS)
# cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
# #TODO: add real python2 env.
# add_test(NAME ${TARGET_NAME}_with_python2
# COMMAND python ${py_test_SRCS} ${py_test_ARGS}
# WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
#endfunction()
function(import_test TARGET_NAME) function(import_test TARGET_NAME)
set(options "") set(options "")
set(oneValueArgs "") set(oneValueArgs "")
...@@ -67,18 +56,24 @@ if (WITH_TESTING) ...@@ -67,18 +56,24 @@ if (WITH_TESTING)
if (IS_TESTING_IMPORT) if (IS_TESTING_IMPORT)
set(src "parl/tests/import_test") set(src "parl/tests/import_test")
import_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH}) import_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
elseif (IS_TESTING_DOCS) elseif (IS_TESTING_DOCS)
docs_test() docs_test()
elseif (IS_TESTING_SERIALLY)
file(GLOB_RECURSE TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_test_alone.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
endforeach()
else () else ()
file(GLOB_RECURSE TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_test.py") file(GLOB_RECURSE TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_test.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS}) foreach(src ${TEST_OPS})
if (${src} MATCHES ".*remote.*") if (${src} MATCHES ".*remote.*")
if (NOT IS_TESTING_GPU) if (NOT IS_TESTING_GPU)
py3_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH}) py_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
endif() endif()
else() else()
py3_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH}) py_test(${src} SRCS ${src}.py ENVS ${PADDLE_PYTHON_PATH})
endif() endif()
endforeach() endforeach()
endif() endif()
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
cd "$(dirname "$0")" cd "$(dirname "$0")"
source ~/.bashrc source ~/.bashrc
export PATH="/root/miniconda3/bin:$PATH" export PATH="/root/miniconda3/bin:$PATH"
source deactivate
source activate docs source activate docs
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple /work/ pip install -i https://pypi.tuna.tsinghua.edu.cn/simple /work/
make html make html
...@@ -91,7 +91,6 @@ class Job(object): ...@@ -91,7 +91,6 @@ class Job(object):
target=self._reply_worker_heartbeat, target=self._reply_worker_heartbeat,
args=(worker_heartbeat_socket, )) args=(worker_heartbeat_socket, ))
worker_thread.setDaemon(True) worker_thread.setDaemon(True)
worker_thread.start()
# a thread that reply heartbeat signals from the client # a thread that reply heartbeat signals from the client
client_heartbeat_socket, client_heartbeat_address = self._create_heartbeat_server( client_heartbeat_socket, client_heartbeat_address = self._create_heartbeat_server(
...@@ -110,6 +109,7 @@ class Job(object): ...@@ -110,6 +109,7 @@ class Job(object):
[remote_constants.NORMAL_TAG, [remote_constants.NORMAL_TAG,
cloudpickle.dumps(initialized_job)]) cloudpickle.dumps(initialized_job)])
message = self.job_socket.recv_multipart() message = self.job_socket.recv_multipart()
worker_thread.start()
assert message[0] == remote_constants.NORMAL_TAG assert message[0] == remote_constants.NORMAL_TAG
# create the kill_job_socket # create the kill_job_socket
...@@ -160,7 +160,10 @@ class Job(object): ...@@ -160,7 +160,10 @@ class Job(object):
self.kill_job_socket.send_multipart( self.kill_job_socket.send_multipart(
[remote_constants.KILLJOB_TAG, [remote_constants.KILLJOB_TAG,
to_byte(self.job_address)]) to_byte(self.job_address)])
_ = self.kill_job_socket.recv_multipart() try:
_ = self.kill_job_socket.recv_multipart()
except zmq.error.Again as e:
pass
logger.warning("[Job]lost connection with the client, will exit") logger.warning("[Job]lost connection with the client, will exit")
os._exit(1) os._exit(1)
...@@ -262,6 +265,7 @@ class Job(object): ...@@ -262,6 +265,7 @@ class Job(object):
# receive source code from the actor and append them to the environment variables. # receive source code from the actor and append them to the environment variables.
envdir = self.wait_for_files() envdir = self.wait_for_files()
sys.path.append(envdir) sys.path.append(envdir)
self.client_is_alive = True
self.client_thread.start() self.client_thread.start()
try: try:
...@@ -278,7 +282,11 @@ class Job(object): ...@@ -278,7 +282,11 @@ class Job(object):
self.kill_job_socket.send_multipart( self.kill_job_socket.send_multipart(
[remote_constants.KILLJOB_TAG, [remote_constants.KILLJOB_TAG,
to_byte(self.job_address)]) to_byte(self.job_address)])
_ = self.kill_job_socket.recv_multipart() try:
_ = self.kill_job_socket.recv_multipart()
except zmq.error.Again as e:
pass
os._exit(1)
def single_task(self, obj): def single_task(self, obj):
"""An infinite loop waiting for commands from the remote object. """An infinite loop waiting for commands from the remote object.
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.monitor import ClusterMonitor
import time
import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
import timeout_decorator
import subprocess
@parl.remote_class
class Actor(object):
    """Small stateful class, wrapped with ``parl.remote_class``, used by the tests below."""

    def __init__(self, arg1=None, arg2=None):
        self.arg1, self.arg2 = arg1, arg2

    def get_arg1(self):
        """Return the stored ``arg1``."""
        return self.arg1

    def get_arg2(self):
        """Return the stored ``arg2``."""
        return self.arg2

    def set_arg1(self, value):
        """Overwrite ``arg1`` with ``value``."""
        self.arg1 = value

    def set_arg2(self, value):
        """Overwrite ``arg2`` with ``value``."""
        self.arg2 = value

    def get_unable_serialize_object(self):
        # NOTE(review): ``UnableSerializeObject`` is neither defined nor
        # imported in the visible part of this file, so calling this method
        # would raise NameError -- confirm whether the helper class was lost
        # when this test was split out.
        return UnableSerializeObject()

    def add_one(self, value):
        """Increment ``value`` by one and return it."""
        value += 1
        return value

    def add(self, x, y):
        """Sleep for three seconds, then return ``x + y``."""
        time.sleep(3)
        total = x + y
        return total

    def will_raise_exception_func(self):
        """Divide by zero, so every call raises ``ZeroDivisionError``."""
        _ = 1 / 0
class TestClusterMonitor(unittest.TestCase):
    """Checks what ``ClusterMonitor.data`` reports while an actor is created and deleted."""

    def tearDown(self):
        # Call parl.remote.client.disconnect() after every test.
        disconnect()

    def test_add_actor(self):
        """Watch the client / CPU counters around the lifetime of one actor.

        The master runs in a non-daemon thread, so the worker and the master
        are shut down in a ``finally`` clause: a failing assertion must not
        leave that thread running after the test has ended.
        """
        port = 1441
        address = 'localhost:{}'.format(port)
        master = Master(port=port)
        th = threading.Thread(target=master.run)
        th.start()
        time.sleep(1)

        worker = None
        try:
            worker = Worker(address, 1)
            cluster_monitor = ClusterMonitor(address)
            time.sleep(1)
            # No client has connected yet.
            self.assertEqual(0, len(cluster_monitor.data['clients']))

            parl.connect(address)
            time.sleep(10)
            self.assertEqual(1, len(cluster_monitor.data['clients']))
            self.assertEqual(
                1, cluster_monitor.data['workers'][0]['vacant_cpus'])

            actor = Actor()
            time.sleep(20)
            self.assertEqual(
                0, cluster_monitor.data['workers'][0]['vacant_cpus'])
            self.assertEqual(
                1, cluster_monitor.data['workers'][0]['used_cpus'])
            self.assertEqual(
                1, cluster_monitor.data['clients'][0]['actor_num'])

            # Dropping the only reference is expected to release the CPU;
            # the long sleep gives the monitor time to report it.
            del actor
            time.sleep(40)
            self.assertEqual(
                0, cluster_monitor.data['clients'][0]['actor_num'])
            self.assertEqual(
                1, cluster_monitor.data['workers'][0]['vacant_cpus'])
        finally:
            # Same shutdown order as before (worker first, then master), but
            # now also executed when an assertion above fails.
            if worker is not None:
                worker.exit()
            master.exit()
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.monitor import ClusterMonitor
import time
import threading
from parl.remote.client import disconnect
from parl.remote import exceptions
import timeout_decorator
import subprocess
@parl.remote_class
class Actor(object):
    """Small stateful class, wrapped with ``parl.remote_class``, declared for the tests below."""

    def __init__(self, arg1=None, arg2=None):
        self.arg1, self.arg2 = arg1, arg2

    def get_arg1(self):
        """Return the stored ``arg1``."""
        return self.arg1

    def get_arg2(self):
        """Return the stored ``arg2``."""
        return self.arg2

    def set_arg1(self, value):
        """Overwrite ``arg1`` with ``value``."""
        self.arg1 = value

    def set_arg2(self, value):
        """Overwrite ``arg2`` with ``value``."""
        self.arg2 = value

    def get_unable_serialize_object(self):
        # NOTE(review): ``UnableSerializeObject`` is neither defined nor
        # imported in the visible part of this file, so calling this method
        # would raise NameError -- confirm whether the helper class was lost
        # when this test was split out.
        return UnableSerializeObject()

    def add_one(self, value):
        """Increment ``value`` by one and return it."""
        value += 1
        return value

    def add(self, x, y):
        """Sleep for three seconds, then return ``x + y``."""
        time.sleep(3)
        total = x + y
        return total

    def will_raise_exception_func(self):
        """Divide by zero, so every call raises ``ZeroDivisionError``."""
        _ = 1 / 0
class TestClusterMonitor(unittest.TestCase):
    """Checks the worker count reported by ``ClusterMonitor.data`` with twenty workers."""

    def tearDown(self):
        # Call parl.remote.client.disconnect() after every test.
        disconnect()

    def test_twenty_worker(self):
        """Start twenty workers, stop them in two batches, check the count each time.

        The master runs in a non-daemon thread, so every worker that is still
        running and the master itself are shut down in a ``finally`` clause:
        a failing assertion must not leave them alive after the test.
        """
        port = 1440
        address = 'localhost:{}'.format(port)
        master = Master(port=port)
        th = threading.Thread(target=master.run)
        th.start()
        time.sleep(1)

        workers = []
        exited = 0  # number of leading entries of ``workers`` already stopped
        try:
            for _ in range(20):
                worker = Worker(address, 1)
                time.sleep(1)
                workers.append(worker)

            cluster_monitor = ClusterMonitor(address)
            time.sleep(1)
            self.assertEqual(20, len(cluster_monitor.data['workers']))

            # First batch: workers 0..9.
            for i in range(10):
                workers[i].exit()
                exited = i + 1
            time.sleep(40)
            self.assertEqual(10, len(cluster_monitor.data['workers']))

            # Second batch: workers 10..19.
            for i in range(10, 20):
                workers[i].exit()
                exited = i + 1
            time.sleep(40)
            self.assertEqual(0, len(cluster_monitor.data['workers']))
        finally:
            # On success ``exited`` is 20 and this loop is a no-op; after a
            # failure it stops whichever workers are still running.
            for leftover in workers[exited:]:
                leftover.exit()
            master.exit()
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
...@@ -78,60 +78,6 @@ class TestClusterMonitor(unittest.TestCase): ...@@ -78,60 +78,6 @@ class TestClusterMonitor(unittest.TestCase):
self.assertEqual(0, len(cluster_monitor.data['workers'])) self.assertEqual(0, len(cluster_monitor.data['workers']))
master.exit() master.exit()
def test_twenty_worker(self):
port = 1440
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
workers = []
for _ in range(20):
worker = Worker('localhost:{}'.format(port), 1)
workers.append(worker)
cluster_monitor = ClusterMonitor('localhost:{}'.format(port))
time.sleep(1)
self.assertEqual(20, len(cluster_monitor.data['workers']))
for i in range(10):
workers[i].exit()
time.sleep(40)
self.assertEqual(10, len(cluster_monitor.data['workers']))
for i in range(10, 20):
workers[i].exit()
time.sleep(40)
self.assertEqual(0, len(cluster_monitor.data['workers']))
master.exit()
def test_add_actor(self):
port = 1441
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker = Worker('localhost:{}'.format(port), 1)
cluster_monitor = ClusterMonitor('localhost:{}'.format(port))
time.sleep(1)
self.assertEqual(0, len(cluster_monitor.data['clients']))
parl.connect('localhost:{}'.format(port))
time.sleep(10)
self.assertEqual(1, len(cluster_monitor.data['clients']))
self.assertEqual(1, cluster_monitor.data['workers'][0]['vacant_cpus'])
actor = Actor()
time.sleep(20)
self.assertEqual(0, cluster_monitor.data['workers'][0]['vacant_cpus'])
self.assertEqual(1, cluster_monitor.data['workers'][0]['used_cpus'])
self.assertEqual(1, cluster_monitor.data['clients'][0]['actor_num'])
del actor
time.sleep(40)
self.assertEqual(0, cluster_monitor.data['clients'][0]['actor_num'])
self.assertEqual(1, cluster_monitor.data['workers'][0]['vacant_cpus'])
worker.exit()
master.exit()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -61,9 +61,6 @@ class Actor(object): ...@@ -61,9 +61,6 @@ class Actor(object):
class TestCluster(unittest.TestCase): class TestCluster(unittest.TestCase):
def tearDown(self): def tearDown(self):
disconnect() disconnect()
#time.sleep(20)
#command = ("pkill -f remote/job.py")
#subprocess.call([command], shell=True)
def test_actor_exception(self): def test_actor_exception(self):
master = Master(port=1235) master = Master(port=1235)
...@@ -136,7 +133,7 @@ class TestCluster(unittest.TestCase): ...@@ -136,7 +133,7 @@ class TestCluster(unittest.TestCase):
self.assertEqual(master.cpu_num, 8) self.assertEqual(master.cpu_num, 8)
worker2.exit() worker2.exit()
time.sleep(30) time.sleep(50)
self.assertEqual(master.cpu_num, 4) self.assertEqual(master.cpu_num, 4)
master.exit() master.exit()
......
...@@ -45,6 +45,7 @@ class TestCluster(unittest.TestCase): ...@@ -45,6 +45,7 @@ class TestCluster(unittest.TestCase):
actor = Actor() actor = Actor()
ret = actor.add_one(1) ret = actor.add_one(1)
self.assertEqual(ret, 2) self.assertEqual(ret, 2)
disconnect()
def _create_actor(self): def _create_actor(self):
for _ in range(2): for _ in range(2):
...@@ -52,7 +53,7 @@ class TestCluster(unittest.TestCase): ...@@ -52,7 +53,7 @@ class TestCluster(unittest.TestCase):
ret = actor.add_one(1) ret = actor.add_one(1)
self.assertEqual(ret, 2) self.assertEqual(ret, 2)
@timeout_decorator.timeout(seconds=60) @timeout_decorator.timeout(seconds=300)
def test_connect_and_create_actor_in_multiprocessing_with_connected_in_main_process( def test_connect_and_create_actor_in_multiprocessing_with_connected_in_main_process(
self): self):
# start the master # start the master
...@@ -80,57 +81,6 @@ class TestCluster(unittest.TestCase): ...@@ -80,57 +81,6 @@ class TestCluster(unittest.TestCase):
worker1.exit() worker1.exit()
master.exit() master.exit()
@timeout_decorator.timeout(seconds=60)
def test_connect_and_create_actor_in_multiprocessing_without_connected_in_main_process(
self):
# start the master
master = Master(port=8239)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:8239', 4)
proc1 = multiprocessing.Process(
target=self._connect_and_create_actor, args=('localhost:8239', ))
proc2 = multiprocessing.Process(
target=self._connect_and_create_actor, args=('localhost:8239', ))
proc1.start()
proc2.start()
proc1.join()
proc2.join()
self.assertRaises(AssertionError, self._create_actor)
worker1.exit()
master.exit()
@timeout_decorator.timeout(seconds=60)
def test_create_actor_in_multiprocessing(self):
# start the master
master = Master(port=8240)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:8240', 4)
parl.connect('localhost:8240')
proc1 = multiprocessing.Process(target=self._create_actor)
proc2 = multiprocessing.Process(target=self._create_actor)
proc1.start()
proc2.start()
proc1.join()
proc2.join()
# make sure that the client of the main process still works
self._create_actor()
worker1.exit()
master.exit()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
import time
import threading
import timeout_decorator
import multiprocessing
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
@parl.remote_class
class Actor(object):
    """Minimal class, wrapped with ``parl.remote_class``, used by the tests below."""

    def __init__(self, arg1=None, arg2=None):
        self.arg1, self.arg2 = arg1, arg2

    def add_one(self, value):
        """Increment ``value`` by one and return it."""
        value += 1
        return value
class TestCluster(unittest.TestCase):
    """Creates actors from child processes while the main process never connects."""

    def tearDown(self):
        # Call parl.remote.client.disconnect() after every test.
        disconnect()

    def _connect_and_create_actor(self, cluster_addr):
        """Connect to ``cluster_addr``, create two actors in turn, then disconnect."""
        parl.connect(cluster_addr)
        for _ in range(2):
            actor = Actor()
            ret = actor.add_one(1)
            self.assertEqual(ret, 2)
        disconnect()

    def _create_actor(self):
        """Create two actors in turn without connecting first."""
        for _ in range(2):
            actor = Actor()
            ret = actor.add_one(1)
            self.assertEqual(ret, 2)

    @timeout_decorator.timeout(seconds=300)
    def test_connect_and_create_actor_in_multiprocessing_without_connected_in_main_process(
            self):
        """Two child processes connect on their own; the main process must still fail.

        The master runs in a non-daemon thread, so the worker and the master
        are shut down in a ``finally`` clause: neither a failing assertion
        nor the timeout may leave that thread running after the test.
        """
        # start the master
        master = Master(port=8239)
        th = threading.Thread(target=master.run)
        th.start()
        time.sleep(1)

        worker1 = None
        try:
            worker1 = Worker('localhost:8239', 4)

            children = [
                multiprocessing.Process(
                    target=self._connect_and_create_actor,
                    args=('localhost:8239', )) for _ in range(2)
            ]
            for proc in children:
                proc.start()
            for proc in children:
                proc.join()

            # The main process never called parl.connect(), so creating an
            # actor here is expected to raise AssertionError.
            self.assertRaises(AssertionError, self._create_actor)
        finally:
            # Same shutdown order as before (worker first, then master).
            if worker1 is not None:
                worker1.exit()
            master.exit()
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
import time
import threading
import timeout_decorator
import multiprocessing
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
@parl.remote_class
class Actor(object):
    """Minimal class, wrapped with ``parl.remote_class``, used by the tests below."""

    def __init__(self, arg1=None, arg2=None):
        self.arg1, self.arg2 = arg1, arg2

    def add_one(self, value):
        """Increment ``value`` by one and return it."""
        value += 1
        return value
class TestCluster(unittest.TestCase):
    """Creates actors from child processes after the main process has connected."""

    def tearDown(self):
        # Call parl.remote.client.disconnect() after every test.
        disconnect()

    def _connect_and_create_actor(self, cluster_addr):
        """Connect to ``cluster_addr``, create two actors in turn, then disconnect."""
        parl.connect(cluster_addr)
        for _ in range(2):
            actor = Actor()
            ret = actor.add_one(1)
            self.assertEqual(ret, 2)
        disconnect()

    def _create_actor(self):
        """Create two actors in turn without connecting first."""
        for _ in range(2):
            actor = Actor()
            ret = actor.add_one(1)
            self.assertEqual(ret, 2)

    @timeout_decorator.timeout(seconds=300)
    def test_create_actor_in_multiprocessing(self):
        """Connect in the main process, then create actors from two children.

        The master runs in a non-daemon thread, so the worker and the master
        are shut down in a ``finally`` clause: neither a failing assertion
        nor the timeout may leave that thread running after the test.
        """
        # start the master
        master = Master(port=8240)
        th = threading.Thread(target=master.run)
        th.start()
        time.sleep(1)

        worker1 = None
        try:
            worker1 = Worker('localhost:8240', 4)
            parl.connect('localhost:8240')

            children = [
                multiprocessing.Process(target=self._create_actor)
                for _ in range(2)
            ]
            for proc in children:
                proc.start()
            for proc in children:
                proc.join()
            print("[test_create_actor_in_multiprocessing] Join")

            # make sure that the client of the main process still works
            self._create_actor()
        finally:
            # Same shutdown order as before (worker first, then master).
            if worker1 is not None:
                worker1.exit()
            master.exit()
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
...@@ -62,27 +62,6 @@ class TestJob(unittest.TestCase): ...@@ -62,27 +62,6 @@ class TestJob(unittest.TestCase):
def tearDown(self): def tearDown(self):
disconnect() disconnect()
def test_job_exit_exceptionally(self):
master = Master(port=1334)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:1334', 4)
time.sleep(10)
self.assertEqual(worker1.job_buffer.full(), True)
time.sleep(1)
self.assertEqual(master.cpu_num, 4)
print("We are going to kill all the jobs.")
command = ("pkill -f remote/job.py")
subprocess.call([command], shell=True)
parl.connect('localhost:1334')
actor = Actor()
self.assertEqual(actor.add_one(1), 2)
time.sleep(20)
master.exit()
worker1.exit()
@timeout_decorator.timeout(seconds=300) @timeout_decorator.timeout(seconds=300)
def test_acor_exit_exceptionally(self): def test_acor_exit_exceptionally(self):
master = Master(port=1335) master = Master(port=1335)
...@@ -94,7 +73,7 @@ class TestJob(unittest.TestCase): ...@@ -94,7 +73,7 @@ class TestJob(unittest.TestCase):
file_path = __file__.replace('reset_job_test', 'simulate_client') file_path = __file__.replace('reset_job_test', 'simulate_client')
command = [sys.executable, file_path] command = [sys.executable, file_path]
proc = subprocess.Popen(command) proc = subprocess.Popen(command)
time.sleep(10) time.sleep(20)
self.assertEqual(master.cpu_num, 0) self.assertEqual(master.cpu_num, 0)
proc.kill() proc.kill()
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
from parl.utils import logger
import threading
import time
import subprocess
@parl.remote_class
class Actor(object):
    """Small stateful class, wrapped with ``parl.remote_class``, used by the tests below."""

    def __init__(self, arg1=None, arg2=None):
        self.arg1, self.arg2 = arg1, arg2

    def get_arg1(self):
        """Return the stored ``arg1``."""
        return self.arg1

    def get_arg2(self):
        """Return the stored ``arg2``."""
        return self.arg2

    def set_arg1(self, value):
        """Overwrite ``arg1`` with ``value``."""
        self.arg1 = value

    def set_arg2(self, value):
        """Overwrite ``arg2`` with ``value``."""
        self.arg2 = value

    def get_unable_serialize_object(self):
        # NOTE(review): ``UnableSerializeObject`` is neither defined nor
        # imported in the visible part of this file, so calling this method
        # would raise NameError -- confirm whether the helper class was lost
        # when this test was split out.
        return UnableSerializeObject()

    def add_one(self, value):
        """Increment ``value`` by one and return it."""
        value += 1
        return value

    def add(self, x, y):
        """Sleep for three seconds, then return ``x + y``."""
        time.sleep(3)
        total = x + y
        return total

    def will_raise_exception_func(self):
        """Divide by zero, so every call raises ``ZeroDivisionError``."""
        _ = 1 / 0
class TestJobAlone(unittest.TestCase):
    """Kills every ``remote/job.py`` process and checks that an actor can still be created.

    The ``pkill`` pattern used below is not limited to this test's own jobs.
    """

    def tearDown(self):
        # Call parl.remote.client.disconnect() after every test.
        disconnect()

    def test_job_exit_exceptionally(self):
        """Kill all job processes, then create an actor and call it.

        The master runs in a non-daemon thread, so the master and the worker
        are shut down in a ``finally`` clause: a failing assertion must not
        leave that thread running after the test has ended.
        """
        master = Master(port=1334)
        th = threading.Thread(target=master.run)
        th.start()
        time.sleep(1)

        worker1 = None
        try:
            worker1 = Worker('localhost:1334', 4)
            time.sleep(10)
            self.assertEqual(worker1.job_buffer.full(), True)
            time.sleep(1)
            self.assertEqual(master.cpu_num, 4)

            print("We are going to kill all the jobs.")
            command = ("pkill -f remote/job.py")
            subprocess.call([command], shell=True)

            parl.connect('localhost:1334')
            actor = Actor()
            self.assertEqual(actor.add_one(1), 2)
            time.sleep(20)
        finally:
            # Same shutdown order as before (master first, then worker), but
            # now also executed when an assertion above fails.
            master.exit()
            if worker1 is not None:
                worker1.exit()
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册