提交 e8980ed2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1406 Simplify CondVar class

Merge pull request !1406 from JesseKLee/CondVar
...@@ -14,35 +14,34 @@ ...@@ -14,35 +14,34 @@
* limitations under the License. * limitations under the License.
*/ */
#include "dataset/util/cond_var.h" #include "dataset/util/cond_var.h"
#include <exception>
#include <utility> #include <utility>
#include "dataset/util/services.h" #include "dataset/util/services.h"
#include "dataset/util/task_manager.h" #include "dataset/util/task_manager.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
CondVar::CondVar() : svc_(nullptr), my_name_(std::move(Services::GetUniqueID())) {} CondVar::CondVar() : svc_(nullptr), my_name_(Services::GetUniqueID()) {}
Status CondVar::Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred) { Status CondVar::Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred) {
// Append an additional condition on top of the given predicate. try {
// We will also bail out if this cv got interrupted. if (svc_ != nullptr) {
auto f = [this, &pred]() -> bool { return (pred() || (CurState() == State::kInterrupted)); }; // If this cv registers with a global resource tracking, then wait unconditionally.
// If we have interrupt service, just wait on the cv unconditionally. auto f = [this, &pred]() -> bool { return (pred() || this->Interrupted()); };
// Otherwise fall back to the old way of checking interrupt. cv_.wait(*lck, f);
if (svc_) { // If we are interrupted, override the return value if this is the master thread.
cv_.wait(*lck, f); // Master thread is being interrupted mostly because of some thread is reporting error.
if (CurState() == State::kInterrupted) { RETURN_IF_NOT_OK(Task::OverrideInterruptRc(this->GetInterruptStatus()));
Task *my_task = TaskManager::FindMe(); } else {
if (my_task->IsMasterThread() && my_task->CaughtSevereException()) { // Otherwise we wake up once a while to check for interrupt (for this thread).
return TaskManager::GetMasterThreadRc(); auto f = [&pred]() -> bool { return (pred() || this_thread::is_interrupted()); };
} else { while (!f()) {
return Status(StatusCode::kInterrupted); (void)cv_.wait_for(*lck, std::chrono::milliseconds(1));
} }
RETURN_IF_INTERRUPTED();
} }
} else { } catch (const std::exception &e) {
RETURN_IF_NOT_OK(interruptible_wait(&cv_, lck, pred)); RETURN_STATUS_UNEXPECTED(e.what());
if (CurState() == State::kInterrupted) {
return Status(StatusCode::kInterrupted);
}
} }
return Status::OK(); return Status::OK();
} }
...@@ -66,10 +65,9 @@ Status CondVar::Register(std::shared_ptr<IntrpService> svc) { ...@@ -66,10 +65,9 @@ Status CondVar::Register(std::shared_ptr<IntrpService> svc) {
return rc; return rc;
} }
Status CondVar::Interrupt() { void CondVar::Interrupt() {
RETURN_IF_NOT_OK(IntrpResource::Interrupt()); IntrpResource::Interrupt();
cv_.notify_all(); cv_.notify_all();
return Status::OK();
} }
std::string CondVar::my_name() const { return my_name_; } std::string CondVar::my_name() const { return my_name_; }
......
...@@ -35,7 +35,7 @@ class CondVar : public IntrpResource { ...@@ -35,7 +35,7 @@ class CondVar : public IntrpResource {
Status Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred); Status Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred);
Status Interrupt() override; void Interrupt() override;
void NotifyOne() noexcept; void NotifyOne() noexcept;
......
...@@ -29,10 +29,7 @@ class IntrpResource { ...@@ -29,10 +29,7 @@ class IntrpResource {
virtual ~IntrpResource() = default; virtual ~IntrpResource() = default;
virtual Status Interrupt() { virtual void Interrupt() { st_ = State::kInterrupted; }
st_ = State::kInterrupted;
return Status::OK();
}
virtual void ResetIntrpState() { st_ = State::kRunning; } virtual void ResetIntrpState() { st_ = State::kRunning; }
...@@ -40,6 +37,13 @@ class IntrpResource { ...@@ -40,6 +37,13 @@ class IntrpResource {
bool Interrupted() const { return CurState() == State::kInterrupted; } bool Interrupted() const { return CurState() == State::kInterrupted; }
// Translate the interrupt flag of this resource into a Status value.
// @return Status(StatusCode::kInterrupted) when Interrupted() reports true,
//         Status::OK() otherwise.
virtual Status GetInterruptStatus() const {
  return Interrupted() ? Status(StatusCode::kInterrupted) : Status::OK();
}
protected: protected:
std::atomic<State> st_; std::atomic<State> st_;
}; };
......
...@@ -27,7 +27,7 @@ IntrpService::~IntrpService() noexcept { ...@@ -27,7 +27,7 @@ IntrpService::~IntrpService() noexcept {
MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << "."; MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << ".";
if (!all_intrp_resources_.empty()) { if (!all_intrp_resources_.empty()) {
try { try {
(void)InterruptAll(); InterruptAll();
} catch (const std::exception &e) { } catch (const std::exception &e) {
// Ignore all error as we can't throw in the destructor. // Ignore all error as we can't throw in the destructor.
} }
...@@ -64,11 +64,9 @@ Status IntrpService::Deregister(const std::string &name) noexcept { ...@@ -64,11 +64,9 @@ Status IntrpService::Deregister(const std::string &name) noexcept {
std::ostringstream ss; std::ostringstream ss;
ss << this_thread::get_id(); ss << this_thread::get_id();
MS_LOG(DEBUG) << "De-register resource with name " << name << ". Thread ID is " << ss.str() << "."; MS_LOG(DEBUG) << "De-register resource with name " << name << ". Thread ID is " << ss.str() << ".";
auto it = all_intrp_resources_.find(name); auto n = all_intrp_resources_.erase(name);
if (it != all_intrp_resources_.end()) { if (n == 0) {
(void)all_intrp_resources_.erase(it); MS_LOG(INFO) << "Key " << name << " not found.";
} else {
MS_LOG(DEBUG) << "Key " << name << " not found.";
} }
} catch (std::exception &e) { } catch (std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what()); RETURN_STATUS_UNEXPECTED(e.what());
...@@ -76,21 +74,16 @@ Status IntrpService::Deregister(const std::string &name) noexcept { ...@@ -76,21 +74,16 @@ Status IntrpService::Deregister(const std::string &name) noexcept {
return Status::OK(); return Status::OK();
} }
Status IntrpService::InterruptAll() noexcept { void IntrpService::InterruptAll() noexcept {
std::lock_guard<std::mutex> lck(mutex_); std::lock_guard<std::mutex> lck(mutex_);
Status rc;
for (auto const &it : all_intrp_resources_) { for (auto const &it : all_intrp_resources_) {
std::string kName = it.first; std::string kName = it.first;
try { try {
Status rc2 = it.second->Interrupt(); it.second->Interrupt();
if (rc2.IsError()) {
rc = rc2;
}
} catch (const std::exception &e) { } catch (const std::exception &e) {
// continue the clean up. // continue the clean up.
} }
} }
return rc;
} }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
...@@ -47,7 +47,7 @@ class IntrpService : public Service { ...@@ -47,7 +47,7 @@ class IntrpService : public Service {
Status Deregister(const std::string &name) noexcept; Status Deregister(const std::string &name) noexcept;
Status InterruptAll() noexcept; void InterruptAll() noexcept;
Status DoServiceStart() override { return Status::OK(); } Status DoServiceStart() override { return Status::OK(); }
......
...@@ -110,7 +110,7 @@ class Queue { ...@@ -110,7 +110,7 @@ class Queue {
empty_cv_.NotifyAll(); empty_cv_.NotifyAll();
_lock.unlock(); _lock.unlock();
} else { } else {
(void)empty_cv_.Interrupt(); empty_cv_.Interrupt();
} }
return rc; return rc;
} }
...@@ -125,7 +125,7 @@ class Queue { ...@@ -125,7 +125,7 @@ class Queue {
empty_cv_.NotifyAll(); empty_cv_.NotifyAll();
_lock.unlock(); _lock.unlock();
} else { } else {
(void)empty_cv_.Interrupt(); empty_cv_.Interrupt();
} }
return rc; return rc;
} }
...@@ -141,7 +141,7 @@ class Queue { ...@@ -141,7 +141,7 @@ class Queue {
empty_cv_.NotifyAll(); empty_cv_.NotifyAll();
_lock.unlock(); _lock.unlock();
} else { } else {
(void)empty_cv_.Interrupt(); empty_cv_.Interrupt();
} }
return rc; return rc;
} }
...@@ -160,7 +160,7 @@ class Queue { ...@@ -160,7 +160,7 @@ class Queue {
full_cv_.NotifyAll(); full_cv_.NotifyAll();
_lock.unlock(); _lock.unlock();
} else { } else {
(void)full_cv_.Interrupt(); full_cv_.Interrupt();
} }
return rc; return rc;
} }
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <mutex> #include <mutex>
#include <string> #include <string>
#include "dataset/util/memory_pool.h" #include "dataset/util/memory_pool.h"
#include "dataset/util/allocator.h"
#include "dataset/util/service.h" #include "dataset/util/service.h"
#define UNIQUEID_LEN 36 #define UNIQUEID_LEN 36
...@@ -72,6 +73,11 @@ class Services { ...@@ -72,6 +73,11 @@ class Services {
static std::string GetUniqueID(); static std::string GetUniqueID();
// Build an Allocator<T> on top of the memory pool owned by the Services singleton.
// @tparam T element type the returned allocator hands out.
// @return an Allocator<T> constructed from Services::GetInstance().GetServiceMemPool().
template <typename T>
static Allocator<T> GetAllocator() {
  auto service_pool = Services::GetInstance().GetServiceMemPool();
  return Allocator<T>(service_pool);
}
private: private:
static std::once_flag init_instance_flag_; static std::once_flag init_instance_flag_;
static std::unique_ptr<Services> instance_; static std::unique_ptr<Services> instance_;
......
...@@ -72,7 +72,7 @@ void Task::ShutdownGroup() { // Wake up watch dog and shutdown the engine. ...@@ -72,7 +72,7 @@ void Task::ShutdownGroup() { // Wake up watch dog and shutdown the engine.
} }
} }
Status Task::GetTaskErrorIfAny() { Status Task::GetTaskErrorIfAny() const {
std::lock_guard<std::mutex> lk(mux_); std::lock_guard<std::mutex> lk(mux_);
if (caught_severe_exception_) { if (caught_severe_exception_) {
return rc_; return rc_;
...@@ -141,5 +141,13 @@ TaskGroup *Task::MyTaskGroup() { return task_group_; } ...@@ -141,5 +141,13 @@ TaskGroup *Task::MyTaskGroup() { return task_group_; }
void Task::set_task_group(TaskGroup *vg) { task_group_ = vg; } void Task::set_task_group(TaskGroup *vg) { task_group_ = vg; }
Task::~Task() { task_group_ = nullptr; } Task::~Task() { task_group_ = nullptr; }
// Decide which status a caller should surface after a wait or interrupt check.
// @param rc status produced by the interrupt check.
// @return TaskManager::GetMasterThreadRc() when rc is an "interrupted" status and the
//         calling thread is the master thread; otherwise rc is handed back unchanged.
Status Task::OverrideInterruptRc(const Status &rc) {
  // Short-circuit: the master-thread lookup only happens for an interrupted rc.
  const bool interrupted_on_master = rc.IsInterrupted() && this_thread::is_master_thread();
  if (!interrupted_on_master) {
    return rc;
  }
  // The master thread is mostly interrupted because some other thread reported an
  // error, so report the status recorded for the master thread instead.
  return TaskManager::GetMasterThreadRc();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
...@@ -60,7 +60,7 @@ class Task : public IntrpResource { ...@@ -60,7 +60,7 @@ class Task : public IntrpResource {
Task &operator=(Task &&) = delete; Task &operator=(Task &&) = delete;
Status GetTaskErrorIfAny(); Status GetTaskErrorIfAny() const;
void ChangeName(const std::string &newName) { my_name_ = newName; } void ChangeName(const std::string &newName) { my_name_ = newName; }
...@@ -95,10 +95,10 @@ class Task : public IntrpResource { ...@@ -95,10 +95,10 @@ class Task : public IntrpResource {
Status Wait() { return (wp_.Wait()); } Status Wait() { return (wp_.Wait()); }
void set_task_group(TaskGroup *vg); static Status OverrideInterruptRc(const Status &rc);
private: private:
std::mutex mux_; mutable std::mutex mux_;
std::string my_name_; std::string my_name_;
Status rc_; Status rc_;
WaitPost wp_; WaitPost wp_;
...@@ -115,6 +115,7 @@ class Task : public IntrpResource { ...@@ -115,6 +115,7 @@ class Task : public IntrpResource {
void ShutdownGroup(); void ShutdownGroup();
TaskGroup *MyTaskGroup(); TaskGroup *MyTaskGroup();
void set_task_group(TaskGroup *vg);
}; };
extern thread_local Task *gMyTask; extern thread_local Task *gMyTask;
......
...@@ -84,7 +84,7 @@ void TaskManager::interrupt_all() noexcept { ...@@ -84,7 +84,7 @@ void TaskManager::interrupt_all() noexcept {
svc->InterruptAll(); svc->InterruptAll();
} }
} }
(void)master_->Interrupt(); master_->Interrupt();
} }
Task *TaskManager::FindMe() { return gMyTask; } Task *TaskManager::FindMe() { return gMyTask; }
...@@ -94,8 +94,7 @@ TaskManager::TaskManager() try : global_interrupt_(0), ...@@ -94,8 +94,7 @@ TaskManager::TaskManager() try : global_interrupt_(0),
free_lst_(&Task::free), free_lst_(&Task::free),
watchdog_grp_(nullptr), watchdog_grp_(nullptr),
watchdog_(nullptr) { watchdog_(nullptr) {
std::shared_ptr<MemoryPool> mp = Services::GetInstance().GetServiceMemPool(); auto alloc = Services::GetAllocator<Task>();
Allocator<Task> alloc(mp);
// Create a dummy Task for the master thread (this thread) // Create a dummy Task for the master thread (this thread)
master_ = std::allocate_shared<Task>(alloc, "master", []() -> Status { return Status::OK(); }); master_ = std::allocate_shared<Task>(alloc, "master", []() -> Status { return Status::OK(); });
master_->id_ = this_thread::get_id(); master_->id_ = this_thread::get_id();
...@@ -185,7 +184,7 @@ void TaskManager::InterruptMaster(const Status &rc) { ...@@ -185,7 +184,7 @@ void TaskManager::InterruptMaster(const Status &rc) {
TaskManager &tm = TaskManager::GetInstance(); TaskManager &tm = TaskManager::GetInstance();
std::shared_ptr<Task> master = tm.master_; std::shared_ptr<Task> master = tm.master_;
std::lock_guard<std::mutex> lck(master->mux_); std::lock_guard<std::mutex> lck(master->mux_);
(void)master->Interrupt(); master->Interrupt();
if (rc.IsError() && master->rc_.IsOk()) { if (rc.IsError() && master->rc_.IsOk()) {
master->rc_ = rc; master->rc_ = rc;
master->caught_severe_exception_ = true; master->caught_severe_exception_ = true;
...@@ -277,7 +276,7 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio ...@@ -277,7 +276,7 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio
return Status::OK(); return Status::OK();
} }
void TaskGroup::interrupt_all() noexcept { (void)intrp_svc_->InterruptAll(); } void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); }
Status TaskGroup::join_all() { Status TaskGroup::join_all() {
Status rc; Status rc;
...@@ -299,8 +298,7 @@ Status TaskGroup::DoServiceStop() { ...@@ -299,8 +298,7 @@ Status TaskGroup::DoServiceStop() {
} }
TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) { TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) {
std::shared_ptr<MemoryPool> mp = Services::GetInstance().GetServiceMemPool(); auto alloc = Services::GetAllocator<IntrpService>();
Allocator<IntrpService> alloc(mp);
intrp_svc_ = std::allocate_shared<IntrpService>(alloc); intrp_svc_ = std::allocate_shared<IntrpService>(alloc);
(void)Service::ServiceStart(); (void)Service::ServiceStart();
} }
......
...@@ -154,37 +154,27 @@ inline bool is_interrupted() { ...@@ -154,37 +154,27 @@ inline bool is_interrupted() {
return true; return true;
} }
Task *my_task = TaskManager::FindMe(); Task *my_task = TaskManager::FindMe();
return (my_task != nullptr) ? my_task->Interrupted() : false; return my_task->Interrupted();
}
// Tell whether the calling thread is the master thread.
// TaskManager::FindMe() returns the thread-local task pointer for this thread.
// NOTE(review): that pointer is presumably null on a thread that never had a
// Task bound to it - confirm; such a thread is reported as "not master" instead
// of dereferencing a null pointer.
// @return true only when a Task is bound to this thread and it is the master thread.
inline bool is_master_thread() {
  Task *my_task = TaskManager::FindMe();
  return (my_task != nullptr) && my_task->IsMasterThread();
}
// Fetch the interrupt status of the Task bound to the calling thread.
// TaskManager::FindMe() returns the thread-local task pointer for this thread.
// NOTE(review): that pointer is presumably null on a thread that never had a
// Task bound to it - confirm; with no Task there is no per-task interrupt state
// to report, so Status::OK() is returned rather than dereferencing null.
// @return the bound Task's GetInterruptStatus(), or Status::OK() when no Task is bound.
inline Status GetInterruptStatus() {
  Task *my_task = TaskManager::FindMe();
  if (my_task == nullptr) {
    return Status::OK();
  }
  return my_task->GetInterruptStatus();
}
} // namespace this_thread } // namespace this_thread
#define RETURN_IF_INTERRUPTED() \ #define RETURN_IF_INTERRUPTED() \
do { \ do { \
if (mindspore::dataset::this_thread::is_interrupted()) { \ if (mindspore::dataset::this_thread::is_interrupted()) { \
Task *myTask = TaskManager::FindMe(); \ return Task::OverrideInterruptRc(this_thread::GetInterruptStatus()); \
if (myTask->IsMasterThread() && myTask->CaughtSevereException()) { \ } \
return TaskManager::GetMasterThreadRc(); \
} else { \
return Status(StatusCode::kInterrupted); \
} \
} \
} while (false) } while (false)
// Block on a condition variable until a predicate holds, checking for
// interruption before every short timed wait.
// @param cv   condition variable to wait on.
// @param lk   lock handed to cv->wait_for() on each wait.
// @param pred wait ends with Status::OK() as soon as this returns true.
// @return Status::OK() once pred() is true; whatever RETURN_IF_INTERRUPTED() returns
//         when the thread is interrupted; an "unexpected" status if wait_for throws.
inline Status interruptible_wait(std::condition_variable *cv, std::unique_lock<std::mutex> *lk,
                                 const std::function<bool()> &pred) noexcept {
  while (!pred()) {
    // Give up before sleeping again if this thread has been interrupted.
    RETURN_IF_INTERRUPTED();
    try {
      // Sleep at most 1 ms so the interrupt check above runs regularly.
      (void)cv->wait_for(*lk, std::chrono::milliseconds(1));
    } catch (const std::exception &e) {
      // Anything thrown by wait_for is treated as a system error.
      RETURN_STATUS_UNEXPECTED(e.what());
    }
  }
  return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
......
...@@ -139,6 +139,9 @@ Status MindDataTestConnector::Run_test_0() { ...@@ -139,6 +139,9 @@ Status MindDataTestConnector::Run_test_0() {
10); // capacity of each queue 10); // capacity of each queue
DS_ASSERT(my_conn != nullptr); DS_ASSERT(my_conn != nullptr);
rc = my_conn->Register(tg_.get());
RETURN_IF_NOT_OK(rc);
// Spawn a thread to read input_ vector and put it in my_conn // Spawn a thread to read input_ vector and put it in my_conn
rc = tg_->CreateAsyncTask("Worker Push", rc = tg_->CreateAsyncTask("Worker Push",
std::bind(&MindDataTestConnector::FirstWorkerPush, std::bind(&MindDataTestConnector::FirstWorkerPush,
...@@ -184,6 +187,11 @@ Status MindDataTestConnector::Run_test_1() { ...@@ -184,6 +187,11 @@ Status MindDataTestConnector::Run_test_1() {
l3_threads, l3_threads,
conn2_qcap); conn2_qcap);
rc = conn1->Register(tg_.get());
RETURN_IF_NOT_OK(rc);
rc = conn2->Register(tg_.get());
RETURN_IF_NOT_OK(rc);
// Instantiating the threads in the first layer // Instantiating the threads in the first layer
for (int i = 0; i < l1_threads; i++) { for (int i = 0; i < l1_threads; i++) {
rc = tg_->CreateAsyncTask("First Worker Push", rc = tg_->CreateAsyncTask("First Worker Push",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册