提交 282dfc62 编写于 作者: M Megvii Engine Team

refactor(imperative): alloc enum type class on heap

GitOrigin-RevId: d2b2acea229df68151f04ce17c1e73621dd7fb60
上级 1e6ef377
...@@ -170,7 +170,7 @@ struct EnumTrait; ...@@ -170,7 +170,7 @@ struct EnumTrait;
PyObject_HEAD \ PyObject_HEAD \
T value; \ T value; \
constexpr static const char *name = EnumTrait<T>::name; \ constexpr static const char *name = EnumTrait<T>::name; \
static PyTypeObject type; \ static PyTypeObject* type; \
static const char* members[]; \ static const char* members[]; \
static std::unordered_map<std::string, T> mem2value; \ static std::unordered_map<std::string, T> mem2value; \
static PyObject* pyobj_insts[]; static PyObject* pyobj_insts[];
...@@ -196,7 +196,7 @@ struct EnumWrapper { ...@@ -196,7 +196,7 @@ struct EnumWrapper {
} }
static bool load(py::handle src, T& value) { static bool load(py::handle src, T& value) {
PyObject* obj = src.ptr(); PyObject* obj = src.ptr();
if (PyObject_TypeCheck(obj, &type)) { if (PyObject_TypeCheck(obj, type)) {
value = reinterpret_cast<EnumWrapper*>(obj)->value; value = reinterpret_cast<EnumWrapper*>(obj)->value;
return true; return true;
} }
...@@ -224,7 +224,6 @@ struct EnumWrapper { ...@@ -224,7 +224,6 @@ struct EnumWrapper {
template<typename T> template<typename T>
struct BitCombinedEnumWrapper { struct BitCombinedEnumWrapper {
PyEnumHead PyEnumHead
static PyNumberMethods number_methods;
std::string to_string() const { std::string to_string() const {
uint32_t value_int = static_cast<uint32_t>(value); uint32_t value_int = static_cast<uint32_t>(value);
if (value_int == 0) { if (value_int == 0) {
...@@ -302,7 +301,7 @@ struct BitCombinedEnumWrapper { ...@@ -302,7 +301,7 @@ struct BitCombinedEnumWrapper {
} }
static bool load(py::handle src, T& value) { static bool load(py::handle src, T& value) {
PyObject* obj = src.ptr(); PyObject* obj = src.ptr();
if (PyObject_TypeCheck(obj, &type)) { if (PyObject_TypeCheck(obj, type)) {
value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value; value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
return true; return true;
} }
...@@ -330,8 +329,7 @@ struct BitCombinedEnumWrapper { ...@@ -330,8 +329,7 @@ struct BitCombinedEnumWrapper {
auto v = static_cast<std::underlying_type_t<T>>(value); auto v = static_cast<std::underlying_type_t<T>>(value);
mgb_assert(v <= EnumTrait<T>::max); mgb_assert(v <= EnumTrait<T>::max);
if ((!v) || (v & (v - 1))) { if ((!v) || (v & (v - 1))) {
PyTypeObject* pytype = &type; PyObject* obj = type->tp_alloc(type, 0);
PyObject* obj = pytype->tp_alloc(pytype, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value; reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
return obj; return obj;
} else { } else {
......
...@@ -69,3 +69,16 @@ def test_raw_tensor(): ...@@ -69,3 +69,16 @@ def test_raw_tensor():
np.testing.assert_allclose(x * x, yy.numpy()) np.testing.assert_allclose(x * x, yy.numpy())
(yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx) (yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx)
np.testing.assert_allclose(x * x, yy.numpy()) np.testing.assert_allclose(x * x, yy.numpy())
def test_opdef_path():
from megengine.core.ops.builtin import Elemwise
assert Elemwise.__module__ == "megengine.core._imperative_rt.ops"
assert Elemwise.__name__ == "Elemwise"
assert Elemwise.__qualname__ == "Elemwise"
Mode = Elemwise.Mode
assert Mode.__module__ == "megengine.core._imperative_rt.ops"
assert Mode.__name__ == "Mode"
assert Mode.__qualname__ == "Elemwise.Mode"
...@@ -97,7 +97,7 @@ void EnumAttrEmitter::emit_tpl_spl() { ...@@ -97,7 +97,7 @@ void EnumAttrEmitter::emit_tpl_spl() {
if (!firstOccur) return; if (!firstOccur) return;
os << tgfmt( os << tgfmt(
"template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type = {};\n", "template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = nullptr;\n",
&ctx); &ctx);
auto quote = [&](auto&& i) -> std::string { auto quote = [&](auto&& i) -> std::string {
...@@ -120,13 +120,6 @@ $enumTpl<$opClass::$enumClass>::mem2value = {$0}; ...@@ -120,13 +120,6 @@ $enumTpl<$opClass::$enumClass>::mem2value = {$0};
"template<> PyObject* " "template<> PyObject* "
"$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n", "$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n",
&ctx, attr->getEnumMembers().size()); &ctx, attr->getEnumMembers().size());
if (attr->getEnumCombinedFlag()) {
os << tgfmt(
"template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods = {};\n",
&ctx);
}
} }
Initproc EnumAttrEmitter::emit_initproc() { Initproc EnumAttrEmitter::emit_initproc() {
...@@ -140,45 +133,70 @@ void $0(PyTypeObject& py_type) { ...@@ -140,45 +133,70 @@ void $0(PyTypeObject& py_type) {
if (firstOccur) { if (firstOccur) {
os << tgfmt(R"( os << tgfmt(R"(
e_type = {PyVarObject_HEAD_INIT(NULL, 0)}; static PyType_Slot slots[] = {
e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass"; {Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>); {Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "$opClass.$enumClass";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr;
e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare;
)", &ctx); )", &ctx);
if (attr->getEnumCombinedFlag()) { if (attr->getEnumCombinedFlag()) {
// only bit combined enum could new instance because bitwise operation, // only bit combined enum could new instance because bitwise operation,
// others should always use singleton // others should always use singleton
os << tgfmt(R"( os << tgfmt(R"(
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; {Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum},
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; {Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or},
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; {Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and},
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and;
e_type.tp_as_number = &number_method;
)", &ctx); )", &ctx);
} }
os << R"(
{0, NULL}
};)";
os << tgfmt(R"(
static PyType_Spec spec = {
// name
"megengine.core._imperative_rt.ops.$opClass.$enumClass",
// basicsize
sizeof($enumTpl<$opClass::$enumClass>),
// itemsize
0,
// flags
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
// slots
slots
};)", &ctx);
os << tgfmt(R"(
e_type = reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
)", &ctx);
os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n"; for (auto&& i : {
std::pair<std::string, std::string>{"__name__", tgfmt("$enumClass", &ctx)},
{"__module__", "megengine.core._imperative_rt.ops"},
{"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) {
os << formatv(R"(
mgb_assert(
e_type->tp_setattro(
reinterpret_cast<PyObject*>(e_type),
py::cast("{0}").release().ptr(),
py::cast("{1}").release().ptr()) >= 0);
)", i.first, i.second);
}
auto&& members = attr->getEnumMembers(); auto&& members = attr->getEnumMembers();
for (size_t idx = 0; idx < members.size(); ++ idx) { for (size_t idx = 0; idx < members.size(); ++ idx) {
os << tgfmt(R"({ os << tgfmt(R"({
PyObject* inst = e_type.tp_alloc(&e_type, 0); PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0; reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0); mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst; $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
})", &ctx, members[idx], idx); })", &ctx, members[idx], idx);
} }
os << " PyType_Modified(&e_type);\n";
} }
os << tgfmt(R"( os << tgfmt(R"(
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString( mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(&e_type)) >= 0); py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(e_type)) >= 0);
)", &ctx); )", &ctx);
os << "}\n"; os << "}\n";
return initproc; return initproc;
......
...@@ -11,6 +11,7 @@ endif() ...@@ -11,6 +11,7 @@ endif()
# TODO: turn python binding into a static/object library # TODO: turn python binding into a static/object library
add_executable(imperative_test ${SOURCES} ${SRCS}) add_executable(imperative_test ${SOURCES} ${SRCS})
add_dependencies(imperative_test mgb_opdef)
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR}) target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR})
# Python binding # Python binding
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册