提交 03ab8136 编写于 作者: M Megvii Engine Team

fix(core): fix asan error cause by wild thread_pool ptr

GitOrigin-RevId: b1c1b452cd78b3db0ca778c1b31c05593dbe9e96
上级 24a38781
......@@ -51,7 +51,7 @@ void CpuCompNode::CpuDispatchableBase::add_callback(Task&& task) {
class CpuCompNode::WorkerQueue final
: public AsyncQueueSC<TaskElem, WorkerQueue> {
const Locator m_locator;
ThreadPool* m_thread_pool = nullptr;
std::shared_ptr<ThreadPool> m_thread_pool = nullptr;
void on_async_queue_worker_thread_start() override {
mgb_assert(m_locator.device >= 0);
......@@ -74,7 +74,7 @@ public:
explicit WorkerQueue(Locator locator) : m_locator(locator) {}
void attach_thread_pool(ThreadPool* thread_pool) {
void attach_thread_pool(std::shared_ptr<ThreadPool> thread_pool) {
m_thread_pool = thread_pool;
}
......@@ -92,7 +92,7 @@ public:
return m_thread_pool ? m_thread_pool->nr_threads() : 1_z;
}
ThreadPool* get_thread_pool() { return m_thread_pool; }
ThreadPool* get_thread_pool() { return m_thread_pool.get(); }
};
class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder {
......@@ -102,7 +102,7 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder {
SeqRecorderImpl** const m_self_pointer;
std::vector<TaskElem> m_tasks;
ThreadPool* m_thread_pool = nullptr;
std::shared_ptr<ThreadPool> m_thread_pool = nullptr;
const CompNode m_record_compnode;
/*!
* \brief use to check the all ther recording tasks are its self CompNode
......@@ -118,7 +118,8 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder {
}
public:
SeqRecorderImpl(SeqRecorderImpl** self_pointer, ThreadPool* thread_pool,
SeqRecorderImpl(SeqRecorderImpl** self_pointer,
std::shared_ptr<ThreadPool> thread_pool,
const CompNode& comp_node)
: m_self_pointer{self_pointer},
m_thread_pool{thread_pool},
......@@ -239,7 +240,7 @@ public:
return m_thread_pool ? m_thread_pool->nr_threads() : 1_z;
}
ThreadPool* get_thread_pool() { return m_thread_pool; }
ThreadPool* get_thread_pool() { return m_thread_pool.get(); }
};
using CompNodeBaseImpl = CpuCompNode::CompNodeBaseImpl;
......@@ -404,14 +405,14 @@ public:
//! implementation of InplaceCPUDispatcher
class InplaceCPUDispatcher final : public CPUDispatcher {
std::atomic_size_t m_nr_task{0};
ThreadPool* m_thread_pool = nullptr;
std::shared_ptr<ThreadPool> m_thread_pool = nullptr;
//! InplaceCPUDispatcher may used by both type of compnodes, so
//! m_comp_node's type should be base class.
CompNodeBaseImpl* const m_comp_node;
public:
InplaceCPUDispatcher(CompNodeBaseImpl* comp_node,
ThreadPool* thread_pool = nullptr)
std::shared_ptr<ThreadPool> thread_pool = nullptr)
: m_thread_pool(thread_pool), m_comp_node(comp_node) {}
void dispatch(Task&& task) override {
......@@ -558,7 +559,7 @@ CompNodeDefaultImpl* CompNodeDefaultImpl::sm_default_cpu_comp_node_ptr =
//! ==================== CompNodeRecorderImpl ======================
class CpuCompNode::CompNodeRecorderImpl final : public CompNodeBaseImpl {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
std::unique_ptr<ThreadPool> m_thread_pool;
std::shared_ptr<ThreadPool> m_thread_pool;
std::shared_ptr<WorkerQueue> m_worker_queue;
//! used during comp node seq rec
......@@ -629,7 +630,7 @@ public:
m_worker_queue(worker_queue) {
auto cn = make_comp_node_from_impl(this);
if (locator.type == DeviceType::MULTITHREAD) {
m_thread_pool = std::unique_ptr<ThreadPool>(
m_thread_pool = std::shared_ptr<ThreadPool>(
new ThreadPool(static_cast<size_t>(locator.nr_threads)));
mgb_assert(m_thread_pool, "ThradPool create failed");
}
......@@ -645,10 +646,10 @@ public:
} else if (locator.type == DeviceType::MULTITHREAD) {
if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) {
m_env.init_cpu({std::make_shared<InplaceCPUDispatcher>(
this, m_thread_pool.get())},
this, m_thread_pool)},
cn);
} else {
m_worker_queue->attach_thread_pool(m_thread_pool.get());
m_worker_queue->attach_thread_pool(m_thread_pool);
m_env.init_cpu({std::make_shared<WorkerQueue::DispatcherImpl>(
m_worker_queue, this)},
cn);
......@@ -807,7 +808,7 @@ public:
std::unique_ptr<CompNodeSeqRecorder> create_seq_recorder(
cg::ComputingGraph*) override {
return std::make_unique<SeqRecorderImpl>(&sm_cur_recorder,
m_thread_pool.get(), this);
m_thread_pool, this);
}
SeqRecorderImpl* cur_recorder() const override { return sm_cur_recorder; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册