提交 f902ba24 编写于 作者: M Megvii Engine Team

docs(megbrain): add notes for fastrun

GitOrigin-RevId: b59f7f205d98e127c6dcaaaedfab556cdf2dba21
上级 d968942f
......@@ -565,6 +565,7 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelp
choose_by_profile(
const ExecutionStrategy& selected_strategy, bool enable_update) const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_profile")))
// no_profiling_on_shape_change is usually false, no interface to change it easily
if (m_desc.no_profiling_on_shape_change) {
auto policy = m_dnn_opr->execution_policy();
if (policy.algo.valid()) {
......@@ -579,6 +580,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelp
}
}
// if update enabled, do profiling and update cache
// enable_update = false only when using HEURISRIC_PROFILE strategy
typename AlgoChooser<Opr>::ImplExecutionPolicy tmp_policy;
bool retrive_from_cache = true;
bool allow_log = false;
......@@ -604,6 +607,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelp
});
}
// try to retrive algorithm from fastrun cache, this time it's guaranteed to get
// result, retrive_from_cache = true, allow_log = true
typename AlgoChooser<Opr>::ImplExecutionPolicy policy;
construct_execution_policy(selected_strategy, policy);
return policy;
......@@ -623,13 +628,16 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
m_incache_layouts.data(), m_incache_layouts.size(), &origin_param,
sizeof(origin_param)};
auto&& rst = cache.get(cache_key);
// failed to find a cache entry, return
if (!rst.valid())
return {{}, rst};
// found a cache entry(it's a vector of Result), but it's empty
auto&& prof = rst.val();
if (prof.empty())
return {{}, rst};
// found non-empty cache result, filter it by workspace limit and attribute
size_t workspace_limit =
m_desc.get_workspace_limit(m_cn, m_execution_policy.workspace_limit);
auto target_attr = extract_algo_attribute(selected_strategy);
......@@ -644,6 +652,8 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
if (contain_attr_all_positive) {
if (!contain_attr_any_negative) {
if (i.workspace <= workspace_limit) {
// found a well-suited algothrim with good workspace limit and
// correct attribute
Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo);
return {algo_desc, rst};
}
......@@ -654,9 +664,11 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
}
}
// failed to find an algorithm that satisfies the actual workspace limit
if (skip_by_workspace)
return {};
// failed to find an algorithm that satisfies the actual attribute
std::string layouts_str = AlgoChooser::format_fixlayouts(m_fastrun_layouts);
if (skip_by_negative) {
mgb_log_error(
......@@ -685,9 +697,12 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, bool retrive_from_cache,
bool allow_log) const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_execution_policy")))
// policy.algo is always invalid when called from choose_by_profile
// policy.algo will be valid when called from profile
if (!policy.algo.valid()) {
if (retrive_from_cache) {
policy.algo = get_profile_result_from_cache(selected_strategy).first;
// nothing is found even with profiling
if (!policy.algo.valid()) {
if (allow_log) {
auto target_attr = extract_algo_attribute(selected_strategy);
......@@ -710,6 +725,8 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
return;
}
} else {
// retrive_from_cache = false happens when using algo choose hook in
// megbrain graph return heuristic algorithm in this case
auto workspace_limit = m_desc.get_workspace_limit(
m_cn, m_execution_policy.workspace_limit);
......@@ -727,11 +744,13 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
}
}
// construct current algorithm
Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
mgb_assert(algo, "Unknown algo description");
std::vector<Algorithm::SearchItem>&& sub_items =
algo->get_subopr_list(to_layout_array<Opr>(m_fastrun_layouts), m_dnn_opr);
// construct sub oprs' algorithm
FOREACH_OPR_TYPE_DISPATCH(sub_items, {
auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(m_cn);
megdnn_opr->param() =
......@@ -790,6 +809,8 @@ std::vector<typename AlgoChooser<Opr>::ImplAlgo> AlgoChooser<
auto heu = choose_by_heuristic(m_execution_policy.strategy);
auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info(args...), m_fastrun_layouts);
bool found = false;
// make heuristic algorithm always the first in all candidate alrogrithms
// so profiling step will always run heuristic algorithm first
for (size_t i = 0; i < ret.size(); ++i) {
if (ret[i].desc == heu.algo) {
found = true;
......@@ -798,6 +819,7 @@ std::vector<typename AlgoChooser<Opr>::ImplAlgo> AlgoChooser<
}
}
// make sure heuristic algorithm is valid
Algorithm* palgo = m_dnn_opr->get_algorithm_from_desc(heu.algo);
mgb_assert(palgo, "Unknown algo description");
mgb_assert(
......@@ -813,6 +835,7 @@ template <typename Opr>
Maybe<AlgoChooserProfileCache::ResultEntry> AlgoChooser<Opr>::AlgoChooserHelper::
profile_single_algo(const ImplExecutionPolicy& policy, double& timeout) const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile_single_algo")))
// fill TimedProfiler<Opr>::param and run actual timed profiler
typename TimedProfiler<Opr>::Param param;
// force check copy size <= dest len-1 from gcc8 for safe
param.execution_policy =
......@@ -867,7 +890,11 @@ template <typename Opr>
void AlgoChooser<Opr>::AlgoChooserHelper::profile(
const ExecutionStrategy& selected_strategy) const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile")))
// some sub oprs have beed profiled before
// sub oprs won't be checked at the beginning of choose_by_profile
auto&& rst = get_profile_result_from_cache(selected_strategy);
// rst.first.valid means there exists valid algorithms for current opr, just return
// otherwise need to profile
if (rst.first.valid())
return;
AlgoChooserProfileCache::Result prof_rst;
......@@ -957,6 +984,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
Algorithm::attribute_str(target_attr.second).c_str(), workspace_limit);
mgb_assert(!prof_rst.empty(), "%s", msg.c_str());
// append some previous profiled results
if (rst.second.valid())
prof_rst.insert(
prof_rst.end(), rst.second.val().begin(), rst.second.val().end());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册