diff --git a/mindspore/ccsrc/dataset/util/cond_var.cc b/mindspore/ccsrc/dataset/util/cond_var.cc index ada401f2161ea2c452bf728cbfa4a00eed727bad..8b1099fb717d130934cb7e369a433ee4d490912b 100644 --- a/mindspore/ccsrc/dataset/util/cond_var.cc +++ b/mindspore/ccsrc/dataset/util/cond_var.cc @@ -14,35 +14,34 @@ * limitations under the License. */ #include "dataset/util/cond_var.h" +#include #include #include "dataset/util/services.h" #include "dataset/util/task_manager.h" namespace mindspore { namespace dataset { -CondVar::CondVar() : svc_(nullptr), my_name_(std::move(Services::GetUniqueID())) {} +CondVar::CondVar() : svc_(nullptr), my_name_(Services::GetUniqueID()) {} Status CondVar::Wait(std::unique_lock *lck, const std::function &pred) { - // Append an additional condition on top of the given predicate. - // We will also bail out if this cv got interrupted. - auto f = [this, &pred]() -> bool { return (pred() || (CurState() == State::kInterrupted)); }; - // If we have interrupt service, just wait on the cv unconditionally. - // Otherwise fall back to the old way of checking interrupt. - if (svc_) { - cv_.wait(*lck, f); - if (CurState() == State::kInterrupted) { - Task *my_task = TaskManager::FindMe(); - if (my_task->IsMasterThread() && my_task->CaughtSevereException()) { - return TaskManager::GetMasterThreadRc(); - } else { - return Status(StatusCode::kInterrupted); + try { + if (svc_ != nullptr) { + // If this cv registers with a global resource tracking, then wait unconditionally. + auto f = [this, &pred]() -> bool { return (pred() || this->Interrupted()); }; + cv_.wait(*lck, f); + // If we are interrupted, override the return value if this is the master thread. + // Master thread is being interrupted mostly because of some thread is reporting error. + RETURN_IF_NOT_OK(Task::OverrideInterruptRc(this->GetInterruptStatus())); + } else { + // Otherwise we wake up once a while to check for interrupt (for this thread). + auto f = [&pred]() -> bool { return (pred() || this_thread::is_interrupted()); }; + while (!f()) { + (void)cv_.wait_for(*lck, std::chrono::milliseconds(1)); } + RETURN_IF_INTERRUPTED(); } - } else { - RETURN_IF_NOT_OK(interruptible_wait(&cv_, lck, pred)); - if (CurState() == State::kInterrupted) { - return Status(StatusCode::kInterrupted); - } + } catch (const std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); } return Status::OK(); } @@ -66,10 +65,9 @@ Status CondVar::Register(std::shared_ptr svc) { return rc; } -Status CondVar::Interrupt() { - RETURN_IF_NOT_OK(IntrpResource::Interrupt()); +void CondVar::Interrupt() { + IntrpResource::Interrupt(); cv_.notify_all(); - return Status::OK(); } std::string CondVar::my_name() const { return my_name_; } diff --git a/mindspore/ccsrc/dataset/util/cond_var.h b/mindspore/ccsrc/dataset/util/cond_var.h index fc29b3315da1b24ca073dbabac9e9ef126963819..b23dcd566ef0861032da69d5c6400487a1e39b41 100644 --- a/mindspore/ccsrc/dataset/util/cond_var.h +++ b/mindspore/ccsrc/dataset/util/cond_var.h @@ -35,7 +35,7 @@ class CondVar : public IntrpResource { Status Wait(std::unique_lock *lck, const std::function &pred); - Status Interrupt() override; + void Interrupt() override; void NotifyOne() noexcept; diff --git a/mindspore/ccsrc/dataset/util/intrp_resource.h b/mindspore/ccsrc/dataset/util/intrp_resource.h index 462a0bd7ef6dc76c09c6076fd65da0412808df85..52024cb90a138b1429c62baeeb8a3797bb50a0ff 100644 --- a/mindspore/ccsrc/dataset/util/intrp_resource.h +++ b/mindspore/ccsrc/dataset/util/intrp_resource.h @@ -29,10 +29,7 @@ class IntrpResource { virtual ~IntrpResource() = default; - virtual Status Interrupt() { - st_ = State::kInterrupted; - return Status::OK(); - } + virtual void Interrupt() { st_ = State::kInterrupted; } virtual void ResetIntrpState() { st_ = State::kRunning; } @@ -40,6 +37,13 @@ class IntrpResource { bool Interrupted() const { return CurState() == State::kInterrupted; } + virtual Status GetInterruptStatus() const { + if (Interrupted()) { + return Status(StatusCode::kInterrupted); + } + return Status::OK(); + } + protected: std::atomic st_; }; diff --git a/mindspore/ccsrc/dataset/util/intrp_service.cc b/mindspore/ccsrc/dataset/util/intrp_service.cc index 6dcafa8e70aaee65e72a84363dca5b5d57316cfc..da8dde992c7ca74156f699dd8e626bbb4e6174aa 100644 --- a/mindspore/ccsrc/dataset/util/intrp_service.cc +++ b/mindspore/ccsrc/dataset/util/intrp_service.cc @@ -27,7 +27,7 @@ IntrpService::~IntrpService() noexcept { MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << "."; if (!all_intrp_resources_.empty()) { try { - (void)InterruptAll(); + InterruptAll(); } catch (const std::exception &e) { // Ignore all error as we can't throw in the destructor. } @@ -64,11 +64,9 @@ Status IntrpService::Deregister(const std::string &name) noexcept { std::ostringstream ss; ss << this_thread::get_id(); MS_LOG(DEBUG) << "De-register resource with name " << name << ". Thread ID is " << ss.str() << "."; - auto it = all_intrp_resources_.find(name); - if (it != all_intrp_resources_.end()) { - (void)all_intrp_resources_.erase(it); - } else { - MS_LOG(DEBUG) << "Key " << name << " not found."; + auto n = all_intrp_resources_.erase(name); + if (n == 0) { + MS_LOG(INFO) << "Key " << name << " not found."; } } catch (std::exception &e) { RETURN_STATUS_UNEXPECTED(e.what()); @@ -76,21 +74,16 @@ Status IntrpService::Deregister(const std::string &name) noexcept { return Status::OK(); } -Status IntrpService::InterruptAll() noexcept { +void IntrpService::InterruptAll() noexcept { std::lock_guard lck(mutex_); - Status rc; for (auto const &it : all_intrp_resources_) { std::string kName = it.first; try { - Status rc2 = it.second->Interrupt(); - if (rc2.IsError()) { - rc = rc2; - } + it.second->Interrupt(); } catch (const std::exception &e) { // continue the clean up. } } - return rc; } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/intrp_service.h b/mindspore/ccsrc/dataset/util/intrp_service.h index 47a7d8d9192408a0a81753b012dec153034d31be..de1d5eb753a70df71f72ec663089934191cdcda3 100644 --- a/mindspore/ccsrc/dataset/util/intrp_service.h +++ b/mindspore/ccsrc/dataset/util/intrp_service.h @@ -47,7 +47,7 @@ class IntrpService : public Service { Status Deregister(const std::string &name) noexcept; - Status InterruptAll() noexcept; + void InterruptAll() noexcept; Status DoServiceStart() override { return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/util/queue.h b/mindspore/ccsrc/dataset/util/queue.h index f0b087cf6d5e3dc63c0c918f6be690208b28cbe1..b97e6a5c28e27c49a08da50905e20ec2816ba6b1 100644 --- a/mindspore/ccsrc/dataset/util/queue.h +++ b/mindspore/ccsrc/dataset/util/queue.h @@ -110,7 +110,7 @@ class Queue { empty_cv_.NotifyAll(); _lock.unlock(); } else { - (void)empty_cv_.Interrupt(); + empty_cv_.Interrupt(); } return rc; } @@ -125,7 +125,7 @@ class Queue { empty_cv_.NotifyAll(); _lock.unlock(); } else { - (void)empty_cv_.Interrupt(); + empty_cv_.Interrupt(); } return rc; } @@ -141,7 +141,7 @@ class Queue { empty_cv_.NotifyAll(); _lock.unlock(); } else { - (void)empty_cv_.Interrupt(); + empty_cv_.Interrupt(); } return rc; } @@ -160,7 +160,7 @@ class Queue { full_cv_.NotifyAll(); _lock.unlock(); } else { - (void)full_cv_.Interrupt(); + full_cv_.Interrupt(); } return rc; } diff --git a/mindspore/ccsrc/dataset/util/services.h b/mindspore/ccsrc/dataset/util/services.h index 5e81c4816edc206e1d154195dcb0b93144ec316c..e19f44dccc47d792455ee9b10acc3b132304b05a 100644 --- a/mindspore/ccsrc/dataset/util/services.h +++ b/mindspore/ccsrc/dataset/util/services.h @@ -20,6 +20,7 @@ #include #include #include "dataset/util/memory_pool.h" +#include "dataset/util/allocator.h" #include "dataset/util/service.h" #define UNIQUEID_LEN 36 @@ -72,6 +73,11 @@ class Services { static std::string GetUniqueID(); + template + static Allocator GetAllocator() { + return Allocator(Services::GetInstance().GetServiceMemPool()); + } + private: static std::once_flag init_instance_flag_; static std::unique_ptr instance_; diff --git a/mindspore/ccsrc/dataset/util/task.cc b/mindspore/ccsrc/dataset/util/task.cc index 0d02ad8317ba2ca376fd57cfa883e5bc3d812a8b..b6be2fa9ee87519ea4332a49b69b5d18454dca65 100644 --- a/mindspore/ccsrc/dataset/util/task.cc +++ b/mindspore/ccsrc/dataset/util/task.cc @@ -72,7 +72,7 @@ void Task::ShutdownGroup() { // Wake up watch dog and shutdown the engine. } } -Status Task::GetTaskErrorIfAny() { +Status Task::GetTaskErrorIfAny() const { std::lock_guard lk(mux_); if (caught_severe_exception_) { return rc_; @@ -141,5 +141,13 @@ TaskGroup *Task::MyTaskGroup() { return task_group_; } void Task::set_task_group(TaskGroup *vg) { task_group_ = vg; } Task::~Task() { task_group_ = nullptr; } +Status Task::OverrideInterruptRc(const Status &rc) { + if (rc.IsInterrupted() && this_thread::is_master_thread()) { + // If we are interrupted, override the return value if this is the master thread. + // Master thread is being interrupted mostly because of some thread is reporting error. + return TaskManager::GetMasterThreadRc(); + } + return rc; +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/task.h b/mindspore/ccsrc/dataset/util/task.h index e9a3e44c5bea928419f9830011756164297ad070..1d544d933d2a49d8fbd1c9a5be9a1c306c324973 100644 --- a/mindspore/ccsrc/dataset/util/task.h +++ b/mindspore/ccsrc/dataset/util/task.h @@ -60,7 +60,7 @@ class Task : public IntrpResource { Task &operator=(Task &&) = delete; - Status GetTaskErrorIfAny(); + Status GetTaskErrorIfAny() const; void ChangeName(const std::string &newName) { my_name_ = newName; } @@ -95,10 +95,10 @@ class Task : public IntrpResource { Status Wait() { return (wp_.Wait()); } - void set_task_group(TaskGroup *vg); + static Status OverrideInterruptRc(const Status &rc); private: - std::mutex mux_; + mutable std::mutex mux_; std::string my_name_; Status rc_; WaitPost wp_; @@ -115,6 +115,7 @@ class Task : public IntrpResource { void ShutdownGroup(); TaskGroup *MyTaskGroup(); + void set_task_group(TaskGroup *vg); }; extern thread_local Task *gMyTask; diff --git a/mindspore/ccsrc/dataset/util/task_manager.cc b/mindspore/ccsrc/dataset/util/task_manager.cc index 31b0eedd7040ffaf3cc18c0d057d868b15e9f98b..36f239a8409e40ac722414865a332b8ca33483fe 100644 --- a/mindspore/ccsrc/dataset/util/task_manager.cc +++ b/mindspore/ccsrc/dataset/util/task_manager.cc @@ -84,7 +84,7 @@ void TaskManager::interrupt_all() noexcept { svc->InterruptAll(); } } - (void)master_->Interrupt(); + master_->Interrupt(); } Task *TaskManager::FindMe() { return gMyTask; } @@ -94,8 +94,7 @@ TaskManager::TaskManager() try : global_interrupt_(0), free_lst_(&Task::free), watchdog_grp_(nullptr), watchdog_(nullptr) { - std::shared_ptr mp = Services::GetInstance().GetServiceMemPool(); - Allocator alloc(mp); + auto alloc = Services::GetAllocator(); // Create a dummy Task for the master thread (this thread) master_ = std::allocate_shared(alloc, "master", []() -> Status { return Status::OK(); }); master_->id_ = this_thread::get_id(); @@ -185,7 +184,7 @@ void TaskManager::InterruptMaster(const Status &rc) { TaskManager &tm = TaskManager::GetInstance(); std::shared_ptr master = tm.master_; std::lock_guard lck(master->mux_); - (void)master->Interrupt(); + master->Interrupt(); if (rc.IsError() && master->rc_.IsOk()) { master->rc_ = rc; master->caught_severe_exception_ = true; @@ -277,7 +276,7 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio return Status::OK(); } -void TaskGroup::interrupt_all() noexcept { (void)intrp_svc_->InterruptAll(); } +void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); } Status TaskGroup::join_all() { Status rc; @@ -299,8 +298,7 @@ Status TaskGroup::DoServiceStop() { } TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) { - std::shared_ptr mp = Services::GetInstance().GetServiceMemPool(); - Allocator alloc(mp); + auto alloc = Services::GetAllocator(); intrp_svc_ = std::allocate_shared(alloc); (void)Service::ServiceStart(); } diff --git a/mindspore/ccsrc/dataset/util/task_manager.h b/mindspore/ccsrc/dataset/util/task_manager.h index 5f5c1eb806925a80a036c8ee8757dc47733be8aa..d49c0fc651e70d1ec481e38c8bffe42730f139d8 100644 --- a/mindspore/ccsrc/dataset/util/task_manager.h +++ b/mindspore/ccsrc/dataset/util/task_manager.h @@ -154,37 +154,27 @@ inline bool is_interrupted() { return true; } Task *my_task = TaskManager::FindMe(); - return (my_task != nullptr) ? my_task->Interrupted() : false; + return my_task->Interrupted(); +} + +inline bool is_master_thread() { + Task *my_task = TaskManager::FindMe(); + return my_task->IsMasterThread(); +} + +inline Status GetInterruptStatus() { + Task *my_task = TaskManager::FindMe(); + return my_task->GetInterruptStatus(); } } // namespace this_thread -#define RETURN_IF_INTERRUPTED() \ - do { \ - if (mindspore::dataset::this_thread::is_interrupted()) { \ - Task *myTask = TaskManager::FindMe(); \ - if (myTask->IsMasterThread() && myTask->CaughtSevereException()) { \ - return TaskManager::GetMasterThreadRc(); \ - } else { \ - return Status(StatusCode::kInterrupted); \ - } \ - } \ +#define RETURN_IF_INTERRUPTED() \ + do { \ + if (mindspore::dataset::this_thread::is_interrupted()) { \ + return Task::OverrideInterruptRc(this_thread::GetInterruptStatus()); \ + } \ } while (false) -inline Status interruptible_wait(std::condition_variable *cv, std::unique_lock *lk, - const std::function &pred) noexcept { - if (!pred()) { - do { - RETURN_IF_INTERRUPTED(); - try { - (void)cv->wait_for(*lk, std::chrono::milliseconds(1)); - } catch (std::exception &e) { - // Anything thrown by wait_for is considered system error. - RETURN_STATUS_UNEXPECTED(e.what()); - } - } while (!pred()); - } - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/tests/ut/cpp/dataset/connector_test.cc b/tests/ut/cpp/dataset/connector_test.cc index f4479f52eb2253392755b477281dc4c8f0d96da6..eb33cc3e8d453b972af7d2bde82807a24d8b1daf 100644 --- a/tests/ut/cpp/dataset/connector_test.cc +++ b/tests/ut/cpp/dataset/connector_test.cc @@ -139,6 +139,9 @@ Status MindDataTestConnector::Run_test_0() { 10); // capacity of each queue DS_ASSERT(my_conn != nullptr); + rc = my_conn->Register(tg_.get()); + RETURN_IF_NOT_OK(rc); + // Spawn a thread to read input_ vector and put it in my_conn rc = tg_->CreateAsyncTask("Worker Push", std::bind(&MindDataTestConnector::FirstWorkerPush, @@ -184,6 +187,11 @@ Status MindDataTestConnector::Run_test_1() { l3_threads, conn2_qcap); + rc = conn1->Register(tg_.get()); + RETURN_IF_NOT_OK(rc); + rc = conn2->Register(tg_.get()); + RETURN_IF_NOT_OK(rc); + // Instantiating the threads in the first layer for (int i = 0; i < l1_threads; i++) { rc = tg_->CreateAsyncTask("First Worker Push",