From 03ab8136e7aea3bd13757f09df9dc2e38e95c4da Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 4 Jun 2021 18:24:54 +0800 Subject: [PATCH] fix(core): fix asan error cause by wild thread_pool ptr GitOrigin-RevId: b1c1b452cd78b3db0ca778c1b31c05593dbe9e96 --- src/core/impl/comp_node/cpu/comp_node.cpp | 27 ++++++++++++----------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/core/impl/comp_node/cpu/comp_node.cpp b/src/core/impl/comp_node/cpu/comp_node.cpp index e5d96cce9..06867df5a 100644 --- a/src/core/impl/comp_node/cpu/comp_node.cpp +++ b/src/core/impl/comp_node/cpu/comp_node.cpp @@ -51,7 +51,7 @@ void CpuCompNode::CpuDispatchableBase::add_callback(Task&& task) { class CpuCompNode::WorkerQueue final : public AsyncQueueSC { const Locator m_locator; - ThreadPool* m_thread_pool = nullptr; + std::shared_ptr m_thread_pool = nullptr; void on_async_queue_worker_thread_start() override { mgb_assert(m_locator.device >= 0); @@ -74,7 +74,7 @@ public: explicit WorkerQueue(Locator locator) : m_locator(locator) {} - void attach_thread_pool(ThreadPool* thread_pool) { + void attach_thread_pool(std::shared_ptr thread_pool) { m_thread_pool = thread_pool; } @@ -92,7 +92,7 @@ public: return m_thread_pool ? m_thread_pool->nr_threads() : 1_z; } - ThreadPool* get_thread_pool() { return m_thread_pool; } + ThreadPool* get_thread_pool() { return m_thread_pool.get(); } }; class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder { @@ -102,7 +102,7 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder { SeqRecorderImpl** const m_self_pointer; std::vector m_tasks; - ThreadPool* m_thread_pool = nullptr; + std::shared_ptr m_thread_pool = nullptr; const CompNode m_record_compnode; /*! * \brief use to check the all ther recording tasks are its self CompNode @@ -118,7 +118,8 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder { } public: - SeqRecorderImpl(SeqRecorderImpl** self_pointer, ThreadPool* thread_pool, + SeqRecorderImpl(SeqRecorderImpl** self_pointer, + std::shared_ptr thread_pool, const CompNode& comp_node) : m_self_pointer{self_pointer}, m_thread_pool{thread_pool}, @@ -239,7 +240,7 @@ public: return m_thread_pool ? m_thread_pool->nr_threads() : 1_z; } - ThreadPool* get_thread_pool() { return m_thread_pool; } + ThreadPool* get_thread_pool() { return m_thread_pool.get(); } }; using CompNodeBaseImpl = CpuCompNode::CompNodeBaseImpl; @@ -404,14 +405,14 @@ public: //! implementation of InplaceCPUDispatcher class InplaceCPUDispatcher final : public CPUDispatcher { std::atomic_size_t m_nr_task{0}; - ThreadPool* m_thread_pool = nullptr; + std::shared_ptr m_thread_pool = nullptr; //! InplaceCPUDispatcher may used by both type of compnodes, so //! m_comp_node's type should be base class. CompNodeBaseImpl* const m_comp_node; public: InplaceCPUDispatcher(CompNodeBaseImpl* comp_node, - ThreadPool* thread_pool = nullptr) + std::shared_ptr thread_pool = nullptr) : m_thread_pool(thread_pool), m_comp_node(comp_node) {} void dispatch(Task&& task) override { @@ -558,7 +559,7 @@ CompNodeDefaultImpl* CompNodeDefaultImpl::sm_default_cpu_comp_node_ptr = //! ==================== CompNodeRecorderImpl ====================== class CpuCompNode::CompNodeRecorderImpl final : public CompNodeBaseImpl { MGB_DYN_TYPE_OBJ_FINAL_DECL; - std::unique_ptr m_thread_pool; + std::shared_ptr m_thread_pool; std::shared_ptr m_worker_queue; //! used during comp node seq rec @@ -629,7 +630,7 @@ public: m_worker_queue(worker_queue) { auto cn = make_comp_node_from_impl(this); if (locator.type == DeviceType::MULTITHREAD) { - m_thread_pool = std::unique_ptr( + m_thread_pool = std::shared_ptr( new ThreadPool(static_cast(locator.nr_threads))); mgb_assert(m_thread_pool, "ThradPool create failed"); } @@ -645,10 +646,10 @@ public: } else if (locator.type == DeviceType::MULTITHREAD) { if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) { m_env.init_cpu({std::make_shared( - this, m_thread_pool.get())}, + this, m_thread_pool)}, cn); } else { - m_worker_queue->attach_thread_pool(m_thread_pool.get()); + m_worker_queue->attach_thread_pool(m_thread_pool); m_env.init_cpu({std::make_shared( m_worker_queue, this)}, cn); @@ -807,7 +808,7 @@ public: std::unique_ptr create_seq_recorder( cg::ComputingGraph*) override { return std::make_unique(&sm_cur_recorder, - m_thread_pool.get(), this); + m_thread_pool, this); } SeqRecorderImpl* cur_recorder() const override { return sm_cur_recorder; } -- GitLab