提交 46988d45 编写于 作者: M Megvii Engine Team

refactor(mgb): code refactor of fast run

GitOrigin-RevId: 2c4b8e06bb3c4b4cb0228ee28988c2371455b1b0
上级 a1e38342
......@@ -66,7 +66,7 @@ class AlgoChooser {
public:
using FixedTensorLayouts = std::array<TensorLayout, arity>;
class ExeContext {
class AlgoChooserHelper {
FixedTensorLayouts m_layouts;
Opr* m_megdnn_opr;
std::string m_param;
......@@ -76,22 +76,23 @@ public:
bool m_allow_weight_preprocess;
public:
ExeContext(const FixedTensorLayouts& layouts, Opr* megdnn_opr,
const std::string& param_str,
const cg::OperatorNodeBase* mgb_opr, const CompNode& cn,
const megdnn::param::ExecutionPolicy& execution_policy,
bool allow_weight_preprocess);
AlgoChooserHelper(
const FixedTensorLayouts& layouts, Opr* megdnn_opr,
const std::string& param_str,
const cg::OperatorNodeBase* mgb_opr, const CompNode& cn,
const megdnn::param::ExecutionPolicy& execution_policy,
bool allow_weight_preprocess);
Opr* megdnn_opr() const { return m_megdnn_opr; }
const cg::OperatorNodeBase* mgb_opr() const { return m_base_mgb_opr; }
const TensorLayout& inp_layout(size_t idx) const {
return m_layouts[idx];
}
cg::ComputingGraph* owner_graph() const {
return m_base_mgb_opr->owner_graph();
}
const cg::OperatorNodeBase* mgb_opr() const { return m_base_mgb_opr; }
const megdnn::param::ExecutionPolicy& execution_policy() const {
return m_execution_policy;
}
......@@ -109,17 +110,40 @@ public:
const FixedTensorLayouts& layouts() const { return m_layouts; }
//! construct algo chain by heuristic
ImplExecutionPolicy choose_by_heuristic(
ExecutionStrategy selected_strategy) const;
const ExecutionStrategy& selected_strategy) const;
//! get all candidate algos, and the one choose_by_heuristic() is
//! put first
std::vector<ImplAlgo> get_all_candidates() const;
//! construct algo chain by profiling
ImplExecutionPolicy choose_by_profile(
const ExecutionStrategy& selected_strategy,
bool enable_update) const;
//! get all profile algorithm from cache, return invalid if not exists
ImplAlgo get_profile_result_from_cache(
const ExecutionStrategy& selected_strategy) const;
/**
* \brief construct execution policy from cache or heuristic.
*
* \param selected_strategy select algo which matched this strategy
* \param[in,out] policy execution policy
* \param retrive_from_cache retrive algo from cache if set True, get
* from heuristic otherwise.
* \return true if contruct success and false when fail
*/
void construct_execution_policy(
const ExecutionStrategy& selected_strategy,
bool retrive_from_cache, ImplExecutionPolicy& policy) const;
//! get workspace size required for specific execution policy
size_t get_workspace_size_bytes(
const ImplExecutionPolicy& policy) const;
//! get all candidate algos, and the one choose_by_heuristic() is
//! put first
std::vector<ImplAlgo> get_all_candidates() const;
/*!
* \brief profile a single algorithm
*
......@@ -132,22 +156,8 @@ public:
Maybe<AlgoChooserProfileCache::ResultEntry> profile_single_algo(
const ImplExecutionPolicy& policy, double& timeout) const;
//! get all profile algorithm from cache, return invalid if not exists
ImplAlgo get_profile_result_from_cache(
ExecutionStrategy selected_strategy) const;
/**
* \brief construct execution policy from cache or heuristic.
*
* \param selected_strategy select algo which matched this strategy
* \param [out] policy execution policy
* \param retrive_from_cache retrive algo from cache if set True, get
* from heuristic otherwise.
* \note When contruction fail, the policy will be cleaned.
*/
void construct_execution_policy(ExecutionStrategy selected_strategy,
ImplExecutionPolicy& policy,
bool retrive_from_cache = true) const;
//! profile and save to cache
void profile(const ExecutionStrategy& selected_strategy) const;
/**
* \brief extract algo attribute from execution strategy and graph
......@@ -168,14 +178,7 @@ public:
private:
//! entrance for getting algorithm according to execution strategy
static ImplExecutionPolicy get_policy(ExeContext& ctx);
//! profile and save to cache
static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy);
static ImplExecutionPolicy choose_by_profile(
ExeContext& ctx, ExecutionStrategy selected_strategy,
bool enable_update = true);
static ImplExecutionPolicy get_policy(const AlgoChooserHelper& helper);
public:
/*!
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册