Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
bcf69d8f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
bcf69d8f
编写于
11月 04, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(imperative): correctly apply sqrt sampling for dtr
GitOrigin-RevId: dabd36551765af1d2646789ae9ed57d8eac4a936
上级
48100781
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
38 addition
and
10 deletion
+38
-10
imperative/src/impl/interpreter/interpreter_impl.cpp
imperative/src/impl/interpreter/interpreter_impl.cpp
+34
-9
imperative/src/impl/interpreter/interpreter_impl.h
imperative/src/impl/interpreter/interpreter_impl.h
+1
-1
imperative/src/impl/interpreter/tensor_info.h
imperative/src/impl/interpreter/tensor_info.h
+3
-0
未找到文件。
imperative/src/impl/interpreter/interpreter_impl.cpp
浏览文件 @
bcf69d8f
...
...
@@ -646,6 +646,10 @@ void ChannelImpl::release_tensor(TensorInfo* dest) {
MGB_RECORD_EVENT
(
TensorReleaseEvent
,
dest
->
id
);
MGB_LOCK_GUARD
(
m_mutex
);
dest
->
ptr
.
reset
();
auto
&
state
=
get_worker_state
();
if
(
dest
->
size_exceeds_thd
(
state
.
options
.
dtr_evictee_minimum_size
))
{
m_dtr
.
erase_candidate
(
dest
);
}
}
void
ChannelImpl
::
regenerate
(
TensorInfo
*
dest
)
{
...
...
@@ -891,8 +895,7 @@ bool ChannelImpl::auto_evict(size_t force_num) {
force_num
>
0
)
{
MGB_RECORD_EVENT
(
AutoEvictEvent
);
sample_on_device
(
m_dtr
.
comp_node
,
false
);
auto
best
=
m_dtr
.
find_best_tensor
(
state
.
options
.
enable_dtr_sqrt_sampling
&&
!
force_num
);
auto
best
=
m_dtr
.
find_best_tensor
(
state
.
options
.
enable_dtr_sqrt_sampling
);
if
(
!
best
)
{
MGB_RECORD_EVENT
(
AutoEvictFinishEvent
);
break
;
...
...
@@ -1300,7 +1303,6 @@ void ChannelImpl::CommandBuffer::enqueue(CommandData cmd) {
if
(
std
::
get_if
<
Del
>
(
&
cmd
)
&&
fuse_del
(
std
::
get
<
Del
>
(
cmd
)))
{
return
;
}
// mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
m_commands
.
push_back
(
{
Profiler
::
next_id
(),
std
::
move
(
cmd
),
state
.
stack_manager
.
dump
()});
auto
flush_pos
=
flush_pos_for
(
m_commands
.
back
());
...
...
@@ -1365,7 +1367,6 @@ bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
if
(
apply_iter
==
end
||
find_last_usage
(
dest
,
{
apply_iter
+
1
,
end
})
!=
end
)
{
return
false
;
}
// mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
std
::
get
<
ApplyOp
>
(
apply_iter
->
data
).
dels
.
push_back
(
dest
);
return
true
;
}
...
...
@@ -1538,16 +1539,26 @@ double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
TensorInfo
*
ChannelImpl
::
DynamicSublinear
::
find_best_tensor
(
bool
enable_dtr_sqrt_sampling
=
false
)
{
if
(
candidates
.
empty
())
return
nullptr
;
double
min_msps
=
-
1
;
TensorInfo
*
best
=
nullptr
;
size_t
sz
=
1
;
if
(
enable_dtr_sqrt_sampling
)
{
while
(
sz
*
sz
<=
candidates
.
size
())
sz
++
;
sz
--
;
}
else
{
sz
=
candidates
.
size
();
}
for
(
auto
i
:
candidates
)
{
size_t
ti
=
rand
()
%
sz
;
for
(
size_t
vi
=
0
;
vi
<
sz
;
vi
++
)
{
if
(
!
enable_dtr_sqrt_sampling
)
{
ti
=
vi
;
}
auto
i
=
candidates
[
ti
];
if
(
i
->
producer
&&
i
->
ptr
&&
!
i
->
pinned
&&
i
->
evict_type
==
EvictType
::
NONE
)
{
double
neighbor_cost
=
estimate_neighbor_cost
(
i
);
size_t
begin_ptr
=
...
...
@@ -1562,8 +1573,11 @@ TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
best
=
i
;
}
}
if
(
--
sz
==
0
)
break
;
if
(
enable_dtr_sqrt_sampling
)
{
ti
+=
rand
()
%
sz
;
if
(
ti
>
candidates
.
size
())
break
;
}
}
return
best
;
}
...
...
@@ -1590,14 +1604,25 @@ std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(
}
void
ChannelImpl
::
DynamicSublinear
::
insert_candidate
(
TensorInfo
*
ptr
)
{
candidates
.
insert
(
ptr
);
// tensor to be inserted must be brand new
mgb_assert
(
ptr
->
cand_index
==
UINT_MAX
,
"got wrong candidate index : %lu"
,
ptr
->
cand_index
);
ptr
->
cand_index
=
candidates
.
size
();
candidates
.
push_back
(
ptr
);
if
(
!
comp_node
.
valid
())
{
comp_node
=
ptr
->
ptr
->
comp_node
();
}
}
void
ChannelImpl
::
DynamicSublinear
::
erase_candidate
(
TensorInfo
*
ptr
)
{
candidates
.
erase
(
ptr
);
// some tensors may be erased already, so just skip them
if
(
ptr
->
cand_index
!=
UINT_MAX
)
{
std
::
swap
(
candidates
[
ptr
->
cand_index
],
candidates
.
back
());
candidates
[
ptr
->
cand_index
]
->
cand_index
=
ptr
->
cand_index
;
candidates
.
pop_back
();
ptr
->
cand_index
=
UINT_MAX
;
}
}
void
ChannelImpl
::
DynamicSublinear
::
update_used_time
(
TensorInfo
*
ptr
)
{
...
...
imperative/src/impl/interpreter/interpreter_impl.h
浏览文件 @
bcf69d8f
...
...
@@ -335,7 +335,7 @@ private:
CompNode
comp_node
;
//! store all tensors that may be evicted
std
::
unordered_set
<
TensorInfo
*>
candidates
;
SmallVector
<
TensorInfo
*>
candidates
;
bool
is_bad_op
(
std
::
string
op_name
)
{
return
std
::
find
(
op_blacklist
.
begin
(),
op_blacklist
.
end
(),
op_name
)
!=
...
...
imperative/src/impl/interpreter/tensor_info.h
浏览文件 @
bcf69d8f
...
...
@@ -170,6 +170,9 @@ struct TensorInfo {
bool
size_exceeds_thd
(
size_t
thd
)
{
return
memory
>
thd
;
}
SmallVector
<
ComputePath
*>
users
;
// UINT_MAX as a magic default value
size_t
cand_index
=
UINT_MAX
;
};
}
// namespace interpreter::intl
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录