提交 28ebd730 编写于 作者: L liyong

fix bug when int or float is numpy type

上级 74f2c89d
...@@ -29,6 +29,7 @@ class ShardWriter: ...@@ -29,6 +29,7 @@ class ShardWriter:
The class would write MindRecord File series. The class would write MindRecord File series.
""" """
def __init__(self): def __init__(self):
self._writer = ms.ShardWriter() self._writer = ms.ShardWriter()
self._header = None self._header = None
...@@ -161,7 +162,7 @@ class ShardWriter: ...@@ -161,7 +162,7 @@ class ShardWriter:
if row_blob: if row_blob:
blob_data.append(list(row_blob)) blob_data.append(list(row_blob))
# filter raw data according to schema # filter raw data according to schema
row_raw = {field: item[field] row_raw = {field: self._convert_np_types(item[field])
for field in self._header.schema.keys() - self._header.blob_fields if field in item} for field in self._header.schema.keys() - self._header.blob_fields if field in item}
if row_raw: if row_raw:
raw_data.append(row_raw) raw_data.append(row_raw)
...@@ -172,6 +173,12 @@ class ShardWriter: ...@@ -172,6 +173,12 @@ class ShardWriter:
raise MRMWriteDatasetError raise MRMWriteDatasetError
return ret return ret
def _convert_np_types(self, val):
"""convert numpy type to python primitive type"""
if isinstance(val, (np.int32, np.int64, np.float32, np.float64)):
return val.item()
return val
def _merge_blob(self, blob_data): def _merge_blob(self, blob_data):
""" """
Merge multiple blob data whose type is bytes or ndarray Merge multiple blob data whose type is bytes or ndarray
......
...@@ -1853,3 +1853,42 @@ def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset( ...@@ -1853,3 +1853,42 @@ def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset(
os.remove("{}".format(mindrecord_file_name)) os.remove("{}".format(mindrecord_file_name))
os.remove("{}.db".format(mindrecord_file_name)) os.remove("{}.db".format(mindrecord_file_name))
def test_numpy_generic():
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
for x in range(FILES_NUM)]
for x in paths:
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
writer = FileWriter(CV_FILE_NAME, FILES_NUM)
cv_schema_json = {"label1": {"type": "int32"}, "label2": {"type": "int64"},
"label3": {"type": "float32"}, "label4": {"type": "float64"}}
data = []
for idx in range(10):
row = {}
row['label1'] = np.int32(idx)
row['label2'] = np.int64(idx*10)
row['label3'] = np.float32(idx+0.12345)
row['label4'] = np.float64(idx+0.12345789)
data.append(row)
writer.add_schema(cv_schema_json, "img_schema")
writer.write_raw_data(data)
writer.commit()
num_readers = 4
data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers, shuffle=False)
assert data_set.get_dataset_size() == 10
idx = 0
for item in data_set.create_dict_iterator():
assert item['label1'] == item['label1']
assert item['label2'] == item['label2']
assert item['label3'] == item['label3']
assert item['label4'] == item['label4']
idx += 1
assert idx == 10
for x in paths:
os.remove("{}".format(x))
os.remove("{}.db".format(x))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册