提交 e19b9af1 编写于 作者: M Megvii Engine Team

feat(imperative): add bit combined enum to python C extension

GitOrigin-RevId: 92307dd2ca077ea5606657f7cb7b321fd0dc8129
上级 a3ea1f15
...@@ -73,7 +73,7 @@ PyTypeObject PyOpType(name); ...@@ -73,7 +73,7 @@ PyTypeObject PyOpType(name);
} \ } \
} while (0) } while (0)
template<typename T, typename SFINAE=void> template <typename T, typename SFINAE = void>
struct pyobj_convert_generic { struct pyobj_convert_generic {
static T from(PyObject* obj) { static T from(PyObject* obj) {
// TODO: remove this guard which is used for pybind11 implicit conversion // TODO: remove this guard which is used for pybind11 implicit conversion
...@@ -87,7 +87,12 @@ struct pyobj_convert_generic { ...@@ -87,7 +87,12 @@ struct pyobj_convert_generic {
} }
}; };
template<typename T> template <typename T>
struct EnumTrait {
static constexpr bool is_bit_combined = false;
};
template <typename T>
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
PyObject* obj = type->tp_alloc(type, 0); PyObject* obj = type->tp_alloc(type, 0);
T* self = reinterpret_cast<T*>(obj); T* self = reinterpret_cast<T*>(obj);
...@@ -203,9 +208,10 @@ struct EnumWrapper { ...@@ -203,9 +208,10 @@ struct EnumWrapper {
} }
}; };
template<typename T> template <typename T>
struct pyobj_convert_generic<T, struct pyobj_convert_generic<T,
std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> { std::enable_if_t<std::is_enum_v<std::decay_t<T>> &&
!EnumTrait<T>::is_bit_combined>> {
using Wrapper = EnumWrapper<T>; using Wrapper = EnumWrapper<T>;
static T from(PyObject* obj) { static T from(PyObject* obj) {
if (PyObject_TypeCheck(obj, &Wrapper::type)) { if (PyObject_TypeCheck(obj, &Wrapper::type)) {
...@@ -223,6 +229,115 @@ struct pyobj_convert_generic<T, ...@@ -223,6 +229,115 @@ struct pyobj_convert_generic<T,
} }
}; };
template<typename T>
struct BitCombinedEnumWrapper {
static_assert(std::is_enum_v<T>);
PyObject_HEAD
T value;
static const char* name;
static PyTypeObject type;
static std::unordered_map<T, std::string> type2str;
static std::unordered_map<std::string, T> str2type;
static PyNumberMethods number_methods;
BitCombinedEnumWrapper() = default;
BitCombinedEnumWrapper(T v): value(v) {}
BitCombinedEnumWrapper(std::string&& str)
: BitCombinedEnumWrapper(str2type.at(normalize_enum(str))) {}
std::string to_string() const {
if (static_cast<uint32_t>(value) == 0) {
return "None";
} else {
auto ret = std::string();
bool first = true;
for (uint32_t i = 0; i < 32; i++) {
uint32_t value_int = static_cast<uint32_t>(value);
auto it = type2str.find(static_cast<T>((1 << i) & value_int));
if (it != type2str.end()) {
if (!first) {
ret += " + ";
} else {
first = false;
}
ret += (std::string(name) + "." + it->second);
}
}
return ret;
}
}
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) {
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1);
return obj;
}
static int py_init(PyObject* self, PyObject* args, PyObject*) {
int input = 1;
if (PyArg_ParseTuple(args, "|i", &input)){
reinterpret_cast<BitCombinedEnumWrapper*>(self)->value =
static_cast<T>(input);
}
return 0;
}
static PyObject* py_repr(PyObject* self) {
return pyobj_convert_generic<std::string>::to(
reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string());
}
static PyObject* py_or(PyObject* self, PyObject* other) {
if(!(self->ob_type == other->ob_type)){
return PyErr_Format(
PyExc_RuntimeError,
"Operand in or operator must be the same type.");
}
PyObject* obj = type.tp_alloc(&type, 0);
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(
static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs));
return obj;
}
static PyObject* py_and(PyObject* self, PyObject* other) {
if (!(self->ob_type == other->ob_type)) {
return PyErr_Format(
PyExc_RuntimeError,
"Operand in and operator must be the same type.");
}
PyObject* obj = type.tp_alloc(&type, 0);
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(
static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs));
return obj;
}
static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
if (op == Py_EQ || op == Py_NE) {
RETURN_RICHCOMPARE(lhs, rhs, op);
}
Py_RETURN_NOTIMPLEMENTED;
}
};
template <typename T>
struct pyobj_convert_generic<T,
std::enable_if_t<std::is_enum_v<std::decay_t<T>> &&
EnumTrait<T>::is_bit_combined>> {
using Wrapper = BitCombinedEnumWrapper<T>;
static T from(PyObject* obj) {
if (PyObject_TypeCheck(obj, &Wrapper::type)) {
return reinterpret_cast<Wrapper*>(obj)->value;
}
// try as string
// TODO: type checkcd
return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value;
}
static PyObject* to(T t) {
PyTypeObject* pytype = &Wrapper::type;
PyObject* obj = pytype->tp_alloc(pytype, 0);
reinterpret_cast<Wrapper*>(obj)->value = t;
return obj;
}
};
void _init_py_op_def(py::module m) { void _init_py_op_def(py::module m) {
using py_op = PyOp(OpDef); using py_op = PyOp(OpDef);
auto& py_type = PyOpType(OpDef); auto& py_type = PyOpType(OpDef);
......
...@@ -408,19 +408,14 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ...@@ -408,19 +408,14 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
os << ";\n\n"; os << ";\n\n";
} }
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { static std::string gen_op_def_python_c_extension_enum(
auto className = op.getCppClassName(); raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr,
llvm::StringRef className) {
std::string body; std::string body;
// generate PyType for enum class member
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
unsigned int enumID; unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) { if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase(); auto&& aliasBase = alias->getAliasBase();
enumID = enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
llvm::cast<MgbEnumAttr>(aliasBase)
.getBaseRecord()->getID();
} else { } else {
enumID = attr->getBaseRecord()->getID(); enumID = attr->getBaseRecord()->getID();
} }
...@@ -428,20 +423,20 @@ static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, Enu ...@@ -428,20 +423,20 @@ static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, Enu
auto&& iter = enumAlias.find(enumID); auto&& iter = enumAlias.find(enumID);
auto enumName = attr->getEnumName(); auto enumName = attr->getEnumName();
body += "{\n"; body += "{\n";
body += formatv( body += formatv("auto& e_type = EnumWrapper<{0}::{1}>::type;", className,
"auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName enumName);
);
if (iter == enumAlias.end()) { if (iter == enumAlias.end()) {
os << formatv( os << formatv(
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n",
className, enumName); className, enumName);
os << formatv( os << formatv(
"template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n", "template<> const char* EnumWrapper<{0}::{1}>::name = "
"\"{0}.{1}\";\n",
className, enumName); className, enumName);
std::vector<std::string> pairStr; std::vector<std::string> pairStr;
for (auto&& i: attr->getEnumMembers()) { for (auto&& i : attr->getEnumMembers()) {
pairStr.push_back(formatv( pairStr.push_back(
"{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
className, enumName, i)); className, enumName, i));
} }
os << formatv(R"( os << formatv(R"(
...@@ -449,11 +444,12 @@ template<> std::unordered_map<std::string, {0}::{1}> ...@@ -449,11 +444,12 @@ template<> std::unordered_map<std::string, {0}::{1}>
EnumWrapper<{0}::{1}>::str2type = {{ EnumWrapper<{0}::{1}>::str2type = {{
{2} {2}
}; };
)", className, enumName, llvm::join(pairStr, ", ")); )",
className, enumName, llvm::join(pairStr, ", "));
pairStr.clear(); pairStr.clear();
for (auto&& i: attr->getEnumMembers()) { for (auto&& i : attr->getEnumMembers()) {
pairStr.push_back(formatv( pairStr.push_back(
"{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
className, enumName, i)); className, enumName, i));
} }
os << formatv(R"( os << formatv(R"(
...@@ -461,7 +457,8 @@ template<> std::unordered_map<{0}::{1}, std::string> ...@@ -461,7 +457,8 @@ template<> std::unordered_map<{0}::{1}, std::string>
EnumWrapper<{0}::{1}>::type2str = {{ EnumWrapper<{0}::{1}>::type2str = {{
{2} {2}
}; };
)", className, enumName, llvm::join(pairStr, ", ")); )",
className, enumName, llvm::join(pairStr, ", "));
body += formatv(R"( body += formatv(R"(
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
...@@ -472,13 +469,113 @@ EnumWrapper<{0}::{1}>::type2str = {{ ...@@ -472,13 +469,113 @@ EnumWrapper<{0}::{1}>::type2str = {{
e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr;
e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare;
mgb_assert(PyType_Ready(&e_type) >= 0); mgb_assert(PyType_Ready(&e_type) >= 0);
)", className, enumName); )",
for (auto&& i: attr->getEnumMembers()) { className, enumName);
for (auto&& i : attr->getEnumMembers()) {
body += formatv(R"({{ body += formatv(R"({{
PyObject* inst = e_type.tp_alloc(&e_type, 0); PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2}; reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
})", className, enumName, i); })",
className, enumName, i);
}
enumAlias.emplace(enumID, std::make_pair(className, enumName));
}
body += formatv(R"(
PyType_Modified(&e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
)",
enumName);
body += "}\n";
return body;
}
static std::string gen_op_def_python_c_extension_bit_combined_enum(
raw_ostream& os, EnumContext& ctx, MgbEnumAttr* attr,
llvm::StringRef className) {
std::string body;
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID = llvm::cast<MgbEnumAttr>(aliasBase).getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
auto&& enumAlias = ctx.enumAlias;
auto&& iter = enumAlias.find(enumID);
auto enumName = attr->getEnumName();
body += "{\n";
body += formatv("auto& e_type = BitCombinedEnumWrapper<{0}::{1}>::type;",
className, enumName);
if (iter == enumAlias.end()) {
os << formatv(
"template<> PyTypeObject "
"BitCombinedEnumWrapper<{0}::{1}>::type={{};\n",
className, enumName);
os << formatv(
"template<> PyNumberMethods "
"BitCombinedEnumWrapper<{0}::{1}>::number_methods={{};\n",
className, enumName);
os << formatv(
"template<> const char* BitCombinedEnumWrapper<{0}::{1}>::name "
"= \"{0}.{1}\";\n",
className, enumName);
os << formatv(
"template<> struct EnumTrait<{0}::{1}> {{ static constexpr "
"bool is_bit_combined = true;};\n",
className, enumName);
std::vector<std::string> pairStr;
for (auto&& i : attr->getEnumMembers()) {
pairStr.push_back(
formatv("{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
className, enumName, i));
}
os << formatv(R"(
template<> std::unordered_map<std::string, {0}::{1}>
BitCombinedEnumWrapper<{0}::{1}>::str2type = {{
{2}
};
)",
className, enumName, llvm::join(pairStr, ", "));
pairStr.clear();
for (auto&& i : attr->getEnumMembers()) {
pairStr.push_back(
formatv("{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
className, enumName, i));
}
os << formatv(R"(
template<> std::unordered_map<{0}::{1}, std::string>
BitCombinedEnumWrapper<{0}::{1}>::type2str = {{
{2}
};
)",
className, enumName, llvm::join(pairStr, ", "));
body += formatv(R"(
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
e_type.tp_basicsize = sizeof(BitCombinedEnumWrapper<{0}::{1}>);
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "{0}.{1}";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_new = BitCombinedEnumWrapper<{0}::{1}>::py_new_combined_enum;
e_type.tp_init = BitCombinedEnumWrapper<{0}::{1}>::py_init;
e_type.tp_repr = BitCombinedEnumWrapper<{0}::{1}>::py_repr;
e_type.tp_richcompare = BitCombinedEnumWrapper<{0}::{1}>::tp_richcompare;
auto& number_method = BitCombinedEnumWrapper<{0}::{1}>::number_methods;
number_method.nb_or = BitCombinedEnumWrapper<{0}::{1}>::py_or;
number_method.nb_and = BitCombinedEnumWrapper<{0}::{1}>::py_and;
e_type.tp_as_number = &number_method;
mgb_assert(PyType_Ready(&e_type) >= 0);
)",
className, enumName);
for (auto&& i : attr->getEnumMembers()) {
body += formatv(R"({{
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<BitCombinedEnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
})",
className, enumName, i);
} }
enumAlias.emplace(enumID, std::make_pair(className, enumName)); enumAlias.emplace(enumID, std::make_pair(className, enumName));
} }
...@@ -486,8 +583,26 @@ EnumWrapper<{0}::{1}>::type2str = {{ ...@@ -486,8 +583,26 @@ EnumWrapper<{0}::{1}>::type2str = {{
PyType_Modified(&e_type); PyType_Modified(&e_type);
mgb_assert(PyDict_SetItemString( mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0); py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
)", enumName); )",
enumName);
body += "}\n"; body += "}\n";
return body;
}
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
auto className = op.getCppClassName();
std::string body;
// generate PyType for enum class member
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
if (attr->getEnumCombinedFlag()) {
body += gen_op_def_python_c_extension_bit_combined_enum(
os, ctx, attr, className);
} else {
body += gen_op_def_python_c_extension_enum(os, ctx, attr,
className);
}
} }
} }
......
...@@ -141,15 +141,13 @@ R"__usage__( ...@@ -141,15 +141,13 @@ R"__usage__(
)__usage__" )__usage__"
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
R"__usage__( R"__usage__(
--fast-run --full-run
This param will be deperated later, please replace with param --full-profile. Enable full-run mode. Operators with multiple algorithms would be profiled
--full-profile
Enable full-profile mode. Operators with multiple algorithms would be profiled
on the real device with actual input shapes, all algorithms will be profiled on the real device with actual input shapes, all algorithms will be profiled
include naive algorithms. include naive algorithms.
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details.
--fast-profile --fast-run
Enable fast-profile mode. Operators with multiple algorithms would be profiled Enable fast-run mode. Operators with multiple algorithms would be profiled
on the real device with actual input shapes, this mode will only profile the on the real device with actual input shapes, this mode will only profile the
well optimized algorithms to get the profile result fast. well optimized algorithms to get the profile result fast.
See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details. See `mgb::gopt::enable_opr_algo_profiling_inplace` for more details.
...@@ -519,8 +517,8 @@ struct Args { ...@@ -519,8 +517,8 @@ struct Args {
bool disable_assert_throw = false; bool disable_assert_throw = false;
bool share_param_mem = false; bool share_param_mem = false;
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
bool use_full_profile = false; bool use_full_run = false;
bool use_fast_profile = false; bool use_fast_run = false;
#endif #endif
bool reproducible = false; bool reproducible = false;
std::string fast_run_cache_path; std::string fast_run_cache_path;
...@@ -704,13 +702,13 @@ void run_test_st(Args &env) { ...@@ -704,13 +702,13 @@ void run_test_st(Args &env) {
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
S strategy = S::HEURISTIC; S strategy = S::HEURISTIC;
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
if (env.use_full_profile) { if (env.use_full_run) {
if (env.reproducible) { if (env.reproducible) {
strategy = S::PROFILE | S::REPRODUCIBLE; strategy = S::PROFILE | S::REPRODUCIBLE;
} else { } else {
strategy = S::PROFILE; strategy = S::PROFILE;
} }
} else if (env.use_fast_profile) { } else if (env.use_fast_run) {
strategy = S::PROFILE | S::OPTMIZED; strategy = S::PROFILE | S::OPTMIZED;
} else if (env.reproducible) { } else if (env.reproducible) {
strategy = S::HEURISTIC | S::REPRODUCIBLE; strategy = S::HEURISTIC | S::REPRODUCIBLE;
...@@ -740,12 +738,12 @@ void run_test_st(Args &env) { ...@@ -740,12 +738,12 @@ void run_test_st(Args &env) {
std::make_shared<InFilePersistentCache>(buf.get(), flen)); std::make_shared<InFilePersistentCache>(buf.get(), flen));
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
} else { } else {
mgb_assert(env.use_full_profile || env.use_fast_profile, mgb_assert(env.use_full_run || env.use_fast_run,
"fast-run or fast-profile should be enabled"); "fast-run or fast-run should be enabled");
PersistentCache::set_impl( PersistentCache::set_impl(
std::make_shared<InFilePersistentCache>()); std::make_shared<InFilePersistentCache>());
} }
if (!env.use_full_profile && !env.use_fast_profile) if (!env.use_full_run && !env.use_fast_run)
#endif #endif
mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); mgb::gopt::enable_opr_use_profiling_cache_inplace(vars);
} }
...@@ -1326,18 +1324,11 @@ Args Args::from_argv(int argc, char **argv) { ...@@ -1326,18 +1324,11 @@ Args Args::from_argv(int argc, char **argv) {
} }
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
if (!strcmp(argv[i], "--fast-run")) { if (!strcmp(argv[i], "--fast-run")) {
mgb_log_warn( ret.use_fast_run = true;
"--fast-run param will be deperated later, please replace "
"with --full-profile or --fast-profile.");
ret.use_full_profile = true;
continue;
}
if (!strcmp(argv[i], "--full-profile")) {
ret.use_full_profile = true;
continue; continue;
} }
if (!strcmp(argv[i], "--fast-profile")) { if (!strcmp(argv[i], "--full-run")) {
ret.use_fast_profile = true; ret.use_full_run = true;
continue; continue;
} }
#endif #endif
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#pragma once #pragma once
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#include "megbrain/opr/param_defs.h"
#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
#include <memory> #include <memory>
...@@ -250,10 +249,4 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) { ...@@ -250,10 +249,4 @@ inline constexpr std::size_t operator"" _z(unsigned long long n) {
} // namespace mgb } // namespace mgb
namespace megdnn {
namespace param {
MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy)
}
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "megbrain/utils/hashable.h" #include "megbrain/utils/hashable.h"
#include "megbrain/utils/thin/hash_table.h" #include "megbrain/utils/thin/hash_table.h"
#include "megbrain/utils/small_vector.h" #include "megbrain/utils/small_vector.h"
#include "megbrain/opr/param_defs.h"
#include <type_traits> #include <type_traits>
...@@ -1043,4 +1044,10 @@ MGB_DEFINE_CLS_WITH_SUPER(_name final, _base ,##__VA_ARGS__) \ ...@@ -1043,4 +1044,10 @@ MGB_DEFINE_CLS_WITH_SUPER(_name final, _base ,##__VA_ARGS__) \
} // namespace cg } // namespace cg
} // namespace mgb } // namespace mgb
namespace megdnn {
namespace param {
MGB_DEF_ENUM_CLASS_BIT_OPR(ExecutionPolicy::Strategy)
}
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -278,6 +278,19 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( ...@@ -278,6 +278,19 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
return ret; return ret;
} }
//! Test whether the algo attribute of a algo match the require
//! algo_strategy
static bool algo_attribute_match_strategy(AlgoAttribute attribute,
ExecutionStrategy selected_strategy) {
bool ret = true;
if (selected_strategy & ExecutionStrategy::OPTMIZED) {
ret &= (!static_cast<bool>(AlgoAttribute::NAIVE & attribute));
} else if (selected_strategy & ExecutionStrategy::REPRODUCIBLE) {
ret &= static_cast<bool>(AlgoAttribute::REPRODUCIBLE & attribute);
}
return ret;
}
} // namespace } // namespace
namespace mgb { namespace mgb {
...@@ -285,8 +298,8 @@ namespace opr { ...@@ -285,8 +298,8 @@ namespace opr {
template <typename Opr> template <typename Opr>
void AlgoChooser<Opr>::profile(ExeContext& ctx, void AlgoChooser<Opr>::profile(ExeContext& ctx,
ExecutionStrategy select_strategy) { ExecutionStrategy selected_strategy) {
if (ctx.get_profile_result_from_cache(select_strategy).valid()) if (ctx.get_profile_result_from_cache(selected_strategy).valid())
return; return;
AlgoChooserProfileCache::Result prof_rst; AlgoChooserProfileCache::Result prof_rst;
...@@ -306,9 +319,19 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, ...@@ -306,9 +319,19 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
algo.name.c_str(), str_on_inp_shape.c_str()); algo.name.c_str(), str_on_inp_shape.c_str());
ImplExecutionPolicy policy; ImplExecutionPolicy policy;
policy.algo = algo.desc; policy.algo = algo.desc;
ctx.construct_execution_policy(select_strategy, policy); ctx.construct_execution_policy(selected_strategy, policy);
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) {
continue; continue;
}
auto algo_attribute = ctx.megdnn_opr()
->get_algorithm_from_desc(policy.algo)
->attribute();
if (!algo_attribute_match_strategy(algo_attribute, selected_strategy)) {
mgb_log_debug(
"skip algo %s, which is not match the profile strategy.",
algo.name.c_str());
continue;
}
timer.reset(); timer.reset();
MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); } MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); }
...@@ -356,7 +379,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, ...@@ -356,7 +379,7 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy typename AlgoChooser<Opr>::ImplExecutionPolicy
AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx,
ExecutionStrategy select_strategy, ExecutionStrategy selected_strategy,
bool enable_update) { bool enable_update) {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile")))
if (ctx.owner_graph()->options().no_profiling_on_shape_change) { if (ctx.owner_graph()->options().no_profiling_on_shape_change) {
...@@ -378,11 +401,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, ...@@ -378,11 +401,11 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx,
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
_item.param, ctx.mgb_opr(), ctx.comp_node(), _item.param, ctx.mgb_opr(), ctx.comp_node(),
ctx.execution_policy(), ctx.allow_weight_preprocess()); ctx.execution_policy(), ctx.allow_weight_preprocess());
AlgoChooser<_Opr>::profile(sub_ctx, select_strategy); AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy);
}); });
} }
typename AlgoChooser<Opr>::ImplExecutionPolicy policy; typename AlgoChooser<Opr>::ImplExecutionPolicy policy;
ctx.construct_execution_policy(select_strategy, policy); ctx.construct_execution_policy(selected_strategy, policy);
return policy; return policy;
MIDOUT_E MIDOUT_E
} }
...@@ -440,7 +463,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( ...@@ -440,7 +463,8 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
if (!policy.algo.valid()) if (!policy.algo.valid())
policy = ctx.choose_by_heuristic(opr_strategy); policy = ctx.choose_by_heuristic(opr_strategy);
return policy; return policy;
} else if ((opr_strategy & ExecutionStrategy::HEURISTIC)) { } else if (!static_cast<int>(opr_strategy) ||
(opr_strategy & ExecutionStrategy::HEURISTIC)) {
return ctx.choose_by_heuristic(opr_strategy); return ctx.choose_by_heuristic(opr_strategy);
} }
#if MGB_ENABLE_FASTRUN #if MGB_ENABLE_FASTRUN
...@@ -449,7 +473,7 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( ...@@ -449,7 +473,7 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
} }
#endif #endif
else { else {
mgb_throw(GraphError, "bad convolution ExecutionPolicy strategy"); mgb_throw(GraphError, "bad ExecutionPolicy strategy");
} }
} }
...@@ -495,7 +519,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext( ...@@ -495,7 +519,7 @@ AlgoChooser<Opr>::ExeContext::ExeContext(
template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplAlgo typename AlgoChooser<Opr>::ImplAlgo
AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
ExecutionStrategy select_strategy) const { ExecutionStrategy selected_strategy) const {
MIDOUT_B(Opr, MIDOUT_B(Opr,
midout_iv(MGB_HASH_STR( midout_iv(MGB_HASH_STR(
"AlgoChooser::ExeContext::get_profile_result_from_cache"))) "AlgoChooser::ExeContext::get_profile_result_from_cache")))
...@@ -519,7 +543,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( ...@@ -519,7 +543,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
if (prof.empty()) if (prof.empty())
return {}; return {};
for (auto&& i : prof) { for (auto&& i : prof) {
if (!(select_strategy & ExecutionStrategy::REPRODUCIBLE) || if (!(selected_strategy & ExecutionStrategy::REPRODUCIBLE) ||
static_cast<AlgoAttribute>(i.attribute) & static_cast<AlgoAttribute>(i.attribute) &
AlgoAttribute::REPRODUCIBLE) { AlgoAttribute::REPRODUCIBLE) {
auto iter = algo_map.find(i.algo); auto iter = algo_map.find(i.algo);
...@@ -550,7 +574,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( ...@@ -550,7 +574,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplExecutionPolicy typename AlgoChooser<Opr>::ImplExecutionPolicy
AlgoChooser<Opr>::ExeContext::choose_by_heuristic( AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
ExecutionStrategy select_strategy) const { ExecutionStrategy selected_strategy) const {
if (m_execution_policy.workspace_limit != if (m_execution_policy.workspace_limit !=
std::numeric_limits<decltype( std::numeric_limits<decltype(
m_execution_policy.workspace_limit)>::max()) { m_execution_policy.workspace_limit)>::max()) {
...@@ -558,7 +582,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( ...@@ -558,7 +582,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
"workspace_limit should not be setted if choose algo by " "workspace_limit should not be setted if choose algo by "
"heuristic"); "heuristic");
} }
bool reproducible = static_cast<bool>(select_strategy & bool reproducible = static_cast<bool>(selected_strategy &
ExecutionStrategy::REPRODUCIBLE); ExecutionStrategy::REPRODUCIBLE);
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit); owner_graph(), m_cn, m_execution_policy.workspace_limit);
...@@ -582,7 +606,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( ...@@ -582,7 +606,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, _item.param, m_base_mgb_opr, m_cn, m_execution_policy,
m_allow_weight_preprocess); m_allow_weight_preprocess);
policy.sub_policy.push_back( policy.sub_policy.push_back(
sub_ctx.choose_by_heuristic(select_strategy)); sub_ctx.choose_by_heuristic(selected_strategy));
}); });
return policy; return policy;
...@@ -613,15 +637,15 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const { ...@@ -613,15 +637,15 @@ AlgoChooser<Opr>::ExeContext::get_all_candidates() const {
template <typename Opr> template <typename Opr>
void AlgoChooser<Opr>::ExeContext::construct_execution_policy( void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
ExecutionStrategy select_strategy, ExecutionStrategy selected_strategy,
typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, typename AlgoChooser<Opr>::ImplExecutionPolicy& policy,
bool retrive_from_cache) const { bool retrive_from_cache) const {
bool reproducible = static_cast<bool>(select_strategy & bool reproducible = static_cast<bool>(selected_strategy &
ExecutionStrategy::REPRODUCIBLE); ExecutionStrategy::REPRODUCIBLE);
if (!policy.algo.valid()) { if (!policy.algo.valid()) {
if (retrive_from_cache) { if (retrive_from_cache) {
policy.algo = policy.algo =
get_profile_result_from_cache(select_strategy).desc; get_profile_result_from_cache(selected_strategy).desc;
} else { } else {
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit); owner_graph(), m_cn, m_execution_policy.workspace_limit);
...@@ -651,7 +675,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( ...@@ -651,7 +675,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, _item.param, m_base_mgb_opr, m_cn, m_execution_policy,
m_allow_weight_preprocess); m_allow_weight_preprocess);
policy.sub_policy.push_back({}); policy.sub_policy.push_back({});
sub_ctx.construct_execution_policy(select_strategy, sub_ctx.construct_execution_policy(selected_strategy,
policy.sub_policy.back(), policy.sub_policy.back(),
retrive_from_cache); retrive_from_cache);
}); });
......
...@@ -110,7 +110,7 @@ public: ...@@ -110,7 +110,7 @@ public:
const FixedTensorLayouts& layouts() const { return m_layouts; } const FixedTensorLayouts& layouts() const { return m_layouts; }
ImplExecutionPolicy choose_by_heuristic( ImplExecutionPolicy choose_by_heuristic(
ExecutionStrategy select_strategy) const; ExecutionStrategy selected_strategy) const;
//! get all candidate algos, and the one choose_by_heuristic() is //! get all candidate algos, and the one choose_by_heuristic() is
//! put first //! put first
...@@ -134,17 +134,17 @@ public: ...@@ -134,17 +134,17 @@ public:
//! get all profile algorithm from cache, return invalid if not exists //! get all profile algorithm from cache, return invalid if not exists
ImplAlgo get_profile_result_from_cache( ImplAlgo get_profile_result_from_cache(
ExecutionStrategy select_strategy) const; ExecutionStrategy selected_strategy) const;
/** /**
* \brief construct execution policy from cache or heuristic. * \brief construct execution policy from cache or heuristic.
* *
* \param select_strategy select algo which matched this strategy * \param selected_strategy select algo which matched this strategy
* \param policy execution policy * \param policy execution policy
* \param retrive_from_cache retrive algo from cache if set True, get * \param retrive_from_cache retrive algo from cache if set True, get
* from heuristic otherwise. * from heuristic otherwise.
*/ */
void construct_execution_policy(ExecutionStrategy select_strategy, void construct_execution_policy(ExecutionStrategy selected_strategy,
ImplExecutionPolicy& policy, ImplExecutionPolicy& policy,
bool retrive_from_cache = true) const; bool retrive_from_cache = true) const;
...@@ -161,10 +161,10 @@ private: ...@@ -161,10 +161,10 @@ private:
//! profile and save to cache //! profile and save to cache
static void profile(ExeContext& ctx, ExecutionStrategy select_strategy); static void profile(ExeContext& ctx, ExecutionStrategy selected_strategy);
static ImplExecutionPolicy choose_by_profile( static ImplExecutionPolicy choose_by_profile(
ExeContext& ctx, ExecutionStrategy select_strategy, ExeContext& ctx, ExecutionStrategy selected_strategy,
bool enable_update = true); bool enable_update = true);
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册