diff --git a/imperative/python/megengine/data/dataloader.py b/imperative/python/megengine/data/dataloader.py index 4372b2b753afaf0a7612d8ec794c6755314137bb..39f2f407cb6c43b1e344184b97e066e98fe6e65d 100644 --- a/imperative/python/megengine/data/dataloader.py +++ b/imperative/python/megengine/data/dataloader.py @@ -114,7 +114,7 @@ class DataLoader: self.__initialized = True def __iter__(self): - if platform.system() == "Windows": + if platform.system() == "Windows" and self.num_workers > 0: print( "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero" ) diff --git a/src/core/impl/utils/thread.cpp b/src/core/impl/utils/thread.cpp index 59c6d1f0abdd5f4f9f30a00a3854f3d342c7ce93..43b4d7b2285181ec9b7a4b3fbc27dce470f00892 100644 --- a/src/core/impl/utils/thread.cpp +++ b/src/core/impl/utils/thread.cpp @@ -67,6 +67,9 @@ namespace { /* =============== SCQueueSynchronizer =============== */ size_t SCQueueSynchronizer::cached_max_spin = 0; +#ifdef WIN32 +bool SCQueueSynchronizer::is_into_atexit = false; +#endif size_t SCQueueSynchronizer::max_spin() { if (cached_max_spin) diff --git a/src/core/include/megbrain/utils/thread_impl_1.h b/src/core/include/megbrain/utils/thread_impl_1.h index c2719a5d974a25697ff09baf790511000bd95892..137ab59e4cd0f83b95df35f89161c9306dbcacb4 100644 --- a/src/core/include/megbrain/utils/thread_impl_1.h +++ b/src/core/include/megbrain/utils/thread_impl_1.h @@ -72,6 +72,13 @@ namespace mgb { return m_worker_started; } +#ifdef WIN32 + static bool is_into_atexit; + void set_finish_called(bool status) { + m_wait_finish_called = status; + } +#endif + static size_t max_spin(); void start_worker(std::thread thread); @@ -143,14 +150,29 @@ namespace mgb { }; public: - void add_task(const Param &param) { +#ifdef WIN32 + bool check_is_into_atexit() { + if (SCQueueSynchronizer::is_into_atexit) { + mgb_log_warn( + "add_task after system call atexit happened! 
" + "ignore it, workround for windows os force INT " + "some thread before shared_ptr destructor " + "finish!!"); + m_synchronizer.set_finish_called(true); + } + + return SCQueueSynchronizer::is_into_atexit; + } +#endif + + void add_task(const Param& param) { SyncedParam* p = allocate_task(); new (p->get()) Param(param); p->init_done.store(true, std::memory_order_release); m_synchronizer.producer_add(); } - void add_task(Param &&param) { + void add_task(Param&& param) { SyncedParam* p = allocate_task(); new (p->get()) Param(std::move(param)); p->init_done.store(true, std::memory_order_release); @@ -165,6 +187,10 @@ namespace mgb { void wait_all_task_finish() { auto tgt = m_queue_tail_tid.load(std::memory_order_acquire); do { +#ifdef WIN32 + if (check_is_into_atexit()) + return; +#endif // we need a loop because other threads might be adding new // tasks, and m_queue_tail_tid is increased before // producer_add() @@ -184,6 +210,10 @@ namespace mgb { void wait_task_queue_empty() { size_t tgt, done; do { +#ifdef WIN32 + if (check_is_into_atexit()) + return; +#endif m_synchronizer.producer_wait(); // producer_wait() only waits for tasks that are added upon // entrance of the function, and new tasks might be added @@ -272,6 +302,17 @@ namespace mgb { // reload newest tail tail = m_queue_tail; if (!m_synchronizer.worker_started()) { +#ifdef WIN32 + if (!SCQueueSynchronizer::is_into_atexit) { + auto cb_atexit = [] { + SCQueueSynchronizer::is_into_atexit = true; + }; + auto err = atexit(cb_atexit); + mgb_assert(!err, + "failed to register windows_call_atexit " + "at exit"); + } +#endif m_synchronizer.start_worker(std::thread{ &AsyncQueueSC::worker_thread_impl, this}); }