提交 e19b9af1 编写于 作者: M Megvii Engine Team

feat(imperative): add bit combined enum to python C extension

GitOrigin-RevId: 92307dd2ca077ea5606657f7cb7b321fd0dc8129
上级 a3ea1f15
......@@ -73,7 +73,7 @@ PyTypeObject PyOpType(name);
} \
} while (0)
template<typename T, typename SFINAE=void>
template <typename T, typename SFINAE = void>
struct pyobj_convert_generic {
static T from(PyObject* obj) {
// TODO: remove this guard which is used for pybind11 implicit conversion
......@@ -87,7 +87,12 @@ struct pyobj_convert_generic {
}
};
template<typename T>
template <typename T>
struct EnumTrait {
static constexpr bool is_bit_combined = false;
};
template <typename T>
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
PyObject* obj = type->tp_alloc(type, 0);
T* self = reinterpret_cast<T*>(obj);
......@@ -203,9 +208,10 @@ struct EnumWrapper {
}
};
template<typename T>
template <typename T>
struct pyobj_convert_generic<T,
std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> {
std::enable_if_t<std::is_enum_v<std::decay_t<T>> &&
!EnumTrait<T>::is_bit_combined>> {
using Wrapper = EnumWrapper<T>;
static T from(PyObject* obj) {
if (PyObject_TypeCheck(obj, &Wrapper::type)) {
......@@ -223,6 +229,115 @@ struct pyobj_convert_generic<T,
}
};
template<typename T>
struct BitCombinedEnumWrapper {
static_assert(std::is_enum_v<T>);
PyObject_HEAD
T value;
static const char* name;
static PyTypeObject type;
static std::unordered_map<T, std::string> type2str;
static std::unordered_map<std::string, T> str2type;
static PyNumberMethods number_methods;
BitCombinedEnumWrapper() = default;
BitCombinedEnumWrapper(T v): value(v) {}
BitCombinedEnumWrapper(std::string&& str)
: BitCombinedEnumWrapper(str2type.at(normalize_enum(str))) {}
std::string to_string() const {
if (static_cast<uint32_t>(value) == 0) {
return "None";
} else {
auto ret = std::string();
bool first = true;
for (uint32_t i = 0; i < 32; i++) {
uint32_t value_int = static_cast<uint32_t>(value);
auto it = type2str.find(static_cast<T>((1 << i) & value_int));
if (it != type2str.end()) {
if (!first) {
ret += " + ";
} else {
first = false;
}
ret += (std::string(name) + "." + it->second);
}
}
return ret;
}
}
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) {
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1);
return obj;
}
static int py_init(PyObject* self, PyObject* args, PyObject*) {
int input = 1;
if (PyArg_ParseTuple(args, "|i", &input)){
reinterpret_cast<BitCombinedEnumWrapper*>(self)->value =
static_cast<T>(input);
}
return 0;
}
static PyObject* py_repr(PyObject* self) {
return pyobj_convert_generic<std::string>::to(
reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string());
}
static PyObject* py_or(PyObject* self, PyObject* other) {
if(!(self->ob_type == other->ob_type)){
return PyErr_Format(
PyExc_RuntimeError,
"Operand in or operator must be the same type.");
}
PyObject* obj = type.tp_alloc(&type, 0);
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(
static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs));
return obj;
}
static PyObject* py_and(PyObject* self, PyObject* other) {
if (!(self->ob_type == other->ob_type)) {
return PyErr_Format(
PyExc_RuntimeError,
"Operand in and operator must be the same type.");
}
PyObject* obj = type.tp_alloc(&type, 0);
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(
static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs));
return obj;
}
static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
if (op == Py_EQ || op == Py_NE) {
RETURN_RICHCOMPARE(lhs, rhs, op);
}
Py_RETURN_NOTIMPLEMENTED;
}
};
template <typename T>
struct pyobj_convert_generic<T,
std::enable_if_t<std::is_enum_v<std::decay_t<T>> &&
EnumTrait<T>::is_bit_combined>> {
using Wrapper = BitCombinedEnumWrapper<T>;
static T from(PyObject* obj) {
if (PyObject_TypeCheck(obj, &Wrapper::type)) {
return reinterpret_cast<Wrapper*>(obj)->value;
}
// try as string
// TODO: type checkcd
return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value;
}
static PyObject* to(T t) {
PyTypeObject* pytype = &Wrapper::type;
PyObject* obj = pytype->tp_alloc(pytype, 0);
reinterpret_cast<Wrapper*>(obj)->value = t;
return obj;
}
};
void _init_py_op_def(py::module m) {
using py_op = PyOp(OpDef);
auto& py_type = PyOpType(OpDef);
......
......@@ -408,19 +408,14 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
os << ";\n\n";
}
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
auto className = op.getCppClassName();
static std::string gen_op_def_python_c_extension_enum(
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr,
llvm::StringRef className) {
std::string body;
// generate PyType for enum class member
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID =
llvm::cast<MgbEnumAttr>(aliasBase)
.getBaseRecord()->getID();
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
......@@ -428,20 +423,20 @@ static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, Enu
auto&& iter = enumAlias.find(enumID);
auto enumName = attr->getEnumName();
body += "{\n";
body += formatv(
"auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName
);
body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className,
enumName);
if (iter == enumAlias.end()) {
os << formatv(
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n",
className, enumName);
os << formatv(
"template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n",
"template<> const char* EnumWrapper<{0}::{1}>::name = "
"\"{0}.{1}\";\n",
className, enumName);
std::vector<std::string> pairStr;
for (auto&& i: attr->getEnumMembers()) {
pairStr.push_back(formatv(
"{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
for (auto&& i : attr->getEnumMembers()) {
pairStr.push_back(
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
className, enumName, i));
}
os << formatv(R"(
......@@ -449,11 +444,12 @@ template<> std::unordered_map<std::string, {0}::{1}>
EnumWrapper<{0}::{1}>::str2type = {{
{2}
};
)", className, enumName, llvm::join(pairStr, ", "));
)",
className, enumName, llvm::join(pairStr, ", "));
pairStr.clear();
for (auto&& i: attr->getEnumMembers()) {
pairStr.push_back(formatv(
"{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
for (auto&& i : attr->getEnumMembers()) {
pairStr.push_back(
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
className, enumName, i));
}
os << formatv(R"(
......@@ -461,7 +457,8 @@ template<> std::unordered_map<{0}::{1}, std::string>
EnumWrapper<{0}::{1}>::type2str = {{
{2}
};
)", className, enumName, llvm::join(pairStr, ", "));
)",
className, enumName, llvm::join(pairStr, ", "));
body += formatv(R"(
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
......@@ -472,13 +469,113 @@ EnumWrapper<{0}::{1}>::type2str = {{
e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr;
e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare;
mgb_assert(PyType_Ready(&e_type) >= 0);
)", className, enumName);
for (auto&& i: attr->getEnumMembers()) {
)",
className, enumName);
for (auto&& i : attr->getEnumMembers()) {
body += formatv(R"({{
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
})", className, enumName, i);
})",
className, enumName, i);
}
enumAlias.emplace(enumID, std::make_pair(className, enumName));
}
body += formatv(R"(
PyType_Modified(&e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
)",
enumName);
body += "}\n";
return body;
}
static std::string gen_op_def_python_c_extension_bit_combined_enum(
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr,
llvm::StringRef className) {
std::string body;
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
auto&& enumAlias = ctx.enumAlias;
auto&& iter = enumAlias.find(enumID);
auto enumName = attr->getEnumName();
body += "{\n";
body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;",
className, enumName);
if (iter == enumAlias.end()) {
os << formatv(
"template<> PyTypeObject "
"BitCombinedEnumWrapper<{0}::{1}>::type={{};\n",
className, enumName);
os << formatv(
"template<> PyNumberMethods "
"BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n",
className, enumName);
os << formatv(
"template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name "
"= \"{0}.{1}\";\n",
className, enumName);
os << formatv(
"template<> struct EnumTrait<{0}::{1}> {{ static constexpr "
"bool is_bit_combined = true;};\n",
className, enumName);
std::vector<std::string> pairStr;
for (auto&& i : attr->getEnumMembers()) {
pairStr.push_back(
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
className, enumName, i));
}
os << formatv(R"(
template<> std::unordered_map<std::string, {0}::{1}>
BitCombinedEnumWrapper<{0}::{1}>::str2type = {{
{2}
};
)",
className, enumName, llvm::join(pairStr, ", "));
pairStr.clear();
for (auto&& i : attr->getEnumMembers()) {
pairStr.push_back(
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
className, enumName, i));
}
os << formatv(R"(
template<> std::unordered_map<{0}::{1}, std::string>
BitCombinedEnumWrapper<{0}::{1}>::type2str = {{
{2}
};
)",
className, enumName, llvm::join(pairStr, ", "));
body += formatv(R"(
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>);
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "{0}.{1}";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum;
e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init;
e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr;
e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare;
auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods;
number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or;
number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and;
e_type.tp_as_number = &number_method;
mgb_assert(PyType_Ready(&e_type) >= 0);
)",
className, enumName);
for (auto&& i : attr->getEnumMembers()) {
body += formatv(R"({{
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<BitCombinedEnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
})",
className, enumName, i);
}
enumAlias.emplace(enumID, std::make_pair(className, enumName));
}
......@@ -486,8 +583,26 @@ EnumWrapper<{0}::{1}>::type2str = {{
PyType_Modified(&e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
)", enumName);
)",
enumName);
body += "}\n";
return body;
}
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
auto className = op.getCppClassName();
std::string body;
// generate PyType for enum class member
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
if (attr->getEnumCombinedFlag()) {
body += gen_op_def_python_c_extension_bit_combined_enum(
os, ctx, attr, className);
} else {
body += gen_op_def_python_c_extension_enum(os, ctx, attr,
className);
}
}
}
......
......@@ -141,15 +141,13 @@ R"__usage__(
)__usage__"
#if MGB_ENABLE_FASTRUN
R"__usage__(
--fast-run
This param will be deperated later, please replace with param --full-profile.
--full-profile
Enable full-profile mode. Operators with multiple algorithms would be profiled
--full-run
Enable full-run mode. Operators with multiple algorithms would be profiled
on the real device with actual input shapes, all algorithms will be profiled
include naive algorithms.
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details.
--fast-profile
Enable fast-profile mode. Operators with multiple algorithms would be profiled
--fast-run
Enable fast-run mode. Operators with multiple algorithms would be profiled
on the real device with actual input shapes, this mode will only profile the
well optimized algorithms to get the profile result fast.
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details.
......@@ -519,8 +517,8 @@ struct Args {
bool disable_assert_throw = false;
bool share_param_mem = false;
#if MGB_ENABLE_FASTRUN
bool use_full_profile = false;
bool use_fast_profile = false;
bool use_full_run = false;
bool use_fast_run = false;
#endif
bool reproducible = false;
std::string fast_run_cache_path;
......@@ -704,13 +702,13 @@ void run_test_st(Args &env) {
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
S strategy = S::HEURISTIC;
#if MGB_ENABLE_FASTRUN
if (env.use_full_profile) {
if (env.use_full_run) {
if (env.reproducible) {
strategy = S::PROFILE | S::REPRODUCIBLE;
} else {
strategy = S::PROFILE;
}
} else if (env.use_fast_profile) {
} else if (env.use_fast_run) {
strategy = S::PROFILE | S::OPTMIZED;
} else if (env.reproducible) {
strategy = S::HEURISTIC | S::REPRODUCIBLE;
......@@ -740,12 +738,12 @@ void run_test_st(Args &env) {
std::make_shared<InFilePersistentCache>(buf.get(), flen));
#if MGB_ENABLE_FASTRUN
} else {
mgb_assert(env.use_full_profile || env.use_fast_profile,
"fast-run or fast-profile should be enabled");
mgb_assert(env.use_full_run || env.use_fast_run,
"fast-run or fast-run should be enabled");
PersistentCache::set_impl(
std::make_shared<InFilePersistentCache>());
}
if (!env.use_full_profile && !env.use_fast_profile)
if (!env.use_full_run && !env.use_fast_run)
#endif
mgb::gopt::enable_opr_use_profiling_cache_inplace(vars);
}
......@@ -1326,18 +1324,11 @@ Args Args::from_argv(int argc, char **argv) {
}
#if MGB_ENABLE_FASTRUN
if (!strcmp(argv[i], "--fast-run")) {
mgb_log_warn(
"--fast-run param will be deperated later, please replace "
"with --full-profile or --fast-profile.");
ret.use_full_profile = true;
continue;
}
if (!strcmp(argv[i], "--full-profile")) {
ret.use_full_profile = true;
ret.use_fast_run = true;
continue;
}
if (!strcmp(argv[i], "--fast-profile")) {
ret.use_fast_profile = true;
if (!strcmp(argv[i], "--full-run")) {
ret.use_full_run = true;
continue;
}
#endif
......
......@@ -12,7 +12,6 @@
#pragma once
#include "megbrain_build_config.h"
#include "megbrain/opr/param_defs.h"
#include "megdnn/basic_types.h"
#include <memory>
......@@ -250,10 +249,4 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) {
} // namespace mgb
namespace megdnn {
namespace param {
MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy)
}
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -18,6 +18,7 @@
#include "megbrain/utils/hashable.h"
#include "megbrain/utils/thin/hash_table.h"
#include "megbrain/utils/small_vector.h"
#include "megbrain/opr/param_defs.h"
#include <type_traits>
......@@ -1043,4 +1044,10 @@ MGB_DEFINE_CLS_WITH_SUPER(_name final, _base ,##__VA_ARGS__) \
} // namespace cg
} // namespace mgb
namespace megdnn {
namespace param {
MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy)
}
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -278,6 +278,19 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
return ret;
}
//! Test whether the algo attribute of a algo match the require
//! algo_strategy
static bool algo_attribute_match_strategy(AlgoAttribute attribute,
ExecutionStrategy selected_strategy) {
bool ret = true;
if (selected_strategy & ExecutionStrategy::OPTMIZED) {
ret &= (!static_cast<bool>(AlgoAttribute::NAIVE & attribute));
} else if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) {
ret &= static_cast<bool>(AlgoAttribute::REPRODUCIBLE & attribute);
}
return ret;
}
} // namespace
namespace mgb {
......@@ -285,8 +298,8 @@ namespace opr {
template <typename Opr>
void AlgoChooser<Opr>::profile(ExeContext& ctx,
ExecutionStrategy select_strategy) {
if (ctx.get_profile_result_from_cache(select_strategy).valid())
ExecutionStrategy selected_strategy) {
if (ctx.get_profile_result_from_cache(selected_strategy).valid())
return;
AlgoChooserProfileCache::Result prof_rst;
......@@ -306,9 +319,19 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
algo.name.c_str(), str_on_inp_shape.c_str());
ImplExecutionPolicy policy;
policy.algo = algo.desc;
ctx.construct_execution_policy(select_strategy, policy);
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit)
ctx.construct_execution_policy(selected_strategy, policy);
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) {
continue;
}
auto algo_attribute = ctx.megdnn_opr()
->get_algorithm_from_desc(policy.algo)
->attribute();
if (!algo_attribute_match_strategy(algo_attribute, selected_strategy)) {
mgb_log_debug(
"skip algo %s, which is not match the profile strategy.",
algo.name.c_str());
continue;
}
timer.reset();
MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); }
......@@ -356,7 +379,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy
AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx,
ExecutionStrategy select_strategy,
ExecutionStrategy selected_strategy,
bool enable_update) {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile")))
if (ctx.owner_graph()->options().no_profiling_on_shape_change) {
......@@ -378,11 +401,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx,
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
_item.param, ctx.mgb_opr(), ctx.comp_node(),
ctx.execution_policy(), ctx.allow_weight_preprocess());
AlgoChooser<_Opr>::profile(sub_ctx, select_strategy);
AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy);
});
}
typename AlgoChooser<Opr>::ImplExecutionPolicy policy;
ctx.construct_execution_policy(select_strategy, policy);
ctx.construct_execution_policy(selected_strategy, policy);
return policy;
MIDOUT_E
}
......@@ -440,7 +463,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
if (!policy.algo.valid())
policy = ctx.choose_by_heuristic(opr_strategy);
return policy;
} else if ((opr_strategy & ExecutionStrategy::HEURISTIC)) {
} else if (!static_cast<int>(opr_strategy) ||
(opr_strategy & ExecutionStrategy::HEURISTIC)) {
return ctx.choose_by_heuristic(opr_strategy);
}
#if MGB_ENABLE_FASTRUN
......@@ -449,7 +473,7 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
}
#endif
else {
mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy");
mgb_throw(GraphError, "bad ExecutionPolicy strategy");
}
}
......@@ -495,7 +519,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext(
template <typename Opr>
typename AlgoChooser<Opr>::ImplAlgo
AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
ExecutionStrategy select_strategy) const {
ExecutionStrategy selected_strategy) const {
MIDOUT_B(Opr,
midout_iv(MGB_HASH_STR(
"AlgoChooser::ExeContext::get_profile_result_from_cache")))
......@@ -519,7 +543,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
if (prof.empty())
return {};
for (auto&& i : prof) {
if (!(select_strategy & ExecutionStrategy::REPRODUCIBLE) ||
if (!(selected_strategy & ExecutionStrategy::REPRODUCIBLE) ||
static_cast<AlgoAttribute>(i.attribute) &
AlgoAttribute::REPRODUCIBLE) {
auto iter = algo_map.find(i.algo);
......@@ -550,7 +574,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy
AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
ExecutionStrategy select_strategy) const {
ExecutionStrategy selected_strategy) const {
if (m_execution_policy.workspace_limit !=
std::numeric_limits<decltype(
m_execution_policy.workspace_limit)>::max()) {
......@@ -558,7 +582,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
"workspace_limit should not be setted if choose algo by "
"heuristic");
}
bool reproducible = static_cast<bool>(select_strategy &
bool reproducible = static_cast<bool>(selected_strategy &
ExecutionStrategy::REPRODUCIBLE);
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);
......@@ -582,7 +606,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
_item.param, m_base_mgb_opr, m_cn, m_execution_policy,
m_allow_weight_preprocess);
policy.sub_policy.push_back(
sub_ctx.choose_by_heuristic(select_strategy));
sub_ctx.choose_by_heuristic(selected_strategy));
});
return policy;
......@@ -613,15 +637,15 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const {
template <typename Opr>
void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
ExecutionStrategy select_strategy,
ExecutionStrategy selected_strategy,
typename AlgoChooser<Opr>::ImplExecutionPolicy& policy,
bool retrive_from_cache) const {
bool reproducible = static_cast<bool>(select_strategy &
bool reproducible = static_cast<bool>(selected_strategy &
ExecutionStrategy::REPRODUCIBLE);
if (!policy.algo.valid()) {
if (retrive_from_cache) {
policy.algo =
get_profile_result_from_cache(select_strategy).desc;
get_profile_result_from_cache(selected_strategy).desc;
} else {
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);
......@@ -651,7 +675,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
_item.param, m_base_mgb_opr, m_cn, m_execution_policy,
m_allow_weight_preprocess);
policy.sub_policy.push_back({});
sub_ctx.construct_execution_policy(select_strategy,
sub_ctx.construct_execution_policy(selected_strategy,
policy.sub_policy.back(),
retrive_from_cache);
});
......
......@@ -110,7 +110,7 @@ public:
const FixedTensorLayouts& layouts() const { return m_layouts; }
ImplExecutionPolicy choose_by_heuristic(
ExecutionStrategy select_strategy) const;
ExecutionStrategy selected_strategy) const;
//! get all candidate algos, and the one choose_by_heuristic() is
//! put first
......@@ -134,17 +134,17 @@ public:
//! get all profile algorithm from cache, return invalid if not exists
ImplAlgo get_profile_result_from_cache(
ExecutionStrategy select_strategy) const;
ExecutionStrategy selected_strategy) const;
/**
* \brief construct execution policy from cache or heuristic.
*
* \param select_strategy select algo which matched this strategy
* \param selected_strategy select algo which matched this strategy
* \param policy execution policy
* \param retrive_from_cache retrive algo from cache if set True, get
* from heuristic otherwise.
*/
void construct_execution_policy(ExecutionStrategy select_strategy,
void construct_execution_policy(ExecutionStrategy selected_strategy,
ImplExecutionPolicy& policy,
bool retrive_from_cache = true) const;
......@@ -161,10 +161,10 @@ private:
//! profile and save to cache
static void profile(ExeContext& ctx, ExecutionStrategy select_strategy);
static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy);
static ImplExecutionPolicy choose_by_profile(
ExeContext& ctx, ExecutionStrategy select_strategy,
ExeContext& ctx, ExecutionStrategy selected_strategy,
bool enable_update = true);
public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册