提交 5d637d07 编写于 作者: M Megvii Engine Team

refactor(mgb): code refactor of fast run

GitOrigin-RevId: 2c4b8e06bb3c4b4cb0228ee28988c2371455b1b0
上级 f6bd4f59
...@@ -66,7 +66,7 @@ class AlgoChooser { ...@@ -66,7 +66,7 @@ class AlgoChooser {
public: public:
using FixedTensorLayouts = std::array<TensorLayout, arity>; using FixedTensorLayouts = std::array<TensorLayout, arity>;
class ExeContext { class AlgoChooserHelper {
FixedTensorLayouts m_layouts; FixedTensorLayouts m_layouts;
Opr* m_megdnn_opr; Opr* m_megdnn_opr;
std::string m_param; std::string m_param;
...@@ -76,22 +76,23 @@ public: ...@@ -76,22 +76,23 @@ public:
bool m_allow_weight_preprocess; bool m_allow_weight_preprocess;
public: public:
ExeContext(const FixedTensorLayouts& layouts, Opr* megdnn_opr, AlgoChooserHelper(
const std::string& param_str, const FixedTensorLayouts& layouts, Opr* megdnn_opr,
const cg::OperatorNodeBase* mgb_opr, const CompNode& cn, const std::string& param_str,
const megdnn::param::ExecutionPolicy& execution_policy, const cg::OperatorNodeBase* mgb_opr, const CompNode& cn,
bool allow_weight_preprocess); const megdnn::param::ExecutionPolicy& execution_policy,
bool allow_weight_preprocess);
Opr* megdnn_opr() const { return m_megdnn_opr; } Opr* megdnn_opr() const { return m_megdnn_opr; }
const cg::OperatorNodeBase* mgb_opr() const { return m_base_mgb_opr; }
const TensorLayout& inp_layout(size_t idx) const { const TensorLayout& inp_layout(size_t idx) const {
return m_layouts[idx]; return m_layouts[idx];
} }
cg::ComputingGraph* owner_graph() const { cg::ComputingGraph* owner_graph() const {
return m_base_mgb_opr->owner_graph(); return m_base_mgb_opr->owner_graph();
} }
const cg::OperatorNodeBase* mgb_opr() const { return m_base_mgb_opr; }
const megdnn::param::ExecutionPolicy& execution_policy() const { const megdnn::param::ExecutionPolicy& execution_policy() const {
return m_execution_policy; return m_execution_policy;
} }
...@@ -109,17 +110,40 @@ public: ...@@ -109,17 +110,40 @@ public:
const FixedTensorLayouts& layouts() const { return m_layouts; } const FixedTensorLayouts& layouts() const { return m_layouts; }
//! construct algo chain by heuristic
ImplExecutionPolicy choose_by_heuristic( ImplExecutionPolicy choose_by_heuristic(
ExecutionStrategy selected_strategy) const; const ExecutionStrategy& selected_strategy) const;
//! get all candidate algos, and the one choose_by_heuristic() is //! construct algo chain by profiling
//! put first ImplExecutionPolicy choose_by_profile(
std::vector<ImplAlgo> get_all_candidates() const; const ExecutionStrategy& selected_strategy,
bool enable_update) const;
//! get all profile algorithm from cache, return invalid if not exists
ImplAlgo get_profile_result_from_cache(
const ExecutionStrategy& selected_strategy) const;
/**
* \brief construct execution policy from cache or heuristic.
*
* \param selected_strategy select algo which matched this strategy
* \param[in,out] policy execution policy
 * \param retrive_from_cache retrieve the algo from cache if set true, get
* from heuristic otherwise.
 * \return true if construction succeeds and false when it fails
*/
void construct_execution_policy(
const ExecutionStrategy& selected_strategy,
bool retrive_from_cache, ImplExecutionPolicy& policy) const;
//! get workspace size required for specific execution policy //! get workspace size required for specific execution policy
size_t get_workspace_size_bytes( size_t get_workspace_size_bytes(
const ImplExecutionPolicy& policy) const; const ImplExecutionPolicy& policy) const;
//! get all candidate algos, and the one choose_by_heuristic() is
//! put first
std::vector<ImplAlgo> get_all_candidates() const;
/*! /*!
* \brief profile a single algorithm * \brief profile a single algorithm
* *
...@@ -132,22 +156,8 @@ public: ...@@ -132,22 +156,8 @@ public:
Maybe<AlgoChooserProfileCache::ResultEntry> profile_single_algo( Maybe<AlgoChooserProfileCache::ResultEntry> profile_single_algo(
const ImplExecutionPolicy& policy, double& timeout) const; const ImplExecutionPolicy& policy, double& timeout) const;
//! get all profile algorithm from cache, return invalid if not exists //! profile and save to cache
ImplAlgo get_profile_result_from_cache( void profile(const ExecutionStrategy& selected_strategy) const;
ExecutionStrategy selected_strategy) const;
/**
* \brief construct execution policy from cache or heuristic.
*
* \param selected_strategy select algo which matched this strategy
* \param [out] policy execution policy
* \param retrive_from_cache retrive algo from cache if set True, get
* from heuristic otherwise.
* \note When contruction fail, the policy will be cleaned.
*/
void construct_execution_policy(ExecutionStrategy selected_strategy,
ImplExecutionPolicy& policy,
bool retrive_from_cache = true) const;
/** /**
* \brief extract algo attribute from execution strategy and graph * \brief extract algo attribute from execution strategy and graph
...@@ -168,14 +178,7 @@ public: ...@@ -168,14 +178,7 @@ public:
private: private:
//! entrance for getting algorithm according to execution strategy //! entrance for getting algorithm according to execution strategy
static ImplExecutionPolicy get_policy(ExeContext& ctx); static ImplExecutionPolicy get_policy(const AlgoChooserHelper& helper);
//! profile and save to cache
static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy);
static ImplExecutionPolicy choose_by_profile(
ExeContext& ctx, ExecutionStrategy selected_strategy,
bool enable_update = true);
public: public:
/*! /*!
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册