提交 578cdb42 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1454 Allow TaskGroup::join_all() to block for GNN

Merge pull request !1454 from JesseKLee/zirui
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "mindspore/ccsrc/mindrecord/include/shard_error.h" #include "mindspore/ccsrc/mindrecord/include/shard_error.h"
#include "dataset/engine/gnn/local_edge.h" #include "dataset/engine/gnn/local_edge.h"
#include "dataset/engine/gnn/local_node.h" #include "dataset/engine/gnn/local_node.h"
#include "dataset/util/task_manager.h"
using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>; using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
...@@ -80,7 +81,7 @@ Status GraphLoader::InitAndLoad() { ...@@ -80,7 +81,7 @@ Status GraphLoader::InitAndLoad() {
n_feature_maps_.resize(num_workers_); n_feature_maps_.resize(num_workers_);
e_feature_maps_.resize(num_workers_); e_feature_maps_.resize(num_workers_);
default_feature_maps_.resize(num_workers_); default_feature_maps_.resize(num_workers_);
std::vector<std::future<Status>> r_codes(num_workers_); TaskGroup vg;
shard_reader_ = std::make_unique<ShardReader>(); shard_reader_ = std::make_unique<ShardReader>();
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS, CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS,
...@@ -97,12 +98,11 @@ Status GraphLoader::InitAndLoad() { ...@@ -97,12 +98,11 @@ Status GraphLoader::InitAndLoad() {
// launching worker threads // launching worker threads
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
r_codes[wkr_id] = std::async(std::launch::async, &GraphLoader::WorkerEntry, this, wkr_id); RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id)));
} }
// wait for threads to finish and check its return code // wait for threads to finish and check its return code
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { vg.join_all(Task::WaitFlag::kBlocking);
RETURN_IF_NOT_OK(r_codes[wkr_id].get()); RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny());
}
return Status::OK(); return Status::OK();
} }
...@@ -201,8 +201,11 @@ Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<u ...@@ -201,8 +201,11 @@ Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<u
} }
Status GraphLoader::WorkerEntry(int32_t worker_id) { Status GraphLoader::WorkerEntry(int32_t worker_id) {
// Handshake
TaskManager::FindMe()->Post();
ShardTuple rows = shard_reader_->GetNextById(row_id_++, worker_id); ShardTuple rows = shard_reader_->GetNextById(row_id_++, worker_id);
while (rows.empty() == false) { while (rows.empty() == false) {
RETURN_IF_INTERRUPTED();
for (const auto &tupled_row : rows) { for (const auto &tupled_row : rows) {
std::vector<uint8_t> col_blob = std::get<0>(tupled_row); std::vector<uint8_t> col_blob = std::get<0>(tupled_row);
mindrecord::json col_jsn = std::get<1>(tupled_row); mindrecord::json col_jsn = std::get<1>(tupled_row);
......
...@@ -108,11 +108,15 @@ Status Task::Run() { ...@@ -108,11 +108,15 @@ Status Task::Run() {
return rc; return rc;
} }
Status Task::Join() { Status Task::Join(WaitFlag blocking) {
if (running_) { if (running_) {
RETURN_UNEXPECTED_IF_NULL(MyTaskGroup()); RETURN_UNEXPECTED_IF_NULL(MyTaskGroup());
auto interrupt_svc = MyTaskGroup()->GetIntrpService(); auto interrupt_svc = MyTaskGroup()->GetIntrpService();
try { try {
if (blocking == WaitFlag::kBlocking) {
// If we are asked to wait, then wait
thrd_.get();
} else if (blocking == WaitFlag::kNonBlocking) {
// There is a race condition in the global resource tracking such that a thread can miss the // There is a race condition in the global resource tracking such that a thread can miss the
// interrupt and becomes blocked on a conditional variable forever. As a result, calling // interrupt and becomes blocked on a conditional variable forever. As a result, calling
// join() will not come back. We need some timeout version of join such that if the thread // join() will not come back. We need some timeout version of join such that if the thread
...@@ -123,6 +127,9 @@ Status Task::Join() { ...@@ -123,6 +127,9 @@ Status Task::Join() {
MS_LOG(INFO) << "Some threads not responding. Interrupt again"; MS_LOG(INFO) << "Some threads not responding. Interrupt again";
interrupt_svc->InterruptAll(); interrupt_svc->InterruptAll();
} }
} else {
RETURN_STATUS_UNEXPECTED("Unknown WaitFlag");
}
std::stringstream ss; std::stringstream ss;
ss << get_id(); ss << get_id();
MS_LOG(DEBUG) << MyName() << " Thread ID " << ss.str() << " Stopped."; MS_LOG(DEBUG) << MyName() << " Thread ID " << ss.str() << " Stopped.";
......
...@@ -42,9 +42,10 @@ class TaskManager; ...@@ -42,9 +42,10 @@ class TaskManager;
class Task : public IntrpResource { class Task : public IntrpResource {
public: public:
friend class TaskManager; friend class TaskManager;
friend class TaskGroup; friend class TaskGroup;
enum class WaitFlag : int { kBlocking, kNonBlocking };
Task(const std::string &myName, const std::function<Status()> &f); Task(const std::string &myName, const std::function<Status()> &f);
// Future objects are not copyable. // Future objects are not copyable.
...@@ -74,7 +75,7 @@ class Task : public IntrpResource { ...@@ -74,7 +75,7 @@ class Task : public IntrpResource {
// Run the task // Run the task
Status Run(); Status Run();
Status Join(); Status Join(WaitFlag wf = WaitFlag::kBlocking);
bool Running() const { return running_; } bool Running() const { return running_; }
......
...@@ -278,12 +278,12 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio ...@@ -278,12 +278,12 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio
void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); } void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); }
Status TaskGroup::join_all() { Status TaskGroup::join_all(Task::WaitFlag wf) {
Status rc; Status rc;
Status rc2; Status rc2;
SharedLock lck(&rw_lock_); SharedLock lck(&rw_lock_);
for (Task &tk : grp_list_) { for (Task &tk : grp_list_) {
rc = tk.Join(); rc = tk.Join(wf);
if (rc.IsError()) { if (rc.IsError()) {
rc2 = rc; rc2 = rc;
} }
...@@ -294,7 +294,7 @@ Status TaskGroup::join_all() { ...@@ -294,7 +294,7 @@ Status TaskGroup::join_all() {
Status TaskGroup::DoServiceStop() { Status TaskGroup::DoServiceStop() {
intrp_svc_->ServiceStop(); intrp_svc_->ServiceStop();
interrupt_all(); interrupt_all();
return (join_all()); return (join_all(Task::WaitFlag::kNonBlocking));
} }
TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) { TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) {
......
...@@ -122,7 +122,7 @@ class TaskGroup : public Service { ...@@ -122,7 +122,7 @@ class TaskGroup : public Service {
void interrupt_all() noexcept; void interrupt_all() noexcept;
Status join_all(); Status join_all(Task::WaitFlag wf = Task::WaitFlag::kBlocking);
int size() const noexcept { return grp_list_.count; } int size() const noexcept { return grp_list_.count; }
......
...@@ -48,7 +48,7 @@ TEST_F(MindDataTestIntrpService, Test1) { ...@@ -48,7 +48,7 @@ TEST_F(MindDataTestIntrpService, Test1) {
return rc; return rc;
}); });
vg_.GetIntrpService()->InterruptAll(); vg_.GetIntrpService()->InterruptAll();
vg_.join_all(); vg_.join_all(Task::WaitFlag::kNonBlocking);
} }
TEST_F(MindDataTestIntrpService, Test2) { TEST_F(MindDataTestIntrpService, Test2) {
...@@ -64,5 +64,5 @@ TEST_F(MindDataTestIntrpService, Test2) { ...@@ -64,5 +64,5 @@ TEST_F(MindDataTestIntrpService, Test2) {
return rc; return rc;
}); });
vg_.GetIntrpService()->InterruptAll(); vg_.GetIntrpService()->InterruptAll();
vg_.join_all(); vg_.join_all(Task::WaitFlag::kNonBlocking);
} }
\ No newline at end of file
...@@ -80,5 +80,5 @@ TEST_F(MindDataTestTaskManager, Test2) { ...@@ -80,5 +80,5 @@ TEST_F(MindDataTestTaskManager, Test2) {
vg.interrupt_all(); vg.interrupt_all();
EXPECT_TRUE(rc.IsOk()); EXPECT_TRUE(rc.IsOk());
// Now we test the async Join // Now we test the async Join
ASSERT_TRUE(vg.join_all().IsOk()); ASSERT_TRUE(vg.join_all(Task::WaitFlag::kNonBlocking).IsOk());
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册