提交 f521532a 编写于 作者: L liyong

fix field_name probelem from tfrecord to mindrecord

上级 b5d8dad4
......@@ -385,9 +385,14 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
}
TensorRow row;
std::unordered_map<std::string, int32_t> column_name_id_map =
iterator_->GetColumnNameMap(); // map of column name, id
bool first_loop = true; // build schema in first loop
std::unordered_map<std::string, int32_t> column_name_id_map;
for (auto el : iterator_->GetColumnNameMap()) {
std::string column_name = el.first;
std::transform(column_name.begin(), column_name.end(), column_name.begin(),
[](unsigned char c) { return ispunct(c) ? '_' : c; });
column_name_id_map[column_name] = el.second;
}
bool first_loop = true; // build schema in first loop
do {
json row_raw_data;
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
......@@ -402,7 +407,10 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
std::vector<std::string> index_fields;
s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
RETURN_IF_NOT_OK(s);
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id);
if (mindrecord::SUCCESS !=
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) {
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader.");
}
mr_writer->SetShardHeader(mr_header);
first_loop = false;
}
......@@ -422,7 +430,9 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
}
} while (!row.empty());
mr_writer->Commit();
mindrecord::ShardIndexGenerator::finalize(file_names);
if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::finalize(file_names)) {
RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator.");
}
return Status::OK();
}
......
......@@ -16,6 +16,7 @@
This is the test module for saveOp.
"""
import os
from string import punctuation
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.mindrecord import FileWriter
......@@ -24,7 +25,7 @@ import pytest
CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord"
FILES_NUM = 1
num_readers = 1
......@@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file):
with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
d1.save(CV_FILE_NAME2, 1, "tfrecord")
def cast_name(key):
"""
Cast schema names which containing special characters to valid names.
"""
special_symbols = set('{}{}'.format(punctuation, ' '))
special_symbols.remove('_')
new_key = ['_' if x in special_symbols else x for x in key]
casted_key = ''.join(new_key)
return casted_key
def test_case_07():
if os.path.exists("{}".format(CV_FILE_NAME2)):
os.remove("{}".format(CV_FILE_NAME2))
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
os.remove("{}.db".format(CV_FILE_NAME2))
d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False)
tf_data = []
for x in d1.create_dict_iterator():
tf_data.append(x)
d1.save(CV_FILE_NAME2, FILES_NUM)
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
num_parallel_workers=num_readers,
shuffle=False)
mr_data = []
for x in d2.create_dict_iterator():
mr_data.append(x)
count = 0
for x in tf_data:
for k, v in x.items():
if isinstance(v, np.ndarray):
assert (v == mr_data[count][cast_name(k)]).all()
else:
assert v == mr_data[count][cast_name(k)]
count += 1
assert count == 10
if os.path.exists("{}".format(CV_FILE_NAME2)):
os.remove("{}".format(CV_FILE_NAME2))
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
os.remove("{}.db".format(CV_FILE_NAME2))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册