diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index 06100741e50698827c3e9df4e9333c98b38ec9c9..8eb2c5a9b726406c743e177dbbc7ee5ae1ac4974 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -73,7 +73,7 @@ PyTypeObject PyOpType(name); } \ } while (0) -template +template struct pyobj_convert_generic { static T from(PyObject* obj) { // TODO: remove this guard which is used for pybind11 implicit conversion @@ -87,7 +87,12 @@ struct pyobj_convert_generic { } }; -template +template +struct EnumTrait { + static constexpr bool is_bit_combined = false; +}; + +template PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { PyObject* obj = type->tp_alloc(type, 0); T* self = reinterpret_cast(obj); @@ -203,9 +208,10 @@ struct EnumWrapper { } }; -template +template struct pyobj_convert_generic>>> { + std::enable_if_t> && + !EnumTrait::is_bit_combined>> { using Wrapper = EnumWrapper; static T from(PyObject* obj) { if (PyObject_TypeCheck(obj, &Wrapper::type)) { @@ -223,6 +229,115 @@ struct pyobj_convert_generic +struct BitCombinedEnumWrapper { + static_assert(std::is_enum_v); + PyObject_HEAD + T value; + static const char* name; + static PyTypeObject type; + static std::unordered_map type2str; + static std::unordered_map str2type; + static PyNumberMethods number_methods; + BitCombinedEnumWrapper() = default; + BitCombinedEnumWrapper(T v): value(v) {} + BitCombinedEnumWrapper(std::string&& str) + : BitCombinedEnumWrapper(str2type.at(normalize_enum(str))) {} + std::string to_string() const { + if (static_cast(value) == 0) { + return "None"; + } else { + auto ret = std::string(); + bool first = true; + for (uint32_t i = 0; i < 32; i++) { + uint32_t value_int = static_cast(value); + auto it = type2str.find(static_cast((1 << i) & value_int)); + if (it != type2str.end()) { + if (!first) { + ret += " + "; + } else { + first = false; + } + ret += (std::string(name) + "." + it->second); + } + } + return ret; + } + } + static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) { + PyObject* obj = type->tp_alloc(type, 0); + reinterpret_cast(obj)->value = static_cast(1); + return obj; + } + static int py_init(PyObject* self, PyObject* args, PyObject*) { + int input = 1; + if (PyArg_ParseTuple(args, "|i", &input)){ + reinterpret_cast(self)->value = + static_cast(input); + } + return 0; + } + static PyObject* py_repr(PyObject* self) { + return pyobj_convert_generic::to( + reinterpret_cast(self)->to_string()); + } + static PyObject* py_or(PyObject* self, PyObject* other) { + if(!(self->ob_type == other->ob_type)){ + return PyErr_Format( + PyExc_RuntimeError, + "Operand in or operator must be the same type."); + } + PyObject* obj = type.tp_alloc(&type, 0); + T lhs = reinterpret_cast(self)->value, + rhs = reinterpret_cast(other)->value; + reinterpret_cast(obj)->value = static_cast( + static_cast(lhs) | static_cast(rhs)); + return obj; + } + static PyObject* py_and(PyObject* self, PyObject* other) { + if (!(self->ob_type == other->ob_type)) { + return PyErr_Format( + PyExc_RuntimeError, + "Operand in and operator must be the same type."); + } + PyObject* obj = type.tp_alloc(&type, 0); + T lhs = reinterpret_cast(self)->value, + rhs = reinterpret_cast(other)->value; + reinterpret_cast(obj)->value = static_cast( + static_cast(lhs) & static_cast(rhs)); + return obj; + } + static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) { + T lhs = reinterpret_cast(self)->value, + rhs = reinterpret_cast(other)->value; + if (op == Py_EQ || op == Py_NE) { + RETURN_RICHCOMPARE(lhs, rhs, op); + } + Py_RETURN_NOTIMPLEMENTED; + } +}; + +template +struct pyobj_convert_generic> && + EnumTrait::is_bit_combined>> { + using Wrapper = BitCombinedEnumWrapper; + static T from(PyObject* obj) { + if (PyObject_TypeCheck(obj, &Wrapper::type)) { + return reinterpret_cast(obj)->value; + } + // try as string + // TODO: type checkcd + return Wrapper(pyobj_convert_generic::from(obj)).value; + } + static PyObject* to(T t) { + PyTypeObject* pytype = &Wrapper::type; + PyObject* obj = pytype->tp_alloc(pytype, 0); + reinterpret_cast(obj)->value = t; + return obj; + } +}; + void _init_py_op_def(py::module m) { using py_op = PyOp(OpDef); auto& py_type = PyOpType(OpDef); diff --git a/imperative/tablegen/autogen.cpp b/imperative/tablegen/autogen.cpp index 1e00f8f3d7901610d5519065a59f08897c3d12da..44b3dabfb81b8d46e8c812e42ad2073c83c8c544 100644 --- a/imperative/tablegen/autogen.cpp +++ b/imperative/tablegen/autogen.cpp @@ -408,61 +408,58 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& os << ";\n\n"; } -static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { - auto className = op.getCppClassName(); +static std::string gen_op_def_python_c_extension_enum( + raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, + llvm::StringRef className) { std::string body; - - // generate PyType for enum class member - for (auto&& i : op.getMgbAttributes()) { - if (auto attr = llvm::dyn_cast(&i.attr)) { - unsigned int enumID; - if (auto alias = llvm::dyn_cast(attr)) { - auto&& aliasBase = alias->getAliasBase(); - enumID = - llvm::cast(aliasBase) - .getBaseRecord()->getID(); - } else { - enumID = attr->getBaseRecord()->getID(); - } - auto&& enumAlias = ctx.enumAlias; - auto&& iter = enumAlias.find(enumID); - auto enumName = attr->getEnumName(); - body += "{\n"; - body += formatv( - "auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName - ); - if (iter == enumAlias.end()) { - os << formatv( - "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", - className, enumName); - os << formatv( - "template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n", - className, enumName); - std::vector pairStr; - for (auto&& i: attr->getEnumMembers()) { - pairStr.push_back(formatv( - "{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", - className, enumName, i)); - } - os << formatv(R"( + unsigned int enumID; + if (auto alias = llvm::dyn_cast(attr)) { + auto&& aliasBase = alias->getAliasBase(); + enumID = llvm::cast(aliasBase).getBaseRecord()->getID(); + } else { + enumID = attr->getBaseRecord()->getID(); + } + auto&& enumAlias = ctx.enumAlias; + auto&& iter = enumAlias.find(enumID); + auto enumName = attr->getEnumName(); + body += "{\n"; + body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className, + enumName); + if (iter == enumAlias.end()) { + os << formatv( + "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", + className, enumName); + os << formatv( + "template<> const char* EnumWrapper<{0}::{1}>::name = " + "\"{0}.{1}\";\n", + className, enumName); + std::vector pairStr; + for (auto&& i : attr->getEnumMembers()) { + pairStr.push_back( + formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", + className, enumName, i)); + } + os << formatv(R"( template<> std::unordered_map EnumWrapper<{0}::{1}>::str2type = {{ {2} }; -)", className, enumName, llvm::join(pairStr, ", ")); - pairStr.clear(); - for (auto&& i: attr->getEnumMembers()) { - pairStr.push_back(formatv( - "{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", - className, enumName, i)); - } - os << formatv(R"( +)", + className, enumName, llvm::join(pairStr, ", ")); + pairStr.clear(); + for (auto&& i : attr->getEnumMembers()) { + pairStr.push_back( + formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", + className, enumName, i)); + } + os << formatv(R"( template<> std::unordered_map<{0}::{1}, std::string> EnumWrapper<{0}::{1}>::type2str = {{ {2} }; -)", className, enumName, llvm::join(pairStr, ", ")); - body += formatv(R"( +)", + className, enumName, llvm::join(pairStr, ", ")); + body += formatv(R"( e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); @@ -472,22 +469,140 @@ EnumWrapper<{0}::{1}>::type2str = {{ e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; mgb_assert(PyType_Ready(&e_type) >= 0); -)", className, enumName); - for (auto&& i: attr->getEnumMembers()) { - body += formatv(R"({{ +)", + className, enumName); + for (auto&& i : attr->getEnumMembers()) { + body += formatv(R"({{ PyObject* inst = e_type.tp_alloc(&e_type, 0); reinterpret_cast*>(inst)->value = {0}::{1}::{2}; mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); -})", className, enumName, i); - } - enumAlias.emplace(enumID, std::make_pair(className, enumName)); - } - body += formatv(R"( +})", + className, enumName, i); + } + enumAlias.emplace(enumID, std::make_pair(className, enumName)); + } + body += formatv(R"( + PyType_Modified(&e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "{0}", reinterpret_cast(&e_type)) >= 0); +)", + enumName); + body += "}\n"; + return body; +} + +static std::string gen_op_def_python_c_extension_bit_combined_enum( + raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr, + llvm::StringRef className) { + std::string body; + unsigned int enumID; + if (auto alias = llvm::dyn_cast(attr)) { + auto&& aliasBase = alias->getAliasBase(); + enumID = llvm::cast(aliasBase).getBaseRecord()->getID(); + } else { + enumID = attr->getBaseRecord()->getID(); + } + auto&& enumAlias = ctx.enumAlias; + auto&& iter = enumAlias.find(enumID); + auto enumName = attr->getEnumName(); + body += "{\n"; + body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;", + className, enumName); + if (iter == enumAlias.end()) { + os << formatv( + "template<> PyTypeObject " + "BitCombinedEnumWrapper<{0}::{1}>::type={{};\n", + className, enumName); + os << formatv( + "template<> PyNumberMethods " + "BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n", + className, enumName); + os << formatv( + "template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name " + "= \"{0}.{1}\";\n", + className, enumName); + os << formatv( + "template<> struct EnumTrait<{0}::{1}> {{ static constexpr " + "bool is_bit_combined = true;};\n", + className, enumName); + std::vector pairStr; + for (auto&& i : attr->getEnumMembers()) { + pairStr.push_back( + formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", + className, enumName, i)); + } + os << formatv(R"( +template<> std::unordered_map +BitCombinedEnumWrapper<{0}::{1}>::str2type = {{ + {2} +}; +)", + className, enumName, llvm::join(pairStr, ", ")); + pairStr.clear(); + for (auto&& i : attr->getEnumMembers()) { + pairStr.push_back( + formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", + className, enumName, i)); + } + os << formatv(R"( +template<> std::unordered_map<{0}::{1}, std::string> +BitCombinedEnumWrapper<{0}::{1}>::type2str = {{ + {2} +}; +)", + className, enumName, llvm::join(pairStr, ", ")); + body += formatv(R"( + e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; + e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; + e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>); + e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + e_type.tp_doc = "{0}.{1}"; + e_type.tp_base = &PyBaseObject_Type; + e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum; + e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init; + e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr; + e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare; + auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods; + number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or; + number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and; + e_type.tp_as_number = &number_method; + mgb_assert(PyType_Ready(&e_type) >= 0); +)", + className, enumName); + for (auto&& i : attr->getEnumMembers()) { + body += formatv(R"({{ + PyObject* inst = e_type.tp_alloc(&e_type, 0); + reinterpret_cast*>(inst)->value = {0}::{1}::{2}; + mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); +})", + className, enumName, i); + } + enumAlias.emplace(enumID, std::make_pair(className, enumName)); + } + body += formatv(R"( PyType_Modified(&e_type); mgb_assert(PyDict_SetItemString( py_type.tp_dict, "{0}", reinterpret_cast(&e_type)) >= 0); -)", enumName); - body += "}\n"; +)", + enumName); + body += "}\n"; + return body; +} + +static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { + auto className = op.getCppClassName(); + std::string body; + + // generate PyType for enum class member + for (auto&& i : op.getMgbAttributes()) { + if (auto attr = llvm::dyn_cast(&i.attr)) { + if (attr->getEnumCombinedFlag()) { + body += gen_op_def_python_c_extension_bit_combined_enum( + os, ctx, attr, className); + } else { + body += gen_op_def_python_c_extension_enum(os, ctx, attr, + className); + } } } diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp index f25a0e37497e8c6a728210abd84b7e4ed15c0a69..e51ffa6a25b7977c6b52cac54ed60278c425bbc1 100644 --- a/sdk/load-and-run/src/mgblar.cpp +++ b/sdk/load-and-run/src/mgblar.cpp @@ -141,15 +141,13 @@ R"__usage__( )__usage__" #if MGB_ENABLE_FASTRUN R"__usage__( - --fast-run - This param will be deperated later, please replace with param --full-profile. - --full-profile - Enable full-profile mode. Operators with multiple algorithms would be profiled + --full-run + Enable full-run mode. Operators with multiple algorithms would be profiled on the real device with actual input shapes, all algorithms will be profiled include naive algorithms. See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. - --fast-profile - Enable fast-profile mode. Operators with multiple algorithms would be profiled + --fast-run + Enable fast-run mode. Operators with multiple algorithms would be profiled on the real device with actual input shapes, this mode will only profile the well optimized algorithms to get the profile result fast. See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. @@ -519,8 +517,8 @@ struct Args { bool disable_assert_throw = false; bool share_param_mem = false; #if MGB_ENABLE_FASTRUN - bool use_full_profile = false; - bool use_fast_profile = false; + bool use_full_run = false; + bool use_fast_run = false; #endif bool reproducible = false; std::string fast_run_cache_path; @@ -704,13 +702,13 @@ void run_test_st(Args &env) { using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; S strategy = S::HEURISTIC; #if MGB_ENABLE_FASTRUN - if (env.use_full_profile) { + if (env.use_full_run) { if (env.reproducible) { strategy = S::PROFILE | S::REPRODUCIBLE; } else { strategy = S::PROFILE; } - } else if (env.use_fast_profile) { + } else if (env.use_fast_run) { strategy = S::PROFILE | S::OPTMIZED; } else if (env.reproducible) { strategy = S::HEURISTIC | S::REPRODUCIBLE; @@ -740,12 +738,12 @@ void run_test_st(Args &env) { std::make_shared(buf.get(), flen)); #if MGB_ENABLE_FASTRUN } else { - mgb_assert(env.use_full_profile || env.use_fast_profile, - "fast-run or fast-profile should be enabled"); + mgb_assert(env.use_full_run || env.use_fast_run, + "fast-run or fast-run should be enabled"); PersistentCache::set_impl( std::make_shared()); } - if (!env.use_full_profile && !env.use_fast_profile) + if (!env.use_full_run && !env.use_fast_run) #endif mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); } @@ -1326,18 +1324,11 @@ Args Args::from_argv(int argc, char **argv) { } #if MGB_ENABLE_FASTRUN if (!strcmp(argv[i], "--fast-run")) { - mgb_log_warn( - "--fast-run param will be deperated later, please replace " - "with --full-profile or --fast-profile."); - ret.use_full_profile = true; + ret.use_fast_run = true; continue; } - if (!strcmp(argv[i], "--full-profile")) { - ret.use_full_profile = true; - continue; - } - if (!strcmp(argv[i], "--fast-profile")) { - ret.use_fast_profile = true; + if (!strcmp(argv[i], "--full-run")) { + ret.use_full_run = true; continue; } #endif diff --git a/src/core/include/megbrain/common.h b/src/core/include/megbrain/common.h index 085ff414481db259a15034f02a69af4120810e04..cb2781e302bf4d405e7cdb12e1a3a0dcf9644f53 100644 --- a/src/core/include/megbrain/common.h +++ b/src/core/include/megbrain/common.h @@ -12,7 +12,6 @@ #pragma once #include "megbrain_build_config.h" -#include "megbrain/opr/param_defs.h" #include "megdnn/basic_types.h" #include @@ -250,10 +249,4 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) { } // namespace mgb -namespace megdnn { -namespace param { -MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) -} -} // namespace megdnn - // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/core/include/megbrain/graph/operator_node.h b/src/core/include/megbrain/graph/operator_node.h index 27c597416dd6270e28b4577c64fe1635e761d264..021e255f99f4e619668fff80f679603a326ccc90 100644 --- a/src/core/include/megbrain/graph/operator_node.h +++ b/src/core/include/megbrain/graph/operator_node.h @@ -18,6 +18,7 @@ #include "megbrain/utils/hashable.h" #include "megbrain/utils/thin/hash_table.h" #include "megbrain/utils/small_vector.h" +#include "megbrain/opr/param_defs.h" #include @@ -1043,4 +1044,10 @@ MGB_DEFINE_CLS_WITH_SUPER(_name final, _base ,##__VA_ARGS__) \ } // namespace cg } // namespace mgb +namespace megdnn { +namespace param { +MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy) +} +} // namespace megdnn + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp index 22c49155782363a23ca9dbd48c68a7ab8da62770..9f712c4040913b84fe0af16e3f4786356dd91972 100644 --- a/src/opr/impl/search_policy/algo_chooser.cpp +++ b/src/opr/impl/search_policy/algo_chooser.cpp @@ -278,6 +278,19 @@ std::vector flatten_search_space( return ret; } +//! Test whether the algo attribute of a algo match the require +//! algo_strategy +static bool algo_attribute_match_strategy(AlgoAttribute attribute, + ExecutionStrategy selected_strategy) { + bool ret = true; + if (selected_strategy & ExecutionStrategy::OPTMIZED) { + ret &= (!static_cast(AlgoAttribute::NAIVE & attribute)); + } else if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) { + ret &= static_cast(AlgoAttribute::REPRODUCIBLE & attribute); + } + return ret; +} + } // namespace namespace mgb { @@ -285,8 +298,8 @@ namespace opr { template void AlgoChooser::profile(ExeContext& ctx, - ExecutionStrategy select_strategy) { - if (ctx.get_profile_result_from_cache(select_strategy).valid()) + ExecutionStrategy selected_strategy) { + if (ctx.get_profile_result_from_cache(selected_strategy).valid()) return; AlgoChooserProfileCache::Result prof_rst; @@ -306,9 +319,19 @@ void AlgoChooser::profile(ExeContext& ctx, algo.name.c_str(), str_on_inp_shape.c_str()); ImplExecutionPolicy policy; policy.algo = algo.desc; - ctx.construct_execution_policy(select_strategy, policy); - if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) + ctx.construct_execution_policy(selected_strategy, policy); + if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) { continue; + } + auto algo_attribute = ctx.megdnn_opr() + ->get_algorithm_from_desc(policy.algo) + ->attribute(); + if (!algo_attribute_match_strategy(algo_attribute, selected_strategy)) { + mgb_log_debug( + "skip algo %s, which is not match the profile strategy.", + algo.name.c_str()); + continue; + } timer.reset(); MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); } @@ -356,7 +379,7 @@ void AlgoChooser::profile(ExeContext& ctx, template typename AlgoChooser::ImplExecutionPolicy AlgoChooser::choose_by_profile(ExeContext& ctx, - ExecutionStrategy select_strategy, + ExecutionStrategy selected_strategy, bool enable_update) { MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) if (ctx.owner_graph()->options().no_profiling_on_shape_change) { @@ -378,11 +401,11 @@ AlgoChooser::choose_by_profile(ExeContext& ctx, to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), _item.param, ctx.mgb_opr(), ctx.comp_node(), ctx.execution_policy(), ctx.allow_weight_preprocess()); - AlgoChooser<_Opr>::profile(sub_ctx, select_strategy); + AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy); }); } typename AlgoChooser::ImplExecutionPolicy policy; - ctx.construct_execution_policy(select_strategy, policy); + ctx.construct_execution_policy(selected_strategy, policy); return policy; MIDOUT_E } @@ -440,7 +463,8 @@ typename AlgoChooser::ImplExecutionPolicy AlgoChooser::get_policy( if (!policy.algo.valid()) policy = ctx.choose_by_heuristic(opr_strategy); return policy; - } else if ((opr_strategy & ExecutionStrategy::HEURISTIC)) { + } else if (!static_cast(opr_strategy) || + (opr_strategy & ExecutionStrategy::HEURISTIC)) { return ctx.choose_by_heuristic(opr_strategy); } #if MGB_ENABLE_FASTRUN @@ -449,7 +473,7 @@ typename AlgoChooser::ImplExecutionPolicy AlgoChooser::get_policy( } #endif else { - mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy"); + mgb_throw(GraphError, "bad ExecutionPolicy strategy"); } } @@ -495,7 +519,7 @@ AlgoChooser::ExeContext::ExeContext( template typename AlgoChooser::ImplAlgo AlgoChooser::ExeContext::get_profile_result_from_cache( - ExecutionStrategy select_strategy) const { + ExecutionStrategy selected_strategy) const { MIDOUT_B(Opr, midout_iv(MGB_HASH_STR( "AlgoChooser::ExeContext::get_profile_result_from_cache"))) @@ -519,7 +543,7 @@ AlgoChooser::ExeContext::get_profile_result_from_cache( if (prof.empty()) return {}; for (auto&& i : prof) { - if (!(select_strategy & ExecutionStrategy::REPRODUCIBLE) || + if (!(selected_strategy & ExecutionStrategy::REPRODUCIBLE) || static_cast(i.attribute) & AlgoAttribute::REPRODUCIBLE) { auto iter = algo_map.find(i.algo); @@ -550,7 +574,7 @@ AlgoChooser::ExeContext::get_profile_result_from_cache( template typename AlgoChooser::ImplExecutionPolicy AlgoChooser::ExeContext::choose_by_heuristic( - ExecutionStrategy select_strategy) const { + ExecutionStrategy selected_strategy) const { if (m_execution_policy.workspace_limit != std::numeric_limits::max()) { @@ -558,7 +582,7 @@ AlgoChooser::ExeContext::choose_by_heuristic( "workspace_limit should not be setted if choose algo by " "heuristic"); } - bool reproducible = static_cast(select_strategy & + bool reproducible = static_cast(selected_strategy & ExecutionStrategy::REPRODUCIBLE); auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( owner_graph(), m_cn, m_execution_policy.workspace_limit); @@ -582,7 +606,7 @@ AlgoChooser::ExeContext::choose_by_heuristic( _item.param, m_base_mgb_opr, m_cn, m_execution_policy, m_allow_weight_preprocess); policy.sub_policy.push_back( - sub_ctx.choose_by_heuristic(select_strategy)); + sub_ctx.choose_by_heuristic(selected_strategy)); }); return policy; @@ -613,15 +637,15 @@ AlgoChooser::ExeContext::get_all_candidates() const { template void AlgoChooser::ExeContext::construct_execution_policy( - ExecutionStrategy select_strategy, + ExecutionStrategy selected_strategy, typename AlgoChooser::ImplExecutionPolicy& policy, bool retrive_from_cache) const { - bool reproducible = static_cast(select_strategy & + bool reproducible = static_cast(selected_strategy & ExecutionStrategy::REPRODUCIBLE); if (!policy.algo.valid()) { if (retrive_from_cache) { policy.algo = - get_profile_result_from_cache(select_strategy).desc; + get_profile_result_from_cache(selected_strategy).desc; } else { auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( owner_graph(), m_cn, m_execution_policy.workspace_limit); @@ -651,7 +675,7 @@ void AlgoChooser::ExeContext::construct_execution_policy( _item.param, m_base_mgb_opr, m_cn, m_execution_policy, m_allow_weight_preprocess); policy.sub_policy.push_back({}); - sub_ctx.construct_execution_policy(select_strategy, + sub_ctx.construct_execution_policy(selected_strategy, policy.sub_policy.back(), retrive_from_cache); }); diff --git a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h index a9af2081373906423ec28d8914873692d1579a82..bb193e18af7efb04785d8ef0bcd50de5135931cf 100644 --- a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h +++ b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h @@ -110,7 +110,7 @@ public: const FixedTensorLayouts& layouts() const { return m_layouts; } ImplExecutionPolicy choose_by_heuristic( - ExecutionStrategy select_strategy) const; + ExecutionStrategy selected_strategy) const; //! get all candidate algos, and the one choose_by_heuristic() is //! put first @@ -134,17 +134,17 @@ public: //! get all profile algorithm from cache, return invalid if not exists ImplAlgo get_profile_result_from_cache( - ExecutionStrategy select_strategy) const; + ExecutionStrategy selected_strategy) const; /** * \brief construct execution policy from cache or heuristic. * - * \param select_strategy select algo which matched this strategy + * \param selected_strategy select algo which matched this strategy * \param policy execution policy * \param retrive_from_cache retrive algo from cache if set True, get * from heuristic otherwise. */ - void construct_execution_policy(ExecutionStrategy select_strategy, + void construct_execution_policy(ExecutionStrategy selected_strategy, ImplExecutionPolicy& policy, bool retrive_from_cache = true) const; @@ -161,10 +161,10 @@ private: //! profile and save to cache - static void profile(ExeContext& ctx, ExecutionStrategy select_strategy); + static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy); static ImplExecutionPolicy choose_by_profile( - ExeContext& ctx, ExecutionStrategy select_strategy, + ExeContext& ctx, ExecutionStrategy selected_strategy, bool enable_update = true); public: