未验证 提交 a5807b77 编写于 作者: Webbley's avatar Webbley 提交者: GitHub

Merge pull request #124 from Yelrose/master

fixed mp_reader for list
......@@ -27,8 +27,11 @@ import time
import paddle.fluid as fluid
from multiprocessing import Queue
import threading
from collections import namedtuple
_np_serialized_data = namedtuple("_np_serialized_data", ["value", "shape", "dtype"])
def serialize_data(data):
"""serialize_data"""
if data is None:
......@@ -38,13 +41,17 @@ def serialize_data(data):
def numpy_serialize_data(data):
"""serialize_data"""
ret_data = {}
for key in data:
if isinstance(data[key], np.ndarray):
ret_data[key] = (data[key].tobytes(), list(data[key].shape),
"%s" % data[key].dtype)
else:
ret_data[key] = data[key]
ret_data = copy.deepcopy(data)
if isinstance(ret_data, list):
for key, value in enumerate(ret_data):
if isinstance(ret_data[key], np.ndarray):
ret_data[key] = _np_serialized_data(value=ret_data[key].tobytes(),
shape=list(ret_data[key].shape), dtype="%s" % ret_data[key].dtype)
elif isinstance(ret_data, dict):
for key in ret_data:
if isinstance(ret_data[key], np.ndarray):
ret_data[key] = _np_serialized_data(value=ret_data[key].tobytes(),
shape=list(ret_data[key].shape), dtype="%s" % ret_data[key].dtype)
return ret_data
......@@ -52,11 +59,18 @@ def numpy_deserialize_data(data):
"""deserialize_data"""
if data is None:
return None
if isinstance(data, list):
for key, value in enumerate(data):
if isinstance(value, _np_serialized_data):
data[key] = np.frombuffer(buffer=data[key].value,
dtype=data[key].dtype).reshape(data[key].shape)
elif isinstance(data, dict):
for key in data:
if isinstance(data[key], tuple):
value = np.frombuffer(
data[key][0], dtype=data[key][2]).reshape(data[key][1])
data[key] = value
if isinstance(data[key], _np_serialized_data):
data[key] = np.frombuffer(buffer=data[key].value,
dtype=data[key].dtype).reshape(data[key].shape)
return data
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册