Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
221ec38a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
221ec38a
编写于
3月 18, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative): reduce profiler overhead
GitOrigin-RevId: dded9d9391a49883f1c1035c1aa10f8174867367
上级
1c01128f
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
89 addition
and
47 deletion
+89
-47
imperative/src/impl/interpreter/interpreter_impl.cpp
imperative/src/impl/interpreter/interpreter_impl.cpp
+88
-45
imperative/src/include/megbrain/imperative/profiler.h
imperative/src/include/megbrain/imperative/profiler.h
+1
-2
未找到文件。
imperative/src/impl/interpreter/interpreter_impl.cpp
浏览文件 @
221ec38a
...
...
@@ -52,7 +52,9 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
info
->
desc
.
layout
=
data
.
layout
();
info
->
desc
.
comp_node
=
data
.
comp_node
();
info
->
ptr
=
Tensor
::
make
(
data
);
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
TensorProduceEvent
>
(
info
->
id
,
info
->
desc
.
layout
,
info
->
desc
.
comp_node
);
}
return
info
;
}
...
...
@@ -147,8 +149,9 @@ void ChannelImpl::dispatch_default_cpu(
return
tid
;
};
OpEvent
event_data
=
{
++
m_last_id
,
op
,
tinfo_to_tid
(
input_infos
),
{}};
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
HostOpExecuteEvent
>
(
event_data
);
}
OpDef
::
apply_on_device_tensornd
(
*
op
,
input_tensornds
,
&
output_tensornds
);
...
...
@@ -169,8 +172,9 @@ void ChannelImpl::dispatch_default_cpu(
}
event_data
.
outputs
=
tinfo_to_tid
(
output_infos
);
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
HostOpFinishEvent
>
(
event_data
);
}
}
void
ChannelImpl
::
dispatch_kernel
(
...
...
@@ -267,13 +271,17 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
m_waitee
=
info
;
regenerate
(
info
);
m_buffer
.
enqueue
(
GetValue
{
info
});
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
TensorWaitPropEvent
>
(
info
->
id
,
TensorInfo
::
HostValue
);
}
m_cv
.
wait
(
lock
,
[
&
]()
{
check_worker_exc_unsafe
();
tensor_ptr
=
info
->
ptr
;
return
value_fetched
();
});
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
TensorWaitPropFinishEvent
>
(
info
->
id
,
TensorInfo
::
HostValue
);
}
m_waitee
=
nullptr
;
}
return
tensor_ptr
->
get_value
();
...
...
@@ -290,12 +298,16 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
mgb_assert
(
!
m_waitee
);
m_waitee
=
info
;
m_buffer
.
flush
();
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
TensorWaitPropEvent
>
(
info
->
id
,
TensorInfo
::
Shape
);
}
m_cv
.
wait
(
lock
,
[
&
]()
{
check_worker_exc_unsafe
();
return
static_cast
<
bool
>
(
info
->
ptr
);
});
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
TensorWaitPropFinishEvent
>
(
info
->
id
,
TensorInfo
::
Shape
);
}
m_waitee
=
nullptr
;
TensorShape
ret
=
info
->
ptr
->
layout
();
mgb_assert
(
ret
.
ndim
!=
0
);
...
...
@@ -306,7 +318,9 @@ DType ChannelImpl::get_dtype(Handle handle) {
mgb_assert
(
m_valid_handle
.
find
(
handle
)
!=
m_valid_handle
.
end
(),
"invalid handle: %p"
,
handle
);
auto
info
=
reinterpret_cast
<
TensorInfo
*>
(
handle
);
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
TensorGetPropEvent
>
(
info
->
id
,
TensorInfo
::
DType
);
}
auto
ret
=
info
->
desc
.
layout
.
dtype
;
mgb_assert
(
ret
.
valid
());
return
ret
;
...
...
@@ -316,7 +330,9 @@ CompNode ChannelImpl::get_device(Handle handle) {
mgb_assert
(
m_valid_handle
.
find
(
handle
)
!=
m_valid_handle
.
end
(),
"invalid handle: %p"
,
handle
);
auto
info
=
reinterpret_cast
<
TensorInfo
*>
(
handle
);
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
TensorGetPropEvent
>
(
info
->
id
,
TensorInfo
::
Device
);
}
auto
ret
=
info
->
desc
.
comp_node
;
mgb_assert
(
ret
.
valid
());
return
ret
;
...
...
@@ -331,22 +347,30 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
m_waitee
=
info
;
regenerate
(
info
);
m_buffer
.
flush
();
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
TensorWaitPropEvent
>
(
info
->
id
,
TensorInfo
::
DevValue
);
}
m_cv
.
wait
(
lock
,
[
&
]()
{
check_worker_exc_unsafe
();
return
static_cast
<
bool
>
(
info
->
ptr
);
});
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
TensorWaitPropFinishEvent
>
(
info
->
id
,
TensorInfo
::
DevValue
);
}
m_waitee
=
nullptr
;
return
info
->
ptr
->
dev_tensor
();
}
void
ChannelImpl
::
sync
()
{
m_buffer
.
flush
();
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
SyncStartEvent
>
();
}
m_worker
.
wait_all_task_finish
();
CompNode
::
sync_all
();
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
SyncFinishEvent
>
();
}
MGB_LOCK_GUARD
(
m_mutex
);
check_worker_exc_unsafe
();
}
...
...
@@ -369,13 +393,17 @@ TensorInfo* ChannelImpl::alloc() {
auto
info
=
m_pool
.
alloc
();
m_valid_handle
.
insert
(
info
);
info
->
id
=
m_last_id
++
;
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
TensorDeclareEvent
>
(
info
->
id
);
}
return
info
;
}
void
ChannelImpl
::
free
(
TensorInfo
*
ptr
)
{
MGB_LOCK_GUARD
(
m_mutex
);
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
TensorEraseEvent
>
(
ptr
->
id
);
}
m_pool
.
free
(
ptr
);
}
...
...
@@ -389,7 +417,9 @@ ChannelImpl::~ChannelImpl() {
void
ChannelImpl
::
produce_tensor
(
TensorInfo
*
dest
,
TensorPtr
ptr
)
{
MGB_LOCK_GUARD
(
m_mutex
);
if
(
m_worker_state
.
profiler
->
is_profiling
())
{
m_worker_state
.
profiler
->
record_host
<
TensorProduceEvent
>
(
dest
->
id
,
ptr
->
layout
(),
ptr
->
comp_node
());
}
dest
->
value_fetched
=
ptr
->
value_fetched
();
// update tensor desc for static infer
dest
->
desc
.
layout
=
ptr
->
layout
();
...
...
@@ -471,13 +501,17 @@ void ChannelImpl::sync_device_scope(CompNode device) {
}
void
ChannelImpl
::
process_one_task
(
IdentifiedCommand
&
icmd
)
{
if
(
m_worker_state
.
profiler
->
is_profiling
())
{
m_worker_state
.
profiler
->
record_host
<
CommandExecuteEvent
>
(
icmd
);
}
bool
finished
=
false
;
auto
do_finish_command
=
[
&
]{
if
(
finished
)
{
return
;
}
if
(
m_worker_state
.
profiler
->
is_profiling
())
{
m_worker_state
.
profiler
->
record_host
<
CommandFinishEvent
>
(
icmd
);
}
finished
=
true
;
};
//TODO: remove std::visit for support osx 10.12
...
...
@@ -498,6 +532,8 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
tensor_inputs
.
push_back
(
i
->
ptr
);
}
// Begin profiling operator
OpEvent
event_data
;
if
(
m_worker_state
.
profiler
->
is_profiling
())
{
auto
tinfo_to_tid
=
[
&
](
SmallVector
<
TensorInfo
*>
tinfo
)
{
SmallVector
<
uint64_t
>
tid
;
for
(
auto
*
ptinfo
:
tinfo
)
{
...
...
@@ -505,7 +541,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
}
return
tid
;
};
OpEvent
event_data
=
{
apply_id
,
cmd
.
op
,
tinfo_to_tid
(
cmd
.
inputs
),
tinfo_to_tid
(
cmd
.
outputs
)};
event_data
=
{
apply_id
,
cmd
.
op
,
tinfo_to_tid
(
cmd
.
inputs
),
tinfo_to_tid
(
cmd
.
outputs
)};
// Collecting devices
for
(
auto
i
:
cmd
.
inputs
)
{
devices
.
push_back
(
i
->
desc
.
comp_node
);
...
...
@@ -514,6 +550,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
devices
.
push_back
(
i
->
desc
.
comp_node
);
}
devices
.
erase
(
std
::
unique
(
devices
.
begin
(),
devices
.
end
()),
devices
.
end
());
}
// Fused by command buffer. @see: CommandBuffer::fuse_del
// Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by tensor_inputs after Del.
// Note for exprs like 'y = x op x', inplace is unsupported yet but Del would be also fused.
...
...
@@ -643,7 +680,7 @@ void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
if
(
std
::
get_if
<
Del
>
(
&
cmd
)
&&
fuse_del
(
std
::
get
<
Del
>
(
cmd
)))
{
return
;
}
mgb_log_debug
(
"%s Enqueued"
,
to_string
(
cmd
).
c_str
());
//
mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
m_commands
.
push_back
(
std
::
move
(
cmd
));
auto
flush_pos
=
flush_pos_for
(
m_commands
.
back
());
flush
(
flush_pos
);
...
...
@@ -655,9 +692,11 @@ void ChannelImpl::CommandBuffer::flush() {
void
ChannelImpl
::
CommandBuffer
::
flush
(
Handle
pos
)
{
for
(
auto
iter
=
m_commands
.
begin
();
iter
!=
pos
;
++
iter
)
{
mgb_log_debug
(
"%s Flushed"
,
to_string
(
*
iter
).
c_str
());
//
mgb_log_debug("%s Flushed", to_string(*iter).c_str());
IdentifiedCommand
icmd
{
++
m_owner
->
m_last_id
,
std
::
move
(
*
iter
)};
if
(
m_owner
->
m_channel_state
.
profiler
->
is_profiling
())
{
m_owner
->
m_channel_state
.
profiler
->
record_host
<
CommandEnqueueEvent
>
(
icmd
);
}
m_owner
->
m_worker
.
add_task
(
std
::
move
(
icmd
));
}
m_commands
.
erase
(
m_commands
.
begin
(),
pos
);
...
...
@@ -705,7 +744,7 @@ bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
if
(
apply_iter
==
end
||
find_last_usage
(
dest
,
{
apply_iter
+
1
,
end
})
!=
end
)
{
return
false
;
}
mgb_log_debug
(
"%s Fused"
,
to_string
(
Command
{
cmd
}).
c_str
());
//
mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
std
::
get
<
ApplyOp
>
(
*
apply_iter
).
dels
.
push_back
(
dest
);
return
true
;
}
...
...
@@ -771,16 +810,20 @@ void ChannelImpl::stop_profile(std::string basename, std::string format) {
}
void
ChannelImpl
::
push_scope
(
std
::
string
name
)
{
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
m_channel_state
.
profiler
->
record_host
<
ChannelBeginScope
>
(
name
);
m_channel_state
.
scopes
.
push_back
(
name
);
m_buffer
.
enqueue
(
PushScope
{
name
});
}
}
void
ChannelImpl
::
pop_scope
(
std
::
string
name
)
{
if
(
m_channel_state
.
profiler
->
is_profiling
())
{
mgb_assert
((
!
m_channel_state
.
scopes
.
empty
())
&&
m_channel_state
.
scopes
.
back
()
==
name
,
"scope name mismatch"
);
m_channel_state
.
scopes
.
pop_back
();
m_channel_state
.
profiler
->
record_host
<
ChannelEndScope
>
(
name
);
m_buffer
.
enqueue
(
PopScope
{
name
});
}
}
void
ChannelImpl
::
assert_in_channel
()
{
...
...
imperative/src/include/megbrain/imperative/profiler.h
浏览文件 @
221ec38a
...
...
@@ -163,7 +163,6 @@ public:
}
// unsafe
bool
is_profiling
()
{
MGB_LOCK_GUARD
(
m_lock
);
return
m_status
==
Profiling
;
}
void
start
(
Mask
mask
)
{
...
...
@@ -188,7 +187,7 @@ public:
protected:
std
::
vector
<
Record
>
m_record_list
;
Mask
m_event_mask
;
Status
m_status
=
NotStarted
;
std
::
atomic
<
Status
>
m_status
=
NotStarted
;
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录