提交 f521532a 编写于 作者: L liyong

fix field_name problem from tfrecord to mindrecord

上级 b5d8dad4
...@@ -385,9 +385,14 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const ...@@ -385,9 +385,14 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
} }
TensorRow row; TensorRow row;
std::unordered_map<std::string, int32_t> column_name_id_map = std::unordered_map<std::string, int32_t> column_name_id_map;
iterator_->GetColumnNameMap(); // map of column name, id for (auto el : iterator_->GetColumnNameMap()) {
bool first_loop = true; // build schema in first loop std::string column_name = el.first;
std::transform(column_name.begin(), column_name.end(), column_name.begin(),
[](unsigned char c) { return ispunct(c) ? '_' : c; });
column_name_id_map[column_name] = el.second;
}
bool first_loop = true; // build schema in first loop
do { do {
json row_raw_data; json row_raw_data;
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data; std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
...@@ -402,7 +407,10 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const ...@@ -402,7 +407,10 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
std::vector<std::string> index_fields; std::vector<std::string> index_fields;
s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields); s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
RETURN_IF_NOT_OK(s); RETURN_IF_NOT_OK(s);
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id); if (mindrecord::SUCCESS !=
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) {
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader.");
}
mr_writer->SetShardHeader(mr_header); mr_writer->SetShardHeader(mr_header);
first_loop = false; first_loop = false;
} }
...@@ -422,7 +430,9 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const ...@@ -422,7 +430,9 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
} }
} while (!row.empty()); } while (!row.empty());
mr_writer->Commit(); mr_writer->Commit();
mindrecord::ShardIndexGenerator::finalize(file_names); if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::finalize(file_names)) {
RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator.");
}
return Status::OK(); return Status::OK();
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
This is the test module for saveOp. This is the test module for saveOp.
""" """
import os import os
from string import punctuation
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import log as logger from mindspore import log as logger
from mindspore.mindrecord import FileWriter from mindspore.mindrecord import FileWriter
...@@ -24,7 +25,7 @@ import pytest ...@@ -24,7 +25,7 @@ import pytest
CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord" CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord" CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord"
FILES_NUM = 1 FILES_NUM = 1
num_readers = 1 num_readers = 1
...@@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file): ...@@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file):
with pytest.raises(Exception, match="tfrecord dataset format is not supported."): with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
d1.save(CV_FILE_NAME2, 1, "tfrecord") d1.save(CV_FILE_NAME2, 1, "tfrecord")
def cast_name(key):
    """
    Convert a schema name that contains special characters into a valid name.

    Every ASCII punctuation character (as listed in string.punctuation) and
    every space in `key` is replaced with an underscore; all other characters
    are kept unchanged.

    Args:
        key (str): Original column / schema name.

    Returns:
        str, the name with special characters replaced by '_'.
    """
    # An underscore maps to itself, so it needs no special treatment here.
    replaceable = frozenset(punctuation) | {' '}
    return ''.join('_' if ch in replaceable else ch for ch in key)
def test_case_07():
    """
    Save a TFRecord dataset as a MindRecord file and verify the round trip.

    The rows read from TFRECORD_FILES are written with `save`, read back via
    MindDataset, and compared column by column. Column names are looked up in
    the MindRecord rows through cast_name. Exactly 10 rows are expected.
    """
    def remove_output_files():
        # Remove the generated MindRecord file and its companion .db file.
        for path in (CV_FILE_NAME2, "{}.db".format(CV_FILE_NAME2)):
            if os.path.exists(path):
                os.remove(path)

    # Start from a clean state in case an earlier run left files behind.
    remove_output_files()
    try:
        d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False)
        tf_data = []
        for x in d1.create_dict_iterator():
            tf_data.append(x)
        d1.save(CV_FILE_NAME2, FILES_NUM)

        d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
                            num_parallel_workers=num_readers,
                            shuffle=False)
        mr_data = []
        for x in d2.create_dict_iterator():
            mr_data.append(x)

        count = 0
        for x in tf_data:
            for k, v in x.items():
                if isinstance(v, np.ndarray):
                    assert (v == mr_data[count][cast_name(k)]).all()
                else:
                    assert v == mr_data[count][cast_name(k)]
            count += 1
        assert count == 10
    finally:
        # Clean up even when an assertion or the save itself fails, so the
        # leftover files cannot affect other tests using CV_FILE_NAME2.
        remove_output_files()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册