提交 a8108522 编写于 作者: M Megvii Engine Team

refactor(imperative): refactor enum param type caster

GitOrigin-RevId: 1aae07f143b8d1c0176a41de790fee8d6b2f1a25
上级 dcff115e
......@@ -21,6 +21,7 @@
#include "./helper.h"
#include "megbrain/plugin/profiler.h"
#include "./common.h"
#include "./ops.h"
#include "megbrain/gopt/inference.h"
......@@ -265,18 +266,8 @@ void init_graph_rt(py::module m) {
});
m.def("modify_opr_algo_strategy_inplace", [](const VarNodeArray& dest_vars,
const std::string& strategy) {
_AlgoStrategy stg;
const std::unordered_map<std::string, std::function<void()>> m{
{"HEURISTIC", [&]() { stg = _AlgoStrategy::HEURISTIC; }},
{"PROFILE", [&]() { stg = _AlgoStrategy::PROFILE; }},
{"REPRODUCIBLE", [&]() { stg = _AlgoStrategy::REPRODUCIBLE; }},
{"OPTIMIZED", [&]() { stg = _AlgoStrategy::OPTIMIZED; }},
};
auto it = m.find(strategy);
mgb_assert(it != m.end(), "Invalid strategy string!");
it->second();
mgb::gopt::modify_opr_algo_strategy_inplace(dest_vars, stg);
const _AlgoStrategy& strategy) {
mgb::gopt::modify_opr_algo_strategy_inplace(dest_vars, strategy);
});
m.def("get_info_for_strip", [](const std::vector<VarNode*>& dest_vars) {
......
......@@ -73,29 +73,6 @@ PyTypeObject PyOpType(name);
} \
} while (0)
template <typename T, typename SFINAE = void>
struct pyobj_convert_generic {
static T from(PyObject* obj) {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
return py::cast<T>(py::handle(obj));
}
template<typename U,
typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>>
static PyObject* to(U&& t) {
return py::cast(std::forward<U>(t)).release().ptr();
}
};
template<typename T, typename SFINAE=void>
struct EnumTrait;
template <typename T>
struct EnumTrait<T, std::enable_if_t<std::is_enum_v<T>>> {
static constexpr bool is_bit_combined = false;
static constexpr std::underlying_type_t<T> max = 0;
};
template <typename T>
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
PyObject* obj = type->tp_alloc(type, 0);
......@@ -115,7 +92,7 @@ void py_dealloc_generic(PyObject* obj) {
template<typename T, typename U, U T::Ty::*attr>
PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
auto& op = reinterpret_cast<T*>(obj)->inst();
return pyobj_convert_generic<U>::to(op.*attr);
return py::cast(op.*attr).release().ptr();
}
#define py_get_generic(name, attr) \
py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
......@@ -128,7 +105,9 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
}
auto& op = reinterpret_cast<T*>(obj)->inst();
try {
op.*attr = pyobj_convert_generic<U>::from(value);
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
op.*attr = py::cast<U>(py::handle(value));
} CATCH_ALL(-1)
return 0;
}
......@@ -148,8 +127,8 @@ PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;
PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
return pyobj_convert_generic<std::string>::to(
reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope());
return py::cast(
reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope()).release().ptr();
}
int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
......@@ -159,7 +138,7 @@ int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
}
try {
reinterpret_cast<PyOp(OpDef)*>(obj)->op
->set_scope(pyobj_convert_generic<std::string>::from(value));
->set_scope(py::cast<std::string>(py::handle(value)));
} CATCH_ALL(-1)
return 0;
}
......@@ -183,24 +162,29 @@ PyObject* PyOp(OpDef)::tp_richcompare(PyObject *self, PyObject *other, int op) {
Py_RETURN_NOTIMPLEMENTED;
}
template<typename T>
struct EnumTrait;
#define PyEnumHead \
static_assert(std::is_enum_v<T>); \
PyObject_HEAD \
T value; \
constexpr static const char *name = EnumTrait<T>::name; \
static PyTypeObject type; \
static const char* members[]; \
static std::unordered_map<std::string, T> mem2value; \
static PyObject* pyobj_insts[];
template<typename T>
struct EnumWrapper {
static_assert(std::is_enum_v<T>);
PyObject_HEAD
T value;
static const char* name;
static PyTypeObject type;
static std::unordered_map<T, std::string> type2str;
static std::unordered_map<std::string, T> str2type;
EnumWrapper() = default;
EnumWrapper(T v): value(v) {}
EnumWrapper(std::string&& str): EnumWrapper(str2type.at(normalize_enum(str))) {}
PyEnumHead
std::string to_string() const {
return type2str.at(value);
return members[static_cast<size_t>(value)];
}
static PyObject* py_repr(PyObject* self) {
return pyobj_convert_generic<std::string>::to(
std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string());
return py::cast(
std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string())
.release().ptr();
}
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) {
T lhs = reinterpret_cast<EnumWrapper*>(self)->value,
......@@ -210,59 +194,52 @@ struct EnumWrapper {
}
Py_RETURN_NOTIMPLEMENTED;
}
};
template <typename T>
struct pyobj_convert_generic<T,
std::enable_if_t<std::is_enum_v<std::decay_t<T>> &&
!EnumTrait<T>::is_bit_combined>> {
using Wrapper = EnumWrapper<T>;
static T from(PyObject* obj) {
if (PyObject_TypeCheck(obj, &Wrapper::type)) {
return reinterpret_cast<Wrapper*>(obj)->value;
}
// try as string
// TODO: type checkcd
return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value;
}
static PyObject* to(T t) {
PyTypeObject* pytype = &Wrapper::type;
PyObject* obj = pytype->tp_alloc(pytype, 0);
reinterpret_cast<Wrapper*>(obj)->value = t;
static bool load(py::handle src, T& value) {
PyObject* obj = src.ptr();
if (PyObject_TypeCheck(obj, &type)) {
value = reinterpret_cast<EnumWrapper*>(obj)->value;
return true;
}
if (py::isinstance<py::str>(src)) {
auto&& iter = mem2value.find(
normalize_enum(py::cast<std::string>(src)));
if (iter != mem2value.end()) {
value = iter->second;
return true;
} else {
return false;
}
}
return false;
}
static PyObject* cast(const T& value) {
auto v = static_cast<std::underlying_type_t<T>>(value);
mgb_assert(v <= EnumTrait<T>::max);
PyObject* obj = pyobj_insts[v];
Py_INCREF(obj);
return obj;
}
};
template<typename T>
struct BitCombinedEnumWrapper {
static_assert(std::is_enum_v<T>);
PyObject_HEAD
T value;
static const char* name;
static PyTypeObject type;
static std::unordered_map<T, std::string> type2str;
static std::unordered_map<std::string, T> str2type;
PyEnumHead
static PyNumberMethods number_methods;
BitCombinedEnumWrapper() = default;
BitCombinedEnumWrapper(T v): value(v) {}
BitCombinedEnumWrapper(std::string&& str)
: BitCombinedEnumWrapper(str2type.at(normalize_enum(str))) {}
std::string to_string() const {
if (static_cast<uint32_t>(value) == 0) {
uint32_t value_int = static_cast<uint32_t>(value);
if (value_int == 0) {
return "None";
} else {
auto ret = std::string();
std::string ret;
bool first = true;
for (uint32_t i = 0; i < 32; i++) {
uint32_t value_int = static_cast<uint32_t>(value);
auto it = type2str.find(static_cast<T>((1 << i) & value_int));
if (it != type2str.end()) {
if (value_int >> i & 1) {
if (!first) {
ret += " + ";
} else {
first = false;
}
ret += (std::string(name) + "." + it->second);
ret += (std::string(name) + "." + members[i]);
}
}
return ret;
......@@ -280,17 +257,20 @@ struct BitCombinedEnumWrapper {
return nullptr;
}
T value;
try {
value = pyobj_convert_generic<T>::from(input);
} CATCH_ALL(nullptr);
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
return obj;
if (load(input, value)) {
return cast(value);
} else {
PyErr_SetString(PyExc_RuntimeError,
mgb::ssprintf("Cannot convert type %s to type %s\n",
input->ob_type->tp_name, name).c_str());
return nullptr;
}
}
}
static PyObject* py_repr(PyObject* self) {
return pyobj_convert_generic<std::string>::to(
reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string());
return py::cast(
reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
.release().ptr();
}
static PyObject* py_or(PyObject* self, PyObject* other) {
if(!(self->ob_type == other->ob_type)){
......@@ -298,12 +278,9 @@ struct BitCombinedEnumWrapper {
PyExc_RuntimeError,
"Operand in or operator must be the same type.");
}
PyObject* obj = type.tp_alloc(&type, 0);
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(
static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs));
return obj;
return cast(lhs | rhs);
}
static PyObject* py_and(PyObject* self, PyObject* other) {
if (!(self->ob_type == other->ob_type)) {
......@@ -311,12 +288,9 @@ struct BitCombinedEnumWrapper {
PyExc_RuntimeError,
"Operand in and operator must be the same type.");
}
PyObject* obj = type.tp_alloc(&type, 0);
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(
static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs));
return obj;
return cast(lhs & rhs);
}
static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
......@@ -326,33 +300,46 @@ struct BitCombinedEnumWrapper {
}
Py_RETURN_NOTIMPLEMENTED;
}
};
template <typename T>
struct pyobj_convert_generic<T,
std::enable_if_t<std::is_enum_v<std::decay_t<T>> &&
EnumTrait<T>::is_bit_combined>> {
using Wrapper = BitCombinedEnumWrapper<T>;
static T from(PyObject* obj) {
if (PyObject_TypeCheck(obj, &Wrapper::type)) {
return reinterpret_cast<Wrapper*>(obj)->value;
} else if(PyLong_Check(obj)) {
auto value = pyobj_convert_generic<std::underlying_type_t<T>>::from(obj);
mgb_throw_if(value > EnumTrait<T>::max, mgb::MegBrainError,
"out of range, cannot convert %zu to %s",
static_cast<uint32_t>(value), Wrapper::name);
return static_cast<T>(value);
}
// try as string
// TODO: type checkcd
return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value;
}
static PyObject* to(T t) {
PyTypeObject* pytype = &Wrapper::type;
static bool load(py::handle src, T& value) {
PyObject* obj = src.ptr();
if (PyObject_TypeCheck(obj, &type)) {
value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
return true;
}
if (py::isinstance<py::str>(src)) {
auto&& iter = mem2value.find(
normalize_enum(py::cast<std::string>(src)));
if (iter != mem2value.end()) {
value = iter->second;
return true;
} else {
return false;
}
}
if (py::isinstance<py::int_>(obj)) {
auto v = py::cast<std::underlying_type_t<T>>(src);
if(v > EnumTrait<T>::max) {
return false;
}
value = static_cast<T>(v);
return true;
}
return false;
}
static PyObject* cast(const T& value) {
auto v = static_cast<std::underlying_type_t<T>>(value);
mgb_assert(v <= EnumTrait<T>::max);
if ((!v) || (v & (v - 1))) {
PyTypeObject* pytype = &type;
PyObject* obj = pytype->tp_alloc(pytype, 0);
reinterpret_cast<Wrapper*>(obj)->value = t;
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
return obj;
} else {
PyObject* obj = pyobj_insts[__builtin_ctz(v)];
Py_INCREF(obj);
return obj;
}
}
};
void _init_py_op_def(py::module m) {
......@@ -443,7 +430,6 @@ void _init_py_op_base(py::module m) {
#include "opdef.cpy.inl"
#undef CATCH_ALL
} // anonymous namespace
namespace PYBIND11_NAMESPACE {
......@@ -478,6 +464,25 @@ handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
return py::handle(obj);
}
#define ENUM_CASTER_IMPL(T) \
bool type_caster<T>::load(handle src, bool) { \
return EnumWrapper<T>::load(src, value); \
} \
handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
return EnumWrapper<T>::cast(value); \
}
FOR_EACH_ENUM_PARAM(ENUM_CASTER_IMPL)
#define BIT_COMBINED_ENUM_CASTER_IMPL(T) \
bool type_caster<T>::load(handle src, bool) { \
return BitCombinedEnumWrapper<T>::load(src, value); \
} \
handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
return BitCombinedEnumWrapper<T>::cast(value); \
}
FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)
} // detail
} // PYBIND11_NAMESPACE
......
......@@ -12,5 +12,26 @@
#pragma once
#include "./helper.h"
#include "./enum_macro.h"
#include "megdnn/opr_param_defs.h"
#include "megbrain/opr/param_defs.h"
namespace PYBIND11_NAMESPACE {
namespace detail {
#define ENUM_CASTER_DEF(name) \
template<> struct type_caster<name> { \
PYBIND11_TYPE_CASTER(name, _(#name)); \
public: \
bool load(handle src, bool); \
static handle cast(const name& v, return_value_policy, handle); \
};
FOR_EACH_ENUM_PARAM(ENUM_CASTER_DEF)
FOR_EACH_BIT_COMBINED_ENUM_PARAM(ENUM_CASTER_DEF)
} // detail
} // PYBIND11_NAMESPACE
void init_ops(pybind11::module m);
......@@ -12,5 +12,6 @@ tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header")
tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body")
tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding")
tablegen(MGB opdef.cpy.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-c-extension")
add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl opdef.cpy.inl param_defs_tblgen)
tablegen(MGB enum_macro.h ${MGE_IR_INCLUDE_DIRS} "--gen-enum-list-macro")
add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl opdef.cpy.inl enum_macro.h param_defs_tblgen)
set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE)
......@@ -12,6 +12,7 @@
#include "./targets/cpp_class.h"
#include "./targets/pybind11.h"
#include "./targets/python_c_extension.h"
#include "./targets/macros.h"
using llvm::raw_ostream;
using llvm::RecordKeeper;
......@@ -21,7 +22,8 @@ enum ActionType {
CppHeader,
CppBody,
Pybind,
CPython
CPython,
EnumListMacro
};
// NOLINTNEXTLINE
......@@ -34,7 +36,9 @@ llvm::cl::opt<ActionType> action(
clEnumValN(Pybind, "gen-python-binding",
"Generate pybind11 python bindings"),
clEnumValN(CPython, "gen-python-c-extension",
"Generate python c extensions")));
"Generate python c extensions"),
clEnumValN(EnumListMacro, "gen-enum-list-macro",
"Generate enum param list macro")));
using namespace mlir::tblgen;
......@@ -53,5 +57,8 @@ int main(int argc, char **argv) {
if (action == ActionType::CPython) {
return TableGenMain(argv[0], &gen_op_def_python_c_extension);
}
if (action == ActionType::EnumListMacro) {
return TableGenMain(argv[0], &gen_enum_param_list_macro);
}
return -1;
}
/**
* \file imperative/tablegen/targets/macros.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./cpp_class.h"
#include "../emitter.h"
namespace mlir::tblgen {
bool gen_enum_param_list_macro(raw_ostream &os, llvm::RecordKeeper &keeper) {
std::vector<std::pair<std::string, std::string>> enums;
std::vector<std::pair<std::string, std::string>> bit_enums;
Environment env;
foreach_operator(keeper, [&](MgbOp& op) {
for (auto&& i : op.getAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
auto insert = [&](const MgbEnumAttr& attr) {
auto&& item = std::make_pair(
attr.getParentNamespace(), attr.getEnumName());
if (env.enumAlias.emplace(
attr.getBaseRecord()->getID(), std::move(item)).second) {
if (attr.getEnumCombinedFlag()) {
bit_enums.emplace_back(item);
} else {
enums.emplace_back(item);
}
}
};
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
insert(llvm::cast<MgbEnumAttr>(aliasBase));
} else {
insert(*attr);
}
}
}
});
os << "#define FOR_EACH_ENUM_PARAM(cb)";
for (auto && i : enums) {
os << formatv(" \\\n cb({0}::{1});", i.first, i.second);
}
os << "\n";
os << "#define FOR_EACH_BIT_COMBINED_ENUM_PARAM(cb)";
for (auto && i : bit_enums) {
os << formatv(" \\\n cb({0}::{1});", i.first, i.second);
}
os << "\n";
return false;
}
} // namespace mlir::tblgen
/**
* \file imperative/tablegen/targets/macros.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "../helper.h"
namespace mlir::tblgen {
bool gen_enum_param_list_macro(raw_ostream &os, llvm::RecordKeeper &keeper);
} // namespace mlir::tblgen
......@@ -60,6 +60,7 @@ public:
Initproc emit();
protected:
void emit_trait();
void emit_tpl_spl();
Initproc emit_initproc();
......@@ -69,50 +70,63 @@ protected:
};
Initproc EnumAttrEmitter::emit() {
emit_trait();
emit_tpl_spl();
return emit_initproc();
}
void EnumAttrEmitter::emit_tpl_spl() {
void EnumAttrEmitter::emit_trait() {
if (!firstOccur) return;
os << tgfmt(
"template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type={};\n",
&ctx);
auto enumMax = [&] {
if (attr->getEnumCombinedFlag()) {
return formatv("(1llu << {0}) - 1", attr->getEnumMembers().size());
} else {
return formatv("{0} - 1", attr->getEnumMembers().size());
}
};
os << tgfmt(R"(
template<> struct EnumTrait<$opClass::$enumClass> {
static constexpr const char *name = "$opClass.$enumClass";
static constexpr std::underlying_type_t<$opClass::$enumClass> max = $0;
};
)", &ctx, enumMax());
}
os << tgfmt(
"template<> const char* $enumTpl<$opClass::$enumClass>::name = "
"\"$opClass.$enumClass\";\n",
&ctx);
void EnumAttrEmitter::emit_tpl_spl() {
if (!firstOccur) return;
if (attr->getEnumCombinedFlag()) {
os << tgfmt(
"template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods={};\n",
"template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type = {};\n",
&ctx);
auto quote = [&](auto&& i) -> std::string {
return formatv("\"{0}\"", i);
};
os << tgfmt(R"(
template<> struct EnumTrait<$opClass::$enumClass> {
static constexpr bool is_bit_combined = true;
static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1;
};
)", &ctx, attr->getEnumMembers().size());
}
template<> const char*
$enumTpl<$opClass::$enumClass>::members[] = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", "));
auto str2type = [&](auto&& i) -> std::string {
auto mem2value = [&](auto&& i) -> std::string {
return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i);
};
os << tgfmt(R"(
template<> std::unordered_map<std::string, $opClass::$enumClass>
$enumTpl<$opClass::$enumClass>::str2type = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), str2type), ", "));
$enumTpl<$opClass::$enumClass>::mem2value = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), mem2value), ", "));
auto type2str = [&](auto&& i) -> std::string {
return tgfmt("{$opClass::$enumClass::$0, normalize_enum(\"$0\")}", &ctx, i);
};
os << tgfmt(R"(
template<> std::unordered_map<$opClass::$enumClass, std::string>
$enumTpl<$opClass::$enumClass>::type2str = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), type2str), ", "));
os << tgfmt(
"template<> PyObject* "
"$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};\n",
&ctx, attr->getEnumMembers().size());
if (attr->getEnumCombinedFlag()) {
os << tgfmt(
"template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods = {};\n",
&ctx);
}
}
Initproc EnumAttrEmitter::emit_initproc() {
......@@ -150,14 +164,16 @@ void $0(PyTypeObject& py_type) {
os << " mgb_assert(PyType_Ready(&e_type) >= 0);\n";
for (auto&& i : attr->getEnumMembers()) {
auto&& members = attr->getEnumMembers();
for (size_t idx = 0; idx < members.size(); ++ idx) {
os << tgfmt(R"({
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0);
PyType_Modified(&e_type);
})", &ctx, i);
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
})", &ctx, members[idx], idx);
}
os << " PyType_Modified(&e_type);\n";
}
os << tgfmt(R"(
......@@ -225,8 +241,10 @@ void OpDefEmitter::emit_py_init() {
initBody += tgfmt(R"(
if ($0) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp($_self)*>(self)->inst().$0 =
pyobj_convert_generic<decltype($_self::$0)>::from($0);
py::cast<decltype($_self::$0)>(py::handle($0));
} CATCH_ALL(-1)
}
)", &ctx, attr.name);
......@@ -236,7 +254,7 @@ void OpDefEmitter::emit_py_init() {
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(pyobj_convert_generic<std::string>::from(scope));
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
)", &ctx);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册