Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
03ab8136
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
03ab8136
编写于
6月 04, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(core): fix asan error cause by wild thread_pool ptr
GitOrigin-RevId: b1c1b452cd78b3db0ca778c1b31c05593dbe9e96
上级
24a38781
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
14 addition
and
13 deletion
+14
-13
src/core/impl/comp_node/cpu/comp_node.cpp
src/core/impl/comp_node/cpu/comp_node.cpp
+14
-13
未找到文件。
src/core/impl/comp_node/cpu/comp_node.cpp
浏览文件 @
03ab8136
...
...
@@ -51,7 +51,7 @@ void CpuCompNode::CpuDispatchableBase::add_callback(Task&& task) {
class
CpuCompNode
::
WorkerQueue
final
:
public
AsyncQueueSC
<
TaskElem
,
WorkerQueue
>
{
const
Locator
m_locator
;
ThreadPool
*
m_thread_pool
=
nullptr
;
std
::
shared_ptr
<
ThreadPool
>
m_thread_pool
=
nullptr
;
void
on_async_queue_worker_thread_start
()
override
{
mgb_assert
(
m_locator
.
device
>=
0
);
...
...
@@ -74,7 +74,7 @@ public:
explicit
WorkerQueue
(
Locator
locator
)
:
m_locator
(
locator
)
{}
void
attach_thread_pool
(
ThreadPool
*
thread_pool
)
{
void
attach_thread_pool
(
std
::
shared_ptr
<
ThreadPool
>
thread_pool
)
{
m_thread_pool
=
thread_pool
;
}
...
...
@@ -92,7 +92,7 @@ public:
return
m_thread_pool
?
m_thread_pool
->
nr_threads
()
:
1
_z
;
}
ThreadPool
*
get_thread_pool
()
{
return
m_thread_pool
;
}
ThreadPool
*
get_thread_pool
()
{
return
m_thread_pool
.
get
()
;
}
};
class
CpuCompNode
::
SeqRecorderImpl
final
:
public
CompNodeSeqRecorder
{
...
...
@@ -102,7 +102,7 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder {
SeqRecorderImpl
**
const
m_self_pointer
;
std
::
vector
<
TaskElem
>
m_tasks
;
ThreadPool
*
m_thread_pool
=
nullptr
;
std
::
shared_ptr
<
ThreadPool
>
m_thread_pool
=
nullptr
;
const
CompNode
m_record_compnode
;
/*!
* \brief use to check the all ther recording tasks are its self CompNode
...
...
@@ -118,7 +118,8 @@ class CpuCompNode::SeqRecorderImpl final : public CompNodeSeqRecorder {
}
public:
SeqRecorderImpl
(
SeqRecorderImpl
**
self_pointer
,
ThreadPool
*
thread_pool
,
SeqRecorderImpl
(
SeqRecorderImpl
**
self_pointer
,
std
::
shared_ptr
<
ThreadPool
>
thread_pool
,
const
CompNode
&
comp_node
)
:
m_self_pointer
{
self_pointer
},
m_thread_pool
{
thread_pool
},
...
...
@@ -239,7 +240,7 @@ public:
return
m_thread_pool
?
m_thread_pool
->
nr_threads
()
:
1
_z
;
}
ThreadPool
*
get_thread_pool
()
{
return
m_thread_pool
;
}
ThreadPool
*
get_thread_pool
()
{
return
m_thread_pool
.
get
()
;
}
};
using
CompNodeBaseImpl
=
CpuCompNode
::
CompNodeBaseImpl
;
...
...
@@ -404,14 +405,14 @@ public:
//! implementation of InplaceCPUDispatcher
class
InplaceCPUDispatcher
final
:
public
CPUDispatcher
{
std
::
atomic_size_t
m_nr_task
{
0
};
ThreadPool
*
m_thread_pool
=
nullptr
;
std
::
shared_ptr
<
ThreadPool
>
m_thread_pool
=
nullptr
;
//! InplaceCPUDispatcher may used by both type of compnodes, so
//! m_comp_node's type should be base class.
CompNodeBaseImpl
*
const
m_comp_node
;
public:
InplaceCPUDispatcher
(
CompNodeBaseImpl
*
comp_node
,
ThreadPool
*
thread_pool
=
nullptr
)
std
::
shared_ptr
<
ThreadPool
>
thread_pool
=
nullptr
)
:
m_thread_pool
(
thread_pool
),
m_comp_node
(
comp_node
)
{}
void
dispatch
(
Task
&&
task
)
override
{
...
...
@@ -558,7 +559,7 @@ CompNodeDefaultImpl* CompNodeDefaultImpl::sm_default_cpu_comp_node_ptr =
//! ==================== CompNodeRecorderImpl ======================
class
CpuCompNode
::
CompNodeRecorderImpl
final
:
public
CompNodeBaseImpl
{
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
std
::
unique
_ptr
<
ThreadPool
>
m_thread_pool
;
std
::
shared
_ptr
<
ThreadPool
>
m_thread_pool
;
std
::
shared_ptr
<
WorkerQueue
>
m_worker_queue
;
//! used during comp node seq rec
...
...
@@ -629,7 +630,7 @@ public:
m_worker_queue
(
worker_queue
)
{
auto
cn
=
make_comp_node_from_impl
(
this
);
if
(
locator
.
type
==
DeviceType
::
MULTITHREAD
)
{
m_thread_pool
=
std
::
unique
_ptr
<
ThreadPool
>
(
m_thread_pool
=
std
::
shared
_ptr
<
ThreadPool
>
(
new
ThreadPool
(
static_cast
<
size_t
>
(
locator
.
nr_threads
)));
mgb_assert
(
m_thread_pool
,
"ThradPool create failed"
);
}
...
...
@@ -645,10 +646,10 @@ public:
}
else
if
(
locator
.
type
==
DeviceType
::
MULTITHREAD
)
{
if
(
locator
.
device
==
Locator
::
DEVICE_MULTITHREAD_DEFAULT
)
{
m_env
.
init_cpu
({
std
::
make_shared
<
InplaceCPUDispatcher
>
(
this
,
m_thread_pool
.
get
()
)},
this
,
m_thread_pool
)},
cn
);
}
else
{
m_worker_queue
->
attach_thread_pool
(
m_thread_pool
.
get
()
);
m_worker_queue
->
attach_thread_pool
(
m_thread_pool
);
m_env
.
init_cpu
({
std
::
make_shared
<
WorkerQueue
::
DispatcherImpl
>
(
m_worker_queue
,
this
)},
cn
);
...
...
@@ -807,7 +808,7 @@ public:
std
::
unique_ptr
<
CompNodeSeqRecorder
>
create_seq_recorder
(
cg
::
ComputingGraph
*
)
override
{
return
std
::
make_unique
<
SeqRecorderImpl
>
(
&
sm_cur_recorder
,
m_thread_pool
.
get
()
,
this
);
m_thread_pool
,
this
);
}
SeqRecorderImpl
*
cur_recorder
()
const
override
{
return
sm_cur_recorder
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录