提交 4b141f8d 编写于 作者: M Megvii Engine Team

fix(mgb): add usable-depend-on-shape attr

GitOrigin-RevId: 3a14fa6b6f61c30999a9cd7f20f90dccc52e1377
上级 15b647ae
......@@ -122,6 +122,11 @@ public:
* these algorithms to speed up fastrun.
* */
NAIVE = 1 << 1,
/**
* \brief whether the usability of the algo depends on the shape, i.e. the
* algo may no longer be usable once the shape has changed.
* */
USABLE_DEPEND_ON_SHAPE = 1 << 2,
};
/**
......
......@@ -35,7 +35,8 @@ public:
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; }
bool usable(const KernSizeParam&) const override;
......@@ -146,7 +147,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; }
bool usable(const KernSizeParam&) const override;
......@@ -220,7 +222,8 @@ public:
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
......@@ -235,7 +238,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override {
return "AARCH64_INT8X8X16_MK4_16X12X4";
......@@ -253,7 +257,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override {
return "AARCH64_INT8X8X16_MK4_K8X8X8";
......@@ -271,7 +276,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; }
bool usable(const KernSizeParam&) const override;
......@@ -330,7 +336,8 @@ public:
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; }
bool usable(const KernSizeParam&) const override;
......
......@@ -34,7 +34,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; }
bool usable(const KernSizeParam&) const override;
......@@ -50,7 +51,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; }
bool usable(const KernSizeParam&) const override;
......@@ -67,7 +69,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; }
bool usable(const KernSizeParam&) const override;
......@@ -102,7 +105,8 @@ public:
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; }
bool usable(const KernSizeParam&) const override;
......
......@@ -35,7 +35,8 @@ public:
class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; }
bool usable(const KernSizeParam&) const override;
......@@ -224,7 +225,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; }
bool usable(const KernSizeParam&) const override;
......@@ -266,7 +268,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; }
bool usable(const KernSizeParam&) const override;
......
......@@ -18,7 +18,8 @@ using namespace megdnn;
#define FOREACH_ALGO_ATTRIBUTE(cb) \
cb(DEFAULT) \
cb(REPRODUCIBLE) \
cb(NAIVE)
cb(NAIVE) \
cb(USABLE_DEPEND_ON_SHAPE)
namespace {
inline const char* attr_str(const AlgoAttribute& attr) {
......
......@@ -184,7 +184,8 @@ public:
const char* name() const override { return "CHANNEL_WISE_SMALL"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
};
......
......@@ -89,7 +89,8 @@ public:
void exec(const ExecArgs& args) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override {
......@@ -108,7 +109,8 @@ public:
void exec(const ExecArgs& args) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override {
......
......@@ -114,7 +114,8 @@ public:
void exec(const ExecArgs& args) const override;
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
};
......@@ -231,7 +232,8 @@ public:
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K)
......
......@@ -100,7 +100,8 @@ public:
const char* name() const override { return "BLAS"; }
void exec(const ExecArgs& args) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS)
};
......
......@@ -135,7 +135,8 @@ public:
class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "X86_F32MK8_8X8"; }
bool usable(const KernSizeParam&) const override;
......
......@@ -276,21 +276,6 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
return ret;
}
//! return pair<positive_attr, negative_attr>
std::pair<AlgoAttribute, AlgoAttribute>
extract_algo_attribute_from_execution_strategy(
const ExecutionStrategy& strategy) {
std::pair<AlgoAttribute, AlgoAttribute> ret =
std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
if (strategy & ExecutionStrategy::REPRODUCIBLE) {
ret.first |= AlgoAttribute::REPRODUCIBLE;
}
if (strategy & ExecutionStrategy::OPTIMIZED) {
ret.second |= AlgoAttribute::NAIVE;
}
return ret;
}
} // namespace
namespace mgb {
......@@ -303,9 +288,9 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
return;
AlgoChooserProfileCache::Result prof_rst;
auto target_attr =
extract_algo_attribute_from_execution_strategy(selected_strategy);
std::string layouts_str = format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out);
auto target_attr = ctx.extract_algo_attribute(selected_strategy);
std::string layouts_str =
format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out);
double cur_timeout = 0;
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
......@@ -558,16 +543,15 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
if (prof.empty())
return {};
auto attr_from_strategy =
extract_algo_attribute_from_execution_strategy(selected_strategy);
auto target_attr = extract_algo_attribute(selected_strategy);
for (auto&& i : prof) {
auto attr_of_algo =
static_cast<megdnn::Algorithm::Attribute>(i.attribute);
bool contain_attr_all_positive =
(attr_from_strategy.first ==
(attr_of_algo & attr_from_strategy.first));
(target_attr.first ==
(attr_of_algo & target_attr.first));
bool contain_attr_any_negative =
static_cast<bool>(attr_of_algo & attr_from_strategy.second);
static_cast<bool>(attr_of_algo & target_attr.second);
if (contain_attr_all_positive && !contain_attr_any_negative) {
auto iter = algo_map.find(i.algo);
mgb_assert(iter != algo_map.end(),
......@@ -586,8 +570,8 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
mgb_log_error(
"algos read from cache could not satisfy attribute with %s and "
"without %s",
Algorithm::attribute_str(attr_from_strategy.first).c_str(),
Algorithm::attribute_str(attr_from_strategy.second).c_str());
Algorithm::attribute_str(target_attr.first).c_str(),
Algorithm::attribute_str(target_attr.second).c_str());
mgb_trap();
MIDOUT_E
......@@ -606,8 +590,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
}
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);
auto attr =
extract_algo_attribute_from_execution_strategy(selected_strategy);
auto attr = extract_algo_attribute(selected_strategy);
ImplExecutionPolicy policy;
policy.algo =
APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
......@@ -668,9 +651,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
if (retrive_from_cache) {
policy.algo = get_profile_result_from_cache(selected_strategy).desc;
if (!policy.algo.valid()) {
auto target_attr =
extract_algo_attribute_from_execution_strategy(
selected_strategy);
auto target_attr = extract_algo_attribute(selected_strategy);
std::string layouts_str =
format_fixlayouts<Opr>(m_layouts, arity_in, arity_out);
std::string msg = ssprintf(
......@@ -692,8 +673,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);
auto attr = extract_algo_attribute_from_execution_strategy(
selected_strategy);
auto attr = extract_algo_attribute(selected_strategy);
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit, attr.first,
attr.second),
......@@ -837,6 +817,24 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
return result;
}
/**
 * \brief translate an execution strategy into the algorithm attributes
 *      used to filter candidate algorithms.
 *
 * \param strategy execution strategy whose flags are inspected
 * \return pair<positive_attr, negative_attr>; both sides start as
 *      AlgoAttribute::DEFAULT and only gain bits for the flags set in
 *      \p strategy
 */
