提交 63479f8e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1574 fix tfreadop hang

Merge pull request !1574 from yanghaitao/yht_tfreadop_equal_rows_hang_r0.3
...@@ -433,11 +433,13 @@ Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) { ...@@ -433,11 +433,13 @@ Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) {
int64_t start_offset = 0; int64_t start_offset = 0;
int64_t end_offset = 0; int64_t end_offset = 0;
bool finish = false; bool finish = false;
bool end_of_epoch = false;
while (!finish) { while (!finish) {
for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
{ {
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
if (load_io_block_queue_ == false) { if (load_io_block_queue_ == false) {
end_of_epoch = true;
break; break;
} }
} }
...@@ -461,7 +463,8 @@ Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) { ...@@ -461,7 +463,8 @@ Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) {
pre_count += filename_numrows_[file_name]; pre_count += filename_numrows_[file_name];
} }
} }
if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) { if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ &&
!end_of_epoch) {
finish = false; finish = false;
} else { } else {
finish = true; finish = true;
...@@ -478,12 +481,14 @@ Status TFReaderOp::FillIOBlockNoShuffle() { ...@@ -478,12 +481,14 @@ Status TFReaderOp::FillIOBlockNoShuffle() {
int64_t start_offset = 0; int64_t start_offset = 0;
int64_t end_offset = 0; int64_t end_offset = 0;
bool finish = false; bool finish = false;
bool end_of_epoch = false;
while (!finish) { while (!finish) {
// Iterate over all the keys and add one key to each block. // Iterate over all the keys and add one key to each block.
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
{ {
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
if (load_io_block_queue_ == false) { if (load_io_block_queue_ == false) {
end_of_epoch = true;
break; break;
} }
} }
...@@ -505,7 +510,8 @@ Status TFReaderOp::FillIOBlockNoShuffle() { ...@@ -505,7 +510,8 @@ Status TFReaderOp::FillIOBlockNoShuffle() {
pre_count += filename_numrows_[file_name]; pre_count += filename_numrows_[file_name];
} }
} }
if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) { if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ &&
!end_of_epoch) {
finish = false; finish = false;
} else { } else {
finish = true; finish = true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册