Commit 7c09e41f authored by Megvii Engine Team

refactor(mgb): add circular dependency check

GitOrigin-RevId: 01fdb8684be2c594d9b8d9a57d28528cf5412dc6
Parent af42ce7e
......@@ -12,6 +12,8 @@
#include "megbrain/opr/search_policy/algo_chooser.h"
#include <limits>
#include <unordered_set>
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
#include "megbrain/opr/search_policy/profiler.h"
......@@ -22,6 +24,7 @@
//! TODO: some megdnn oprs have to be known here when midout.h is produced;
//! fix this if there is a more graceful way.
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "megdnn/oprs/base.h"
#include "midout.h"
......@@ -78,6 +81,58 @@ std::string format_fixlayouts(
return ret;
}
/**
 * \brief Check whether the sub-opr list has a circular dependency.
 */
class CircularDepsChecker {
struct SearchItemStorage {
std::string data_hold;
size_t hash = 0;
SearchItemStorage(const Algorithm::SearchItem& item) {
Algorithm::serialize_write_pod(item.opr_type, data_hold);
for (auto&& layout : item.layouts) {
data_hold += layout.serialize();
}
data_hold += item.param;
}
SearchItemStorage& init_hash() {
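            // hash the serialized item with xxHash64; 20201225 is just a
            // fixed seed so hashes are stable across runs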
hash = XXHash64CT::hash(data_hold.data(), data_hold.size(),
20201225);
return *this;
}
bool operator==(const SearchItemStorage& rhs) const {
return data_hold == rhs.data_hold;
}
struct Hash {
size_t operator()(const SearchItemStorage& s) const {
return s.hash;
}
};
};
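    //! keys on the current search path; re-inserting one signals a cycle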
std::unordered_set<SearchItemStorage, SearchItemStorage::Hash> m_set;
public:
void put(const megdnn::Algorithm::SearchItem& key) {
SearchItemStorage key_storage(key);
key_storage.init_hash();
mgb_assert(m_set.find(key_storage) == m_set.end(),
"Circular dependency during flatten search space");
auto ret = m_set.insert(std::move(key_storage));
mgb_assert(ret.second);
}
void remove(const megdnn::Algorithm::SearchItem& key) {
SearchItemStorage key_storage(key);
key_storage.init_hash();
auto&& iter = m_set.find(key_storage);
mgb_assert(iter != m_set.end());
m_set.erase(iter);
}
};
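Used correctly, put() and remove() bracket each recursive expansion, so m_set always holds exactly the items on the current search path and a repeated put() exposes a cycle. A minimal standalone sketch of this DFS cycle-detection pattern (toy Node type and key; not code from this commit):

#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

struct Node {  // hypothetical stand-in for a sub-opr search item
    std::string key;
    std::vector<const Node*> deps;
};

void expand(const Node& n, std::unordered_set<std::string>& on_path) {
    // put(): a key already on the current path means a cycle
    bool inserted = on_path.insert(n.key).second;
    assert(inserted && "circular dependency");
    (void)inserted;  // avoid unused warning when asserts are disabled
    for (const Node* d : n.deps)
        expand(*d, on_path);
    on_path.erase(n.key);  // remove(): finished expanding this node
}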
///////////////// OprTypeTrait /////////////////////////////
template <megdnn::Algorithm::OprType>
struct OprFromOprTypeTrait;
......@@ -176,14 +231,26 @@ typename opr::AlgoChooser<Opr>::FixedTensorLayouts to_fixed_layouts(
return ret;
}
} // namespace
namespace mgb {
namespace opr {
/**
 * \brief Flatten the search space in postorder traversal.
 * The sub-opr search constructs a search tree:
 *
 *               A
 *            /     \
 *        B1 B2      C
 *        /     \
 *   D1 D2 D3    E
 *
 * The tree is traversed in postorder:
 * D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A
 */
template <typename Opr>
std::vector<megdnn::Algorithm::SearchItem>
AlgoChooser<Opr>::flatten_search_space(const ExeContext& ctx) {
std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
const typename opr::AlgoChooser<Opr>::ExeContext& ctx,
CircularDepsChecker& checker) {
auto&& search_item = megdnn::Algorithm::SearchItem{
OprTypeFromOprTrait<Opr>::opr_type, ctx.param(),
to_layout_array<Opr>(ctx.layouts())};
checker.put(search_item);
std::vector<megdnn::Algorithm::SearchItem> ret;
for (auto algo_info : ctx.get_all_candidates()) {
megdnn::Algorithm* algo = ctx.get_algorithm_from_desc(algo_info.desc);
......@@ -193,23 +260,29 @@ AlgoChooser<Opr>::flatten_search_space(const ExeContext& ctx) {
ctx.megdnn_opr());
FOREACH_OPR_TYPE_DISPATCH(sub_items, {
auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(ctx.comp_node());
auto&& megdnn_opr =
opr::intl::create_megdnn_opr<_Opr>(ctx.comp_node());
megdnn_opr->param() =
Algorithm::deserialize_read_pod<typename _Opr::Param>(
_item.param);
typename AlgoChooser<_Opr>::ExeContext sub_ctx(
typename opr::AlgoChooser<_Opr>::ExeContext sub_ctx(
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
_item.param, ctx.mgb_opr(), ctx.comp_node(),
ctx.execution_policy(), ctx.allow_weight_preprocess());
auto space = AlgoChooser<_Opr>::flatten_search_space(sub_ctx);
auto space = flatten_search_space<_Opr>(sub_ctx, checker);
ret.insert(ret.end(), space.begin(), space.end());
});
}
ret.push_back({OprTypeFromOprTrait<Opr>::opr_type, ctx.param(),
to_layout_array<Opr>(ctx.layouts())});
ret.push_back(search_item);
checker.remove(search_item);
return ret;
}
} // namespace
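For intuition, here is a toy sketch of the recursion shape above (hypothetical Item type; not the real SearchItem machinery): sub-items are flattened and appended first, and the node itself is pushed last, which is exactly postorder.

#include <string>
#include <vector>

struct Item {
    std::string name;
    std::vector<Item> subs;  // sub search items (std::vector of an
                             // incomplete type requires C++17)
};

std::vector<std::string> flatten(const Item& it) {
    std::vector<std::string> ret;
    for (const Item& s : it.subs) {
        auto space = flatten(s);  // recurse into children first
        ret.insert(ret.end(), space.begin(), space.end());
    }
    ret.push_back(it.name);  // then append the node itself: postorder
    return ret;
}

// Example: if A has children B and C, and B has children D and E,
// flatten(A) yields D, E, B, C, A - every child precedes its parent.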
namespace mgb {
namespace opr {
template <typename Opr>
void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) {
if (ctx.get_profile_result_from_cache(require_reproducible).valid())
......@@ -289,7 +362,9 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, bool require_reproducible,
}
if (enable_update) {
auto&& search_items = flatten_search_space(ctx);
CircularDepsChecker circular_deps_checker;
auto&& search_items =
flatten_search_space<Opr>(ctx, circular_deps_checker);
FOREACH_OPR_TYPE_DISPATCH(search_items, {
auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(ctx.comp_node());
megdnn_opr->param() =
......@@ -382,14 +457,12 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
AlgoChooser<megdnn::Opr>::get_policy(ExeContext& ctx); \
template void AlgoChooser<megdnn::Opr>::profile( \
ExeContext& ctx, bool require_reproducible); \
template std::vector<megdnn::Algorithm::SearchItem> \
AlgoChooser<megdnn::Opr>::flatten_search_space(const ExeContext& ctx); \
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
AlgoChooser<megdnn::Opr>::choose_by_profile( \
ExeContext& ctx, bool require_reproducible, bool enable_update); \
template size_t AlgoChooser<megdnn::Opr>::setup_algo( \
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \
const MGBOpr* mgb_opr, bool allow_weight_preprocess);
const MGBOpr* mgb_opr, bool allow_weight_preprocess); \
MGB_FOREACH_FASTRUN_OPR(INST)
......
......@@ -159,21 +159,6 @@ private:
bool require_reproducible,
bool enable_update = true);
/**
 * \brief Flatten the search space in postorder traversal.
 * The sub-opr search constructs a search tree:
 *
 *               A
 *            /     \
 *        B1 B2      C
 *        /     \
 *   D1 D2 D3    E
 *
 * The tree is traversed in postorder:
 * D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A
 */
static std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
const ExeContext& ctx);
public:
/*!
* \brief setup algorithm and return workspace size
......