From 9b645a7a08cb2c378ea5d3e56cc615b0e11a92b6 Mon Sep 17 00:00:00 2001
From: Yelrose <270018958@qq.com>
Date: Wed, 9 Sep 2020 11:56:27 +0800
Subject: [PATCH] fixed mp_reader for list

---
 pgl/utils/mp_reader.py | 38 ++++++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 12 deletions(-)

diff --git a/pgl/utils/mp_reader.py b/pgl/utils/mp_reader.py
index a796283..fd3b972 100644
--- a/pgl/utils/mp_reader.py
+++ b/pgl/utils/mp_reader.py
@@ -27,8 +27,11 @@ import time
 import paddle.fluid as fluid
 from multiprocessing import Queue
 import threading
+from collections import namedtuple
 
+_np_serialized_data = namedtuple("_np_serialized_data", ["value", "shape", "dtype"])
+
 
 def serialize_data(data):
     """serialize_data"""
     if data is None:
@@ -38,13 +41,17 @@ def serialize_data(data):
 
 def numpy_serialize_data(data):
     """serialize_data"""
-    ret_data = {}
-    for key in data:
-        if isinstance(data[key], np.ndarray):
-            ret_data[key] = (data[key].tobytes(), list(data[key].shape),
-                             "%s" % data[key].dtype)
-        else:
-            ret_data[key] = data[key]
+    ret_data = copy.deepcopy(data)
+    if isinstance(ret_data, list):
+        for key, value in enumerate(ret_data):
+            if isinstance(ret_data[key], np.ndarray):
+                ret_data[key] = _np_serialized_data(value=ret_data[key].tobytes(),
+                    shape=list(ret_data[key].shape), dtype="%s" % ret_data[key].dtype)
+    elif isinstance(ret_data, dict):
+        for key in ret_data:
+            if isinstance(ret_data[key], np.ndarray):
+                ret_data[key] = _np_serialized_data(value=ret_data[key].tobytes(),
+                    shape=list(ret_data[key].shape), dtype="%s" % ret_data[key].dtype)
     return ret_data
 
 
@@ -52,11 +59,18 @@ def numpy_deserialize_data(data):
     """deserialize_data"""
     if data is None:
         return None
-    for key in data:
-        if isinstance(data[key], tuple):
-            value = np.frombuffer(
-                data[key][0], dtype=data[key][2]).reshape(data[key][1])
-            data[key] = value
+
+    if isinstance(data, list):
+        for key, value in enumerate(data):
+            if isinstance(value, _np_serialized_data):
+                data[key] = np.frombuffer(buffer=data[key].value,
+                    dtype=data[key].dtype).reshape(data[key].shape)
+
+    elif isinstance(data, dict):
+        for key in data:
+            if isinstance(data[key], _np_serialized_data):
+                data[key] = np.frombuffer(buffer=data[key].value,
+                    dtype=data[key].dtype).reshape(data[key].shape)
     return data
 
 
-- 
GitLab