From a81085221a1120709c24263d67b13a90726235ad Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 14 Apr 2021 18:15:12 +0800 Subject: [PATCH] refactor(imperative): refactor enum param type caster GitOrigin-RevId: 1aae07f143b8d1c0176a41de790fee8d6b2f1a25 --- imperative/python/src/graph_rt.cpp | 15 +- imperative/python/src/ops.cpp | 237 +++++++++--------- imperative/python/src/ops.h | 21 ++ imperative/tablegen/CMakeLists.txt | 3 +- imperative/tablegen/autogen.cpp | 11 +- imperative/tablegen/targets/macros.cpp | 56 +++++ imperative/tablegen/targets/macros.h | 19 ++ .../tablegen/targets/python_c_extension.cpp | 82 +++--- 8 files changed, 281 insertions(+), 163 deletions(-) create mode 100644 imperative/tablegen/targets/macros.cpp create mode 100644 imperative/tablegen/targets/macros.h diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 7ea3b9b19..15768f7ec 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -21,6 +21,7 @@ #include "./helper.h" #include "megbrain/plugin/profiler.h" #include "./common.h" +#include "./ops.h" #include "megbrain/gopt/inference.h" @@ -265,18 +266,8 @@ void init_graph_rt(py::module m) { }); m.def("modify_opr_algo_strategy_inplace", [](const VarNodeArray& dest_vars, - const std::string& strategy) { - _AlgoStrategy stg; - const std::unordered_map> m{ - {"HEURISTIC", [&]() { stg = _AlgoStrategy::HEURISTIC; }}, - {"PROFILE", [&]() { stg = _AlgoStrategy::PROFILE; }}, - {"REPRODUCIBLE", [&]() { stg = _AlgoStrategy::REPRODUCIBLE; }}, - {"OPTIMIZED", [&]() { stg = _AlgoStrategy::OPTIMIZED; }}, - }; - auto it = m.find(strategy); - mgb_assert(it != m.end(), "Invalid strategy string!"); - it->second(); - mgb::gopt::modify_opr_algo_strategy_inplace(dest_vars, stg); + const _AlgoStrategy& strategy) { + mgb::gopt::modify_opr_algo_strategy_inplace(dest_vars, strategy); }); m.def("get_info_for_strip", [](const std::vector& dest_vars) { diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index f34975fcc..392092edc 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -73,29 +73,6 @@ PyTypeObject PyOpType(name); } \ } while (0) -template -struct pyobj_convert_generic { - static T from(PyObject* obj) { - // TODO: remove this guard which is used for pybind11 implicit conversion - py::detail::loader_life_support guard{}; - return py::cast(py::handle(obj)); - } - template>>> - static PyObject* to(U&& t) { - return py::cast(std::forward(t)).release().ptr(); - } -}; - -template -struct EnumTrait; - -template -struct EnumTrait>> { - static constexpr bool is_bit_combined = false; - static constexpr std::underlying_type_t max = 0; -}; - template PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { PyObject* obj = type->tp_alloc(type, 0); @@ -115,7 +92,7 @@ void py_dealloc_generic(PyObject* obj) { template PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) { auto& op = reinterpret_cast(obj)->inst(); - return pyobj_convert_generic::to(op.*attr); + return py::cast(op.*attr).release().ptr(); } #define py_get_generic(name, attr) \ py_get_generic_impl().attr), &name::attr> @@ -128,7 +105,9 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { } auto& op = reinterpret_cast(obj)->inst(); try { - op.*attr = pyobj_convert_generic::from(value); + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + op.*attr = py::cast(py::handle(value)); } CATCH_ALL(-1) return 0; } @@ -148,8 +127,8 @@ PyTypeObject PyOpType(OpDef); std::unordered_map PyOp(OpDef)::ctype2pytype; PyObject* py_get_scope(PyObject* obj, void* /* closure */) { - return pyobj_convert_generic::to( - reinterpret_cast(obj)->op->scope()); + return py::cast( + reinterpret_cast(obj)->op->scope()).release().ptr(); } int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) { @@ -159,7 +138,7 @@ int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) { } try { reinterpret_cast(obj)->op - ->set_scope(pyobj_convert_generic::from(value)); + ->set_scope(py::cast(py::handle(value))); } CATCH_ALL(-1) return 0; } @@ -183,24 +162,29 @@ PyObject* PyOp(OpDef)::tp_richcompare(PyObject *self, PyObject *other, int op) { Py_RETURN_NOTIMPLEMENTED; } +template +struct EnumTrait; + +#define PyEnumHead \ + static_assert(std::is_enum_v); \ + PyObject_HEAD \ + T value; \ + constexpr static const char *name = EnumTrait::name; \ + static PyTypeObject type; \ + static const char* members[]; \ + static std::unordered_map mem2value; \ + static PyObject* pyobj_insts[]; + template struct EnumWrapper { - static_assert(std::is_enum_v); - PyObject_HEAD - T value; - static const char* name; - static PyTypeObject type; - static std::unordered_map type2str; - static std::unordered_map str2type; - EnumWrapper() = default; - EnumWrapper(T v): value(v) {} - EnumWrapper(std::string&& str): EnumWrapper(str2type.at(normalize_enum(str))) {} + PyEnumHead std::string to_string() const { - return type2str.at(value); + return members[static_cast(value)]; } static PyObject* py_repr(PyObject* self) { - return pyobj_convert_generic::to( - std::string(name) + "." + reinterpret_cast(self)->to_string()); + return py::cast( + std::string(name) + "." + reinterpret_cast(self)->to_string()) + .release().ptr(); } static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) { T lhs = reinterpret_cast(self)->value, @@ -210,59 +194,52 @@ struct EnumWrapper { } Py_RETURN_NOTIMPLEMENTED; } -}; - -template -struct pyobj_convert_generic> && - !EnumTrait::is_bit_combined>> { - using Wrapper = EnumWrapper; - static T from(PyObject* obj) { - if (PyObject_TypeCheck(obj, &Wrapper::type)) { - return reinterpret_cast(obj)->value; + static bool load(py::handle src, T& value) { + PyObject* obj = src.ptr(); + if (PyObject_TypeCheck(obj, &type)) { + value = reinterpret_cast(obj)->value; + return true; + } + if (py::isinstance(src)) { + auto&& iter = mem2value.find( + normalize_enum(py::cast(src))); + if (iter != mem2value.end()) { + value = iter->second; + return true; + } else { + return false; + } } - // try as string - // TODO: type checkcd - return Wrapper(pyobj_convert_generic::from(obj)).value; + return false; } - static PyObject* to(T t) { - PyTypeObject* pytype = &Wrapper::type; - PyObject* obj = pytype->tp_alloc(pytype, 0); - reinterpret_cast(obj)->value = t; + static PyObject* cast(const T& value) { + auto v = static_cast>(value); + mgb_assert(v <= EnumTrait::max); + PyObject* obj = pyobj_insts[v]; + Py_INCREF(obj); return obj; } }; template struct BitCombinedEnumWrapper { - static_assert(std::is_enum_v); - PyObject_HEAD - T value; - static const char* name; - static PyTypeObject type; - static std::unordered_map type2str; - static std::unordered_map str2type; + PyEnumHead static PyNumberMethods number_methods; - BitCombinedEnumWrapper() = default; - BitCombinedEnumWrapper(T v): value(v) {} - BitCombinedEnumWrapper(std::string&& str) - : BitCombinedEnumWrapper(str2type.at(normalize_enum(str))) {} std::string to_string() const { - if (static_cast(value) == 0) { + uint32_t value_int = static_cast(value); + if (value_int == 0) { return "None"; } else { - auto ret = std::string(); + std::string ret; bool first = true; for (uint32_t i = 0; i < 32; i++) { - uint32_t value_int = static_cast(value); - auto it = type2str.find(static_cast((1 << i) & value_int)); - if (it != type2str.end()) { + if (value_int >> i & 1) { if (!first) { ret += " + "; } else { first = false; } - ret += (std::string(name) + "." + it->second); + ret += (std::string(name) + "." + members[i]); } } return ret; @@ -280,17 +257,20 @@ struct BitCombinedEnumWrapper { return nullptr; } T value; - try { - value = pyobj_convert_generic::from(input); - } CATCH_ALL(nullptr); - PyObject* obj = type->tp_alloc(type, 0); - reinterpret_cast(obj)->value = value; - return obj; + if (load(input, value)) { + return cast(value); + } else { + PyErr_SetString(PyExc_RuntimeError, + mgb::ssprintf("Cannot convert type %s to type %s\n", + input->ob_type->tp_name, name).c_str()); + return nullptr; + } } } static PyObject* py_repr(PyObject* self) { - return pyobj_convert_generic::to( - reinterpret_cast(self)->to_string()); + return py::cast( + reinterpret_cast(self)->to_string()) + .release().ptr(); } static PyObject* py_or(PyObject* self, PyObject* other) { if(!(self->ob_type == other->ob_type)){ @@ -298,12 +278,9 @@ struct BitCombinedEnumWrapper { PyExc_RuntimeError, "Operand in or operator must be the same type."); } - PyObject* obj = type.tp_alloc(&type, 0); T lhs = reinterpret_cast(self)->value, rhs = reinterpret_cast(other)->value; - reinterpret_cast(obj)->value = static_cast( - static_cast(lhs) | static_cast(rhs)); - return obj; + return cast(lhs | rhs); } static PyObject* py_and(PyObject* self, PyObject* other) { if (!(self->ob_type == other->ob_type)) { @@ -311,12 +288,9 @@ struct BitCombinedEnumWrapper { PyExc_RuntimeError, "Operand in and operator must be the same type."); } - PyObject* obj = type.tp_alloc(&type, 0); T lhs = reinterpret_cast(self)->value, rhs = reinterpret_cast(other)->value; - reinterpret_cast(obj)->value = static_cast( - static_cast(lhs) & static_cast(rhs)); - return obj; + return cast(lhs & rhs); } static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) { T lhs = reinterpret_cast(self)->value, @@ -326,32 +300,45 @@ struct BitCombinedEnumWrapper { } Py_RETURN_NOTIMPLEMENTED; } -}; - -template -struct pyobj_convert_generic> && - EnumTrait::is_bit_combined>> { - using Wrapper = BitCombinedEnumWrapper; - static T from(PyObject* obj) { - if (PyObject_TypeCheck(obj, &Wrapper::type)) { - return reinterpret_cast(obj)->value; - } else if(PyLong_Check(obj)) { - auto value = pyobj_convert_generic>::from(obj); - mgb_throw_if(value > EnumTrait::max, mgb::MegBrainError, - "out of range, cannot convert %zu to %s", - static_cast(value), Wrapper::name); - return static_cast(value); + static bool load(py::handle src, T& value) { + PyObject* obj = src.ptr(); + if (PyObject_TypeCheck(obj, &type)) { + value = reinterpret_cast(obj)->value; + return true; + } + if (py::isinstance(src)) { + auto&& iter = mem2value.find( + normalize_enum(py::cast(src))); + if (iter != mem2value.end()) { + value = iter->second; + return true; + } else { + return false; + } + } + if (py::isinstance(obj)) { + auto v = py::cast>(src); + if(v > EnumTrait::max) { + return false; + } + value = static_cast(v); + return true; } - // try as string - // TODO: type checkcd - return Wrapper(pyobj_convert_generic::from(obj)).value; + return false; } - static PyObject* to(T t) { - PyTypeObject* pytype = &Wrapper::type; - PyObject* obj = pytype->tp_alloc(pytype, 0); - reinterpret_cast(obj)->value = t; - return obj; + static PyObject* cast(const T& value) { + auto v = static_cast>(value); + mgb_assert(v <= EnumTrait::max); + if ((!v) || (v & (v - 1))) { + PyTypeObject* pytype = &type; + PyObject* obj = pytype->tp_alloc(pytype, 0); + reinterpret_cast(obj)->value = value; + return obj; + } else { + PyObject* obj = pyobj_insts[__builtin_ctz(v)]; + Py_INCREF(obj); + return obj; + } } }; @@ -443,7 +430,6 @@ void _init_py_op_base(py::module m) { #include "opdef.cpy.inl" #undef CATCH_ALL - } // anonymous namespace namespace PYBIND11_NAMESPACE { @@ -478,6 +464,25 @@ handle type_caster::cast(const OpDef& op, return_value_policy, handle) { reinterpret_cast(obj)->op = const_cast(op).shared_from_this(); return py::handle(obj); } + +#define ENUM_CASTER_IMPL(T) \ +bool type_caster::load(handle src, bool) { \ + return EnumWrapper::load(src, value); \ +} \ +handle type_caster::cast(const T& value, return_value_policy, handle) { \ + return EnumWrapper::cast(value); \ +} +FOR_EACH_ENUM_PARAM(ENUM_CASTER_IMPL) + +#define BIT_COMBINED_ENUM_CASTER_IMPL(T) \ +bool type_caster::load(handle src, bool) { \ + return BitCombinedEnumWrapper::load(src, value); \ +} \ +handle type_caster::cast(const T& value, return_value_policy, handle) { \ + return BitCombinedEnumWrapper::cast(value); \ +} +FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL) + } // detail } // PYBIND11_NAMESPACE diff --git a/imperative/python/src/ops.h b/imperative/python/src/ops.h index 1beba75b3..f230c8b55 100644 --- a/imperative/python/src/ops.h +++ b/imperative/python/src/ops.h @@ -12,5 +12,26 @@ #pragma once #include "./helper.h" +#include "./enum_macro.h" + +#include "megdnn/opr_param_defs.h" +#include "megbrain/opr/param_defs.h" + +namespace PYBIND11_NAMESPACE { +namespace detail { + +#define ENUM_CASTER_DEF(name) \ +template<> struct type_caster { \ + PYBIND11_TYPE_CASTER(name, _(#name)); \ +public: \ + bool load(handle src, bool); \ + static handle cast(const name& v, return_value_policy, handle); \ +}; + +FOR_EACH_ENUM_PARAM(ENUM_CASTER_DEF) +FOR_EACH_BIT_COMBINED_ENUM_PARAM(ENUM_CASTER_DEF) + +} // detail +} // PYBIND11_NAMESPACE void init_ops(pybind11::module m); diff --git a/imperative/tablegen/CMakeLists.txt b/imperative/tablegen/CMakeLists.txt index 1a5466ef9..7b4a18027 100644 --- a/imperative/tablegen/CMakeLists.txt +++ b/imperative/tablegen/CMakeLists.txt @@ -12,5 +12,6 @@ tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header") tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body") tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding") tablegen(MGB opdef.cpy.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-c-extension") -add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl opdef.cpy.inl param_defs_tblgen) +tablegen(MGB enum_macro.h ${MGE_IR_INCLUDE_DIRS} "--gen-enum-list-macro") +add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl opdef.cpy.inl enum_macro.h param_defs_tblgen) set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE) diff --git a/imperative/tablegen/autogen.cpp b/imperative/tablegen/autogen.cpp index 83a861111..27ea91a27 100644 --- a/imperative/tablegen/autogen.cpp +++ b/imperative/tablegen/autogen.cpp @@ -12,6 +12,7 @@ #include "./targets/cpp_class.h" #include "./targets/pybind11.h" #include "./targets/python_c_extension.h" +#include "./targets/macros.h" using llvm::raw_ostream; using llvm::RecordKeeper; @@ -21,7 +22,8 @@ enum ActionType { CppHeader, CppBody, Pybind, - CPython + CPython, + EnumListMacro }; // NOLINTNEXTLINE @@ -34,7 +36,9 @@ llvm::cl::opt action( clEnumValN(Pybind, "gen-python-binding", "Generate pybind11 python bindings"), clEnumValN(CPython, "gen-python-c-extension", - "Generate python c extensions"))); + "Generate python c extensions"), + clEnumValN(EnumListMacro, "gen-enum-list-macro", + "Generate enum param list macro"))); using namespace mlir::tblgen; @@ -53,5 +57,8 @@ int main(int argc, char **argv) { if (action == ActionType::CPython) { return TableGenMain(argv[0], &gen_op_def_python_c_extension); } + if (action == ActionType::EnumListMacro) { + return TableGenMain(argv[0], &gen_enum_param_list_macro); + } return -1; } diff --git a/imperative/tablegen/targets/macros.cpp b/imperative/tablegen/targets/macros.cpp new file mode 100644 index 000000000..9df4256b1 --- /dev/null +++ b/imperative/tablegen/targets/macros.cpp @@ -0,0 +1,56 @@ +/** + * \file imperative/tablegen/targets/macros.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "./cpp_class.h" +#include "../emitter.h" + +namespace mlir::tblgen { +bool gen_enum_param_list_macro(raw_ostream &os, llvm::RecordKeeper &keeper) { + std::vector> enums; + std::vector> bit_enums; + Environment env; + foreach_operator(keeper, [&](MgbOp& op) { + for (auto&& i : op.getAttributes()) { + if (auto attr = llvm::dyn_cast(&i.attr)) { + auto insert = [&](const MgbEnumAttr& attr) { + auto&& item = std::make_pair( + attr.getParentNamespace(), attr.getEnumName()); + if (env.enumAlias.emplace( + attr.getBaseRecord()->getID(), std::move(item)).second) { + if (attr.getEnumCombinedFlag()) { + bit_enums.emplace_back(item); + } else { + enums.emplace_back(item); + } + } + }; + if (auto alias = llvm::dyn_cast(attr)) { + auto&& aliasBase = alias->getAliasBase(); + insert(llvm::cast(aliasBase)); + } else { + insert(*attr); + } + } + } + }); + os << "#define FOR_EACH_ENUM_PARAM(cb)"; + for (auto && i : enums) { + os << formatv(" \\\n cb({0}::{1});", i.first, i.second); + } + os << "\n"; + os << "#define FOR_EACH_BIT_COMBINED_ENUM_PARAM(cb)"; + for (auto && i : bit_enums) { + os << formatv(" \\\n cb({0}::{1});", i.first, i.second); + } + os << "\n"; + return false; +} +} // namespace mlir::tblgen diff --git a/imperative/tablegen/targets/macros.h b/imperative/tablegen/targets/macros.h new file mode 100644 index 000000000..be599a4e1 --- /dev/null +++ b/imperative/tablegen/targets/macros.h @@ -0,0 +1,19 @@ +/** + * \file imperative/tablegen/targets/macros.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "../helper.h" + +namespace mlir::tblgen { + +bool gen_enum_param_list_macro(raw_ostream &os, llvm::RecordKeeper &keeper); + +} // namespace mlir::tblgen diff --git a/imperative/tablegen/targets/python_c_extension.cpp b/imperative/tablegen/targets/python_c_extension.cpp index 130962e0e..a51f2ceb3 100644 --- a/imperative/tablegen/targets/python_c_extension.cpp +++ b/imperative/tablegen/targets/python_c_extension.cpp @@ -60,6 +60,7 @@ public: Initproc emit(); protected: + void emit_trait(); void emit_tpl_spl(); Initproc emit_initproc(); @@ -69,50 +70,63 @@ protected: }; Initproc EnumAttrEmitter::emit() { + emit_trait(); emit_tpl_spl(); return emit_initproc(); } +void EnumAttrEmitter::emit_trait() { + if (!firstOccur) return; + + auto enumMax = [&] { + if (attr->getEnumCombinedFlag()) { + return formatv("(1llu << {0}) - 1", attr->getEnumMembers().size()); + } else { + return formatv("{0} - 1", attr->getEnumMembers().size()); + } + }; + os << tgfmt(R"( +template<> struct EnumTrait<$opClass::$enumClass> { + static constexpr const char *name = "$opClass.$enumClass"; + static constexpr std::underlying_type_t<$opClass::$enumClass> max = $0; +}; +)", &ctx, enumMax()); +} + void EnumAttrEmitter::emit_tpl_spl() { if (!firstOccur) return; os << tgfmt( - "template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type={};\n", + "template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type = {};\n", &ctx); + auto quote = [&](auto&& i) -> std::string { + return formatv("\"{0}\"", i); + }; + os << tgfmt(R"( +template<> const char* +$enumTpl<$opClass::$enumClass>::members[] = {$0}; +)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", ")); + + auto mem2value = [&](auto&& i) -> std::string { + return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i); + }; + os << tgfmt(R"( +template<> std::unordered_map +$enumTpl<$opClass::$enumClass>::mem2value = {$0}; +)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), mem2value), ", ")); + os << tgfmt( - "template<> const char* $enumTpl<$opClass::$enumClass>::name = " - "\"$opClass.$enumClass\";\n", - &ctx); + "template<> PyObject* " + "$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n", + &ctx, attr->getEnumMembers().size()); if (attr->getEnumCombinedFlag()) { os << tgfmt( "template<> PyNumberMethods " - "$enumTpl<$opClass::$enumClass>::number_methods={};\n", + "$enumTpl<$opClass::$enumClass>::number_methods = {};\n", &ctx); - os << tgfmt(R"( -template<> struct EnumTrait<$opClass::$enumClass> { - static constexpr bool is_bit_combined = true; - static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1; -}; -)", &ctx, attr->getEnumMembers().size()); } - - auto str2type = [&](auto&& i) -> std::string { - return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i); - }; - os << tgfmt(R"( -template<> std::unordered_map -$enumTpl<$opClass::$enumClass>::str2type = {$0}; -)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), str2type), ", ")); - - auto type2str = [&](auto&& i) -> std::string { - return tgfmt("{$opClass::$enumClass::$0, normalize_enum(\"$0\")}", &ctx, i); - }; - os << tgfmt(R"( -template<> std::unordered_map<$opClass::$enumClass, std::string> -$enumTpl<$opClass::$enumClass>::type2str = {$0}; -)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), type2str), ", ")); } Initproc EnumAttrEmitter::emit_initproc() { @@ -150,14 +164,16 @@ void $0(PyTypeObject& py_type) { os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n"; - for (auto&& i : attr->getEnumMembers()) { + auto&& members = attr->getEnumMembers(); + for (size_t idx = 0; idx < members.size(); ++ idx) { os << tgfmt(R"({ PyObject* inst = e_type.tp_alloc(&e_type, 0); reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0; mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0); - PyType_Modified(&e_type); -})", &ctx, i); + $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst; +})", &ctx, members[idx], idx); } + os << " PyType_Modified(&e_type);\n"; } os << tgfmt(R"( @@ -225,8 +241,10 @@ void OpDefEmitter::emit_py_init() { initBody += tgfmt(R"( if ($0) { try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; reinterpret_cast(self)->inst().$0 = - pyobj_convert_generic::from($0); + py::cast(py::handle($0)); } CATCH_ALL(-1) } )", &ctx, attr.name); @@ -236,7 +254,7 @@ void OpDefEmitter::emit_py_init() { if (scope) { try { reinterpret_cast(self)->op - ->set_scope(pyobj_convert_generic::from(scope)); + ->set_scope(py::cast(py::handle(scope))); } CATCH_ALL(-1) } )", &ctx); -- GitLab