提交 e8980ed2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1406 Simplify CondVar class

Merge pull request !1406 from JesseKLee/CondVar
......@@ -14,35 +14,34 @@
* limitations under the License.
*/
#include "dataset/util/cond_var.h"
#include <exception>
#include <utility>
#include "dataset/util/services.h"
#include "dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
CondVar::CondVar() : svc_(nullptr), my_name_(std::move(Services::GetUniqueID())) {}
CondVar::CondVar() : svc_(nullptr), my_name_(Services::GetUniqueID()) {}
Status CondVar::Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred) {
// Append an additional condition on top of the given predicate.
// We will also bail out if this cv got interrupted.
auto f = [this, &pred]() -> bool { return (pred() || (CurState() == State::kInterrupted)); };
// If we have interrupt service, just wait on the cv unconditionally.
// Otherwise fall back to the old way of checking interrupt.
if (svc_) {
cv_.wait(*lck, f);
if (CurState() == State::kInterrupted) {
Task *my_task = TaskManager::FindMe();
if (my_task->IsMasterThread() && my_task->CaughtSevereException()) {
return TaskManager::GetMasterThreadRc();
} else {
return Status(StatusCode::kInterrupted);
try {
if (svc_ != nullptr) {
// If this cv registers with a global resource tracking, then wait unconditionally.
auto f = [this, &pred]() -> bool { return (pred() || this->Interrupted()); };
cv_.wait(*lck, f);
// If we are interrupted, override the return value if this is the master thread.
// Master thread is being interrupted mostly because of some thread is reporting error.
RETURN_IF_NOT_OK(Task::OverrideInterruptRc(this->GetInterruptStatus()));
} else {
// Otherwise we wake up once a while to check for interrupt (for this thread).
auto f = [&pred]() -> bool { return (pred() || this_thread::is_interrupted()); };
while (!f()) {
(void)cv_.wait_for(*lck, std::chrono::milliseconds(1));
}
RETURN_IF_INTERRUPTED();
}
} else {
RETURN_IF_NOT_OK(interruptible_wait(&cv_, lck, pred));
if (CurState() == State::kInterrupted) {
return Status(StatusCode::kInterrupted);
}
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return Status::OK();
}
......@@ -66,10 +65,9 @@ Status CondVar::Register(std::shared_ptr<IntrpService> svc) {
return rc;
}
Status CondVar::Interrupt() {
RETURN_IF_NOT_OK(IntrpResource::Interrupt());
void CondVar::Interrupt() {
IntrpResource::Interrupt();
cv_.notify_all();
return Status::OK();
}
std::string CondVar::my_name() const { return my_name_; }
......
......@@ -35,7 +35,7 @@ class CondVar : public IntrpResource {
Status Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred);
Status Interrupt() override;
void Interrupt() override;
void NotifyOne() noexcept;
......
......@@ -29,10 +29,7 @@ class IntrpResource {
virtual ~IntrpResource() = default;
virtual Status Interrupt() {
st_ = State::kInterrupted;
return Status::OK();
}
virtual void Interrupt() { st_ = State::kInterrupted; }
virtual void ResetIntrpState() { st_ = State::kRunning; }
......@@ -40,6 +37,13 @@ class IntrpResource {
bool Interrupted() const { return CurState() == State::kInterrupted; }
virtual Status GetInterruptStatus() const {
if (Interrupted()) {
return Status(StatusCode::kInterrupted);
}
return Status::OK();
}
protected:
std::atomic<State> st_;
};
......
......@@ -27,7 +27,7 @@ IntrpService::~IntrpService() noexcept {
MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << ".";
if (!all_intrp_resources_.empty()) {
try {
(void)InterruptAll();
InterruptAll();
} catch (const std::exception &e) {
// Ignore all error as we can't throw in the destructor.
}
......@@ -64,11 +64,9 @@ Status IntrpService::Deregister(const std::string &name) noexcept {
std::ostringstream ss;
ss << this_thread::get_id();
MS_LOG(DEBUG) << "De-register resource with name " << name << ". Thread ID is " << ss.str() << ".";
auto it = all_intrp_resources_.find(name);
if (it != all_intrp_resources_.end()) {
(void)all_intrp_resources_.erase(it);
} else {
MS_LOG(DEBUG) << "Key " << name << " not found.";
auto n = all_intrp_resources_.erase(name);
if (n == 0) {
MS_LOG(INFO) << "Key " << name << " not found.";
}
} catch (std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
......@@ -76,21 +74,16 @@ Status IntrpService::Deregister(const std::string &name) noexcept {
return Status::OK();
}
Status IntrpService::InterruptAll() noexcept {
void IntrpService::InterruptAll() noexcept {
std::lock_guard<std::mutex> lck(mutex_);
Status rc;
for (auto const &it : all_intrp_resources_) {
std::string kName = it.first;
try {
Status rc2 = it.second->Interrupt();
if (rc2.IsError()) {
rc = rc2;
}
it.second->Interrupt();
} catch (const std::exception &e) {
// continue the clean up.
}
}
return rc;
}
} // namespace dataset
} // namespace mindspore
......@@ -47,7 +47,7 @@ class IntrpService : public Service {
Status Deregister(const std::string &name) noexcept;
Status InterruptAll() noexcept;
void InterruptAll() noexcept;
Status DoServiceStart() override { return Status::OK(); }
......
......@@ -110,7 +110,7 @@ class Queue {
empty_cv_.NotifyAll();
_lock.unlock();
} else {
(void)empty_cv_.Interrupt();
empty_cv_.Interrupt();
}
return rc;
}
......@@ -125,7 +125,7 @@ class Queue {
empty_cv_.NotifyAll();
_lock.unlock();
} else {
(void)empty_cv_.Interrupt();
empty_cv_.Interrupt();
}
return rc;
}
......@@ -141,7 +141,7 @@ class Queue {
empty_cv_.NotifyAll();
_lock.unlock();
} else {
(void)empty_cv_.Interrupt();
empty_cv_.Interrupt();
}
return rc;
}
......@@ -160,7 +160,7 @@ class Queue {
full_cv_.NotifyAll();
_lock.unlock();
} else {
(void)full_cv_.Interrupt();
full_cv_.Interrupt();
}
return rc;
}
......
......@@ -20,6 +20,7 @@
#include <mutex>
#include <string>
#include "dataset/util/memory_pool.h"
#include "dataset/util/allocator.h"
#include "dataset/util/service.h"
#define UNIQUEID_LEN 36
......@@ -72,6 +73,11 @@ class Services {
static std::string GetUniqueID();
template <typename T>
static Allocator<T> GetAllocator() {
return Allocator<T>(Services::GetInstance().GetServiceMemPool());
}
private:
static std::once_flag init_instance_flag_;
static std::unique_ptr<Services> instance_;
......
......@@ -72,7 +72,7 @@ void Task::ShutdownGroup() { // Wake up watch dog and shutdown the engine.
}
}
Status Task::GetTaskErrorIfAny() {
Status Task::GetTaskErrorIfAny() const {
std::lock_guard<std::mutex> lk(mux_);
if (caught_severe_exception_) {
return rc_;
......@@ -141,5 +141,13 @@ TaskGroup *Task::MyTaskGroup() { return task_group_; }
void Task::set_task_group(TaskGroup *vg) { task_group_ = vg; }
Task::~Task() { task_group_ = nullptr; }
Status Task::OverrideInterruptRc(const Status &rc) {
if (rc.IsInterrupted() && this_thread::is_master_thread()) {
// If we are interrupted, override the return value if this is the master thread.
// Master thread is being interrupted mostly because of some thread is reporting error.
return TaskManager::GetMasterThreadRc();
}
return rc;
}
} // namespace dataset
} // namespace mindspore
......@@ -60,7 +60,7 @@ class Task : public IntrpResource {
Task &operator=(Task &&) = delete;
Status GetTaskErrorIfAny();
Status GetTaskErrorIfAny() const;
void ChangeName(const std::string &newName) { my_name_ = newName; }
......@@ -95,10 +95,10 @@ class Task : public IntrpResource {
Status Wait() { return (wp_.Wait()); }
void set_task_group(TaskGroup *vg);
static Status OverrideInterruptRc(const Status &rc);
private:
std::mutex mux_;
mutable std::mutex mux_;
std::string my_name_;
Status rc_;
WaitPost wp_;
......@@ -115,6 +115,7 @@ class Task : public IntrpResource {
void ShutdownGroup();
TaskGroup *MyTaskGroup();
void set_task_group(TaskGroup *vg);
};
extern thread_local Task *gMyTask;
......
......@@ -84,7 +84,7 @@ void TaskManager::interrupt_all() noexcept {
svc->InterruptAll();
}
}
(void)master_->Interrupt();
master_->Interrupt();
}
Task *TaskManager::FindMe() { return gMyTask; }
......@@ -94,8 +94,7 @@ TaskManager::TaskManager() try : global_interrupt_(0),
free_lst_(&Task::free),
watchdog_grp_(nullptr),
watchdog_(nullptr) {
std::shared_ptr<MemoryPool> mp = Services::GetInstance().GetServiceMemPool();
Allocator<Task> alloc(mp);
auto alloc = Services::GetAllocator<Task>();
// Create a dummy Task for the master thread (this thread)
master_ = std::allocate_shared<Task>(alloc, "master", []() -> Status { return Status::OK(); });
master_->id_ = this_thread::get_id();
......@@ -185,7 +184,7 @@ void TaskManager::InterruptMaster(const Status &rc) {
TaskManager &tm = TaskManager::GetInstance();
std::shared_ptr<Task> master = tm.master_;
std::lock_guard<std::mutex> lck(master->mux_);
(void)master->Interrupt();
master->Interrupt();
if (rc.IsError() && master->rc_.IsOk()) {
master->rc_ = rc;
master->caught_severe_exception_ = true;
......@@ -277,7 +276,7 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio
return Status::OK();
}
void TaskGroup::interrupt_all() noexcept { (void)intrp_svc_->InterruptAll(); }
void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); }
Status TaskGroup::join_all() {
Status rc;
......@@ -299,8 +298,7 @@ Status TaskGroup::DoServiceStop() {
}
TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) {
std::shared_ptr<MemoryPool> mp = Services::GetInstance().GetServiceMemPool();
Allocator<IntrpService> alloc(mp);
auto alloc = Services::GetAllocator<IntrpService>();
intrp_svc_ = std::allocate_shared<IntrpService>(alloc);
(void)Service::ServiceStart();
}
......
......@@ -154,37 +154,27 @@ inline bool is_interrupted() {
return true;
}
Task *my_task = TaskManager::FindMe();
return (my_task != nullptr) ? my_task->Interrupted() : false;
return my_task->Interrupted();
}
inline bool is_master_thread() {
Task *my_task = TaskManager::FindMe();
return my_task->IsMasterThread();
}
inline Status GetInterruptStatus() {
Task *my_task = TaskManager::FindMe();
return my_task->GetInterruptStatus();
}
} // namespace this_thread
#define RETURN_IF_INTERRUPTED() \
do { \
if (mindspore::dataset::this_thread::is_interrupted()) { \
Task *myTask = TaskManager::FindMe(); \
if (myTask->IsMasterThread() && myTask->CaughtSevereException()) { \
return TaskManager::GetMasterThreadRc(); \
} else { \
return Status(StatusCode::kInterrupted); \
} \
} \
#define RETURN_IF_INTERRUPTED() \
do { \
if (mindspore::dataset::this_thread::is_interrupted()) { \
return Task::OverrideInterruptRc(this_thread::GetInterruptStatus()); \
} \
} while (false)
inline Status interruptible_wait(std::condition_variable *cv, std::unique_lock<std::mutex> *lk,
const std::function<bool()> &pred) noexcept {
if (!pred()) {
do {
RETURN_IF_INTERRUPTED();
try {
(void)cv->wait_for(*lk, std::chrono::milliseconds(1));
} catch (std::exception &e) {
// Anything thrown by wait_for is considered system error.
RETURN_STATUS_UNEXPECTED(e.what());
}
} while (!pred());
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
......
......@@ -139,6 +139,9 @@ Status MindDataTestConnector::Run_test_0() {
10); // capacity of each queue
DS_ASSERT(my_conn != nullptr);
rc = my_conn->Register(tg_.get());
RETURN_IF_NOT_OK(rc);
// Spawn a thread to read input_ vector and put it in my_conn
rc = tg_->CreateAsyncTask("Worker Push",
std::bind(&MindDataTestConnector::FirstWorkerPush,
......@@ -184,6 +187,11 @@ Status MindDataTestConnector::Run_test_1() {
l3_threads,
conn2_qcap);
rc = conn1->Register(tg_.get());
RETURN_IF_NOT_OK(rc);
rc = conn2->Register(tg_.get());
RETURN_IF_NOT_OK(rc);
// Instantiating the threads in the first layer
for (int i = 0; i < l1_threads; i++) {
rc = tg_->CreateAsyncTask("First Worker Push",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册