提交 ed70de80 编写于 作者: L liyong

Fix a coredump that occurs when the number of files in the file list is greater than 1000.

上级 e4a7ca7f
...@@ -133,6 +133,7 @@ void BindGlobalParams(py::module *m) { ...@@ -133,6 +133,7 @@ void BindGlobalParams(py::module *m) {
(*m).attr("MAX_PAGE_SIZE") = kMaxPageSize; (*m).attr("MAX_PAGE_SIZE") = kMaxPageSize;
(*m).attr("MIN_SHARD_COUNT") = kMinShardCount; (*m).attr("MIN_SHARD_COUNT") = kMinShardCount;
(*m).attr("MAX_SHARD_COUNT") = kMaxShardCount; (*m).attr("MAX_SHARD_COUNT") = kMaxShardCount;
(*m).attr("MAX_FILE_COUNT") = kMaxFileCount;
(*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount; (*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount;
(void)(*m).def("get_max_thread_num", &GetMaxThreadNum); (void)(*m).def("get_max_thread_num", &GetMaxThreadNum);
} }
......
...@@ -104,7 +104,8 @@ const uint64_t kInt64Len = 8; ...@@ -104,7 +104,8 @@ const uint64_t kInt64Len = 8;
const uint64_t kMinFileSize = kInt64Len; const uint64_t kMinFileSize = kInt64Len;
const int kMinShardCount = 1; const int kMinShardCount = 1;
const int kMaxShardCount = 1000; const int kMaxShardCount = 1000; // write
const int kMaxFileCount = 4096; // read
const int kMinConsumerCount = 1; const int kMinConsumerCount = 1;
const int kMaxConsumerCount = 128; const int kMaxConsumerCount = 128;
......
...@@ -152,7 +152,7 @@ class ShardHeader { ...@@ -152,7 +152,7 @@ class ShardHeader {
MSRStatus CheckIndexField(const std::string &field, const json &schema); MSRStatus CheckIndexField(const std::string &field, const json &schema);
void ParsePage(const json &page, int shard_index, bool load_dataset); MSRStatus ParsePage(const json &page, int shard_index, bool load_dataset);
MSRStatus ParseStatistics(const json &statistics); MSRStatus ParseStatistics(const json &statistics);
......
...@@ -252,7 +252,7 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar ...@@ -252,7 +252,7 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
if (shard_count <= 0) { if (shard_count <= 0) {
return row_group_summary; return row_group_summary;
} }
if (shard_count <= kMaxShardCount) { if (shard_count <= kMaxFileCount) {
for (int shard_id = 0; shard_id < shard_count; ++shard_id) { for (int shard_id = 0; shard_id < shard_count; ++shard_id) {
// return -1 when page's size equals to 0. // return -1 when page's size equals to 0.
auto last_page_id = shard_header_->GetLastPageId(shard_id); auto last_page_id = shard_header_->GetLastPageId(shard_id);
...@@ -1054,7 +1054,7 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, i ...@@ -1054,7 +1054,7 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, i
} }
auto offsets = std::get<1>(ret); auto offsets = std::get<1>(ret);
auto local_columns = std::get<2>(ret); auto local_columns = std::get<2>(ret);
if (shard_count_ <= kMaxShardCount) { if (shard_count_ <= kMaxFileCount) {
for (int shard_id = 0; shard_id < shard_count_; shard_id++) { for (int shard_id = 0; shard_id < shard_count_; shard_id++) {
for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) {
tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1],
......
...@@ -55,7 +55,9 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool l ...@@ -55,7 +55,9 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool l
header_size_ = header["header_size"].get<uint64_t>(); header_size_ = header["header_size"].get<uint64_t>();
page_size_ = header["page_size"].get<uint64_t>(); page_size_ = header["page_size"].get<uint64_t>();
} }
ParsePage(header["page"], shard_index, load_dataset); if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) {
return FAILED;
}
shard_index++; shard_index++;
} }
return SUCCESS; return SUCCESS;
...@@ -248,11 +250,16 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { ...@@ -248,11 +250,16 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
return SUCCESS; return SUCCESS;
} }
void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
// set shard_index when load_dataset is false // set shard_index when load_dataset is false
if (pages_.empty() && shard_count_ <= kMaxShardCount) { if (shard_count_ > kMaxFileCount) {
MS_LOG(ERROR) << "The number of mindrecord files is greater than max value: " << kMaxFileCount;
return FAILED;
}
if (pages_.empty() && shard_count_ <= kMaxFileCount) {
pages_.resize(shard_count_); pages_.resize(shard_count_);
} }
for (auto &page : pages) { for (auto &page : pages) {
int page_id = page["page_id"]; int page_id = page["page_id"];
int shard_id = page["shard_id"]; int shard_id = page["shard_id"];
...@@ -275,6 +282,7 @@ void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_datase ...@@ -275,6 +282,7 @@ void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_datase
pages_[shard_index].push_back(std::move(parsed_page)); pages_[shard_index].push_back(std::move(parsed_page));
} }
} }
return SUCCESS;
} }
MSRStatus ShardHeader::ParseStatistics(const json &statistics) { MSRStatus ShardHeader::ParseStatistics(const json &statistics) {
...@@ -715,7 +723,9 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { ...@@ -715,7 +723,9 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
std::string line; std::string line;
while (std::getline(page_in_handle, line)) { while (std::getline(page_in_handle, line)) {
ParsePage(json::parse(line), -1, true); if (SUCCESS != ParsePage(json::parse(line), -1, true)) {
return FAILED;
}
} }
page_in_handle.close(); page_in_handle.close();
......
...@@ -1054,45 +1054,45 @@ class Dataset: ...@@ -1054,45 +1054,45 @@ class Dataset:
* - type in 'dataset' * - type in 'dataset'
- type in 'mindrecord' - type in 'mindrecord'
- detail - detail
* - DE_BOOL * - bool
- None - None
- Not support - Not support
* - DE_INT8 * - int8
- int32 - int32
- -
* - DE_UINT8 * - uint8
- bytes(1D uint8) - bytes(1D uint8)
- Drop dimension - Drop dimension
* - DE_INT16 * - int16
- int32 - int32
- -
* - DE_UINT16 * - uint16
- int32 - int32
- -
* - DE_INT32 * - int32
- int32 - int32
- -
* - DE_UINT32 * - uint32
- int64 - int64
- -
* - DE_INT64 * - int64
- int64 - int64
- -
* - DE_UINT64 * - uint64
- None - None
- Not support - Not support
* - DE_FLOAT16 * - float16
- Not support - float32
- -
* - DE_FLOAT32 * - float32
- float32 - float32
- -
* - DE_FLOAT64 * - float64
- float64 - float64
- -
* - DE_STRING * - string
- string - string
- Not support multi-dimensional DE_STRING - Not support multi-dimensional string
Note: Note:
1. To save the samples in order, should set dataset's shuffle false and num_files 1. 1. To save the samples in order, should set dataset's shuffle false and num_files 1.
......
...@@ -278,6 +278,8 @@ def check_minddataset(method): ...@@ -278,6 +278,8 @@ def check_minddataset(method):
dataset_file = param_dict.get('dataset_file') dataset_file = param_dict.get('dataset_file')
if isinstance(dataset_file, list): if isinstance(dataset_file, list):
if len(dataset_file) > 4096:
raise ValueError("length of dataset_file should less than or equal to {}.".format(4096))
for f in dataset_file: for f in dataset_file:
check_file(f) check_file(f)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册