提交 9b645a7a 编写于 作者: Y Yelrose

fixed mp_reader for list

上级 1577fb29
...@@ -27,8 +27,11 @@ import time ...@@ -27,8 +27,11 @@ import time
import paddle.fluid as fluid import paddle.fluid as fluid
from multiprocessing import Queue from multiprocessing import Queue
import threading import threading
from collections import namedtuple
_np_serialized_data = namedtuple("_np_serialized_data", ["value", "shape", "dtype"])
def serialize_data(data): def serialize_data(data):
"""serialize_data""" """serialize_data"""
if data is None: if data is None:
...@@ -38,13 +41,17 @@ def serialize_data(data): ...@@ -38,13 +41,17 @@ def serialize_data(data):
def numpy_serialize_data(data): def numpy_serialize_data(data):
"""serialize_data""" """serialize_data"""
ret_data = {} ret_data = copy.deepcopy(data)
for key in data: if isinstance(ret_data, list):
if isinstance(data[key], np.ndarray): for key, value in enumerate(ret_data):
ret_data[key] = (data[key].tobytes(), list(data[key].shape), if isinstance(ret_data[key], np.ndarray):
"%s" % data[key].dtype) ret_data[key] = _np_serialized_data(value=ret_data[key].tobytes(),
else: shape=list(ret_data[key].shape), dtype="%s" % ret_data[key].dtype)
ret_data[key] = data[key] elif isinstance(ret_data, dict):
for key in ret_data:
if isinstance(ret_data[key], np.ndarray):
ret_data[key] = _np_serialized_data(value=ret_data[key].tobytes(),
shape=list(ret_data[key].shape), dtype="%s" % ret_data[key].dtype)
return ret_data return ret_data
...@@ -52,11 +59,18 @@ def numpy_deserialize_data(data): ...@@ -52,11 +59,18 @@ def numpy_deserialize_data(data):
"""deserialize_data""" """deserialize_data"""
if data is None: if data is None:
return None return None
for key in data:
if isinstance(data[key], tuple): if isinstance(data, list):
value = np.frombuffer( for key, value in enumerate(data):
data[key][0], dtype=data[key][2]).reshape(data[key][1]) if isinstance(value, _np_serialized_data):
data[key] = value data[key] = np.frombuffer(buffer=data[key].value,
dtype=data[key].dtype).reshape(data[key].shape)
elif isinstance(data, dict):
for key in data:
if isinstance(data[key], _np_serialized_data):
data[key] = np.frombuffer(buffer=data[key].value,
dtype=data[key].dtype).reshape(data[key].shape)
return data return data
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册