diff --git a/.teamcity/build.sh b/.teamcity/build.sh index 298812ebd8293525b2d72cbe73ffcfcd6a0332c2..9e55dcb4d8c0f3f0f7c5e99c2005059281f84e17 100755 --- a/.teamcity/build.sh +++ b/.teamcity/build.sh @@ -112,6 +112,8 @@ function main() { check_style ;; test) + pip install -i https://pypi.tuna.tsinghua.edu.cn/simple . + pip3.6 install -i https://pypi.tuna.tsinghua.edu.cn/simple . /root/miniconda3/envs/empty_env/bin/pip install -i https://pypi.tuna.tsinghua.edu.cn/simple . /root/miniconda3/envs/paddle1.4.0/bin/pip install -i https://pypi.tuna.tsinghua.edu.cn/simple . run_test_with_gpu diff --git a/.teamcity/requirements.txt b/.teamcity/requirements.txt index 5e3600128b668d7a1e0fce7480a4cbf1d614e665..685076eaa042b423bc15cbcbee389a28b4e632f1 100644 --- a/.teamcity/requirements.txt +++ b/.teamcity/requirements.txt @@ -1,7 +1,5 @@ +# requirements for unittest paddlepaddle-gpu==1.3.0.post97 gym details -termcolor -pyarrow -zmq parameterized diff --git a/parl/remote/tests/remote_test.py b/parl/remote/tests/remote_test.py index 3bf21d04829f5adb48d50bbbc41615de59ee5dd1..f36ab80feee8805f0b026dfffa67b20afe3faae5 100644 --- a/parl/remote/tests/remote_test.py +++ b/parl/remote/tests/remote_test.py @@ -19,6 +19,12 @@ import unittest from parl.remote import * +class UnableSerializeObject(object): + def __init__(self): + # threading.Lock() can not be serialized + self.lock = threading.Lock() + + @parl.remote_class class Simulator: def __init__(self, arg1, arg2=None): @@ -38,7 +44,11 @@ class Simulator: self.arg2 = value def get_unable_serialize_object(self): - return self + return UnableSerializeObject() + + def add_one(self, value): + value += 1 + return value class TestRemote(unittest.TestCase): @@ -122,8 +132,9 @@ class TestRemote(unittest.TestCase): remote_sim = self.remote_manager.get_remote() try: - remote_sim.set_arg1(wrong_arg=remote_sim) - except SerializeError: + remote_sim.set_arg1(UnableSerializeObject()) + except SerializeError as e: + logger.info('Expected exception: {}'.format(e)) # expected return @@ -137,13 +148,14 @@ class TestRemote(unittest.TestCase): try: remote_sim.get_unable_serialize_object() - except RemoteSerializeError: + except RemoteSerializeError as e: # expected + logger.info('Expected exception: {}'.format(e)) return assert False - def test_mutli_remote_object(self): + def test_multi_remote_object(self): server_port = 17776 self._setUp(server_port) @@ -165,7 +177,7 @@ class TestRemote(unittest.TestCase): self.assertEqual(remote_sim1.get_arg1(), 1) self.assertEqual(remote_sim2.get_arg1(), 11) - def test_mutli_remote_object_with_one_failed(self): + def test_multi_remote_object_with_one_failed(self): server_port = 17777 self._setUp(server_port) @@ -235,6 +247,42 @@ class TestRemote(unittest.TestCase): self.assertEqual(remote_sim1.get_arg1(), 1) self.assertEqual(remote_sim2.get_arg1(), 11) + def test_thread_safe_of_remote_module(self): + server_port = 17780 + self._setUp(server_port) + + time.sleep(1) + + thread_num = 10 + for _ in range(thread_num): + # run clients in backend + sim = Simulator(11, arg2=22) + client_thread = threading.Thread( + target=sim.as_remote, args=( + 'localhost', + server_port, + )) + client_thread.setDaemon(True) + client_thread.start() + + time.sleep(1) + threads = [] + for _ in range(thread_num): + remote_sim = self.remote_manager.get_remote() + t = threading.Thread( + target=self._run_remote_add, args=(remote_sim, )) + t.start() + threads.append(t) + + for t in threads: + t.join() + + def _run_remote_add(self, remote_sim): + value = 0 + for i in range(1000): + value = remote_sim.add_one(value) + assert value == i + 1 + if __name__ == '__main__': unittest.main() diff --git a/parl/utils/communication.py b/parl/utils/communication.py index d6ec3e7e3f3efe7f4e89e5d8caf1de15b3e214bc..ea201bae16e571ab429ef8f194228fc5b7fa4432 100644 --- a/parl/utils/communication.py +++ b/parl/utils/communication.py @@ -12,12 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. +import cloudpickle import pyarrow from parl.utils import SerializeError, DeserializeError __all__ = ['dumps_argument', 'loads_argument', 'dumps_return', 'loads_return'] +# Reference: https://github.com/apache/arrow/blob/f88474c84e7f02e226eb4cc32afef5e2bbc6e5b4/python/pyarrow/tests/test_serialization.py#L658-L682 +def _serialize_serializable(obj): + return {"type": type(obj), "data": obj.__dict__} + + +def _deserialize_serializable(obj): + val = obj["type"].__new__(obj["type"]) + val.__dict__.update(obj["data"]) + return val + + +context = pyarrow.default_serialization_context() + +# support deserialize in another environment +context.set_pickle(cloudpickle.dumps, cloudpickle.loads) + +# support serialize and deserialize custom class +context.register_type( + object, + "object", + custom_serializer=_serialize_serializable, + custom_deserializer=_deserialize_serializable) + + def dumps_argument(*args, **kwargs): """ @@ -30,7 +55,7 @@ def dumps_argument(*args, **kwargs): Implementation-dependent object in bytes. """ try: - ret = pyarrow.serialize([args, kwargs]).to_buffer() + ret = pyarrow.serialize([args, kwargs], context=context).to_buffer() except Exception as e: raise SerializeError(e) @@ -49,7 +74,7 @@ def loads_argument(data): like the input of `dumps_argument`, args is a tuple, and kwargs is a dict """ try: - ret = pyarrow.deserialize(data) + ret = pyarrow.deserialize(data, context=context) except Exception as e: raise DeserializeError(e) @@ -67,7 +92,7 @@ def dumps_return(data): Implementation-dependent object in bytes. """ try: - ret = pyarrow.serialize(data).to_buffer() + ret = pyarrow.serialize(data, context=context).to_buffer() except Exception as e: raise SerializeError(e) @@ -85,7 +110,7 @@ def loads_return(data): deserialized data """ try: - ret = pyarrow.deserialize(data) + ret = pyarrow.deserialize(data, context=context) except Exception as e: raise DeserializeError(e) diff --git a/parl/utils/exceptions.py b/parl/utils/exceptions.py index 023cc5c1073a459b853a1652f43eccd640f70347..eb4b4d2139984b2dab578618de667bad8f40f0df 100644 --- a/parl/utils/exceptions.py +++ b/parl/utils/exceptions.py @@ -19,7 +19,7 @@ class UtilsError(Exception): """ def __init__(self, error_info): - self.error_info = '[PARL Utils Error]:\n{}'.format(error_info) + self.error_info = '[PARL Utils Error]: {}'.format(error_info) class SerializeError(UtilsError): @@ -28,6 +28,9 @@ class SerializeError(UtilsError): """ def __init__(self, error_info): + error_info = ( + 'Serialize error, you may have provided an object that cannot be ' + + 'serialized by pyarrow. Detailed error:\n{}'.format(error_info)) super(SerializeError, self).__init__(error_info) def __str__(self): @@ -40,6 +43,10 @@ class DeserializeError(UtilsError): """ def __init__(self, error_info): + error_info = ( + 'Deserialize error, you may have provided an object that cannot be ' + + + 'deserialized by pyarrow. Detailed error:\n{}'.format(error_info)) super(DeserializeError, self).__init__(error_info) def __str__(self): diff --git a/parl/utils/tests/communication_test.py b/parl/utils/tests/communication_test.py index 5bb0b5f8026aea38cf6268617581a9048c164512..06aee9a6879e249f10a07ffe97e255aebd160498 100644 --- a/parl/utils/tests/communication_test.py +++ b/parl/utils/tests/communication_test.py @@ -14,8 +14,10 @@ import numpy as np import time +import threading import unittest -from parl.utils.communication import dumps_return, loads_return +from parl.utils.communication import dumps_return, loads_return, \ + dumps_argument, loads_argument class TestCommunication(unittest.TestCase): @@ -56,11 +58,73 @@ class TestCommunication(unittest.TestCase): for i, data in enumerate([data1, data2, data3]): start = time.time() for _ in range(10): - serialize_bytes = dumps_return(data) - deserialize_result = loads_return(serialize_bytes) + serialize_bytes = dumps_argument(data) + deserialize_result = loads_argument(serialize_bytes) print('Case {}, Average dump and load argument time:'.format(i), (time.time() - start) / 10) + def test_dumps_loads_return_with_custom_class(self): + class A(object): + def __init__(self): + self.a = 3 + + a = A() + serialize_bytes = dumps_return(a) + deserialize_result = loads_return(serialize_bytes) + + assert deserialize_result.a == 3 + + def test_dumps_loads_argument_with_custom_class(self): + class A(object): + def __init__(self): + self.a = 3 + + a = A() + serialize_bytes = dumps_argument(a) + deserialize_result = loads_argument(serialize_bytes) + + assert deserialize_result[0][0].a == 3 + + def test_dumps_loads_return_with_multi_thread(self): + class A(object): + def __init__(self, a): + self.a = a + + def run(i): + a = A(i) + serialize_bytes = dumps_return(a) + deserialize_result = loads_return(serialize_bytes) + assert deserialize_result.a == i + + threads = [] + for i in range(50): + t = threading.Thread(target=run, args=(i, )) + t.start() + threads.append(t) + + for t in threads: + t.join() + + def test_dumps_loads_argument_with_multi_thread(self): + class A(object): + def __init__(self, a): + self.a = a + + def run(i): + a = A(i) + serialize_bytes = dumps_argument(a) + deserialize_result = loads_argument(serialize_bytes) + assert deserialize_result[0][0].a == i + + threads = [] + for i in range(50): + t = threading.Thread(target=run, args=(i, )) + t.start() + threads.append(t) + + for t in threads: + t.join() + if __name__ == '__main__': unittest.main() diff --git a/setup.py b/setup.py index 43165126c001be3a3edd3ee1a952f91b90abb920..7def3f12393e3c534841a06745cc8857b87d5728 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,11 @@ setup( packages=_find_packages(), package_data={'': ['*.so']}, install_requires=[ - "termcolor>=1.1.0", "pyzmq>=17.1.2", "pyarrow>=0.12.0", "scipy>=1.0.0" + "termcolor>=1.1.0", + "pyzmq==18.0.1", + "pyarrow==0.13.0", + "scipy>=1.0.0", + "cloudpickle==1.0.0", ], classifiers=[ 'Intended Audience :: Developers',