From 282dfc62326405fd576c3fe182de133c2b44792b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 20 Apr 2021 16:05:05 +0800 Subject: [PATCH] refactor(imperative): alloc enum type class on heap GitOrigin-RevId: d2b2acea229df68151f04ce17c1e73621dd7fb60 --- imperative/python/src/ops.cpp | 10 ++- .../test/unit/core/test_imperative_rt.py | 13 ++++ .../tablegen/targets/python_c_extension.cpp | 70 ++++++++++++------- imperative/test/CMakeLists.txt | 1 + 4 files changed, 62 insertions(+), 32 deletions(-) diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index 392092edc..1be1253d0 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -170,7 +170,7 @@ struct EnumTrait; PyObject_HEAD \ T value; \ constexpr static const char *name = EnumTrait::name; \ - static PyTypeObject type; \ + static PyTypeObject* type; \ static const char* members[]; \ static std::unordered_map mem2value; \ static PyObject* pyobj_insts[]; @@ -196,7 +196,7 @@ struct EnumWrapper { } static bool load(py::handle src, T& value) { PyObject* obj = src.ptr(); - if (PyObject_TypeCheck(obj, &type)) { + if (PyObject_TypeCheck(obj, type)) { value = reinterpret_cast(obj)->value; return true; } @@ -224,7 +224,6 @@ struct EnumWrapper { template struct BitCombinedEnumWrapper { PyEnumHead - static PyNumberMethods number_methods; std::string to_string() const { uint32_t value_int = static_cast(value); if (value_int == 0) { @@ -302,7 +301,7 @@ struct BitCombinedEnumWrapper { } static bool load(py::handle src, T& value) { PyObject* obj = src.ptr(); - if (PyObject_TypeCheck(obj, &type)) { + if (PyObject_TypeCheck(obj, type)) { value = reinterpret_cast(obj)->value; return true; } @@ -330,8 +329,7 @@ struct BitCombinedEnumWrapper { auto v = static_cast>(value); mgb_assert(v <= EnumTrait::max); if ((!v) || (v & (v - 1))) { - PyTypeObject* pytype = &type; - PyObject* obj = pytype->tp_alloc(pytype, 0); + PyObject* obj = type->tp_alloc(type, 0); reinterpret_cast(obj)->value = value; return obj; } else { diff --git a/imperative/python/test/unit/core/test_imperative_rt.py b/imperative/python/test/unit/core/test_imperative_rt.py index 4934eecf1..04b08307f 100644 --- a/imperative/python/test/unit/core/test_imperative_rt.py +++ b/imperative/python/test/unit/core/test_imperative_rt.py @@ -69,3 +69,16 @@ def test_raw_tensor(): np.testing.assert_allclose(x * x, yy.numpy()) (yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx) np.testing.assert_allclose(x * x, yy.numpy()) + + +def test_opdef_path(): + from megengine.core.ops.builtin import Elemwise + + assert Elemwise.__module__ == "megengine.core._imperative_rt.ops" + assert Elemwise.__name__ == "Elemwise" + assert Elemwise.__qualname__ == "Elemwise" + + Mode = Elemwise.Mode + assert Mode.__module__ == "megengine.core._imperative_rt.ops" + assert Mode.__name__ == "Mode" + assert Mode.__qualname__ == "Elemwise.Mode" diff --git a/imperative/tablegen/targets/python_c_extension.cpp b/imperative/tablegen/targets/python_c_extension.cpp index a51f2ceb3..c17506e13 100644 --- a/imperative/tablegen/targets/python_c_extension.cpp +++ b/imperative/tablegen/targets/python_c_extension.cpp @@ -97,7 +97,7 @@ void EnumAttrEmitter::emit_tpl_spl() { if (!firstOccur) return; os << tgfmt( - "template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type = {};\n", + "template<> PyTypeObject* $enumTpl<$opClass::$enumClass>::type = nullptr;\n", &ctx); auto quote = [&](auto&& i) -> std::string { @@ -120,13 +120,6 @@ $enumTpl<$opClass::$enumClass>::mem2value = {$0}; "template<> PyObject* " "$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n", &ctx, attr->getEnumMembers().size()); - - if (attr->getEnumCombinedFlag()) { - os << tgfmt( - "template<> PyNumberMethods " - "$enumTpl<$opClass::$enumClass>::number_methods = {};\n", - &ctx); - } } Initproc EnumAttrEmitter::emit_initproc() { @@ -140,45 +133,70 @@ void $0(PyTypeObject& py_type) { if (firstOccur) { os << tgfmt(R"( - e_type = {PyVarObject_HEAD_INIT(NULL, 0)}; - e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass"; - e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>); - e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; - e_type.tp_doc = "$opClass.$enumClass"; - e_type.tp_base = &PyBaseObject_Type; - e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr; - e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare; + static PyType_Slot slots[] = { + {Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr}, + {Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare}, )", &ctx); if (attr->getEnumCombinedFlag()) { // only bit combined enum could new instance because bitwise operation, // others should always use singleton os << tgfmt(R"( - e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; - auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; - number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; - number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and; - e_type.tp_as_number = &number_method; + {Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum}, + {Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or}, + {Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and}, )", &ctx); } + os << R"( + {0, NULL} + };)"; + +os << tgfmt(R"( + static PyType_Spec spec = { + // name + "megengine.core._imperative_rt.ops.$opClass.$enumClass", + // basicsize + sizeof($enumTpl<$opClass::$enumClass>), + // itemsize + 0, + // flags + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, + // slots + slots + };)", &ctx); + + os << tgfmt(R"( + e_type = reinterpret_cast(PyType_FromSpec(&spec)); +)", &ctx); - os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n"; + for (auto&& i : { + std::pair{"__name__", tgfmt("$enumClass", &ctx)}, + {"__module__", "megengine.core._imperative_rt.ops"}, + {"__qualname__", tgfmt("$opClass.$enumClass", &ctx)}}) { + os << formatv(R"( + mgb_assert( + e_type->tp_setattro( + reinterpret_cast(e_type), + py::cast("{0}").release().ptr(), + py::cast("{1}").release().ptr()) >= 0); +)", i.first, i.second); + } auto&& members = attr->getEnumMembers(); for (size_t idx = 0; idx < members.size(); ++ idx) { os << tgfmt(R"({ - PyObject* inst = e_type.tp_alloc(&e_type, 0); + PyObject* inst = e_type->tp_alloc(e_type, 0); reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0; - mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0); + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0); $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst; })", &ctx, members[idx], idx); } - os << " PyType_Modified(&e_type);\n"; } os << tgfmt(R"( + Py_INCREF(e_type); mgb_assert(PyDict_SetItemString( - py_type.tp_dict, "$enumClass", reinterpret_cast(&e_type)) >= 0); + py_type.tp_dict, "$enumClass", reinterpret_cast(e_type)) >= 0); )", &ctx); os << "}\n"; return initproc; diff --git a/imperative/test/CMakeLists.txt b/imperative/test/CMakeLists.txt index 68d0dcecb..3ecd79363 100644 --- a/imperative/test/CMakeLists.txt +++ b/imperative/test/CMakeLists.txt @@ -11,6 +11,7 @@ endif() # TODO: turn python binding into a static/object library add_executable(imperative_test ${SOURCES} ${SRCS}) +add_dependencies(imperative_test mgb_opdef) target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR}) # Python binding -- GitLab