Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
46988d45
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
396
Star
4704
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
46988d45
编写于
4月 20, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mgb): code refactor of fast run
GitOrigin-RevId: 2c4b8e06bb3c4b4cb0228ee28988c2371455b1b0
上级
a1e38342
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
369 addition
and
364 deletion
+369
-364
src/opr/impl/search_policy/algo_chooser.cpp
src/opr/impl/search_policy/algo_chooser.cpp
+330
-328
src/opr/include/megbrain/opr/search_policy/algo_chooser.h
src/opr/include/megbrain/opr/search_policy/algo_chooser.h
+39
-36
未找到文件。
src/opr/impl/search_policy/algo_chooser.cpp
浏览文件 @
46988d45
...
...
@@ -243,31 +243,33 @@ typename opr::AlgoChooser<Opr>::FixedTensorLayouts to_fixed_layouts(
*/
template
<
typename
Opr
>
std
::
vector
<
megdnn
::
Algorithm
::
SearchItem
>
flatten_search_space
(
const
typename
opr
::
AlgoChooser
<
Opr
>::
ExeContext
&
ctx
,
const
typename
opr
::
AlgoChooser
<
Opr
>::
AlgoChooserHelper
&
helper
,
CircularDepsChecker
&
checker
)
{
auto
&&
search_item
=
megdnn
::
Algorithm
::
SearchItem
{
OprTypeFromOprTrait
<
Opr
>::
opr_type
,
ctx
.
param
(),
to_layout_array
<
Opr
>
(
ctx
.
layouts
())};
OprTypeFromOprTrait
<
Opr
>::
opr_type
,
helper
.
param
(),
to_layout_array
<
Opr
>
(
helper
.
layouts
())};
checker
.
put
(
search_item
);
std
::
vector
<
megdnn
::
Algorithm
::
SearchItem
>
ret
;
for
(
auto
algo_info
:
ctx
.
get_all_candidates
())
{
megdnn
::
Algorithm
*
algo
=
ctx
.
get_algorithm_from_desc
(
algo_info
.
desc
);
for
(
auto
algo_info
:
helper
.
get_all_candidates
())
{
megdnn
::
Algorithm
*
algo
=
helper
.
get_algorithm_from_desc
(
algo_info
.
desc
);
mgb_assert
(
algo
,
"Unknown algo description"
);
std
::
vector
<
megdnn
::
Algorithm
::
SearchItem
>&&
sub_items
=
algo
->
get_subopr_list
(
to_layout_array
<
Opr
>
(
ctx
.
layouts
()),
ctx
.
megdnn_opr
());
algo
->
get_subopr_list
(
to_layout_array
<
Opr
>
(
helper
.
layouts
()),
helper
.
megdnn_opr
());
FOREACH_OPR_TYPE_DISPATCH
(
sub_items
,
{
auto
&&
megdnn_opr
=
opr
::
intl
::
create_megdnn_opr
<
_Opr
>
(
ctx
.
comp_node
());
opr
::
intl
::
create_megdnn_opr
<
_Opr
>
(
helper
.
comp_node
());
megdnn_opr
->
param
()
=
Algorithm
::
deserialize_read_pod
<
typename
_Opr
::
Param
>
(
_item
.
param
);
typename
opr
::
AlgoChooser
<
_Opr
>::
ExeContext
sub_ctx
(
typename
opr
::
AlgoChooser
<
_Opr
>::
AlgoChooserHelper
sub_helper
(
to_fixed_layouts
<
_Opr
>
(
_item
.
layouts
),
megdnn_opr
.
get
(),
_item
.
param
,
ctx
.
mgb_opr
(),
ctx
.
comp_node
(),
ctx
.
execution_policy
(),
ctx
.
allow_weight_preprocess
());
auto
space
=
flatten_search_space
<
_Opr
>
(
sub_ctx
,
checker
);
_item
.
param
,
helper
.
mgb_opr
(),
helper
.
comp_node
(),
helper
.
execution_policy
(),
helper
.
allow_weight_preprocess
());
auto
space
=
flatten_search_space
<
_Opr
>
(
sub_helper
,
checker
);
ret
.
insert
(
ret
.
end
(),
space
.
begin
(),
space
.
end
());
});
}
...
...
@@ -280,255 +282,113 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
namespace
mgb
{
namespace
opr
{
///////////////////////////// AlgoChooserHelper //////////////////////////
template
<
typename
Opr
>
void
AlgoChooser
<
Opr
>::
profile
(
ExeContext
&
ctx
,
ExecutionStrategy
selected_strategy
)
{
if
(
ctx
.
get_profile_result_from_cache
(
selected_strategy
).
valid
())
return
;
AlgoChooserProfileCache
::
Result
prof_rst
;
auto
target_attr
=
ctx
.
extract_algo_attribute
(
selected_strategy
);
std
::
string
layouts_str
=
format_fixlayouts
<
Opr
>
(
ctx
.
layouts
(),
arity_in
,
arity_out
);
double
cur_timeout
=
0
;
AlgoChooser
<
Opr
>::
AlgoChooserHelper
::
AlgoChooserHelper
(
const
FixedTensorLayouts
&
layouts
,
Opr
*
megdnn_opr
,
const
std
::
string
&
param_str
,
const
cg
::
OperatorNodeBase
*
mgb_opr
,
const
CompNode
&
cn
,
const
megdnn
::
param
::
ExecutionPolicy
&
execution_policy
,
bool
allow_weight_preprocess
)
:
m_layouts
{
layouts
},
m_megdnn_opr
{
megdnn_opr
},
m_param
{
param_str
},
m_base_mgb_opr
{
mgb_opr
},
m_cn
{
cn
},
m_execution_policy
{
execution_policy
},
m_allow_weight_preprocess
{
allow_weight_preprocess
}
{
mgb_assert
(
m_layouts
.
size
()
==
layouts
.
size
());
static_assert
(
std
::
tuple_size
<
FixedTensorLayouts
>::
value
==
3
||
std
::
tuple_size
<
FixedTensorLayouts
>::
value
==
5
||
std
::
tuple_size
<
FixedTensorLayouts
>::
value
==
8
,
"Convolution AlgoChooser assumes arity = 3 , 5 or 8 (for "
"deformable conv)"
);
}
template
<
typename
Opr
>
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
AlgoChooser
<
Opr
>::
AlgoChooserHelper
::
choose_by_heuristic
(
const
ExecutionStrategy
&
selected_strategy
)
const
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"choose_by_heuristic"
)))
ImplExecutionPolicy
policy
;
auto
workspace_limit
=
WorkspaceLimitGetter
::
get_workspace_limit
(
ctx
.
owner_graph
(),
ctx
.
comp_node
(),
ctx
.
execution_policy
().
workspace_limit
);
RealTimer
timer
;
for
(
auto
algo
:
ctx
.
get_all_candidates
())
{
Maybe
<
AlgoChooserProfileCache
::
ResultEntry
>
cur_rst
;
ImplExecutionPolicy
policy
;
policy
.
algo
=
algo
.
desc
;
//! check negative attribute : skip negative attribute
auto
palgo
=
ctx
.
megdnn_opr
()
->
get_algorithm_from_desc
(
policy
.
algo
);
if
(
palgo
->
contain_attribute_any
(
target_attr
.
second
))
{
mgb_log_debug
(
"skip algo %s with attribute(%s), which is not match the "
"profile strategy required contain attribute(%s) and not "
"contain attribute(%s)."
,
algo
.
desc
.
name
.
c_str
(),
Algorithm
::
attribute_str
(
palgo
->
attribute
()).
c_str
(),
Algorithm
::
attribute_str
(
target_attr
.
first
).
c_str
(),
Algorithm
::
attribute_str
(
target_attr
.
second
).
c_str
());
continue
;
}
//! check workspace limit
ctx
.
construct_execution_policy
(
selected_strategy
,
policy
);
if
(
ctx
.
get_workspace_size_bytes
(
policy
)
>=
workspace_limit
)
{
continue
;
}
owner_graph
(),
m_cn
,
m_execution_policy
.
workspace_limit
);
auto
attr
=
extract_algo_attribute
(
selected_strategy
);
policy
.
algo
=
APPLY
(
m_megdnn_opr
->
get_algorithm_info_heuristic
(
args
...,
workspace_limit
,
attr
.
first
,
attr
.
second
),
m_layouts
)
.
desc
;
std
::
string
msg
=
ssprintf
(
"profiling %s algorithm %s %s"
,
ctx
.
mgb_opr
()
->
dyn_typeinfo
()
->
name
,
algo
.
desc
.
name
.
c_str
(),
layouts_str
.
c_str
());
timer
.
reset
();
MGB_TRY
{
cur_rst
=
ctx
.
profile_single_algo
(
policy
,
cur_timeout
);
}
MGB_CATCH
(
std
::
exception
&
exc
,
{
mgb_log_warn
(
"caught exception during %s: %s"
,
msg
.
c_str
(),
exc
.
what
());
continue
;
})
MGB_CATCH
(...,
{
mgb_log_warn
(
"caught exception during %s"
,
msg
.
c_str
());
continue
;
})
if
(
!
cur_rst
.
valid
())
{
mgb_log_warn
(
"timeout when %s; timeout setting: %.3fsec"
,
msg
.
c_str
(),
cur_timeout
);
continue
;
}
if
(
!
cur_timeout
)
{
cur_timeout
=
timer
.
get_secs
()
+
TIMEOUT_TOLERANCE
;
}
else
{
cur_timeout
=
std
::
min
(
cur_timeout
,
timer
.
get_secs
()
+
TIMEOUT_TOLERANCE
);
}
auto
&&
rst
=
cur_rst
.
val
();
mgb_log_debug
(
"%s: workspace: %zu; time: %.3gsec"
,
msg
.
c_str
(),
rst
.
workspace
,
rst
.
time
);
prof_rst
.
push_back
(
rst
);
}
std
::
string
msg
=
ssprintf
(
"no usable %s algorithm %s without attribute(%s) or could not meet "
"workspace limite requirement(%zu)"
,
ctx
.
mgb_opr
()
->
dyn_typeinfo
()
->
name
,
layouts_str
.
c_str
(),
Algorithm
::
attribute_str
(
target_attr
.
second
).
c_str
(),
workspace_limit
);
mgb_assert
(
!
prof_rst
.
empty
(),
"%s"
,
msg
.
c_str
());
Algorithm
*
algo
=
m_megdnn_opr
->
get_algorithm_from_desc
(
policy
.
algo
);
mgb_assert
(
algo
,
"Unknown algo description"
);
std
::
vector
<
Algorithm
::
SearchItem
>&&
sub_items
=
algo
->
get_subopr_list
(
to_layout_array
<
Opr
>
(
m_layouts
),
m_megdnn_opr
);
FixedTensorLayouts
origin_layouts
=
ctx
.
layouts
();
typename
Opr
::
Param
origin_param
=
ctx
.
megdnn_opr
()
->
param
();
AlgoChooserProfileCache
::
Key
cache_key
{
origin_layouts
.
data
(),
origin_layouts
.
size
(),
&
origin_param
,
sizeof
(
origin_param
)};
FOREACH_OPR_TYPE_DISPATCH
(
sub_items
,
{
auto
&&
megdnn_opr
=
intl
::
create_megdnn_opr
<
_Opr
>
(
m_cn
);
megdnn_opr
->
param
()
=
Algorithm
::
deserialize_read_pod
<
typename
_Opr
::
Param
>
(
_item
.
param
);
typename
AlgoChooser
<
_Opr
>::
AlgoChooserHelper
sub_helper
(
to_fixed_layouts
<
_Opr
>
(
_item
.
layouts
),
megdnn_opr
.
get
(),
_item
.
param
,
m_base_mgb_opr
,
m_cn
,
m_execution_policy
,
m_allow_weight_preprocess
);
policy
.
sub_policy
.
push_back
(
sub_helper
.
choose_by_heuristic
(
selected_strategy
));
});
AlgoChooserProfileCache
cache
(
ctx
.
comp_node
(),
profile_name
(
ctx
.
megdnn_opr
()).
c_str
());
cache
.
put
(
cache_key
,
prof_rst
);
return
policy
;
MIDOUT_E
}
template
<
typename
Opr
>
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
AlgoChooser
<
Opr
>::
choose_by_profile
(
ExeContext
&
ctx
,
ExecutionStrategy
selected_strategy
,
bool
enable_update
)
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"AlgoChooser::choose_by_profile"
)))
if
(
ctx
.
owner_graph
()
->
options
().
no_profiling_on_shape_change
)
{
auto
policy
=
ctx
.
megdnn_opr
()
->
execution_policy
();
if
(
policy
.
algo
.
valid
()){
AlgoChooser
<
Opr
>::
AlgoChooserHelper
::
choose_by_profile
(
const
ExecutionStrategy
&
selected_strategy
,
bool
enable_update
)
const
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"choose_by_profile"
)))
if
(
owner_graph
()
->
options
().
no_profiling_on_shape_change
)
{
auto
policy
=
m_megdnn_opr
->
execution_policy
();
if
(
policy
.
algo
.
valid
())
{
return
policy
;
}
if
(
!
algo_usable_on_shape_change
<
Opr
>
())
{
mgb_log_warn
(
"choose algo by heuristic, which may cause performance "
"regression."
);
return
c
tx
.
c
hoose_by_heuristic
(
selected_strategy
);
return
choose_by_heuristic
(
selected_strategy
);
}
}
if
(
enable_update
)
{
CircularDepsChecker
circular_deps_checker
;
auto
&&
search_items
=
flatten_search_space
<
Opr
>
(
ctx
,
circular_deps_checker
);
flatten_search_space
<
Opr
>
(
*
this
,
circular_deps_checker
);
FOREACH_OPR_TYPE_DISPATCH
(
search_items
,
{
auto
&&
megdnn_opr
=
intl
::
create_megdnn_opr
<
_Opr
>
(
ctx
.
comp_node
()
);
auto
&&
megdnn_opr
=
intl
::
create_megdnn_opr
<
_Opr
>
(
m_cn
);
megdnn_opr
->
param
()
=
Algorithm
::
deserialize_read_pod
<
typename
_Opr
::
Param
>
(
_item
.
param
);
typename
AlgoChooser
<
_Opr
>::
ExeContext
sub_ctx
(
typename
AlgoChooser
<
_Opr
>::
AlgoChooserHelper
sub_helper
(
to_fixed_layouts
<
_Opr
>
(
_item
.
layouts
),
megdnn_opr
.
get
(),
_item
.
param
,
ctx
.
mgb_opr
(),
ctx
.
comp_node
()
,
ctx
.
execution_policy
(),
ctx
.
allow_weight_preprocess
()
);
AlgoChooser
<
_Opr
>::
profile
(
sub_ctx
,
selected_strategy
);
_item
.
param
,
m_base_mgb_opr
,
m_cn
,
m_execution_policy
,
m_allow_weight_preprocess
);
sub_helper
.
profile
(
selected_strategy
);
});
}
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
policy
;
c
tx
.
construct_execution_policy
(
selected_strategy
,
policy
);
c
onstruct_execution_policy
(
selected_strategy
,
true
,
policy
);
return
policy
;
MIDOUT_E
}
template
<
typename
Opr
>
size_t
AlgoChooser
<
Opr
>::
setup_algo
(
const
FixedTensorLayouts
&
layouts
,
Opr
*
megdnn_opr
,
const
MGBOpr
*
mgb_opr
,
bool
allow_weight_preprocess
)
{
if
(
WorkspaceLimitGetter
::
is_prealloc_run
(
mgb_opr
->
owner_graph
()))
{
return
0
;
}
std
::
string
param_str
;
Algorithm
::
serialize_write_pod
(
megdnn_opr
->
param
(),
param_str
);
ExeContext
ctx
(
layouts
,
megdnn_opr
,
param_str
,
mgb_opr
,
mgb_opr
->
comp_node
(),
mgb_opr
->
execution_policy
(),
allow_weight_preprocess
);
ImplExecutionPolicy
policy
;
if
(
auto
algo_choose_hook
=
mgb_opr
->
algo_chooser
())
{
policy
=
algo_choose_hook
(
mgb_opr
);
ctx
.
construct_execution_policy
((
ExecutionStrategy
::
HEURISTIC
|
ExecutionStrategy
::
REPRODUCIBLE
),
policy
,
false
);
}
if
(
!
policy
.
algo
.
valid
())
{
policy
=
get_policy
(
ctx
);
}
size_t
workspace
=
ctx
.
get_workspace_size_bytes
(
policy
);
std
::
string
ret
;
ret
.
append
(
mgb_opr
->
dyn_typeinfo
()
->
name
);
ret
+=
format_fixlayouts
<
Opr
>
(
layouts
,
arity_in
,
arity_out
);
Algorithm
*
palgo
=
megdnn_opr
->
get_algorithm_from_desc
(
policy
.
algo
);
mgb_assert
(
palgo
,
"Unknown algo description"
);
ret
.
append
(
"): algo="
+
std
::
string
(
palgo
->
name
()));
ret
.
append
(
ssprintf
(
" workspace=%.2fMiB attirbute(%s)"
,
workspace
/
(
1024
*
1024.0
),
Algorithm
::
attribute_str
(
palgo
->
attribute
()).
c_str
()));
mgb_log_debug
(
"%s"
,
ret
.
c_str
());
megdnn_opr
->
execution_policy
()
=
policy
;
return
workspace
;
}
template
<
typename
Opr
>
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
AlgoChooser
<
Opr
>::
get_policy
(
ExeContext
&
ctx
)
{
MGB_MARK_USED_VAR
(
TIMEOUT_TOLERANCE
);
auto
opr_strategy
=
ctx
.
execution_policy
().
strategy
;
if
((
opr_strategy
&
ExecutionStrategy
::
HEURISTIC
)
&&
(
opr_strategy
&
ExecutionStrategy
::
PROFILE
))
{
ImplExecutionPolicy
policy
=
choose_by_profile
(
ctx
,
opr_strategy
,
false
);
if
(
!
policy
.
algo
.
valid
())
policy
=
ctx
.
choose_by_heuristic
(
opr_strategy
);
return
policy
;
}
else
if
(
!
static_cast
<
int
>
(
opr_strategy
)
||
(
opr_strategy
&
ExecutionStrategy
::
HEURISTIC
))
{
return
ctx
.
choose_by_heuristic
(
opr_strategy
);
}
#if MGB_ENABLE_FASTRUN
else
if
(
opr_strategy
&
ExecutionStrategy
::
PROFILE
)
{
return
choose_by_profile
(
ctx
,
opr_strategy
);
}
#endif
else
{
mgb_throw
(
GraphError
,
"bad ExecutionPolicy strategy"
);
}
}
#define INST(Opr) \
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::get_policy(ExeContext& ctx); \
template void AlgoChooser<megdnn::Opr>::profile(ExeContext& ctx, \
ExecutionStrategy); \
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::choose_by_profile( \
ExeContext& ctx, ExecutionStrategy, bool enable_update); \
template size_t AlgoChooser<megdnn::Opr>::setup_algo( \
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \
const MGBOpr* mgb_opr, bool allow_weight_preprocess);
MGB_FOREACH_FASTRUN_OPR
(
INST
)
#undef INST
//////////////////////////////// ExeContext /////////////////////////////
template
<
typename
Opr
>
AlgoChooser
<
Opr
>::
ExeContext
::
ExeContext
(
const
FixedTensorLayouts
&
layouts
,
Opr
*
megdnn_opr
,
const
std
::
string
&
param_str
,
const
cg
::
OperatorNodeBase
*
mgb_opr
,
const
CompNode
&
cn
,
const
megdnn
::
param
::
ExecutionPolicy
&
execution_policy
,
bool
allow_weight_preprocess
)
:
m_layouts
{
layouts
},
m_megdnn_opr
{
megdnn_opr
},
m_param
{
param_str
},
m_base_mgb_opr
{
mgb_opr
},
m_cn
{
cn
},
m_execution_policy
{
execution_policy
},
m_allow_weight_preprocess
{
allow_weight_preprocess
}
{
mgb_assert
(
m_layouts
.
size
()
==
layouts
.
size
());
static_assert
(
std
::
tuple_size
<
FixedTensorLayouts
>::
value
==
3
||
std
::
tuple_size
<
FixedTensorLayouts
>::
value
==
5
||
std
::
tuple_size
<
FixedTensorLayouts
>::
value
==
8
,
"Convolution AlgoChooser assumes arity = 3 , 5 or 8 (for "
"deformable conv)"
);
}
template
<
typename
Opr
>
typename
AlgoChooser
<
Opr
>::
ImplAlgo
AlgoChooser
<
Opr
>::
ExeContext
::
get_profile_result_from_cache
(
ExecutionStrategy
selected_strategy
)
const
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"AlgoChooser::ExeContext::get_profile_result_from_cache"
)))
AlgoChooserProfileCache
cache
(
m_cn
,
profile_name
(
m_megdnn_opr
).
c_str
());
AlgoChooser
<
Opr
>::
AlgoChooserHelper
::
get_profile_result_from_cache
(
const
ExecutionStrategy
&
selected_strategy
)
const
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"get_profile_result_from_cache"
)))
AlgoChooserProfileCache
cache
(
m_cn
,
profile_name
(
m_megdnn_opr
).
c_str
());
typename
Opr
::
Param
origin_param
=
m_megdnn_opr
->
param
();
AlgoChooserProfileCache
::
Key
cache_key
{
m_layouts
.
data
(),
m_layouts
.
size
(),
...
...
@@ -538,23 +398,22 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
return
{};
auto
&&
prof
=
rst
.
val
();
if
(
prof
.
empty
())
return
{};
std
::
unordered_map
<
std
::
string
,
ImplAlgo
>
algo_map
;
for
(
auto
i
:
get_all_candidates
())
{
auto
ins
=
algo_map
.
emplace
(
i
.
desc
.
name
.
c_str
(),
i
);
mgb_assert
(
ins
.
second
,
"duplicated algo name: %s"
,
i
.
desc
.
name
.
c_str
());
}
if
(
prof
.
empty
())
return
{};
auto
target_attr
=
extract_algo_attribute
(
selected_strategy
);
bool
skip_by_negative
=
false
;
for
(
auto
&&
i
:
prof
)
{
auto
attr_of_algo
=
static_cast
<
megdnn
::
Algorithm
::
Attribute
>
(
i
.
attribute
);
bool
contain_attr_all_positive
=
(
target_attr
.
first
==
(
attr_of_algo
&
target_attr
.
first
));
(
target_attr
.
first
==
(
attr_of_algo
&
target_attr
.
first
));
bool
contain_attr_any_negative
=
static_cast
<
bool
>
(
attr_of_algo
&
target_attr
.
second
);
if
(
contain_attr_all_positive
)
{
...
...
@@ -578,13 +437,14 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
if
(
skip_by_negative
)
{
mgb_log_error
(
"No usable algo. Only navie algos are available, but negative "
"stategy is %s."
,
"No usable algo. There are available algos match positive "
"strategy(%s), but filtered by negative stategy(%s)."
,
Algorithm
::
attribute_str
(
target_attr
.
first
).
c_str
(),
Algorithm
::
attribute_str
(
target_attr
.
second
).
c_str
());
}
else
{
mgb_log_error
(
"No usable algo. algos read from cache could not satisfy "
"
attribute with %s
"
,
"
positive strategy(%s)
"
,
Algorithm
::
attribute_str
(
target_attr
.
first
).
c_str
());
}
...
...
@@ -593,75 +453,10 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
}
template
<
typename
Opr
>
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
AlgoChooser
<
Opr
>::
ExeContext
::
choose_by_heuristic
(
ExecutionStrategy
selected_strategy
)
const
{
if
(
m_execution_policy
.
workspace_limit
!=
std
::
numeric_limits
<
decltype
(
m_execution_policy
.
workspace_limit
)
>::
max
())
{
mgb_log_warn
(
"workspace_limit should not be setted if choose algo by "
"heuristic"
);
}
auto
workspace_limit
=
WorkspaceLimitGetter
::
get_workspace_limit
(
owner_graph
(),
m_cn
,
m_execution_policy
.
workspace_limit
);
auto
attr
=
extract_algo_attribute
(
selected_strategy
);
ImplExecutionPolicy
policy
;
policy
.
algo
=
APPLY
(
m_megdnn_opr
->
get_algorithm_info_heuristic
(
args
...,
workspace_limit
,
attr
.
first
,
attr
.
second
),
m_layouts
)
.
desc
;
Algorithm
*
algo
=
m_megdnn_opr
->
get_algorithm_from_desc
(
policy
.
algo
);
mgb_assert
(
algo
,
"Unknown algo description"
);
std
::
vector
<
Algorithm
::
SearchItem
>&&
sub_items
=
algo
->
get_subopr_list
(
to_layout_array
<
Opr
>
(
m_layouts
),
m_megdnn_opr
);
FOREACH_OPR_TYPE_DISPATCH
(
sub_items
,
{
auto
&&
megdnn_opr
=
intl
::
create_megdnn_opr
<
_Opr
>
(
m_cn
);
megdnn_opr
->
param
()
=
Algorithm
::
deserialize_read_pod
<
typename
_Opr
::
Param
>
(
_item
.
param
);
typename
AlgoChooser
<
_Opr
>::
ExeContext
sub_ctx
(
to_fixed_layouts
<
_Opr
>
(
_item
.
layouts
),
megdnn_opr
.
get
(),
_item
.
param
,
m_base_mgb_opr
,
m_cn
,
m_execution_policy
,
m_allow_weight_preprocess
);
policy
.
sub_policy
.
push_back
(
sub_ctx
.
choose_by_heuristic
(
selected_strategy
));
});
return
policy
;
}
template
<
typename
Opr
>
std
::
vector
<
typename
AlgoChooser
<
Opr
>::
ImplAlgo
>
AlgoChooser
<
Opr
>::
ExeContext
::
get_all_candidates
()
const
{
auto
heu
=
choose_by_heuristic
(
ExecutionStrategy
::
HEURISTIC
);
auto
&&
ret
=
APPLY
(
m_megdnn_opr
->
get_all_algorithms_info
(
args
...),
m_layouts
);
bool
found
=
false
;
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
if
(
ret
[
i
].
desc
==
heu
.
algo
)
{
found
=
true
;
std
::
swap
(
ret
[
i
],
ret
[
0
]);
break
;
}
}
Algorithm
*
palgo
=
m_megdnn_opr
->
get_algorithm_from_desc
(
heu
.
algo
);
mgb_assert
(
palgo
,
"Unknown algo description"
);
mgb_assert
(
found
,
"algo %s got by heuristic not found in "
"candidate list"
,
palgo
->
name
());
return
std
::
move
(
ret
);
}
template
<
typename
Opr
>
void
AlgoChooser
<
Opr
>::
ExeContext
::
construct_execution_policy
(
ExecutionStrategy
selected_strategy
,
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
&
policy
,
bool
retrive_from_cache
)
const
{
void
AlgoChooser
<
Opr
>::
AlgoChooserHelper
::
construct_execution_policy
(
const
ExecutionStrategy
&
selected_strategy
,
bool
retrive_from_cache
,
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
&
policy
)
const
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"construct_execution_policy"
)))
if
(
!
policy
.
algo
.
valid
())
{
if
(
retrive_from_cache
)
{
policy
.
algo
=
get_profile_result_from_cache
(
selected_strategy
).
desc
;
...
...
@@ -712,26 +507,28 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
megdnn_opr
->
param
()
=
Algorithm
::
deserialize_read_pod
<
typename
_Opr
::
Param
>
(
_item
.
param
);
typename
AlgoChooser
<
_Opr
>::
ExeContext
sub_ctx
(
typename
AlgoChooser
<
_Opr
>::
AlgoChooserHelper
sub_helper
(
to_fixed_layouts
<
_Opr
>
(
_item
.
layouts
),
megdnn_opr
.
get
(),
_item
.
param
,
m_base_mgb_opr
,
m_cn
,
m_execution_policy
,
m_allow_weight_preprocess
);
policy
.
sub_policy
.
push_back
({});
sub_
ctx
.
construct_execution_policy
(
selected_strategy
,
policy
.
sub_policy
.
back
()
,
retrive_from_cache
);
sub_
helper
.
construct_execution_policy
(
selected_strategy
,
retrive_from_cache
,
policy
.
sub_policy
.
back
()
);
if
(
!
policy
.
sub_policy
.
back
().
algo
.
valid
())
{
// means sub_
ctx
.construct_execution_policy fails. clean up
// means sub_
helper
.construct_execution_policy fails. clean up
// policy.algo and return
policy
=
{};
return
;
}
});
MIDOUT_E
}
template
<
typename
Opr
>
size_t
AlgoChooser
<
Opr
>::
ExeContext
::
get_workspace_size_bytes
(
size_t
AlgoChooser
<
Opr
>::
AlgoChooserHelper
::
get_workspace_size_bytes
(
const
ImplExecutionPolicy
&
policy
)
const
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"get_workspace_size_bytes"
)))
m_megdnn_opr
->
execution_policy
()
=
policy
;
size_t
result
;
if_constexpr
<
opr_supports_preprocess
<
Opr
>
()
>
(
...
...
@@ -752,12 +549,40 @@ size_t AlgoChooser<Opr>::ExeContext::get_workspace_size_bytes(
m_layouts
);
});
return
result
;
MIDOUT_E
}
template
<
typename
Opr
>
std
::
vector
<
typename
AlgoChooser
<
Opr
>::
ImplAlgo
>
AlgoChooser
<
Opr
>::
AlgoChooserHelper
::
get_all_candidates
()
const
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"get_all_candidates"
)))
auto
heu
=
choose_by_heuristic
(
m_execution_policy
.
strategy
);
auto
&&
ret
=
APPLY
(
m_megdnn_opr
->
get_all_algorithms_info
(
args
...),
m_layouts
);
bool
found
=
false
;
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
if
(
ret
[
i
].
desc
==
heu
.
algo
)
{
found
=
true
;
std
::
swap
(
ret
[
i
],
ret
[
0
]);
break
;
}
}
Algorithm
*
palgo
=
m_megdnn_opr
->
get_algorithm_from_desc
(
heu
.
algo
);
mgb_assert
(
palgo
,
"Unknown algo description"
);
mgb_assert
(
found
,
"algo %s got by heuristic not found in "
"candidate list"
,
palgo
->
name
());
return
std
::
move
(
ret
);
MIDOUT_E
}
template
<
typename
Opr
>
Maybe
<
AlgoChooserProfileCache
::
ResultEntry
>
AlgoChooser
<
Opr
>::
ExeContext
::
profile_single_algo
(
AlgoChooser
<
Opr
>::
AlgoChooserHelper
::
profile_single_algo
(
const
ImplExecutionPolicy
&
policy
,
double
&
timeout
)
const
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"profile_single_algo"
)))
typename
TimedProfiler
<
Opr
>::
Param
param
;
// force check copy size <= dest len-1 from gcc8 for safe
param
.
execution_policy
=
...
...
@@ -791,14 +616,103 @@ AlgoChooser<Opr>::ExeContext::profile_single_algo(
if
(
!
rst
.
valid
())
return
None
;
return
AlgoChooserProfileCache
::
ResultEntry
{
palgo
->
name
(),
static_cast
<
uint32_t
>
(
palgo
->
attribute
()),
palgo
->
name
(),
static_cast
<
uint32_t
>
(
palgo
->
attribute
()),
rst
.
val
().
time
,
param
.
workspace
};
MIDOUT_E
}
template
<
typename
Opr
>
void
AlgoChooser
<
Opr
>::
AlgoChooserHelper
::
profile
(
const
ExecutionStrategy
&
selected_strategy
)
const
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"profile"
)))
if
(
get_profile_result_from_cache
(
selected_strategy
).
valid
())
return
;
AlgoChooserProfileCache
::
Result
prof_rst
;
auto
target_attr
=
extract_algo_attribute
(
selected_strategy
);
std
::
string
layouts_str
=
format_fixlayouts
<
Opr
>
(
m_layouts
,
arity_in
,
arity_out
);
double
cur_timeout
=
0
;
auto
workspace_limit
=
WorkspaceLimitGetter
::
get_workspace_limit
(
owner_graph
(),
m_cn
,
m_execution_policy
.
workspace_limit
);
RealTimer
timer
;
for
(
auto
algo
:
get_all_candidates
())
{
Maybe
<
AlgoChooserProfileCache
::
ResultEntry
>
cur_rst
;
ImplExecutionPolicy
policy
;
policy
.
algo
=
algo
.
desc
;
//! check negative attribute : skip negative attribute
auto
palgo
=
m_megdnn_opr
->
get_algorithm_from_desc
(
policy
.
algo
);
if
(
palgo
->
contain_attribute_any
(
target_attr
.
second
))
{
mgb_log_debug
(
"skip algo %s, which matches the profile strategy required "
"'not contain attribute(%s).'"
,
algo
.
desc
.
name
.
c_str
(),
Algorithm
::
attribute_str
(
target_attr
.
second
).
c_str
());
continue
;
}
//! check workspace limit
construct_execution_policy
(
selected_strategy
,
true
,
policy
);
if
(
get_workspace_size_bytes
(
policy
)
>=
workspace_limit
)
{
continue
;
}
std
::
string
msg
=
ssprintf
(
"profiling %s algorithm %s %s"
,
m_base_mgb_opr
->
dyn_typeinfo
()
->
name
,
algo
.
desc
.
name
.
c_str
(),
layouts_str
.
c_str
());
timer
.
reset
();
MGB_TRY
{
cur_rst
=
profile_single_algo
(
policy
,
cur_timeout
);
}
MGB_CATCH
(
std
::
exception
&
exc
,
{
mgb_log_warn
(
"caught exception during %s: %s"
,
msg
.
c_str
(),
exc
.
what
());
continue
;
})
MGB_CATCH
(...,
{
mgb_log_warn
(
"caught exception during %s"
,
msg
.
c_str
());
continue
;
})
if
(
!
cur_rst
.
valid
())
{
mgb_log_warn
(
"timeout when %s; timeout setting: %.3fsec"
,
msg
.
c_str
(),
cur_timeout
);
continue
;
}
if
(
!
cur_timeout
)
{
cur_timeout
=
timer
.
get_secs
()
+
TIMEOUT_TOLERANCE
;
}
else
{
cur_timeout
=
std
::
min
(
cur_timeout
,
timer
.
get_secs
()
+
TIMEOUT_TOLERANCE
);
}
auto
&&
rst
=
cur_rst
.
val
();
mgb_log_debug
(
"%s: workspace: %zu; time: %.3gsec"
,
msg
.
c_str
(),
rst
.
workspace
,
rst
.
time
);
prof_rst
.
push_back
(
rst
);
}
std
::
string
msg
=
ssprintf
(
"no usable %s algorithm %s without attribute(%s) or could not meet "
"workspace limite requirement(%zu)"
,
m_base_mgb_opr
->
dyn_typeinfo
()
->
name
,
layouts_str
.
c_str
(),
Algorithm
::
attribute_str
(
target_attr
.
second
).
c_str
(),
workspace_limit
);
mgb_assert
(
!
prof_rst
.
empty
(),
"%s"
,
msg
.
c_str
());
FixedTensorLayouts
origin_layouts
=
m_layouts
;
typename
Opr
::
Param
origin_param
=
m_megdnn_opr
->
param
();
AlgoChooserProfileCache
::
Key
cache_key
{
origin_layouts
.
data
(),
origin_layouts
.
size
(),
&
origin_param
,
sizeof
(
origin_param
)};
AlgoChooserProfileCache
cache
(
m_cn
,
profile_name
(
m_megdnn_opr
).
c_str
());
cache
.
put
(
cache_key
,
prof_rst
);
MIDOUT_E
}
template
<
typename
Opr
>
Maybe
<
PreprocessFilter
<
Opr
>>
AlgoChooser
<
Opr
>::
ExeContext
::
construct_fake_preprocess_filter
()
const
{
AlgoChooser
<
Opr
>::
AlgoChooserHelper
::
construct_fake_preprocess_filter
()
const
{
MIDOUT_B
(
Opr
,
midout_iv
(
MGB_HASH_STR
(
"construct_fake_preprocess_filter"
)))
Maybe
<
PreprocessFilter
<
Opr
>>
result
=
None
;
if_constexpr
<
opr_supports_preprocess
<
Opr
>
()
>
([
&
](
auto
_
)
{
if
(
!
m_allow_weight_preprocess
)
...
...
@@ -830,11 +744,12 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
}
});
return
result
;
MIDOUT_E
}
template
<
typename
Opr
>
std
::
pair
<
AlgoAttribute
,
AlgoAttribute
>
AlgoChooser
<
Opr
>::
ExeContext
::
extract_algo_attribute
(
AlgoChooser
<
Opr
>::
AlgoChooserHelper
::
extract_algo_attribute
(
const
ExecutionStrategy
&
strategy
)
const
{
std
::
pair
<
AlgoAttribute
,
AlgoAttribute
>
ret
=
std
::
make_pair
(
AlgoAttribute
::
DEFAULT
,
AlgoAttribute
::
DEFAULT
);
...
...
@@ -851,41 +766,128 @@ AlgoChooser<Opr>::ExeContext::extract_algo_attribute(
}
#define INST(Opr) \
template AlgoChooser<megdnn::Opr>::
ExeContext::ExeContext(
\
template AlgoChooser<megdnn::Opr>::
AlgoChooserHelper::AlgoChooserHelper(
\
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \
const std::string& param_str, const cg::OperatorNodeBase* mgb_opr, \
const CompNode& cn, \
const megdnn::param::ExecutionPolicy& execution_policy, \
bool allow_weight_preprocess); \
template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::ExeContext::choose_by_heuristic( \
ExecutionStrategy select_strategy) const; \
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::choose_by_heuristic( \
const ExecutionStrategy& select_strategy) const; \
template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::choose_by_profile( \
const ExecutionStrategy& select_strategy, bool enable_update) \
const; \
template typename AlgoChooser<megdnn::Opr>::ImplAlgo \
AlgoChooser<megdnn::Opr>::ExeContext::get_profile_result_from_cache( \
ExecutionStrategy select_strategy) const; \
template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo> \
AlgoChooser<megdnn::Opr>::ExeContext::get_all_candidates() const; \
AlgoChooser<megdnn::Opr>::AlgoChooserHelper:: \
get_profile_result_from_cache( \
const ExecutionStrategy& select_strategy) const; \
template void \
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::construct_execution_policy( \
const ExecutionStrategy& select_strategy, bool retrive_from_cache, \
typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy) \
const; \
template size_t \
AlgoChooser<megdnn::Opr>::
ExeContext::get_workspace_size_bytes(
\
AlgoChooser<megdnn::Opr>::
AlgoChooserHelper::get_workspace_size_bytes(
\
const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \
policy) const; \
template void \
AlgoChooser<megdnn::Opr>::ExeContext::construct_execution_policy( \
ExecutionStrategy select_strategy, \
typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy, \
bool retrive_from_cache) const; \
template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo> \
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::get_all_candidates() const; \
template Maybe<AlgoChooserProfileCache::ResultEntry> \
AlgoChooser<megdnn::Opr>::
ExeContext::profile_single_algo(
\
AlgoChooser<megdnn::Opr>::
AlgoChooserHelper::profile_single_algo(
\
const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \
policy, \
double& timeout) const; \
template std::pair<AlgoAttribute, AlgoAttribute> \
AlgoChooser<megdnn::Opr>::ExeContext::extract_algo_attribute( \
const ExecutionStrategy& strategy) const;
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::extract_algo_attribute( \
const ExecutionStrategy& strategy) const; \
template void AlgoChooser<megdnn::Opr>::AlgoChooserHelper::profile( \
const ExecutionStrategy& selected_strategy) const;
MGB_FOREACH_FASTRUN_OPR
(
INST
)
#undef INST
//////////////////////////////// AlgoChoose /////////////////////////////
template
<
typename
Opr
>
typename
AlgoChooser
<
Opr
>::
ImplExecutionPolicy
AlgoChooser
<
Opr
>::
get_policy
(
const
AlgoChooserHelper
&
helper
)
{
auto
opr_strategy
=
helper
.
execution_policy
().
strategy
;
if
(
opr_strategy
&
ExecutionStrategy
::
HEURISTIC
)
{
if
(
opr_strategy
&
ExecutionStrategy
::
PROFILE
)
{
//! this strategy will choose from cache first, then choost by
//! heuristic if fail.
ImplExecutionPolicy
policy
=
helper
.
choose_by_profile
(
opr_strategy
,
false
);
if
(
!
policy
.
algo
.
valid
())
{
policy
=
helper
.
choose_by_heuristic
(
opr_strategy
);
}
return
policy
;
}
else
{
return
helper
.
choose_by_heuristic
(
opr_strategy
);
}
}
#if MGB_ENABLE_FASTRUN
else
if
(
opr_strategy
&
ExecutionStrategy
::
PROFILE
)
{
return
helper
.
choose_by_profile
(
opr_strategy
,
true
);
}
#endif
else
{
mgb_throw
(
GraphError
,
"bad ExecutionPolicy strategy"
);
}
}
template
<
typename
Opr
>
size_t
AlgoChooser
<
Opr
>::
setup_algo
(
const
FixedTensorLayouts
&
layouts
,
Opr
*
megdnn_opr
,
const
MGBOpr
*
mgb_opr
,
bool
allow_weight_preprocess
)
{
if
(
WorkspaceLimitGetter
::
is_prealloc_run
(
mgb_opr
->
owner_graph
()))
{
return
0
;
}
std
::
string
param_str
;
Algorithm
::
serialize_write_pod
(
megdnn_opr
->
param
(),
param_str
);
AlgoChooserHelper
helper
(
layouts
,
megdnn_opr
,
param_str
,
mgb_opr
,
mgb_opr
->
comp_node
(),
mgb_opr
->
execution_policy
(),
allow_weight_preprocess
);
ImplExecutionPolicy
policy
;
if
(
auto
algo_choose_hook
=
mgb_opr
->
algo_chooser
())
{
policy
=
algo_choose_hook
(
mgb_opr
);
auto
strategy
=
ExecutionStrategy
::
HEURISTIC
|
ExecutionStrategy
::
REPRODUCIBLE
;
helper
.
construct_execution_policy
(
strategy
,
false
,
policy
);
}
if
(
!
policy
.
algo
.
valid
())
{
policy
=
get_policy
(
helper
);
}
size_t
workspace
=
helper
.
get_workspace_size_bytes
(
policy
);
std
::
string
ret
;
ret
.
append
(
mgb_opr
->
dyn_typeinfo
()
->
name
);
ret
+=
format_fixlayouts
<
Opr
>
(
layouts
,
arity_in
,
arity_out
);
Algorithm
*
palgo
=
megdnn_opr
->
get_algorithm_from_desc
(
policy
.
algo
);
mgb_assert
(
palgo
,
"Unknown algo description"
);
ret
.
append
(
"): algo="
+
std
::
string
(
palgo
->
name
()));
ret
.
append
(
ssprintf
(
" workspace=%.2fMiB attirbute=%d"
,
workspace
/
(
1024
*
1024.0
),
static_cast
<
uint32_t
>
(
palgo
->
attribute
())));
mgb_log_debug
(
"%s"
,
ret
.
c_str
());
megdnn_opr
->
execution_policy
()
=
policy
;
return
workspace
;
}
#define INST(Opr) \
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::get_policy(const AlgoChooserHelper& proxy); \
template size_t AlgoChooser<megdnn::Opr>::setup_algo( \
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \
const MGBOpr* mgb_opr, bool allow_weight_preprocess);
MGB_FOREACH_FASTRUN_OPR
(
INST
)
#undef INST
}
// namespace opr
}
// namespace mgb
...
...
src/opr/include/megbrain/opr/search_policy/algo_chooser.h
浏览文件 @
46988d45
...
...
@@ -66,7 +66,7 @@ class AlgoChooser {
public:
using
FixedTensorLayouts
=
std
::
array
<
TensorLayout
,
arity
>
;
class
ExeContext
{
class
AlgoChooserHelper
{
FixedTensorLayouts
m_layouts
;
Opr
*
m_megdnn_opr
;
std
::
string
m_param
;
...
...
@@ -76,22 +76,23 @@ public:
bool
m_allow_weight_preprocess
;
public:
ExeContext
(
const
FixedTensorLayouts
&
layouts
,
Opr
*
megdnn_opr
,
const
std
::
string
&
param_str
,
const
cg
::
OperatorNodeBase
*
mgb_opr
,
const
CompNode
&
cn
,
const
megdnn
::
param
::
ExecutionPolicy
&
execution_policy
,
bool
allow_weight_preprocess
);
AlgoChooserHelper
(
const
FixedTensorLayouts
&
layouts
,
Opr
*
megdnn_opr
,
const
std
::
string
&
param_str
,
const
cg
::
OperatorNodeBase
*
mgb_opr
,
const
CompNode
&
cn
,
const
megdnn
::
param
::
ExecutionPolicy
&
execution_policy
,
bool
allow_weight_preprocess
);
Opr
*
megdnn_opr
()
const
{
return
m_megdnn_opr
;
}
const
cg
::
OperatorNodeBase
*
mgb_opr
()
const
{
return
m_base_mgb_opr
;
}
const
TensorLayout
&
inp_layout
(
size_t
idx
)
const
{
return
m_layouts
[
idx
];
}
cg
::
ComputingGraph
*
owner_graph
()
const
{
return
m_base_mgb_opr
->
owner_graph
();
}
const
cg
::
OperatorNodeBase
*
mgb_opr
()
const
{
return
m_base_mgb_opr
;
}
const
megdnn
::
param
::
ExecutionPolicy
&
execution_policy
()
const
{
return
m_execution_policy
;
}
...
...
@@ -109,17 +110,40 @@ public:
const
FixedTensorLayouts
&
layouts
()
const
{
return
m_layouts
;
}
//! construct algo chain by heuristic
ImplExecutionPolicy
choose_by_heuristic
(
ExecutionStrategy
selected_strategy
)
const
;
const
ExecutionStrategy
&
selected_strategy
)
const
;
//! get all candidate algos, and the one choose_by_heuristic() is
//! put first
std
::
vector
<
ImplAlgo
>
get_all_candidates
()
const
;
//! construct algo chain by profiling
ImplExecutionPolicy
choose_by_profile
(
const
ExecutionStrategy
&
selected_strategy
,
bool
enable_update
)
const
;
//! get all profile algorithm from cache, return invalid if not exists
ImplAlgo
get_profile_result_from_cache
(
const
ExecutionStrategy
&
selected_strategy
)
const
;
/**
* \brief construct execution policy from cache or heuristic.
*
* \param selected_strategy select algo which matched this strategy
* \param[in,out] policy execution policy
* \param retrive_from_cache retrive algo from cache if set True, get
* from heuristic otherwise.
* \return true if contruct success and false when fail
*/
void
construct_execution_policy
(
const
ExecutionStrategy
&
selected_strategy
,
bool
retrive_from_cache
,
ImplExecutionPolicy
&
policy
)
const
;
//! get workspace size required for specific execution policy
size_t
get_workspace_size_bytes
(
const
ImplExecutionPolicy
&
policy
)
const
;
//! get all candidate algos, and the one choose_by_heuristic() is
//! put first
std
::
vector
<
ImplAlgo
>
get_all_candidates
()
const
;
/*!
* \brief profile a single algorithm
*
...
...
@@ -132,22 +156,8 @@ public:
Maybe
<
AlgoChooserProfileCache
::
ResultEntry
>
profile_single_algo
(
const
ImplExecutionPolicy
&
policy
,
double
&
timeout
)
const
;
//! get all profile algorithm from cache, return invalid if not exists
ImplAlgo
get_profile_result_from_cache
(
ExecutionStrategy
selected_strategy
)
const
;
/**
* \brief construct execution policy from cache or heuristic.
*
* \param selected_strategy select algo which matched this strategy
* \param [out] policy execution policy
* \param retrive_from_cache retrive algo from cache if set True, get
* from heuristic otherwise.
* \note When contruction fail, the policy will be cleaned.
*/
void
construct_execution_policy
(
ExecutionStrategy
selected_strategy
,
ImplExecutionPolicy
&
policy
,
bool
retrive_from_cache
=
true
)
const
;
//! profile and save to cache
void
profile
(
const
ExecutionStrategy
&
selected_strategy
)
const
;
/**
* \brief extract algo attribute from execution strategy and graph
...
...
@@ -168,14 +178,7 @@ public:
private:
//! entrance for getting algorithm according to execution strategy
static
ImplExecutionPolicy
get_policy
(
ExeContext
&
ctx
);
//! profile and save to cache
static
void
profile
(
ExeContext
&
ctx
,
ExecutionStrategy
selected_strategy
);
static
ImplExecutionPolicy
choose_by_profile
(
ExeContext
&
ctx
,
ExecutionStrategy
selected_strategy
,
bool
enable_update
=
true
);
static
ImplExecutionPolicy
get_policy
(
const
AlgoChooserHelper
&
helper
);
public:
/*!
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录