diff --git a/imperative/python/megengine/core/autodiff/builtin_op_utils.py b/imperative/python/megengine/core/autodiff/builtin_op_utils.py index 5071e2b34748e79c3ad27f8884d2762701399db7..8db31ce26f70cb7c8e1fd3d6f691f699c1d04591 100644 --- a/imperative/python/megengine/core/autodiff/builtin_op_utils.py +++ b/imperative/python/megengine/core/autodiff/builtin_op_utils.py @@ -48,7 +48,7 @@ def _(op: OpDef, inputs, outputs, input_requires_grad): isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD ): grad_fn = elemwise_add_grad_fn - elif isinstance(op, Reduce) and op.mode.name == "SUM": + elif isinstance(op, Reduce) and op.mode == Reduce.Mode.SUM: grad_fn = reduce_sum_grad_fn else: grad_fn = default_grad_fn diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index f7781495fb35ef9e165317ec079966c942f26b4f..d4ff8b10f12cafc48678929fd2081b84b3da922f 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -447,8 +447,8 @@ def _(op: OpDef, *args: VarNode): def _(op: BackwardGraph, *args: VarNode): assert args graph = args[0].graph - return op.interpret( - lambda op, args: apply(op, *args), graph._make_const_for_backward, args + return BackwardGraph.interpret( + op, lambda op, args: apply(op, *args), graph._make_const_for_backward, args ) diff --git a/imperative/python/src/helper.h b/imperative/python/src/helper.h index 05aaf922115e06591c464a556bbd3893f6e08fdb..ab29206fdd87a434e8fa6ee3ba0f33d627d38bf2 100644 --- a/imperative/python/src/helper.h +++ b/imperative/python/src/helper.h @@ -13,6 +13,7 @@ #include "megbrain/graph.h" #include "megbrain/utils/persistent_cache.h" +#include "megbrain/imperative/op_def.h" #include #include @@ -376,6 +377,32 @@ namespace detail { } }; + template<> struct type_caster { + protected: + std::shared_ptr value; + public: + static constexpr auto name = _("OpDef"); + + operator mgb::imperative::OpDef&() { return *value; } + operator const mgb::imperative::OpDef&() { return *value; } + operator std::shared_ptr&() { return value; } + operator std::shared_ptr&&() && { return std::move(value); } + + template using cast_op_type = T; + + bool load(handle src, bool convert); + + static handle cast(const mgb::imperative::OpDef& op, return_value_policy /* policy */, handle /* parent */); + + static handle cast(std::shared_ptr op, return_value_policy policy, handle parent) { + return cast(*op, policy, parent); + } + }; + + template <> struct type_caster> : + public type_caster { + template using cast_op_type = pybind11::detail::movable_cast_op_type; + }; } // detail } // PYBIND11_NAMESPACE diff --git a/imperative/python/src/imperative_rt.cpp b/imperative/python/src/imperative_rt.cpp index 812fa724324f3d40718600fea4f15fdc2b4e812a..8baa379f0a87c33d410089a3828317acc5439694 100644 --- a/imperative/python/src/imperative_rt.cpp +++ b/imperative/python/src/imperative_rt.cpp @@ -106,13 +106,4 @@ void init_imperative_rt(py::module m) { }); m.def("make_backward_graph", &make_backward_graph); - - py::class_>(m, "OpDef") - .def("ctype", [](const OpDef& opdef) { - return opdef.dyn_typeinfo()->name; - }) - .def("__eq__", [](const OpDef& lhs, const OpDef& rhs) { - return lhs.is_same(rhs); - }) - .def("__hash__", &OpDef::hash); } diff --git a/imperative/python/src/module.cpp b/imperative/python/src/module.cpp index a0710efc303e6844fe5834fbb05c3deca5fbf5da..5e9e559955699d8914447f4e955f5dc18f815178 100644 --- a/imperative/python/src/module.cpp +++ b/imperative/python/src/module.cpp @@ -63,6 +63,7 @@ PYBIND11_MODULE(MODULE_NAME, m) { from .utils import * from .imperative import * from .graph import * + from .ops import OpDef )", py::getattr(m, "__dict__")); diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index a12b953a0b0273e054edbb4489196936cddbb7d2..ba76bd2e4006bb83fd911bc3c8204a384d7e0c31 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -16,7 +16,11 @@ #include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/autogen.h" +#include +#include + namespace py = pybind11; +using namespace mgb::imperative; namespace { auto normalize_enum(const std::string& in) { @@ -28,20 +32,256 @@ auto normalize_enum(const std::string& in) { } } // anonymous namespace +namespace { +#define PyOp(name) Py##name +#define PyOpType(name) PyOp(name)::py_type + +#define PyOpDefBegin(name) \ +struct PyOp(name) : PyOpDef { \ + using Ty = name; \ + Ty& inst() { return op->cast_final_safe(); } \ + static PyTypeObject py_type; + +#define PyOpDefEnd(name) \ +}; \ +PyTypeObject PyOpType(name); + +#define RETURN_RICHCOMPARE(val1, val2, op) \ + do { \ + switch (op) { \ + case Py_EQ: if ((val1) == (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ + case Py_NE: if ((val1) != (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ + case Py_LT: if ((val1) < (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ + case Py_GT: if ((val1) > (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ + case Py_LE: if ((val1) <= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ + case Py_GE: if ((val1) >= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \ + default: \ + Py_FatalError("Unreachable C code path reached"); \ + } \ + } while (0) + +template +struct pyobj_convert_generic { + static T from(PyObject* obj) { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + return py::cast(py::handle(obj)); + } + template>>> + static PyObject* to(U&& t) { + return py::cast(std::forward(t)).release().ptr(); + } +}; + +template +PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) { + PyObject* obj = type->tp_alloc(type, 0); + T* self = reinterpret_cast(obj); + if (self != NULL) { + self->op = T::Ty::make(); + } + return obj; +} + +template +void py_dealloc_generic(PyObject* obj) { + reinterpret_cast(obj)->op.reset(); + Py_TYPE(obj)->tp_free(obj); +} + +template +PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) { + auto& op = reinterpret_cast(obj)->inst(); + return pyobj_convert_generic::to(op.*attr); +} +#define py_get_generic(name, attr) \ + py_get_generic_impl().attr), &name::attr> + +template +int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute"); + return -1; + } + auto& op = reinterpret_cast(obj)->inst(); + try { + op.*attr = pyobj_convert_generic::from(value); + return 0; + } catch(py::error_already_set& e) { + e.restore(); + } catch(py::builtin_exception& e) { + e.set_error(); + } catch(...) { + PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); + } + return -1; +} +#define py_set_generic(name, attr) \ + py_set_generic_impl().attr), &name::attr> + +struct PyOpDef { + PyObject_HEAD + std::shared_ptr op; + static PyTypeObject py_type; + static std::unordered_map ctype2pytype; + static Py_hash_t tp_hash(PyObject *obj); + static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op); +}; +PyTypeObject PyOpType(OpDef); +std::unordered_map PyOp(OpDef)::ctype2pytype; + +Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) { + return static_cast( + reinterpret_cast(obj)->op->hash()); +} + +PyObject* PyOp(OpDef)::tp_richcompare(PyObject *self, PyObject *other, int op) { + bool same = reinterpret_cast(self)->op->is_same( + *reinterpret_cast(other)->op); + if (op == Py_EQ || op == Py_NE) { + RETURN_RICHCOMPARE(same, true, op); + } + Py_RETURN_NOTIMPLEMENTED; +} + +template +struct EnumWrapper { + static_assert(std::is_enum_v); + PyObject_HEAD + T value; + static const char* name; + static PyTypeObject type; + static std::unordered_map type2str; + static std::unordered_map str2type; + EnumWrapper() = default; + EnumWrapper(T v): value(v) {} + EnumWrapper(std::string&& str): EnumWrapper(str2type.at(normalize_enum(str))) {} + std::string to_string() const { + return type2str.at(value); + } + static PyObject* py_repr(PyObject* self) { + return pyobj_convert_generic::to( + std::string(name) + "." + reinterpret_cast(self)->to_string()); + } + static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) { + T lhs = reinterpret_cast(self)->value, + rhs = reinterpret_cast(other)->value; + if (op == Py_EQ || op == Py_NE) { + RETURN_RICHCOMPARE(lhs, rhs, op); + } + Py_RETURN_NOTIMPLEMENTED; + } +}; + +template +struct pyobj_convert_generic>>> { + using Wrapper = EnumWrapper; + static T from(PyObject* obj) { + if (PyObject_TypeCheck(obj, &Wrapper::type)) { + return reinterpret_cast(obj)->value; + } + // try as string + // TODO: type checkcd + return Wrapper(pyobj_convert_generic::from(obj)).value; + } + static PyObject* to(T t) { + PyTypeObject* pytype = &Wrapper::type; + PyObject* obj = pytype->tp_alloc(pytype, 0); + reinterpret_cast(obj)->value = t; + return obj; + } +}; + +void _init_py_op_def(py::module m) { + auto& py_type = PyOpType(OpDef); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.OpDef"; + py_type.tp_basicsize = sizeof(PyOp(OpDef)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "OpDef"; + py_type.tp_base = &PyBaseObject_Type; + py_type.tp_hash = PyOp(OpDef)::tp_hash; + py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare; + mgb_assert(PyType_Ready(&py_type) >= 0); + m.add_object("OpDef", reinterpret_cast(&py_type)); +} + +/*********** begin of hand-write opdefs **************/ + +PyOpDefBegin(BackwardGraph) // {{ +// }; +PyOpDefEnd(BackwardGraph) + +void _init_py_backward_graph(py::module m) { + using py_op = PyOp(BackwardGraph); + auto& py_type = PyOpType(BackwardGraph); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph"; + py_type.tp_basicsize = sizeof(PyOp(BackwardGraph)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "BackwardGraph"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + mgb_assert(PyType_Ready(&py_type) >= 0); + // FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction + auto interpret = py::cpp_function( + [](OpDef& self, py::object pyf, py::object pyc, + const mgb::SmallVector& inputs) { + auto f = [pyf](OpDef& op, const mgb::SmallVector& inputs) { + return py::cast>(pyf(op.shared_from_this(), inputs)); + }; + auto c = [pyc](const TensorPtr& tensor) { + return pyc(tensor->dev_tensor()); + }; + return self.cast_final_safe().graph().interpret(f, c, inputs); + }); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0); + PyType_Modified(&py_type); + m.add_object("BackwardGraph", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second); +} + +/*********** end of hand-write opdefs **************/ + +// auto generated opdefs +#include "opdef.cpy.inl" + +} // anonymous namespace + +namespace PYBIND11_NAMESPACE { +namespace detail { +bool type_caster::load(handle src, bool convert) { + PyObject* obj = src.ptr(); + if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) { + return false; + } + value = reinterpret_cast(obj)->op; + return true; +} +handle type_caster::cast(const OpDef& op, return_value_policy, handle) { + PyTypeObject* pytype; + auto& c2p = PyOp(OpDef)::ctype2pytype; + auto&& iter = c2p.find(op.dyn_typeinfo()); + if (iter != c2p.end()) { // FIXME: should always meet this condition + pytype = iter->second; + } else { // which means unregistered op type, jsut make it as an opaque op type + // currently, only OprAttr goes into this branch + pytype = &PyOpType(OpDef); + } + PyObject* obj = pytype->tp_alloc(pytype, 0); + mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef))); + reinterpret_cast(obj)->op = const_cast(op).shared_from_this(); + return py::handle(obj); +} +} // detail +} // PYBIND11_NAMESPACE + void init_ops(py::module m) { - using namespace mgb::imperative; - - py::class_, OpDef>(m, "BackwardGraph") - .def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc, - const mgb::SmallVector& inputs) { - auto f = [pyf](OpDef& op, const mgb::SmallVector& inputs) { - return py::cast>(pyf(op.shared_from_this(), inputs)); - }; - auto c = [pyc](const TensorPtr& tensor) { - return pyc(tensor->dev_tensor()); - }; - return self.graph().interpret(f, c, inputs); - }); - - #include "opdef.py.inl" + _init_py_op_def(m); + _init_py_backward_graph(m); + INIT_ALL_OP(m) } diff --git a/imperative/src/impl/ops/cond_take.cpp b/imperative/src/impl/ops/cond_take.cpp index 3fa3643a389d3e51246c5546b5a5e29908aa4ab7..70f55ce8112f264a1a48a97258958fa701cfaf87 100644 --- a/imperative/src/impl/ops/cond_take.cpp +++ b/imperative/src/impl/ops/cond_take.cpp @@ -76,7 +76,7 @@ cg::OperatorNodeBase* apply_on_var_node( SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs) { - auto opr = def.cast_final_safe(); + auto&& opr = def.cast_final_safe(); mgb_assert(opr.same_type()); mgb_assert(inputs.size() == 2, "CondTake take 2 inputs, got %lu", inputs.size()); diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp index 98d1c7b26b5382230c160736fc8686804dc24d22..21ed0f9259be93b0f205a3aed6ba7ea7c3c7f5d4 100644 --- a/imperative/src/impl/ops/tensor_manip.cpp +++ b/imperative/src/impl/ops/tensor_manip.cpp @@ -111,7 +111,7 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node( SmallVector param_pack_split_apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs) { - auto param = def.cast_final_safe(); + auto&& param = def.cast_final_safe(); mgb_assert(inputs.size() == 1, "ParamPackSplit take 1 input, got %lu", inputs.size()); auto&& inp = inputs[0]; auto&& shp = inp->layout(); diff --git a/imperative/src/include/megbrain/imperative/op_def.h b/imperative/src/include/megbrain/imperative/op_def.h index 5e122f9c80dee94ebc09007aaf7e2e18ced00b6d..5dd9b421eab5707ea3c4ccb23e22c674937246ca 100644 --- a/imperative/src/include/megbrain/imperative/op_def.h +++ b/imperative/src/include/megbrain/imperative/op_def.h @@ -27,6 +27,7 @@ struct BackwardGraphResult { }; class OpDef : public Hashable, + public NonCopyableObj, public std::enable_shared_from_this { mutable const OpTrait* m_trait = nullptr; public: @@ -64,7 +65,7 @@ template class OpDefImplBase : public OpDef { public: template - static std::shared_ptr make(Args&& ...args) { + static std::shared_ptr make(Args&& ...args) { return std::make_shared(std::forward(args)...); } }; diff --git a/imperative/tablegen/CMakeLists.txt b/imperative/tablegen/CMakeLists.txt index 5beb57952d8576ffcd5bd7c0f7bafd4d16339160..31d3c5e87d8644c450e1eafb8445481d98764d8f 100644 --- a/imperative/tablegen/CMakeLists.txt +++ b/imperative/tablegen/CMakeLists.txt @@ -10,5 +10,6 @@ set(LLVM_TARGET_DEFINITIONS ${MGE_IR_DIR}/ops.td) tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header") tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body") tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding") -add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl param_defs_tblgen) +tablegen(MGB opdef.cpy.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-c-extension") +add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl opdef.cpy.inl param_defs_tblgen) set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE) diff --git a/imperative/tablegen/autogen.cpp b/imperative/tablegen/autogen.cpp index f2dc3dc53eb6857126eea06bf2639901ad8d4405..5f75cf57769a513764bb571ea90f1daec048f5a0 100644 --- a/imperative/tablegen/autogen.cpp +++ b/imperative/tablegen/autogen.cpp @@ -11,7 +11,8 @@ enum ActionType { None, CppHeader, CppBody, - Pybind + Pybind, + CPython }; // NOLINTNEXTLINE @@ -22,7 +23,9 @@ llvm::cl::opt action( clEnumValN(CppBody, "gen-cpp-body", "Generate operator cpp body"), clEnumValN(Pybind, "gen-python-binding", - "Generate pybind11 python bindings"))); + "Generate pybind11 python bindings"), + clEnumValN(CPython, "gen-python-c-extension", + "Generate python c extensions"))); using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; @@ -196,7 +199,7 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { formatMethImpl("hash") ); os << formatv( - " auto op_ = def_.cast_final_safe<{0}>();\n" + " auto&& op_ = def_.cast_final_safe<{0}>();\n" " static_cast(op_);\n", className ); @@ -210,8 +213,8 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { formatMethImpl("is_same_st") ); os << formatv( - " auto a_ = lhs_.cast_final_safe<{0}>(),\n" - " b_ = rhs_.cast_final_safe<{0}>();\n" + " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n" + " &&b_ = rhs_.cast_final_safe<{0}>();\n" " static_cast(a_);\n" " static_cast(b_);\n", className @@ -237,15 +240,15 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { } } -struct PybindContext { - std::unordered_map enumAlias; +struct EnumContext { + std::unordered_map> enumAlias; }; -static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext& ctx) { - auto class_name = op.getCppClassName(); +static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { + auto className = op.getCppClassName(); os << formatv( "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", - class_name + className ); for (auto&& i : op.getMgbAttributes()) { if (auto attr = llvm::dyn_cast(&i.attr)) { @@ -263,17 +266,17 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext if (iter == enumAlias.end()) { os << formatv( "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", - class_name, attr->getEnumName() + className, attr->getEnumName() ); std::vector body; for (auto&& i: attr->getEnumMembers()) { os << formatv( "\n .value(\"{2}\", {0}::{1}::{2})", - class_name, attr->getEnumName(), i + className, attr->getEnumName(), i ); body.push_back(formatv( "if (str == \"{2}\") return {0}::{1}::{2};", - class_name, attr->getEnumName(), i + className, attr->getEnumName(), i )); } os << formatv( @@ -286,21 +289,21 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext ); os << formatv( "py::implicitly_convertible();\n\n", - class_name, attr->getEnumName() + className, attr->getEnumName() ); - enumAlias.emplace(enumID, formatv( - "{0}Inst.attr(\"{1}\")", class_name, attr->getEnumName() - )); + enumAlias.emplace(enumID, + std::make_pair(className, attr->getEnumName())); } else { os << formatv( - "{0}Inst.attr(\"{1}\") = {2};\n\n", - class_name, attr->getEnumName(), iter->second + "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", + className, attr->getEnumName(), + iter->second.first, iter->second.second ); } } } // generate op class binding - os << formatv("{0}Inst", class_name); + os << formatv("{0}Inst", className); bool hasDefaultCtor = op.getMgbAttributes().empty(); if (!hasDefaultCtor) { os << "\n .def(py::init<"; @@ -327,12 +330,184 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext for (auto &&i : op.getMgbAttributes()) { os << formatv( "\n .def_readwrite(\"{0}\", &{1}::{0})", - i.name, class_name + i.name, className ); } os << ";\n\n"; } +static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) { + auto className = op.getCppClassName(); + std::string body; + + // generate PyType for enum class member + for (auto&& i : op.getMgbAttributes()) { + if (auto attr = llvm::dyn_cast(&i.attr)) { + unsigned int enumID; + if (auto alias = llvm::dyn_cast(attr)) { + auto&& aliasBase = alias->getAliasBase(); + enumID = + llvm::cast(aliasBase) + .getBaseRecord()->getID(); + } else { + enumID = attr->getBaseRecord()->getID(); + } + auto&& enumAlias = ctx.enumAlias; + auto&& iter = enumAlias.find(enumID); + auto enumName = attr->getEnumName(); + body += "{\n"; + body += formatv( + "auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName + ); + if (iter == enumAlias.end()) { + os << formatv( + "template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n", + className, enumName); + os << formatv( + "template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n", + className, enumName); + std::vector pairStr; + for (auto&& i: attr->getEnumMembers()) { + pairStr.push_back(formatv( + "{{normalize_enum(\"{2}\"), {0}::{1}::{2}}", + className, enumName, i)); + } + os << formatv(R"( +template<> std::unordered_map +EnumWrapper<{0}::{1}>::str2type = {{ + {2} +}; +)", className, enumName, llvm::join(pairStr, ", ")); + pairStr.clear(); + for (auto&& i: attr->getEnumMembers()) { + pairStr.push_back(formatv( + "{{{0}::{1}::{2}, normalize_enum(\"{2}\")}", + className, enumName, i)); + } + os << formatv(R"( +template<> std::unordered_map<{0}::{1}, std::string> +EnumWrapper<{0}::{1}>::type2str = {{ + {2} +}; +)", className, enumName, llvm::join(pairStr, ", ")); + body += formatv(R"( + e_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; + e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}"; + e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>); + e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + e_type.tp_doc = "{0}.{1}"; + e_type.tp_base = &PyBaseObject_Type; + e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr; + e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare; + mgb_assert(PyType_Ready(&e_type) >= 0); +)", className, enumName); + for (auto&& i: attr->getEnumMembers()) { + body += formatv(R"({{ + PyObject* inst = e_type.tp_alloc(&e_type, 0); + reinterpret_cast*>(inst)->value = {0}::{1}::{2}; + mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0); +})", className, enumName, i); + } + enumAlias.emplace(enumID, std::make_pair(className, enumName)); + } + body += formatv(R"( + PyType_Modified(&e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "{0}", reinterpret_cast(&e_type)) >= 0); +)", enumName); + body += "}\n"; + } + } + + // generate getsetters + std::vector getsetters; + for (auto &&i : op.getMgbAttributes()) { + getsetters.push_back(formatv( + "{{\"{1}\", py_get_generic({0}, {1}), py_set_generic({0}, {1}), \"{1}\", NULL},", + className, i.name)); + } + + // generate tp_init + std::string initBody; + if (!op.getMgbAttributes().empty()) { + initBody += "static const char* kwlist[] = {"; + llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { + initBody += formatv("\"{0}\", ", attr.name); + }); + initBody += "NULL};\n"; + initBody += " PyObject "; + std::vector attrs; + llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { + attrs.push_back(formatv("*{0} = NULL", attr.name)); + }); + initBody += llvm::join(attrs, ", ") + ";\n"; + initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; + initBody += std::string(op.getMgbAttributes().size(), 'O'); + initBody += "\", const_cast(kwlist)"; + llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { + initBody += formatv(" ,&{0}", attr.name); + }); + initBody += "))\n"; + initBody += " return -1;\n"; + llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { + initBody += formatv(R"( + if ({1}) {{ + try {{ + reinterpret_cast(self)->inst().{1} = + pyobj_convert_generic::from({1}); + } catch(py::error_already_set& e) {{ + e.restore(); + return -1; + } catch(py::builtin_exception& e) {{ + e.set_error(); + return -1; + } catch(...) {{ + PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); + return -1; + } + } +)", className, attr.name); + }); + } + initBody += "\n return 0;"; + + os << formatv(R"( +PyOpDefBegin({0}) // {{ + static PyGetSetDef py_getsetters[]; + static int py_init(PyObject *self, PyObject *args, PyObject *kwds); +// }; +PyOpDefEnd({0}) +PyGetSetDef PyOp({0})::py_getsetters[] = {{ + {1} + {{NULL} /* Sentinel */ +}; +int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{ + {2} +} + +void _init_py_{0}(py::module m) {{ + using py_op = PyOp({0}); + auto& py_type = PyOpType({0}); + py_type = {{PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.{0}"; + py_type.tp_basicsize = sizeof(PyOp({0})); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "{0}"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + py_type.tp_init = py_op::py_init; + py_type.tp_getset = py_op::py_getsetters; + mgb_assert(PyType_Ready(&py_type) >= 0); + {3} + PyType_Modified(&py_type); + m.add_object("{0}", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second); +} +)", + op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body); +} + static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, std::function callback) { auto op_base_class = keeper.getClass("Op"); @@ -360,13 +535,26 @@ static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) { } static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { - PybindContext ctx; + EnumContext ctx; using namespace std::placeholders; for_each_operator(os, keeper, std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); return false; } +static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) { + EnumContext ctx; + using namespace std::placeholders; + for_each_operator(os, keeper, + std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx))); + os << "#define INIT_ALL_OP(m)"; + for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) { + os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName()); + }); + os << "\n"; + return false; +} + int main(int argc, char **argv) { llvm::InitLLVM y(argc, argv); llvm::cl::ParseCommandLineOptions(argc, argv); @@ -379,5 +567,8 @@ int main(int argc, char **argv) { if (action == ActionType::Pybind) { return TableGenMain(argv[0], &gen_op_def_pybind11); } + if (action == ActionType::CPython) { + return TableGenMain(argv[0], &gen_op_def_python_c_extension); + } return -1; -} \ No newline at end of file +}