diff --git a/pgl/utils/mp_reader.py b/pgl/utils/mp_reader.py index fd3b97295e1373942b97afd69ab4310277b4aa4b..55c8b640fc5d460a24e87baeae66dfe1f537bd93 100644 --- a/pgl/utils/mp_reader.py +++ b/pgl/utils/mp_reader.py @@ -38,17 +38,20 @@ def serialize_data(data): return None return numpy_serialize_data(data) #, ensure_ascii=False) +def index_iter(data): + """return indexing iter""" + if isinstance(data, list): + return range(len(data)) + elif isinstance(data, dict): + return data.keys() + def numpy_serialize_data(data): """serialize_data""" ret_data = copy.deepcopy(data) - if isinstance(ret_data, list): - for key, value in enumerate(ret_data): - if isinstance(ret_data[key], np.ndarray): - ret_data[key] = _np_serialized_data(value=ret_data[key].tobytes(), - shape=list(ret_data[key].shape), dtype="%s" % ret_data[key].dtype) - elif isinstance(ret_data, dict): - for key in ret_data: + + if isinstance(ret_data, (dict, list)): + for key in index_iter(ret_data): if isinstance(ret_data[key], np.ndarray): ret_data[key] = _np_serialized_data(value=ret_data[key].tobytes(), shape=list(ret_data[key].shape), dtype="%s" % ret_data[key].dtype) @@ -60,14 +63,8 @@ def numpy_deserialize_data(data): if data is None: return None - if isinstance(data, list): - for key, value in enumerate(data): - if isinstance(value, _np_serialized_data): - data[key] = np.frombuffer(buffer=data[key].value, - dtype=data[key].dtype).reshape(data[key].shape) - - elif isinstance(data, dict): - for key in data: + if isinstance(data, (dict, list)): + for key in index_iter(data): if isinstance(data[key], _np_serialized_data): data[key] = np.frombuffer(buffer=data[key].value, dtype=data[key].dtype).reshape(data[key].shape)