diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc index 127769bd68a250f7ae659a72c35f47cf8c68c504..c517fda969eca805dbedf728f2e0253bbb788a90 100644 --- a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc +++ b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc @@ -22,6 +22,7 @@ #include "mindspore/ccsrc/mindrecord/include/shard_error.h" #include "dataset/engine/gnn/local_edge.h" #include "dataset/engine/gnn/local_node.h" +#include "dataset/util/task_manager.h" using ShardTuple = std::vector, mindspore::mindrecord::json>>; @@ -80,7 +81,7 @@ Status GraphLoader::InitAndLoad() { n_feature_maps_.resize(num_workers_); e_feature_maps_.resize(num_workers_); default_feature_maps_.resize(num_workers_); - std::vector> r_codes(num_workers_); + TaskGroup vg; shard_reader_ = std::make_unique(); CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS, @@ -97,12 +98,11 @@ Status GraphLoader::InitAndLoad() { // launching worker threads for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { - r_codes[wkr_id] = std::async(std::launch::async, &GraphLoader::WorkerEntry, this, wkr_id); + RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id))); } // wait for threads to finish and check its return code - for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { - RETURN_IF_NOT_OK(r_codes[wkr_id].get()); - } + vg.join_all(Task::WaitFlag::kBlocking); + RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny()); return Status::OK(); } @@ -201,8 +201,11 @@ Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vectorPost(); ShardTuple rows = shard_reader_->GetNextById(row_id_++, worker_id); while (rows.empty() == false) { + RETURN_IF_INTERRUPTED(); for (const auto &tupled_row : rows) { std::vector col_blob = std::get<0>(tupled_row); mindrecord::json col_jsn = std::get<1>(tupled_row); diff --git a/mindspore/ccsrc/dataset/util/task.cc b/mindspore/ccsrc/dataset/util/task.cc index b6be2fa9ee87519ea4332a49b69b5d18454dca65..ddedbd9221889adaa684087342e1f6a7c991f914 100644 --- a/mindspore/ccsrc/dataset/util/task.cc +++ b/mindspore/ccsrc/dataset/util/task.cc @@ -108,20 +108,27 @@ Status Task::Run() { return rc; } -Status Task::Join() { +Status Task::Join(WaitFlag blocking) { if (running_) { RETURN_UNEXPECTED_IF_NULL(MyTaskGroup()); auto interrupt_svc = MyTaskGroup()->GetIntrpService(); try { - // There is a race condition in the global resource tracking such that a thread can miss the - // interrupt and becomes blocked on a conditional variable forever. As a result, calling - // join() will not come back. We need some timeout version of join such that if the thread - // doesn't come back in a reasonable of time, we will send the interrupt again. - while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) { - // We can't tell which conditional_variable this thread is waiting on. So we may need - // to interrupt everything one more time. - MS_LOG(INFO) << "Some threads not responding. Interrupt again"; - interrupt_svc->InterruptAll(); + if (blocking == WaitFlag::kBlocking) { + // If we are asked to wait, then wait + thrd_.get(); + } else if (blocking == WaitFlag::kNonBlocking) { + // There is a race condition in the global resource tracking such that a thread can miss the + // interrupt and becomes blocked on a conditional variable forever. As a result, calling + // join() will not come back. We need some timeout version of join such that if the thread + // doesn't come back in a reasonable of time, we will send the interrupt again. + while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) { + // We can't tell which conditional_variable this thread is waiting on. So we may need + // to interrupt everything one more time. + MS_LOG(INFO) << "Some threads not responding. Interrupt again"; + interrupt_svc->InterruptAll(); + } + } else { + RETURN_STATUS_UNEXPECTED("Unknown WaitFlag"); } std::stringstream ss; ss << get_id(); diff --git a/mindspore/ccsrc/dataset/util/task.h b/mindspore/ccsrc/dataset/util/task.h index 1d544d933d2a49d8fbd1c9a5be9a1c306c324973..b120ed3d7c649c2e983a4bc1f6e08154edaf8dcb 100644 --- a/mindspore/ccsrc/dataset/util/task.h +++ b/mindspore/ccsrc/dataset/util/task.h @@ -42,9 +42,10 @@ class TaskManager; class Task : public IntrpResource { public: friend class TaskManager; - friend class TaskGroup; + enum class WaitFlag : int { kBlocking, kNonBlocking }; + Task(const std::string &myName, const std::function &f); // Future objects are not copyable. @@ -74,7 +75,7 @@ class Task : public IntrpResource { // Run the task Status Run(); - Status Join(); + Status Join(WaitFlag wf = WaitFlag::kBlocking); bool Running() const { return running_; } diff --git a/mindspore/ccsrc/dataset/util/task_manager.cc b/mindspore/ccsrc/dataset/util/task_manager.cc index 36f239a8409e40ac722414865a332b8ca33483fe..ff573f4f44c012585c5a065f4c394fb296c8dbd1 100644 --- a/mindspore/ccsrc/dataset/util/task_manager.cc +++ b/mindspore/ccsrc/dataset/util/task_manager.cc @@ -278,12 +278,12 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); } -Status TaskGroup::join_all() { +Status TaskGroup::join_all(Task::WaitFlag wf) { Status rc; Status rc2; SharedLock lck(&rw_lock_); for (Task &tk : grp_list_) { - rc = tk.Join(); + rc = tk.Join(wf); if (rc.IsError()) { rc2 = rc; } @@ -294,7 +294,7 @@ Status TaskGroup::join_all() { Status TaskGroup::DoServiceStop() { intrp_svc_->ServiceStop(); interrupt_all(); - return (join_all()); + return (join_all(Task::WaitFlag::kNonBlocking)); } TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) { diff --git a/mindspore/ccsrc/dataset/util/task_manager.h b/mindspore/ccsrc/dataset/util/task_manager.h index d49c0fc651e70d1ec481e38c8bffe42730f139d8..5961c9000e848a88dfc136c44a86bc7d8e0ad327 100644 --- a/mindspore/ccsrc/dataset/util/task_manager.h +++ b/mindspore/ccsrc/dataset/util/task_manager.h @@ -122,7 +122,7 @@ class TaskGroup : public Service { void interrupt_all() noexcept; - Status join_all(); + Status join_all(Task::WaitFlag wf = Task::WaitFlag::kBlocking); int size() const noexcept { return grp_list_.count; } diff --git a/tests/ut/cpp/dataset/interrupt_test.cc b/tests/ut/cpp/dataset/interrupt_test.cc index ee2018a050b2a028520ecd8e4156b7e8b4659c2f..bde9351ca6ac35da26212786f099318bd78d7961 100644 --- a/tests/ut/cpp/dataset/interrupt_test.cc +++ b/tests/ut/cpp/dataset/interrupt_test.cc @@ -48,7 +48,7 @@ TEST_F(MindDataTestIntrpService, Test1) { return rc; }); vg_.GetIntrpService()->InterruptAll(); - vg_.join_all(); + vg_.join_all(Task::WaitFlag::kNonBlocking); } TEST_F(MindDataTestIntrpService, Test2) { @@ -64,5 +64,5 @@ TEST_F(MindDataTestIntrpService, Test2) { return rc; }); vg_.GetIntrpService()->InterruptAll(); - vg_.join_all(); -} + vg_.join_all(Task::WaitFlag::kNonBlocking); +} \ No newline at end of file diff --git a/tests/ut/cpp/dataset/task_manager_test.cc b/tests/ut/cpp/dataset/task_manager_test.cc index 79aa85bb6dfbba19b727fc4e68f515743d2e92d2..a28b10a1fe075610338358fa69e983b46c2e02db 100644 --- a/tests/ut/cpp/dataset/task_manager_test.cc +++ b/tests/ut/cpp/dataset/task_manager_test.cc @@ -80,5 +80,5 @@ TEST_F(MindDataTestTaskManager, Test2) { vg.interrupt_all(); EXPECT_TRUE(rc.IsOk()); // Now we test the async Join - ASSERT_TRUE(vg.join_all().IsOk()); + ASSERT_TRUE(vg.join_all(Task::WaitFlag::kNonBlocking).IsOk()); }