Commit ed70de80 authored by: liyong

fix coredump when the number of files in the file list is more than 1000.

Parent e4a7ca7f
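
For context: the hunks below show that `ShardHeader::ParsePage` only resized `pages_` when `shard_count_ <= kMaxShardCount` (1000), yet still indexed `pages_[shard_index]` for every shard, so reading a dataset split across more than 1000 MindRecord files wrote past the end of the vector and crashed. The commit introduces a separate read-side limit, `kMaxFileCount = 4096`, propagates parse errors instead of crashing, and rejects oversized file lists in the Python validator. A minimal sketch of the now-supported read path (file names are hypothetical and assumed to exist; this is an illustration, not part of the commit):

```python
import mindspore.dataset as ds

# Hypothetical list of 1200 MindRecord shard files: more than the old limit of
# kMaxShardCount = 1000, but within the new read-side limit kMaxFileCount = 4096.
shard_files = ["/data/imagenet.mindrecord{:04d}".format(i) for i in range(1200)]

# Before this commit, building a reader over more than 1000 files could coredump
# inside ShardHeader::ParsePage; with the fix the list is accepted and validated.
dataset = ds.MindDataset(dataset_file=shard_files, shuffle=False)
print(dataset.get_dataset_size())
```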
......@@ -133,6 +133,7 @@ void BindGlobalParams(py::module *m) {
(*m).attr("MAX_PAGE_SIZE") = kMaxPageSize;
(*m).attr("MIN_SHARD_COUNT") = kMinShardCount;
(*m).attr("MAX_SHARD_COUNT") = kMaxShardCount;
(*m).attr("MAX_FILE_COUNT") = kMaxFileCount;
(*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount;
(void)(*m).def("get_max_thread_num", &GetMaxThreadNum);
}
......
......@@ -104,7 +104,8 @@ const uint64_t kInt64Len = 8;
const uint64_t kMinFileSize = kInt64Len;
const int kMinShardCount = 1;
const int kMaxShardCount = 1000;
const int kMaxShardCount = 1000; // write
const int kMaxFileCount = 4096; // read
const int kMinConsumerCount = 1;
const int kMaxConsumerCount = 128;
......
......@@ -152,7 +152,7 @@ class ShardHeader {
MSRStatus CheckIndexField(const std::string &field, const json &schema);
void ParsePage(const json &page, int shard_index, bool load_dataset);
MSRStatus ParsePage(const json &page, int shard_index, bool load_dataset);
MSRStatus ParseStatistics(const json &statistics);
......
......@@ -252,7 +252,7 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
if (shard_count <= 0) {
return row_group_summary;
}
if (shard_count <= kMaxShardCount) {
if (shard_count <= kMaxFileCount) {
for (int shard_id = 0; shard_id < shard_count; ++shard_id) {
// return -1 when page's size equals to 0.
auto last_page_id = shard_header_->GetLastPageId(shard_id);
......@@ -1054,7 +1054,7 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, i
}
auto offsets = std::get<1>(ret);
auto local_columns = std::get<2>(ret);
if (shard_count_ <= kMaxShardCount) {
if (shard_count_ <= kMaxFileCount) {
for (int shard_id = 0; shard_id < shard_count_; shard_id++) {
for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) {
tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1],
......
......@@ -55,7 +55,9 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool l
header_size_ = header["header_size"].get<uint64_t>();
page_size_ = header["page_size"].get<uint64_t>();
}
ParsePage(header["page"], shard_index, load_dataset);
if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) {
return FAILED;
}
shard_index++;
}
return SUCCESS;
......@@ -248,11 +250,16 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
return SUCCESS;
}
void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
// set shard_index when load_dataset is false
if (pages_.empty() && shard_count_ <= kMaxShardCount) {
if (shard_count_ > kMaxFileCount) {
MS_LOG(ERROR) << "The number of mindrecord files is greater than max value: " << kMaxFileCount;
return FAILED;
}
if (pages_.empty() && shard_count_ <= kMaxFileCount) {
pages_.resize(shard_count_);
}
for (auto &page : pages) {
int page_id = page["page_id"];
int shard_id = page["shard_id"];
......@@ -275,6 +282,7 @@ void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_datase
pages_[shard_index].push_back(std::move(parsed_page));
}
}
return SUCCESS;
}
MSRStatus ShardHeader::ParseStatistics(const json &statistics) {
......@@ -715,7 +723,9 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
std::string line;
while (std::getline(page_in_handle, line)) {
ParsePage(json::parse(line), -1, true);
if (SUCCESS != ParsePage(json::parse(line), -1, true)) {
return FAILED;
}
}
page_in_handle.close();
......
......@@ -1054,45 +1054,45 @@ class Dataset:
* - type in 'dataset'
- type in 'mindrecord'
- detail
* - DE_BOOL
* - bool
- None
- Not support
* - DE_INT8
* - int8
- int32
-
* - DE_UINT8
* - uint8
- bytes(1D uint8)
- Drop dimension
* - DE_INT16
* - int16
- int32
-
* - DE_UINT16
* - uint16
- int32
-
* - DE_INT32
* - int32
- int32
-
* - DE_UINT32
* - uint32
- int64
-
* - DE_INT64
* - int64
- int64
-
* - DE_UINT64
* - uint64
- None
- Not support
* - DE_FLOAT16
- Not support
* - float16
- float32
-
* - DE_FLOAT32
* - float32
- float32
-
* - DE_FLOAT64
* - float64
- float64
-
* - DE_STRING
* - string
- string
- Not support multi-dimensional DE_STRING
- Not support multi-dimensional string
Note:
1. To save the samples in order, set the dataset's shuffle to False and num_files to 1.
......
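
To make the note in the docstring hunk above concrete, here is a hedged usage sketch of saving a dataset to 'mindrecord' while preserving sample order, assuming an existing input file (the file name is hypothetical); the source is created with shuffle=False and written with num_files=1, as the note requires:

```python
import mindspore.dataset as ds

# Unshuffled source so the iteration order is deterministic (file name hypothetical).
data = ds.MindDataset(dataset_file=["input.mindrecord"], shuffle=False)

# num_files=1 writes a single output shard, so samples are stored in iteration order.
# Column types are converted according to the table above (e.g. int8 -> int32).
data.save("ordered_output.mindrecord", num_files=1)
```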
......@@ -278,6 +278,8 @@ def check_minddataset(method):
dataset_file = param_dict.get('dataset_file')
if isinstance(dataset_file, list):
if len(dataset_file) > 4096:
raise ValueError("length of dataset_file should less than or equal to {}.".format(4096))
for f in dataset_file:
check_file(f)
else:
......
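
The new check in check_minddataset fails fast in Python instead of letting an oversized file list reach the C++ reader. A hedged sketch of the resulting behavior (paths are hypothetical; since the length check runs before the per-file existence check, the ValueError is raised regardless of whether the files exist):

```python
import mindspore.dataset as ds

# 5000 entries exceed the new 4096-file limit enforced by check_minddataset.
too_many = ["/data/big.mindrecord{:05d}".format(i) for i in range(5000)]

try:
    ds.MindDataset(dataset_file=too_many)
except ValueError as err:
    # e.g. "length of dataset_file should less than or equal to 4096."
    print(err)
```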