提交 578cdb42 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1454 Allow TaskGroup::join_all() to block for GNN

Merge pull request !1454 from JesseKLee/zirui
......@@ -22,6 +22,7 @@
#include "mindspore/ccsrc/mindrecord/include/shard_error.h"
#include "dataset/engine/gnn/local_edge.h"
#include "dataset/engine/gnn/local_node.h"
#include "dataset/util/task_manager.h"
using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
......@@ -80,7 +81,7 @@ Status GraphLoader::InitAndLoad() {
n_feature_maps_.resize(num_workers_);
e_feature_maps_.resize(num_workers_);
default_feature_maps_.resize(num_workers_);
std::vector<std::future<Status>> r_codes(num_workers_);
TaskGroup vg;
shard_reader_ = std::make_unique<ShardReader>();
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS,
......@@ -97,12 +98,11 @@ Status GraphLoader::InitAndLoad() {
// launching worker threads
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
r_codes[wkr_id] = std::async(std::launch::async, &GraphLoader::WorkerEntry, this, wkr_id);
RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id)));
}
// wait for threads to finish and check its return code
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
RETURN_IF_NOT_OK(r_codes[wkr_id].get());
}
vg.join_all(Task::WaitFlag::kBlocking);
RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny());
return Status::OK();
}
......@@ -201,8 +201,11 @@ Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<u
}
Status GraphLoader::WorkerEntry(int32_t worker_id) {
// Handshake
TaskManager::FindMe()->Post();
ShardTuple rows = shard_reader_->GetNextById(row_id_++, worker_id);
while (rows.empty() == false) {
RETURN_IF_INTERRUPTED();
for (const auto &tupled_row : rows) {
std::vector<uint8_t> col_blob = std::get<0>(tupled_row);
mindrecord::json col_jsn = std::get<1>(tupled_row);
......
......@@ -108,20 +108,27 @@ Status Task::Run() {
return rc;
}
Status Task::Join() {
Status Task::Join(WaitFlag blocking) {
if (running_) {
RETURN_UNEXPECTED_IF_NULL(MyTaskGroup());
auto interrupt_svc = MyTaskGroup()->GetIntrpService();
try {
// There is a race condition in the global resource tracking such that a thread can miss the
// interrupt and becomes blocked on a conditional variable forever. As a result, calling
// join() will not come back. We need some timeout version of join such that if the thread
// doesn't come back in a reasonable of time, we will send the interrupt again.
while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) {
// We can't tell which conditional_variable this thread is waiting on. So we may need
// to interrupt everything one more time.
MS_LOG(INFO) << "Some threads not responding. Interrupt again";
interrupt_svc->InterruptAll();
if (blocking == WaitFlag::kBlocking) {
// If we are asked to wait, then wait
thrd_.get();
} else if (blocking == WaitFlag::kNonBlocking) {
// There is a race condition in the global resource tracking such that a thread can miss the
// interrupt and becomes blocked on a conditional variable forever. As a result, calling
// join() will not come back. We need some timeout version of join such that if the thread
// doesn't come back in a reasonable of time, we will send the interrupt again.
while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) {
// We can't tell which conditional_variable this thread is waiting on. So we may need
// to interrupt everything one more time.
MS_LOG(INFO) << "Some threads not responding. Interrupt again";
interrupt_svc->InterruptAll();
}
} else {
RETURN_STATUS_UNEXPECTED("Unknown WaitFlag");
}
std::stringstream ss;
ss << get_id();
......
......@@ -42,9 +42,10 @@ class TaskManager;
class Task : public IntrpResource {
public:
friend class TaskManager;
friend class TaskGroup;
enum class WaitFlag : int { kBlocking, kNonBlocking };
Task(const std::string &myName, const std::function<Status()> &f);
// Future objects are not copyable.
......@@ -74,7 +75,7 @@ class Task : public IntrpResource {
// Run the task
Status Run();
Status Join();
Status Join(WaitFlag wf = WaitFlag::kBlocking);
bool Running() const { return running_; }
......
......@@ -278,12 +278,12 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio
void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); }
Status TaskGroup::join_all() {
Status TaskGroup::join_all(Task::WaitFlag wf) {
Status rc;
Status rc2;
SharedLock lck(&rw_lock_);
for (Task &tk : grp_list_) {
rc = tk.Join();
rc = tk.Join(wf);
if (rc.IsError()) {
rc2 = rc;
}
......@@ -294,7 +294,7 @@ Status TaskGroup::join_all() {
Status TaskGroup::DoServiceStop() {
intrp_svc_->ServiceStop();
interrupt_all();
return (join_all());
return (join_all(Task::WaitFlag::kNonBlocking));
}
TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) {
......
......@@ -122,7 +122,7 @@ class TaskGroup : public Service {
void interrupt_all() noexcept;
Status join_all();
Status join_all(Task::WaitFlag wf = Task::WaitFlag::kBlocking);
int size() const noexcept { return grp_list_.count; }
......
......@@ -48,7 +48,7 @@ TEST_F(MindDataTestIntrpService, Test1) {
return rc;
});
vg_.GetIntrpService()->InterruptAll();
vg_.join_all();
vg_.join_all(Task::WaitFlag::kNonBlocking);
}
TEST_F(MindDataTestIntrpService, Test2) {
......@@ -64,5 +64,5 @@ TEST_F(MindDataTestIntrpService, Test2) {
return rc;
});
vg_.GetIntrpService()->InterruptAll();
vg_.join_all();
}
vg_.join_all(Task::WaitFlag::kNonBlocking);
}
\ No newline at end of file
......@@ -80,5 +80,5 @@ TEST_F(MindDataTestTaskManager, Test2) {
vg.interrupt_all();
EXPECT_TRUE(rc.IsOk());
// Now we test the async Join
ASSERT_TRUE(vg.join_all().IsOk());
ASSERT_TRUE(vg.join_all(Task::WaitFlag::kNonBlocking).IsOk());
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册