template <typename Opr>
std::pair<AlgoAttribute, AlgoAttribute>
AlgoChooser<Opr>::ExeContext::extract_algo_attribute(
        const ExecutionStrategy& strategy) const {
    std::pair<AlgoAttribute, AlgoAttribute> ret =
            std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);

    //! from strategy
    if (strategy & ExecutionStrategy::REPRODUCIBLE) {
        //! REPRODUCIBLE goes to the positive side of the pair
        ret.first |= AlgoAttribute::REPRODUCIBLE;
    }
    if (strategy & ExecutionStrategy::OPTIMIZED) {
        //! NAIVE goes to the negative side of the pair
        ret.second |= AlgoAttribute::NAIVE;
    }

    return ret;
}
#define INST(Opr) \
template AlgoChooser<megdnn::Opr>::ExeContext::ExeContext( \
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \
......@@ -865,7 +863,10 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
AlgoChooser<megdnn::Opr>::ExeContext::profile_single_algo( \
const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \
policy, \
double& timeout) const;
double& timeout) const; \
template std::pair<AlgoAttribute, AlgoAttribute> \
AlgoChooser<megdnn::Opr>::ExeContext::extract_algo_attribute( \
const ExecutionStrategy& strategy) const;
MGB_FOREACH_FASTRUN_OPR(INST)
......
......@@ -149,6 +149,16 @@ public:
ImplExecutionPolicy& policy,
bool retrive_from_cache = true) const;
/**
* \brief extract the algo attributes implied by an execution strategy
* (and, per the original note, graph option).
*
* NOTE(review): the definition visible in this change derives the result
* from \p strategy only; confirm where graph options are taken into
* account.
*
* \param strategy execution strategy that the selected algo must match
* \return pair<positive_attr, negative_attr>: callers visible in this
* change accept an algo only if it carries every positive attribute and
* none of the negative attributes
*/
std::pair<AlgoAttribute, AlgoAttribute> extract_algo_attribute(
const ExecutionStrategy& strategy) const;
private:
Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册