Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7c09e41f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
7c09e41f
编写于
1月 29, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mgb): add circular dependency check
GitOrigin-RevId: 01fdb8684be2c594d9b8d9a57d28528cf5412dc6
上级
af42ce7e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
89 addition
and
31 deletion
+89
-31
src/opr/impl/search_policy/algo_chooser.cpp
src/opr/impl/search_policy/algo_chooser.cpp
+89
-16
src/opr/include/megbrain/opr/search_policy/algo_chooser.h
src/opr/include/megbrain/opr/search_policy/algo_chooser.h
+0
-15
未找到文件。
src/opr/impl/search_policy/algo_chooser.cpp
浏览文件 @
7c09e41f
...
...
@@ -12,6 +12,8 @@
#include "megbrain/opr/search_policy/algo_chooser.h"
#include <limits>
#include <unordered_set>
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
#include "megbrain/opr/search_policy/profiler.h"
...
...
@@ -22,6 +24,7 @@
//! TODO: here has to be know some megdnn::opr when there is produced midout.h
//! fix it if there is another graceful way.
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "megdnn/oprs/base.h"
#include "midout.h"
...
...
@@ -78,6 +81,58 @@ std::string format_fixlayouts(
return
ret
;
}
/**
* \brief Check if the sub opr list has circular dependence.
*/
class
CircularDepsChecker
{
struct
SearchItemStorage
{
std
::
string
data_hold
;
size_t
hash
=
0
;
SearchItemStorage
(
const
Algorithm
::
SearchItem
&
item
)
{
Algorithm
::
serialize_write_pod
(
item
.
opr_type
,
data_hold
);
for
(
auto
&&
layout
:
item
.
layouts
)
{
data_hold
+=
layout
.
serialize
();
}
data_hold
+=
item
.
param
;
}
SearchItemStorage
&
init_hash
()
{
hash
=
XXHash64CT
::
hash
(
data_hold
.
data
(),
data_hold
.
size
(),
20201225
);
return
*
this
;
}
bool
operator
==
(
const
SearchItemStorage
&
rhs
)
const
{
return
data_hold
==
rhs
.
data_hold
;
}
struct
Hash
{
size_t
operator
()(
const
SearchItemStorage
&
s
)
const
{
return
s
.
hash
;
}
};
};
std
::
unordered_set
<
SearchItemStorage
,
SearchItemStorage
::
Hash
>
m_set
;
public:
void
put
(
const
megdnn
::
Algorithm
::
SearchItem
&
key
)
{
SearchItemStorage
key_storage
(
key
);
key_storage
.
init_hash
();
mgb_assert
(
m_set
.
find
(
key_storage
)
==
m_set
.
end
(),
"Circular dependency during flatten search space"
);
auto
ret
=
m_set
.
insert
(
std
::
move
(
key_storage
));
mgb_assert
(
ret
.
second
);
}
void
remove
(
const
megdnn
::
Algorithm
::
SearchItem
&
key
)
{
SearchItemStorage
key_storage
(
key
);
key_storage
.
init_hash
();
auto
&&
iter
=
m_set
.
find
(
key_storage
);
mgb_assert
(
iter
!=
m_set
.
end
());
m_set
.
erase
(
iter
);
}
};
///////////////// OprTypeTrait /////////////////////////////
template
<
megdnn
::
Algorithm
::
OprType
>
struct
OprFromOprTypeTrait
;
...
...
@@ -176,14 +231,26 @@ typename opr::AlgoChooser<Opr>::FixedTensorLayouts to_fixed_layouts(
return
ret
;
}
}
// namespace
namespace
mgb
{
namespace
opr
{
/**
* flatten search space in postorder traversal
* The subopr search construct a search tree
*
* A
* / \
* B1B2 C
* / \
* D1D2D3 E
* We use postorder traverse the search tree.
* D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A
*/
template
<
typename
Opr
>
std
::
vector
<
megdnn
::
Algorithm
::
SearchItem
>
AlgoChooser
<
Opr
>::
flatten_search_space
(
const
ExeContext
&
ctx
)
{
std
::
vector
<
megdnn
::
Algorithm
::
SearchItem
>
flatten_search_space
(
const
typename
opr
::
AlgoChooser
<
Opr
>::
ExeContext
&
ctx
,
CircularDepsChecker
&
checker
)
{
auto
&&
search_item
=
megdnn
::
Algorithm
::
SearchItem
{
OprTypeFromOprTrait
<
Opr
>::
opr_type
,
ctx
.
param
(),
to_layout_array
<
Opr
>
(
ctx
.
layouts
())};
checker
.
put
(
search_item
);
std
::
vector
<
megdnn
::
Algorithm
::
SearchItem
>
ret
;
for
(
auto
algo_info
:
ctx
.
get_all_candidates
())
{
megdnn
::
Algorithm
*
algo
=
ctx
.
get_algorithm_from_desc
(
algo_info
.
desc
);
...
...
@@ -193,23 +260,29 @@ AlgoChooser<Opr>::flatten_search_space(const ExeContext& ctx) {
ctx
.
megdnn_opr
());
FOREACH_OPR_TYPE_DISPATCH
(
sub_items
,
{
auto
&&
megdnn_opr
=
intl
::
create_megdnn_opr
<
_Opr
>
(
ctx
.
comp_node
());
auto
&&
megdnn_opr
=
opr
::
intl
::
create_megdnn_opr
<
_Opr
>
(
ctx
.
comp_node
());
megdnn_opr
->
param
()
=
Algorithm
::
deserialize_read_pod
<
typename
_Opr
::
Param
>
(
_item
.
param
);
typename
AlgoChooser
<
_Opr
>::
ExeContext
sub_ctx
(
typename
opr
::
AlgoChooser
<
_Opr
>::
ExeContext
sub_ctx
(
to_fixed_layouts
<
_Opr
>
(
_item
.
layouts
),
megdnn_opr
.
get
(),
_item
.
param
,
ctx
.
mgb_opr
(),
ctx
.
comp_node
(),
ctx
.
execution_policy
(),
ctx
.
allow_weight_preprocess
());
auto
space
=
AlgoChooser
<
_Opr
>::
flatten_search_space
(
sub_ctx
);
auto
space
=
flatten_search_space
<
_Opr
>
(
sub_ctx
,
checker
);
ret
.
insert
(
ret
.
end
(),
space
.
begin
(),
space
.
end
());
});
}
ret
.
push_back
(
{
OprTypeFromOprTrait
<
Opr
>::
opr_type
,
ctx
.
param
(),
to_layout_array
<
Opr
>
(
ctx
.
layouts
())}
);
ret
.
push_back
(
search_item
);
checker
.
remove
(
search_item
);
return
ret
;
}
}
// namespace
namespace
mgb
{
namespace
opr
{
template
<
typename
Opr
>
void
AlgoChooser
<
Opr
>::
profile
(
ExeContext
&
ctx
,
bool
require_reproducible
)
{
if
(
ctx
.
get_profile_result_from_cache
(
require_reproducible
).
valid
())
...
...
@@ -289,7 +362,9 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, bool require_reproducible,
}
if
(
enable_update
)
{
auto
&&
search_items
=
flatten_search_space
(
ctx
);
CircularDepsChecker
circular_deps_checker
;
auto
&&
search_items
=
flatten_search_space
<
Opr
>
(
ctx
,
circular_deps_checker
);
FOREACH_OPR_TYPE_DISPATCH
(
search_items
,
{
auto
&&
megdnn_opr
=
intl
::
create_megdnn_opr
<
_Opr
>
(
ctx
.
comp_node
());
megdnn_opr
->
param
()
=
...
...
@@ -382,14 +457,12 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
AlgoChooser<megdnn::Opr>::get_policy(ExeContext& ctx); \
template void AlgoChooser<megdnn::Opr>::profile( \
ExeContext& ctx, bool require_reproducible); \
template std::vector<megdnn::Algorithm::SearchItem> \
AlgoChooser<megdnn::Opr>::flatten_search_space(const ExeContext& ctx); \
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::choose_by_profile( \
ExeContext& ctx, bool require_reproducible, bool enable_update); \
template size_t AlgoChooser<megdnn::Opr>::setup_algo( \
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \
const MGBOpr* mgb_opr, bool allow_weight_preprocess);
const MGBOpr* mgb_opr, bool allow_weight_preprocess);
\
MGB_FOREACH_FASTRUN_OPR
(
INST
)
...
...
src/opr/include/megbrain/opr/search_policy/algo_chooser.h
浏览文件 @
7c09e41f
...
...
@@ -159,21 +159,6 @@ private:
bool
require_reproducible
,
bool
enable_update
=
true
);
/**
* flatten search space in postorder traversal
* The subopr search construct a search tree
*
* A
* / \
* B1B2 C
* / \
* D1D2D3 E
* We use postorder traverse the search tree.
* D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A
*/
static
std
::
vector
<
megdnn
::
Algorithm
::
SearchItem
>
flatten_search_space
(
const
ExeContext
&
ctx
);
public:
/*!
* \brief setup algorithm and return workspace size
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录