diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc index d9e51efc4e9df445368e0c3b73ed41ac48b26bae..f36182027786e72ae6bfac2920e0db0bb19ea80a 100644 --- a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc @@ -133,6 +133,7 @@ void BindGlobalParams(py::module *m) { (*m).attr("MAX_PAGE_SIZE") = kMaxPageSize; (*m).attr("MIN_SHARD_COUNT") = kMinShardCount; (*m).attr("MAX_SHARD_COUNT") = kMaxShardCount; + (*m).attr("MAX_FILE_COUNT") = kMaxFileCount; (*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount; (void)(*m).def("get_max_thread_num", &GetMaxThreadNum); } diff --git a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h index 198cd46cc8000c08395b8b7a9ff022a5a67d1098..5ee6f70a82f648b502dcd436f0fe2fe28b7ebeb7 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h @@ -104,7 +104,8 @@ const uint64_t kInt64Len = 8; const uint64_t kMinFileSize = kInt64Len; const int kMinShardCount = 1; -const int kMaxShardCount = 1000; +const int kMaxShardCount = 1000; // write +const int kMaxFileCount = 4096; // read const int kMinConsumerCount = 1; const int kMaxConsumerCount = 128; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h index 4aad0d117ed1c5663a6d7944ab8c2fd72eca2b72..51928d7874edef57528de02336c5dcffb5d970ab 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h @@ -152,7 +152,7 @@ class ShardHeader { MSRStatus CheckIndexField(const std::string &field, const json &schema); - void ParsePage(const json &page, int shard_index, bool load_dataset); + MSRStatus ParsePage(const json &page, int shard_index, bool load_dataset); MSRStatus ParseStatistics(const json &statistics); diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc index 84d7fddb6f1183280cca79736eb1301f8e4624ed..37d7f19b625d32d0fad5ab1ccaa256737a6333b6 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc @@ -252,7 +252,7 @@ std::vector> ShardReader::ReadRowGroupSummar if (shard_count <= 0) { return row_group_summary; } - if (shard_count <= kMaxShardCount) { + if (shard_count <= kMaxFileCount) { for (int shard_id = 0; shard_id < shard_count; ++shard_id) { // return -1 when page's size equals to 0. auto last_page_id = shard_header_->GetLastPageId(shard_id); @@ -1054,7 +1054,7 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector(ret); auto local_columns = std::get<2>(ret); - if (shard_count_ <= kMaxShardCount) { + if (shard_count_ <= kMaxFileCount) { for (int shard_id = 0; shard_id < shard_count_; shard_id++) { for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc index 843b412a31c9afc0859dbfb9288f1c1449285a68..f94f92d939abefb499b22505a3bc42a1e6321d55 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc @@ -55,7 +55,9 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector &headers, bool l header_size_ = header["header_size"].get(); page_size_ = header["page_size"].get(); } - ParsePage(header["page"], shard_index, load_dataset); + if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) { + return FAILED; + } shard_index++; } return SUCCESS; @@ -248,11 +250,16 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { return SUCCESS; } -void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { +MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { // set shard_index when load_dataset is false - if (pages_.empty() && shard_count_ <= kMaxShardCount) { + if (shard_count_ > kMaxFileCount) { + MS_LOG(ERROR) << "The number of mindrecord files is greater than max value: " << kMaxFileCount; + return FAILED; + } + if (pages_.empty() && shard_count_ <= kMaxFileCount) { pages_.resize(shard_count_); } + for (auto &page : pages) { int page_id = page["page_id"]; int shard_id = page["shard_id"]; @@ -275,6 +282,7 @@ void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_datase pages_[shard_index].push_back(std::move(parsed_page)); } } + return SUCCESS; } MSRStatus ShardHeader::ParseStatistics(const json &statistics) { @@ -715,7 +723,9 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { std::string line; while (std::getline(page_in_handle, line)) { - ParsePage(json::parse(line), -1, true); + if (SUCCESS != ParsePage(json::parse(line), -1, true)) { + return FAILED; + } } page_in_handle.close(); diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index caf3857e2060e8e863faf9ebff987544fe539d01..d537fc3fb674cd40574eefb2bdd83a45bfeef448 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1054,45 +1054,45 @@ class Dataset: * - type in 'dataset' - type in 'mindrecord' - detail - * - DE_BOOL + * - bool - None - Not support - * - DE_INT8 + * - int8 - int32 - - * - DE_UINT8 + * - uint8 - bytes(1D uint8) - Drop dimension - * - DE_INT16 + * - int16 - int32 - - * - DE_UINT16 + * - uint16 - int32 - - * - DE_INT32 + * - int32 - int32 - - * - DE_UINT32 + * - uint32 - int64 - - * - DE_INT64 + * - int64 - int64 - - * - DE_UINT64 + * - uint64 - None - Not support - * - DE_FLOAT16 - - Not support + * - float16 + - float32 - - * - DE_FLOAT32 + * - float32 - float32 - - * - DE_FLOAT64 + * - float64 - float64 - - * - DE_STRING + * - string - string - - Not support multi-dimensional DE_STRING + - Not support multi-dimensional string Note: 1. To save the samples in order, should set dataset's shuffle false and num_files 1. diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index a9a61c113cd9a9e3fe724dc3f1c7a07f3e21ecbe..03f52380c0e88326a9e08cb7895682ed441feff9 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -278,6 +278,8 @@ def check_minddataset(method): dataset_file = param_dict.get('dataset_file') if isinstance(dataset_file, list): + if len(dataset_file) > 4096: + raise ValueError("length of dataset_file should less than or equal to {}.".format(4096)) for f in dataset_file: check_file(f) else: