diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
index b31fdcf63b8585ea5a24d48dfd689d890bd1ab6c..4302e12954d06f1d10a443fd337a4f31e26ed69b 100644
--- a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
@@ -385,9 +385,14 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
   }
 
   TensorRow row;
-  std::unordered_map<std::string, int32_t> column_name_id_map =
-      iterator_->GetColumnNameMap();  // map of column name, id
-  bool first_loop = true;  // build schema in first loop
+  std::unordered_map<std::string, int32_t> column_name_id_map;
+  for (auto el : iterator_->GetColumnNameMap()) {
+    std::string column_name = el.first;
+    std::transform(column_name.begin(), column_name.end(), column_name.begin(),
+                   [](unsigned char c) { return ispunct(c) ? '_' : c; });
+    column_name_id_map[column_name] = el.second;
+  }
+  bool first_loop = true;  // build schema in first loop
   do {
     json row_raw_data;
     std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
@@ -402,7 +407,10 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
       std::vector<std::string> index_fields;
       s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
       RETURN_IF_NOT_OK(s);
-      mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id);
+      if (mindrecord::SUCCESS !=
+          mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) {
+        RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader.");
+      }
       mr_writer->SetShardHeader(mr_header);
       first_loop = false;
     }
@@ -422,7 +430,9 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
     }
   } while (!row.empty());
   mr_writer->Commit();
-  mindrecord::ShardIndexGenerator::finalize(file_names);
+  if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::finalize(file_names)) {
+    RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator.");
+  }
   return Status::OK();
 }
 
diff --git a/tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord b/tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord
new file mode 100644
index 0000000000000000000000000000000000000000..da4f853e2d7c948921750cdf31539bed30543566
Binary files /dev/null and b/tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord differ
diff --git a/tests/ut/python/dataset/test_save_op.py b/tests/ut/python/dataset/test_save_op.py
index 2ed326276b3f04e816fcbd753ce2d3c35ae6562b..2af14aec1ce97f2fbaa16296dda41541fe44a93e 100644
--- a/tests/ut/python/dataset/test_save_op.py
+++ b/tests/ut/python/dataset/test_save_op.py
@@ -16,6 +16,7 @@
 This is the test module for saveOp.
 """
 import os
+from string import punctuation
 import mindspore.dataset as ds
 from mindspore import log as logger
 from mindspore.mindrecord import FileWriter
@@ -24,7 +25,7 @@ import pytest
 
 CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
 CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
-
+TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord"
 FILES_NUM = 1
 num_readers = 1
 
@@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file):
 
     with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
         d1.save(CV_FILE_NAME2, 1, "tfrecord")
+
+
+def cast_name(key):
+    """
+    Cast schema names that contain special characters to valid names.
+    """
+    special_symbols = set('{}{}'.format(punctuation, ' '))
+    special_symbols.remove('_')
+    new_key = ['_' if x in special_symbols else x for x in key]
+    casted_key = ''.join(new_key)
+    return casted_key
+
+
+def test_case_07():
+    if os.path.exists("{}".format(CV_FILE_NAME2)):
+        os.remove("{}".format(CV_FILE_NAME2))
+    if os.path.exists("{}.db".format(CV_FILE_NAME2)):
+        os.remove("{}.db".format(CV_FILE_NAME2))
+    d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False)
+    tf_data = []
+    for x in d1.create_dict_iterator():
+        tf_data.append(x)
+    d1.save(CV_FILE_NAME2, FILES_NUM)
+    d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
+                        num_parallel_workers=num_readers,
+                        shuffle=False)
+    mr_data = []
+    for x in d2.create_dict_iterator():
+        mr_data.append(x)
+    count = 0
+    for x in tf_data:
+        for k, v in x.items():
+            if isinstance(v, np.ndarray):
+                assert (v == mr_data[count][cast_name(k)]).all()
+            else:
+                assert v == mr_data[count][cast_name(k)]
+        count += 1
+    assert count == 10
+
+    if os.path.exists("{}".format(CV_FILE_NAME2)):
+        os.remove("{}".format(CV_FILE_NAME2))
+    if os.path.exists("{}.db".format(CV_FILE_NAME2)):
+        os.remove("{}.db".format(CV_FILE_NAME2))