diff --git a/mindspore/mindrecord/shardwriter.py b/mindspore/mindrecord/shardwriter.py index 0913201861ce8d77c2e44ee8a2e4169faa626b69..49391d13ce1c4fde42100fc2d71bbf97b1646d39 100644 --- a/mindspore/mindrecord/shardwriter.py +++ b/mindspore/mindrecord/shardwriter.py @@ -29,6 +29,7 @@ class ShardWriter: The class would write MindRecord File series. """ + def __init__(self): self._writer = ms.ShardWriter() self._header = None @@ -161,7 +162,7 @@ class ShardWriter: if row_blob: blob_data.append(list(row_blob)) # filter raw data according to schema - row_raw = {field: item[field] + row_raw = {field: self._convert_np_types(item[field]) for field in self._header.schema.keys() - self._header.blob_fields if field in item} if row_raw: raw_data.append(row_raw) @@ -172,6 +173,12 @@ class ShardWriter: raise MRMWriteDatasetError return ret + def _convert_np_types(self, val): + """convert numpy type to python primitive type""" + if isinstance(val, (np.int32, np.int64, np.float32, np.float64)): + return val.item() + return val + def _merge_blob(self, blob_data): """ Merge multiple blob data whose type is bytes or ndarray diff --git a/tests/ut/python/dataset/test_minddataset.py b/tests/ut/python/dataset/test_minddataset.py index 8d22bd6c50f6bf3b0d9c6e6c1c69a5599275422a..5791ea9618d51a484bf8270bf3d8d2cd51e77246 100644 --- a/tests/ut/python/dataset/test_minddataset.py +++ b/tests/ut/python/dataset/test_minddataset.py @@ -1853,3 +1853,42 @@ def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset( os.remove("{}".format(mindrecord_file_name)) os.remove("{}.db".format(mindrecord_file_name)) + +def test_numpy_generic(): + + paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) + for x in range(FILES_NUM)] + for x in paths: + if os.path.exists("{}".format(x)): + os.remove("{}".format(x)) + if os.path.exists("{}.db".format(x)): + os.remove("{}.db".format(x)) + writer = FileWriter(CV_FILE_NAME, FILES_NUM) + cv_schema_json = {"label1": {"type": "int32"}, "label2": {"type": "int64"}, + "label3": {"type": "float32"}, "label4": {"type": "float64"}} + data = [] + for idx in range(10): + row = {} + row['label1'] = np.int32(idx) + row['label2'] = np.int64(idx*10) + row['label3'] = np.float32(idx+0.12345) + row['label4'] = np.float64(idx+0.12345789) + data.append(row) + writer.add_schema(cv_schema_json, "img_schema") + writer.write_raw_data(data) + writer.commit() + + num_readers = 4 + data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers, shuffle=False) + assert data_set.get_dataset_size() == 10 + idx = 0 + for item in data_set.create_dict_iterator(): + assert item['label1'] == item['label1'] + assert item['label2'] == item['label2'] + assert item['label3'] == item['label3'] + assert item['label4'] == item['label4'] + idx += 1 + assert idx == 10 + for x in paths: + os.remove("{}".format(x)) + os.remove("{}.db".format(x))