From 8ea455503a2ab77ab880537ffd8810b8c3a56c1f Mon Sep 17 00:00:00 2001 From: Yelrose <270018958@qq.com> Date: Wed, 9 Sep 2020 13:26:47 +0800 Subject: [PATCH] fixed mp_reader for list --- pgl/utils/mp_reader.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/pgl/utils/mp_reader.py b/pgl/utils/mp_reader.py index fd3b972..55c8b64 100644 --- a/pgl/utils/mp_reader.py +++ b/pgl/utils/mp_reader.py @@ -38,17 +38,20 @@ def serialize_data(data): return None return numpy_serialize_data(data) #, ensure_ascii=False) +def index_iter(data): + """return indexing iter""" + if isinstance(data, list): + return range(len(data)) + elif isinstance(data, dict): + return data.keys() + def numpy_serialize_data(data): """serialize_data""" ret_data = copy.deepcopy(data) - if isinstance(ret_data, list): - for key, value in enumerate(ret_data): - if isinstance(ret_data[key], np.ndarray): - ret_data[key] = _np_serialized_data(value=ret_data[key].tobytes(), - shape=list(ret_data[key].shape), dtype="%s" % ret_data[key].dtype) - elif isinstance(ret_data, dict): - for key in ret_data: + + if isinstance(ret_data, (dict, list)): + for key in index_iter(ret_data): if isinstance(ret_data[key], np.ndarray): ret_data[key] = _np_serialized_data(value=ret_data[key].tobytes(), shape=list(ret_data[key].shape), dtype="%s" % ret_data[key].dtype) @@ -60,14 +63,8 @@ def numpy_deserialize_data(data): if data is None: return None - if isinstance(data, list): - for key, value in enumerate(data): - if isinstance(value, _np_serialized_data): - data[key] = np.frombuffer(buffer=data[key].value, - dtype=data[key].dtype).reshape(data[key].shape) - - elif isinstance(data, dict): - for key in data: + if isinstance(data, (dict, list)): + for key in index_iter(data): if isinstance(data[key], _np_serialized_data): data[key] = np.frombuffer(buffer=data[key].value, dtype=data[key].dtype).reshape(data[key].shape) -- GitLab