提交 48fc1de8 编写于 作者: H Hongsheng Zeng 提交者: Bo Zhou

support serialize/deserialize instance of custom class (#77)

* support serialize/deserialize instance of custom class

* update version requirement of dependences

* remove requirements of unittest which are included in parl

* use fashion style of pyarrow serialization context; add thread safe unittest of serialize/deserialize

* add thread-safe test of remote module; add more exception tips of serialize/deserialize

* refine comment
上级 6c6ec6c7
......@@ -112,6 +112,8 @@ function main() {
check_style
;;
test)
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple .
pip3.6 install -i https://pypi.tuna.tsinghua.edu.cn/simple .
/root/miniconda3/envs/empty_env/bin/pip install -i https://pypi.tuna.tsinghua.edu.cn/simple .
/root/miniconda3/envs/paddle1.4.0/bin/pip install -i https://pypi.tuna.tsinghua.edu.cn/simple .
run_test_with_gpu
......
# requirements for unittest
paddlepaddle-gpu==1.3.0.post97
gym
details
termcolor
pyarrow
zmq
parameterized
......@@ -19,6 +19,12 @@ import unittest
from parl.remote import *
class UnableSerializeObject(object):
def __init__(self):
# threading.Lock() can not be serialized
self.lock = threading.Lock()
@parl.remote_class
class Simulator:
def __init__(self, arg1, arg2=None):
......@@ -38,7 +44,11 @@ class Simulator:
self.arg2 = value
def get_unable_serialize_object(self):
return self
return UnableSerializeObject()
def add_one(self, value):
value += 1
return value
class TestRemote(unittest.TestCase):
......@@ -122,8 +132,9 @@ class TestRemote(unittest.TestCase):
remote_sim = self.remote_manager.get_remote()
try:
remote_sim.set_arg1(wrong_arg=remote_sim)
except SerializeError:
remote_sim.set_arg1(UnableSerializeObject())
except SerializeError as e:
logger.info('Expected exception: {}'.format(e))
# expected
return
......@@ -137,13 +148,14 @@ class TestRemote(unittest.TestCase):
try:
remote_sim.get_unable_serialize_object()
except RemoteSerializeError:
except RemoteSerializeError as e:
# expected
logger.info('Expected exception: {}'.format(e))
return
assert False
def test_mutli_remote_object(self):
def test_multi_remote_object(self):
server_port = 17776
self._setUp(server_port)
......@@ -165,7 +177,7 @@ class TestRemote(unittest.TestCase):
self.assertEqual(remote_sim1.get_arg1(), 1)
self.assertEqual(remote_sim2.get_arg1(), 11)
def test_mutli_remote_object_with_one_failed(self):
def test_multi_remote_object_with_one_failed(self):
server_port = 17777
self._setUp(server_port)
......@@ -235,6 +247,42 @@ class TestRemote(unittest.TestCase):
self.assertEqual(remote_sim1.get_arg1(), 1)
self.assertEqual(remote_sim2.get_arg1(), 11)
def test_thread_safe_of_remote_module(self):
server_port = 17780
self._setUp(server_port)
time.sleep(1)
thread_num = 10
for _ in range(thread_num):
# run clients in backend
sim = Simulator(11, arg2=22)
client_thread = threading.Thread(
target=sim.as_remote, args=(
'localhost',
server_port,
))
client_thread.setDaemon(True)
client_thread.start()
time.sleep(1)
threads = []
for _ in range(thread_num):
remote_sim = self.remote_manager.get_remote()
t = threading.Thread(
target=self._run_remote_add, args=(remote_sim, ))
t.start()
threads.append(t)
for t in threads:
t.join()
def _run_remote_add(self, remote_sim):
value = 0
for i in range(1000):
value = remote_sim.add_one(value)
assert value == i + 1
if __name__ == '__main__':
unittest.main()
......@@ -12,12 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import cloudpickle
import pyarrow
from parl.utils import SerializeError, DeserializeError
__all__ = ['dumps_argument', 'loads_argument', 'dumps_return', 'loads_return']
# Reference: https://github.com/apache/arrow/blob/f88474c84e7f02e226eb4cc32afef5e2bbc6e5b4/python/pyarrow/tests/test_serialization.py#L658-L682
def _serialize_serializable(obj):
return {"type": type(obj), "data": obj.__dict__}
def _deserialize_serializable(obj):
val = obj["type"].__new__(obj["type"])
val.__dict__.update(obj["data"])
return val
context = pyarrow.default_serialization_context()
# support deserialize in another environment
context.set_pickle(cloudpickle.dumps, cloudpickle.loads)
# support serialize and deserialize custom class
context.register_type(
object,
"object",
custom_serializer=_serialize_serializable,
custom_deserializer=_deserialize_serializable)
def dumps_argument(*args, **kwargs):
"""
......@@ -30,7 +55,7 @@ def dumps_argument(*args, **kwargs):
Implementation-dependent object in bytes.
"""
try:
ret = pyarrow.serialize([args, kwargs]).to_buffer()
ret = pyarrow.serialize([args, kwargs], context=context).to_buffer()
except Exception as e:
raise SerializeError(e)
......@@ -49,7 +74,7 @@ def loads_argument(data):
like the input of `dumps_argument`, args is a tuple, and kwargs is a dict
"""
try:
ret = pyarrow.deserialize(data)
ret = pyarrow.deserialize(data, context=context)
except Exception as e:
raise DeserializeError(e)
......@@ -67,7 +92,7 @@ def dumps_return(data):
Implementation-dependent object in bytes.
"""
try:
ret = pyarrow.serialize(data).to_buffer()
ret = pyarrow.serialize(data, context=context).to_buffer()
except Exception as e:
raise SerializeError(e)
......@@ -85,7 +110,7 @@ def loads_return(data):
deserialized data
"""
try:
ret = pyarrow.deserialize(data)
ret = pyarrow.deserialize(data, context=context)
except Exception as e:
raise DeserializeError(e)
......
......@@ -19,7 +19,7 @@ class UtilsError(Exception):
"""
def __init__(self, error_info):
self.error_info = '[PARL Utils Error]:\n{}'.format(error_info)
self.error_info = '[PARL Utils Error]: {}'.format(error_info)
class SerializeError(UtilsError):
......@@ -28,6 +28,9 @@ class SerializeError(UtilsError):
"""
def __init__(self, error_info):
error_info = (
'Serialize error, you may have provided an object that cannot be '
+ 'serialized by pyarrow. Detailed error:\n{}'.format(error_info))
super(SerializeError, self).__init__(error_info)
def __str__(self):
......@@ -40,6 +43,10 @@ class DeserializeError(UtilsError):
"""
def __init__(self, error_info):
error_info = (
'Deserialize error, you may have provided an object that cannot be '
+
'deserialized by pyarrow. Detailed error:\n{}'.format(error_info))
super(DeserializeError, self).__init__(error_info)
def __str__(self):
......
......@@ -14,8 +14,10 @@
import numpy as np
import time
import threading
import unittest
from parl.utils.communication import dumps_return, loads_return
from parl.utils.communication import dumps_return, loads_return, \
dumps_argument, loads_argument
class TestCommunication(unittest.TestCase):
......@@ -56,11 +58,73 @@ class TestCommunication(unittest.TestCase):
for i, data in enumerate([data1, data2, data3]):
start = time.time()
for _ in range(10):
serialize_bytes = dumps_return(data)
deserialize_result = loads_return(serialize_bytes)
serialize_bytes = dumps_argument(data)
deserialize_result = loads_argument(serialize_bytes)
print('Case {}, Average dump and load argument time:'.format(i),
(time.time() - start) / 10)
def test_dumps_loads_return_with_custom_class(self):
class A(object):
def __init__(self):
self.a = 3
a = A()
serialize_bytes = dumps_return(a)
deserialize_result = loads_return(serialize_bytes)
assert deserialize_result.a == 3
def test_dumps_loads_argument_with_custom_class(self):
class A(object):
def __init__(self):
self.a = 3
a = A()
serialize_bytes = dumps_argument(a)
deserialize_result = loads_argument(serialize_bytes)
assert deserialize_result[0][0].a == 3
def test_dumps_loads_return_with_multi_thread(self):
class A(object):
def __init__(self, a):
self.a = a
def run(i):
a = A(i)
serialize_bytes = dumps_return(a)
deserialize_result = loads_return(serialize_bytes)
assert deserialize_result.a == i
threads = []
for i in range(50):
t = threading.Thread(target=run, args=(i, ))
t.start()
threads.append(t)
for t in threads:
t.join()
def test_dumps_loads_argument_with_multi_thread(self):
class A(object):
def __init__(self, a):
self.a = a
def run(i):
a = A(i)
serialize_bytes = dumps_argument(a)
deserialize_result = loads_argument(serialize_bytes)
assert deserialize_result[0][0].a == i
threads = []
for i in range(50):
t = threading.Thread(target=run, args=(i, ))
t.start()
threads.append(t)
for t in threads:
t.join()
if __name__ == '__main__':
unittest.main()
......@@ -64,7 +64,11 @@ setup(
packages=_find_packages(),
package_data={'': ['*.so']},
install_requires=[
"termcolor>=1.1.0", "pyzmq>=17.1.2", "pyarrow>=0.12.0", "scipy>=1.0.0"
"termcolor>=1.1.0",
"pyzmq==18.0.1",
"pyarrow==0.13.0",
"scipy>=1.0.0",
"cloudpickle==1.0.0",
],
classifiers=[
'Intended Audience :: Developers',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册