提交 567586a0 编写于 作者: M Megvii Engine Team

feat(mgb): fastrun algo profile deduplication

GitOrigin-RevId: 0d1bed781d889e6e105268ea09dc5c96cd4df013
上级 8110bb21
......@@ -627,7 +627,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::choose_by_profile(
}
template <typename Opr>
typename AlgoChooser<Opr>::ImplAlgoDesc
std::pair<typename AlgoChooser<Opr>::ImplAlgoDesc, Maybe<AlgoChooserProfileCache::Result>>
AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
const ExecutionStrategy& selected_strategy) const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_profile_result_from_cache")))
......@@ -639,11 +639,11 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
&origin_param, sizeof(origin_param)};
auto&& rst = cache.get(cache_key);
if (!rst.valid())
return {};
return {{}, rst};
auto&& prof = rst.val();
if (prof.empty())
return {};
return {{}, rst};
auto target_attr = extract_algo_attribute(selected_strategy);
bool skip_by_negative = false;
......@@ -657,7 +657,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
if (contain_attr_all_positive) {
if (!contain_attr_any_negative) {
Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo);
return algo_desc;
return {algo_desc, rst};
} else {
skip_by_negative = true;
}
......@@ -695,7 +695,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_execution_policy")))
if (!policy.algo.valid()) {
if (retrive_from_cache) {
policy.algo = get_profile_result_from_cache(selected_strategy);
policy.algo = get_profile_result_from_cache(selected_strategy).first;
if (!policy.algo.valid()) {
if (allow_log) {
auto target_attr =
......@@ -886,7 +886,8 @@ template <typename Opr>
void AlgoChooser<Opr>::AlgoChooserHelper::profile(
const ExecutionStrategy& selected_strategy) const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile")))
if (get_profile_result_from_cache(selected_strategy).valid())
auto&& rst = get_profile_result_from_cache(selected_strategy);
if (rst.first.valid())
return;
AlgoChooserProfileCache::Result prof_rst;
......@@ -898,7 +899,20 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);
RealTimer timer;
std::unordered_set<std::string> rst_algos;
if (rst.second.valid()) {
std::transform(rst.second.val().begin(), rst.second.val().end(),
std::inserter(rst_algos, rst_algos.end()),
[](const AlgoChooserProfileCache::ResultEntry& result) {
return result.algo;
});
}
for (auto algo : get_all_candidates()) {
std::string desc;
serialize_write_pod(algo.desc, desc);
if (rst_algos.find(desc) != rst_algos.end()) {
continue;
}
Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst;
ImplExecutionPolicy policy;
......@@ -960,6 +974,9 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
Algorithm::attribute_str(target_attr.second).c_str(),
workspace_limit);
mgb_assert(!prof_rst.empty(), "%s", msg.c_str());
if (rst.second.valid())
prof_rst.insert(prof_rst.end(), rst.second.val().begin(),
rst.second.val().end());
FixedTensorLayouts incache_layouts = m_incache_layouts;
typename Opr::Param origin_param = m_dnn_opr->param();
......@@ -1058,7 +1075,8 @@ AlgoChooser<Opr>::AlgoChooserHelper::extract_algo_attribute(
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::choose_by_profile( \
const ExecutionStrategy& select_strategy, bool enable_update) \
const; \
template typename AlgoChooser<megdnn::Opr>::ImplAlgoDesc \
template std::pair<typename AlgoChooser<megdnn::Opr>::ImplAlgoDesc, \
Maybe<AlgoChooserProfileCache::Result>> \
AlgoChooser<megdnn::Opr>::AlgoChooserHelper:: \
get_profile_result_from_cache( \
const ExecutionStrategy& select_strategy) const; \
......
......@@ -131,7 +131,8 @@ public:
bool enable_update) const;
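//! Look up the profile result in the cache. `first` is the algorithm chosen for
//! the selected strategy (left invalid when the cache has no usable entry);
//! `second` is the result of the cache lookup itself.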
ImplAlgoDesc get_profile_result_from_cache(
std::pair<ImplAlgoDesc, Maybe<AlgoChooserProfileCache::Result>>
get_profile_result_from_cache(
const ExecutionStrategy& selected_strategy) const;
/**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册