提交 9bc2134c 编写于 作者: P Peilin Wang

added checking of first row crc to find invalid tfrecord files

Addressed code review comments. Added a check in the Python layer to exclude directories and to raise an error if a pattern does not match any file.

fixed clang format

fixed cppcheck

Fixed cppcheck (used std::accumulate and std::copy_if). Regenerated the tfrecord file so that it contains a correct header; previously it only had a dummy header.

fixed cppcheck: added const reference for string parameter for lambdas, fixed clang format: whitespace adjustments

more clang whitespace fixes...

changed print to logger.info
上级 d8176a77
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
#include "dataset/util/status.h" #include "dataset/util/status.h"
#include "dataset/util/task_manager.h" #include "dataset/util/task_manager.h"
#include "dataset/util/wait_post.h" #include "dataset/util/wait_post.h"
#include "utils/system/crc32c.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
...@@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder() ...@@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder()
builder_data_schema_ = std::make_unique<DataSchema>(); builder_data_schema_ = std::make_unique<DataSchema>();
} }
// Checks whether the header at the start of a file carries a consistent CRC.
//
// Reads an int64_t length field followed by a uint32_t masked CRC from the beginning of the file,
// recomputes the masked CRC over the bytes of the length field, and compares the two values.
//
// @param filename - path of the file to inspect
// @return true only if the file can be opened, both header fields can be read in full, and the
//         stored CRC equals the recomputed one; false otherwise
bool ValidateFirstRowCrc(const std::string &filename) {
  // Open in binary mode: the header is raw bytes, so no text-mode translation may be applied to it.
  std::ifstream reader(filename, std::ios::in | std::ios::binary);
  if (!reader) {
    return false;
  }

  // read data
  int64_t record_length = 0;
  (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
  if (!reader) {
    // File is shorter than the length field, so it cannot hold a complete header.
    return false;
  }

  // read crc from file
  uint32_t masked_crc = 0;
  (void)reader.read(reinterpret_cast<char *>(&masked_crc), static_cast<std::streamsize>(sizeof(uint32_t)));
  if (!reader) {
    // File ends before the stored CRC; there is nothing valid to compare against.
    return false;
  }

  // generate crc from data
  uint32_t generated_crc =
    system::Crc32c::GetMaskCrc32cValue(reinterpret_cast<char *>(&record_length), sizeof(int64_t));

  return masked_crc == generated_crc;
}
Status TFReaderOp::Builder::ValidateInputs() const { Status TFReaderOp::Builder::ValidateInputs() const {
std::string err_msg; std::string err_msg;
err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is smaller or equal to 0\n" : "";
if (!builder_equal_rows_per_shard_) { if (builder_num_workers_ <= 0) {
err_msg += builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_) err_msg += "Number of parallel workers is smaller or equal to 0\n";
? "No enough tf_file files provided\n" }
: "";
if (!builder_equal_rows_per_shard_ &&
builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)) {
err_msg += "Not enough tfrecord files provided\n";
}
if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) {
err_msg += "Wrong sharding configs\n";
} }
err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";
std::vector<std::string> invalid_files(builder_dataset_files_list_.size());
auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(),
[](const std::string &filename) { return !ValidateFirstRowCrc(filename); });
invalid_files.resize(std::distance(invalid_files.begin(), it));
if (!invalid_files.empty()) {
err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n";
std::string accumulated_filenames = std::accumulate(
invalid_files.begin(), invalid_files.end(), std::string(""),
[](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; });
err_msg += accumulated_filenames;
}
return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
} }
...@@ -523,6 +567,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off ...@@ -523,6 +567,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off
RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read)); RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read));
rows_read++; rows_read++;
} }
// ignore crc footer // ignore crc footer
(void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t))); (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
rows_total++; rows_total++;
......
...@@ -900,13 +900,22 @@ class SourceDataset(Dataset): ...@@ -900,13 +900,22 @@ class SourceDataset(Dataset):
List, files. List, files.
""" """
def flat(lists):
return list(np.array(lists).flatten())
if not isinstance(patterns, list): if not isinstance(patterns, list):
patterns = [patterns] patterns = [patterns]
file_list = flat([glob.glob(file, recursive=True) for file in patterns]) file_list = []
unmatched_patterns = []
for pattern in patterns:
matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)]
if matches:
file_list.extend(matches)
else:
unmatched_patterns.append(pattern)
if unmatched_patterns:
raise ValueError("The following patterns did not match any files: ", unmatched_patterns)
if file_list: # not empty if file_list: # not empty
return file_list return file_list
raise ValueError("The list of path names matching the patterns is empty.") raise ValueError("The list of path names matching the patterns is empty.")
......
...@@ -697,3 +697,37 @@ TEST_F(MindDataTestTFReaderOp, TestTotalRowsBasic) { ...@@ -697,3 +697,37 @@ TEST_F(MindDataTestTFReaderOp, TestTotalRowsBasic) {
TFReaderOp::CountTotalRows(&total_rows, filenames, 729, true); TFReaderOp::CountTotalRows(&total_rows, filenames, 729, true);
ASSERT_EQ(total_rows, 60); ASSERT_EQ(total_rows, 60);
} }
// Builds a TFReaderOp from file lists that contain files which are not tfrecord files (and, in the
// second case, also a path that does not exist) and expects Build to report an error both times.
TEST_F(MindDataTestTFReaderOp, TestTFReaderInvalidFiles) {
  // Start with an empty execution tree
  auto my_tree = std::make_shared<ExecutionTree>();

  std::string valid_file = datasets_root_path_ + "/testTFTestAllTypes/test.data";
  std::string schema_file = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json";
  std::string invalid_file = datasets_root_path_ + "/testTFTestAllTypes/invalidFile.txt";
  std::string nonexistent_file = "this/file/doesnt/exist";

  std::shared_ptr<TFReaderOp> my_tfreader_op;
  TFReaderOp::Builder builder;

  // Case 1: the list mixes a valid tfrecord file with a text file and the schema json.
  builder.SetDatasetFilesList({invalid_file, valid_file, schema_file})
    .SetRowsPerBuffer(16)
    .SetNumWorkers(16);

  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  // The schema has to load; otherwise a Build failure below could not be attributed to the file list.
  Status rc = schema->LoadSchemaFile(schema_file, {});
  ASSERT_TRUE(rc.IsOk());
  builder.SetDataSchema(std::move(schema));

  rc = builder.Build(&my_tfreader_op);
  ASSERT_FALSE(rc.IsOk());

  // Case 2: same list plus a path that does not exist on disk.
  builder.SetDatasetFilesList({invalid_file, valid_file, schema_file, nonexistent_file})
    .SetRowsPerBuffer(16)
    .SetNumWorkers(16);

  // The previous schema object was moved into the builder, so a fresh one is loaded here.
  schema = std::make_unique<DataSchema>();
  rc = schema->LoadSchemaFile(schema_file, {});
  ASSERT_TRUE(rc.IsOk());
  builder.SetDataSchema(std::move(schema));

  rc = builder.Build(&my_tfreader_op);
  ASSERT_FALSE(rc.IsOk());
}
this is just a text file, not a valid tfrecord file.
...@@ -32,7 +32,7 @@ def test_case_tf_shape(): ...@@ -32,7 +32,7 @@ def test_case_tf_shape():
ds1 = ds.TFRecordDataset(FILES, schema_file) ds1 = ds.TFRecordDataset(FILES, schema_file)
ds1 = ds1.batch(2) ds1 = ds1.batch(2)
for data in ds1.create_dict_iterator(): for data in ds1.create_dict_iterator():
print(data) logger.info(data)
output_shape = ds1.output_shapes() output_shape = ds1.output_shapes()
assert (len(output_shape[-1]) == 1) assert (len(output_shape[-1]) == 1)
...@@ -203,6 +203,32 @@ def test_tf_record_schema_columns_list(): ...@@ -203,6 +203,32 @@ def test_tf_record_schema_columns_list():
a = row["col_sint32"] a = row["col_sint32"]
assert "col_sint32" in str(info.value) assert "col_sint32" in str(info.value)
def test_case_invalid_files():
    """
    Check error reporting for file lists that contain unusable entries.

    First case: a list mixing a valid tfrecord file with two files that are not
    tfrecord files must raise RuntimeError when the first row is requested, and
    the message must name only the unusable files.

    Second case: adding a path that does not exist must raise ValueError when
    the dataset is constructed, and the message must name only that path.
    """
    valid_file = "../data/dataset/testTFTestAllTypes/test.data"
    invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt"
    files = [invalid_file, valid_file, SCHEMA_FILE]

    data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)

    with pytest.raises(RuntimeError) as info:
        data.create_dict_iterator().get_next()
    # The checks live outside the `with` body: anything placed after the raising
    # call inside it would never execute.
    runtime_msg = str(info.value)
    assert "cannot be opened" in runtime_msg
    assert "not valid tfrecord files" in runtime_msg
    assert valid_file not in runtime_msg
    assert invalid_file in runtime_msg
    assert SCHEMA_FILE in runtime_msg

    nonexistent_file = "this/file/does/not/exist"
    files = [invalid_file, valid_file, SCHEMA_FILE, nonexistent_file]

    with pytest.raises(ValueError) as info:
        ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
    value_msg = str(info.value)
    assert "did not match any files" in value_msg
    assert valid_file not in value_msg
    assert invalid_file not in value_msg
    assert SCHEMA_FILE not in value_msg
    assert nonexistent_file in value_msg
if __name__ == '__main__': if __name__ == '__main__':
test_case_tf_shape() test_case_tf_shape()
test_case_tf_file() test_case_tf_file()
...@@ -212,3 +238,4 @@ if __name__ == '__main__': ...@@ -212,3 +238,4 @@ if __name__ == '__main__':
test_tf_record_schema() test_tf_record_schema()
test_tf_record_shuffle() test_tf_record_shuffle()
test_tf_shard_equal_rows() test_tf_shard_equal_rows()
test_case_invalid_files()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册