提交 577366a2 编写于 作者: M Megvii Engine Team

feat(mge/imperative): basic impl of python c extension for opdef

fix(imperative): fix refcount management on cpython opdef

refactor(mge/imperative): fix compilation for python3.6

GitOrigin-RevId: 332a516895fbba528fca4c1b31de8e674bbca47d
上级 9d928e7f
......@@ -48,7 +48,7 @@ def _(op: OpDef, inputs, outputs, input_requires_grad):
isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD
):
grad_fn = elemwise_add_grad_fn
elif isinstance(op, Reduce) and op.mode.name == "SUM":
elif isinstance(op, Reduce) and op.mode == Reduce.Mode.SUM:
grad_fn = reduce_sum_grad_fn
else:
grad_fn = default_grad_fn
......
......@@ -447,8 +447,8 @@ def _(op: OpDef, *args: VarNode):
def _(op: BackwardGraph, *args: VarNode):
assert args
graph = args[0].graph
return op.interpret(
lambda op, args: apply(op, *args), graph._make_const_for_backward, args
return BackwardGraph.interpret(
op, lambda op, args: apply(op, *args), graph._make_const_for_backward, args
)
......
......@@ -13,6 +13,7 @@
#include "megbrain/graph.h"
#include "megbrain/utils/persistent_cache.h"
#include "megbrain/imperative/op_def.h"
#include <Python.h>
#include <string>
......@@ -376,6 +377,32 @@ namespace detail {
}
};
template<> struct type_caster<mgb::imperative::OpDef> {
protected:
std::shared_ptr<mgb::imperative::OpDef> value;
public:
static constexpr auto name = _("OpDef");
operator mgb::imperative::OpDef&() { return *value; }
operator const mgb::imperative::OpDef&() { return *value; }
operator std::shared_ptr<mgb::imperative::OpDef>&() { return value; }
operator std::shared_ptr<mgb::imperative::OpDef>&&() && { return std::move(value); }
template <typename T> using cast_op_type = T;
bool load(handle src, bool convert);
static handle cast(const mgb::imperative::OpDef& op, return_value_policy /* policy */, handle /* parent */);
static handle cast(std::shared_ptr<mgb::imperative::OpDef> op, return_value_policy policy, handle parent) {
return cast(*op, policy, parent);
}
};
template <> struct type_caster<std::shared_ptr<mgb::imperative::OpDef>> :
public type_caster<mgb::imperative::OpDef> {
template <typename T> using cast_op_type = pybind11::detail::movable_cast_op_type<T>;
};
} // detail
} // PYBIND11_NAMESPACE
......
......@@ -106,13 +106,4 @@ void init_imperative_rt(py::module m) {
});
m.def("make_backward_graph", &make_backward_graph);
py::class_<OpDef, std::shared_ptr<OpDef>>(m, "OpDef")
.def("ctype", [](const OpDef& opdef) {
return opdef.dyn_typeinfo()->name;
})
.def("__eq__", [](const OpDef& lhs, const OpDef& rhs) {
return lhs.is_same(rhs);
})
.def("__hash__", &OpDef::hash);
}
......@@ -63,6 +63,7 @@ PYBIND11_MODULE(MODULE_NAME, m) {
from .utils import *
from .imperative import *
from .graph import *
from .ops import OpDef
)",
py::getattr(m, "__dict__"));
......
......@@ -16,7 +16,11 @@
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/autogen.h"
#include <Python.h>
#include <unordered_map>
namespace py = pybind11;
using namespace mgb::imperative;
namespace {
auto normalize_enum(const std::string& in) {
......@@ -28,11 +32,203 @@ auto normalize_enum(const std::string& in) {
}
} // anonymous namespace
void init_ops(py::module m) {
using namespace mgb::imperative;
namespace {
#define PyOp(name) Py##name
#define PyOpType(name) PyOp(name)::py_type
#define PyOpDefBegin(name) \
struct PyOp(name) : PyOpDef { \
using Ty = name; \
Ty& inst() { return op->cast_final_safe<Ty>(); } \
static PyTypeObject py_type;
#define PyOpDefEnd(name) \
}; \
PyTypeObject PyOpType(name);
#define RETURN_RICHCOMPARE(val1, val2, op) \
do { \
switch (op) { \
case Py_EQ: if ((val1) == (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \
case Py_NE: if ((val1) != (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \
case Py_LT: if ((val1) < (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \
case Py_GT: if ((val1) > (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \
case Py_LE: if ((val1) <= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \
case Py_GE: if ((val1) >= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE; \
default: \
Py_FatalError("Unreachable C code path reached"); \
} \
} while (0)
template<typename T, typename SFINAE=void>
struct pyobj_convert_generic {
static T from(PyObject* obj) {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
return py::cast<T>(py::handle(obj));
}
template<typename U,
typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>>
static PyObject* to(U&& t) {
return py::cast(std::forward<U>(t)).release().ptr();
}
};
template<typename T>
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
PyObject* obj = type->tp_alloc(type, 0);
T* self = reinterpret_cast<T*>(obj);
if (self != NULL) {
self->op = T::Ty::make();
}
return obj;
}
template<typename T>
void py_dealloc_generic(PyObject* obj) {
reinterpret_cast<T*>(obj)->op.reset();
Py_TYPE(obj)->tp_free(obj);
}
template<typename T, typename U, U T::Ty::*attr>
PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
auto& op = reinterpret_cast<T*>(obj)->inst();
return pyobj_convert_generic<U>::to(op.*attr);
}
#define py_get_generic(name, attr) \
py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
template<typename T, typename U, U T::Ty::*attr>
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
if (value == NULL) {
PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
return -1;
}
auto& op = reinterpret_cast<T*>(obj)->inst();
try {
op.*attr = pyobj_convert_generic<U>::from(value);
return 0;
} catch(py::error_already_set& e) {
e.restore();
} catch(py::builtin_exception& e) {
e.set_error();
} catch(...) {
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
}
return -1;
}
#define py_set_generic(name, attr) \
py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
struct PyOpDef {
PyObject_HEAD
std::shared_ptr<OpDef> op;
static PyTypeObject py_type;
static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
static Py_hash_t tp_hash(PyObject *obj);
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op);
};
PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;
Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) {
return static_cast<Py_hash_t>(
reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash());
}
PyObject* PyOp(OpDef)::tp_richcompare(PyObject *self, PyObject *other, int op) {
bool same = reinterpret_cast<PyOp(OpDef)*>(self)->op->is_same(
*reinterpret_cast<PyOp(OpDef)*>(other)->op);
if (op == Py_EQ || op == Py_NE) {
RETURN_RICHCOMPARE(same, true, op);
}
Py_RETURN_NOTIMPLEMENTED;
}
template<typename T>
struct EnumWrapper {
static_assert(std::is_enum_v<T>);
PyObject_HEAD
T value;
static const char* name;
static PyTypeObject type;
static std::unordered_map<T, std::string> type2str;
static std::unordered_map<std::string, T> str2type;
EnumWrapper() = default;
EnumWrapper(T v): value(v) {}
EnumWrapper(std::string&& str): EnumWrapper(str2type.at(normalize_enum(str))) {}
std::string to_string() const {
return type2str.at(value);
}
static PyObject* py_repr(PyObject* self) {
return pyobj_convert_generic<std::string>::to(
std::string(name) + "." + reinterpret_cast<EnumWrapper*>(self)->to_string());
}
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) {
T lhs = reinterpret_cast<EnumWrapper*>(self)->value,
rhs = reinterpret_cast<EnumWrapper*>(other)->value;
if (op == Py_EQ || op == Py_NE) {
RETURN_RICHCOMPARE(lhs, rhs, op);
}
Py_RETURN_NOTIMPLEMENTED;
}
};
template<typename T>
struct pyobj_convert_generic<T,
std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> {
using Wrapper = EnumWrapper<T>;
static T from(PyObject* obj) {
if (PyObject_TypeCheck(obj, &Wrapper::type)) {
return reinterpret_cast<Wrapper*>(obj)->value;
}
// try as string
// TODO: type checkcd
return Wrapper(pyobj_convert_generic<std::string>::from(obj)).value;
}
static PyObject* to(T t) {
PyTypeObject* pytype = &Wrapper::type;
PyObject* obj = pytype->tp_alloc(pytype, 0);
reinterpret_cast<Wrapper*>(obj)->value = t;
return obj;
}
};
void _init_py_op_def(py::module m) {
auto& py_type = PyOpType(OpDef);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.OpDef";
py_type.tp_basicsize = sizeof(PyOp(OpDef));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "OpDef";
py_type.tp_base = &PyBaseObject_Type;
py_type.tp_hash = PyOp(OpDef)::tp_hash;
py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
mgb_assert(PyType_Ready(&py_type) >= 0);
m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
}
/*********** begin of hand-write opdefs **************/
py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph")
.def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc,
PyOpDefBegin(BackwardGraph) // {{
// };
PyOpDefEnd(BackwardGraph)
void _init_py_backward_graph(py::module m) {
using py_op = PyOp(BackwardGraph);
auto& py_type = PyOpType(BackwardGraph);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph";
py_type.tp_basicsize = sizeof(PyOp(BackwardGraph));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "BackwardGraph";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
mgb_assert(PyType_Ready(&py_type) >= 0);
// FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction
auto interpret = py::cpp_function(
[](OpDef& self, py::object pyf, py::object pyc,
const mgb::SmallVector<py::object>& inputs) {
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs));
......@@ -40,8 +236,52 @@ void init_ops(py::module m) {
auto c = [pyc](const TensorPtr& tensor) {
return pyc(tensor->dev_tensor());
};
return self.graph().interpret<py::object>(f, c, inputs);
return self.cast_final_safe<BackwardGraph>().graph().interpret<py::object>(f, c, inputs);
});
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0);
PyType_Modified(&py_type);
m.add_object("BackwardGraph", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second);
}
/*********** end of hand-write opdefs **************/
// auto generated opdefs
#include "opdef.cpy.inl"
} // anonymous namespace
#include "opdef.py.inl"
namespace PYBIND11_NAMESPACE {
namespace detail {
bool type_caster<OpDef>::load(handle src, bool convert) {
PyObject* obj = src.ptr();
if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) {
return false;
}
value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
return true;
}
handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
PyTypeObject* pytype;
auto& c2p = PyOp(OpDef)::ctype2pytype;
auto&& iter = c2p.find(op.dyn_typeinfo());
if (iter != c2p.end()) { // FIXME: should always meet this condition
pytype = iter->second;
} else { // which means unregistered op type, jsut make it as an opaque op type
// currently, only OprAttr goes into this branch
pytype = &PyOpType(OpDef);
}
PyObject* obj = pytype->tp_alloc(pytype, 0);
mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef)));
reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
return py::handle(obj);
}
} // detail
} // PYBIND11_NAMESPACE
void init_ops(py::module m) {
_init_py_op_def(m);
_init_py_backward_graph(m);
INIT_ALL_OP(m)
}
......@@ -76,7 +76,7 @@ cg::OperatorNodeBase* apply_on_var_node(
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
auto opr = def.cast_final_safe<CondTake>();
auto&& opr = def.cast_final_safe<CondTake>();
mgb_assert(opr.same_type<CondTake>());
mgb_assert(inputs.size() == 2, "CondTake take 2 inputs, got %lu",
inputs.size());
......
......@@ -111,7 +111,7 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
auto param = def.cast_final_safe<ParamPackSplit>();
auto&& param = def.cast_final_safe<ParamPackSplit>();
mgb_assert(inputs.size() == 1, "ParamPackSplit take 1 input, got %lu", inputs.size());
auto&& inp = inputs[0];
auto&& shp = inp->layout();
......
......@@ -27,6 +27,7 @@ struct BackwardGraphResult {
};
class OpDef : public Hashable,
public NonCopyableObj,
public std::enable_shared_from_this<OpDef> {
mutable const OpTrait* m_trait = nullptr;
public:
......@@ -64,7 +65,7 @@ template<typename T>
class OpDefImplBase : public OpDef {
public:
template<typename ...Args>
static std::shared_ptr<OpDef> make(Args&& ...args) {
static std::shared_ptr<T> make(Args&& ...args) {
return std::make_shared<T>(std::forward<Args>(args)...);
}
};
......
......@@ -10,5 +10,6 @@ set(LLVM_TARGET_DEFINITIONS ${MGE_IR_DIR}/ops.td)
tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header")
tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body")
tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding")
add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl param_defs_tblgen)
tablegen(MGB opdef.cpy.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-c-extension")
add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl opdef.cpy.inl param_defs_tblgen)
set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE)
......@@ -11,7 +11,8 @@ enum ActionType {
None,
CppHeader,
CppBody,
Pybind
Pybind,
CPython
};
// NOLINTNEXTLINE
......@@ -22,7 +23,9 @@ llvm::cl::opt<ActionType> action(
clEnumValN(CppBody, "gen-cpp-body",
"Generate operator cpp body"),
clEnumValN(Pybind, "gen-python-binding",
"Generate pybind11 python bindings")));
"Generate pybind11 python bindings"),
clEnumValN(CPython, "gen-python-c-extension",
"Generate python c extensions")));
using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
......@@ -196,7 +199,7 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
formatMethImpl("hash")
);
os << formatv(
" auto op_ = def_.cast_final_safe<{0}>();\n"
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
......@@ -210,8 +213,8 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
formatMethImpl("is_same_st")
);
os << formatv(
" auto a_ = lhs_.cast_final_safe<{0}>(),\n"
" b_ = rhs_.cast_final_safe<{0}>();\n"
" auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
" &&b_ = rhs_.cast_final_safe<{0}>();\n"
" static_cast<void>(a_);\n"
" static_cast<void>(b_);\n",
className
......@@ -237,15 +240,15 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
}
}
struct PybindContext {
std::unordered_map<unsigned int, std::string> enumAlias;
struct EnumContext {
std::unordered_map<unsigned int, std::pair<llvm::StringRef, llvm::StringRef>> enumAlias;
};
static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext& ctx) {
auto class_name = op.getCppClassName();
static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
auto className = op.getCppClassName();
os << formatv(
"py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
class_name
className
);
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
......@@ -263,17 +266,17 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext
if (iter == enumAlias.end()) {
os << formatv(
"py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
class_name, attr->getEnumName()
className, attr->getEnumName()
);
std::vector<std::string> body;
for (auto&& i: attr->getEnumMembers()) {
os << formatv(
"\n .value(\"{2}\", {0}::{1}::{2})",
class_name, attr->getEnumName(), i
className, attr->getEnumName(), i
);
body.push_back(formatv(
"if (str == \"{2}\") return {0}::{1}::{2};",
class_name, attr->getEnumName(), i
className, attr->getEnumName(), i
));
}
os << formatv(
......@@ -286,21 +289,21 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext
);
os << formatv(
"py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
class_name, attr->getEnumName()
className, attr->getEnumName()
);
enumAlias.emplace(enumID, formatv(
"{0}Inst.attr(\"{1}\")", class_name, attr->getEnumName()
));
enumAlias.emplace(enumID,
std::make_pair(className, attr->getEnumName()));
} else {
os << formatv(
"{0}Inst.attr(\"{1}\") = {2};\n\n",
class_name, attr->getEnumName(), iter->second
"{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n",
className, attr->getEnumName(),
iter->second.first, iter->second.second
);
}
}
}
// generate op class binding
os << formatv("{0}Inst", class_name);
os << formatv("{0}Inst", className);
bool hasDefaultCtor = op.getMgbAttributes().empty();
if (!hasDefaultCtor) {
os << "\n .def(py::init<";
......@@ -327,12 +330,184 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext
for (auto &&i : op.getMgbAttributes()) {
os << formatv(
"\n .def_readwrite(\"{0}\", &{1}::{0})",
i.name, class_name
i.name, className
);
}
os << ";\n\n";
}
static void gen_op_def_python_c_extension_single(raw_ostream &os, MgbOp& op, EnumContext& ctx) {
auto className = op.getCppClassName();
std::string body;
// generate PyType for enum class member
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
unsigned int enumID;
if (auto alias = llvm::dyn_cast<MgbAliasAttr>(attr)) {
auto&& aliasBase = alias->getAliasBase();
enumID =
llvm::cast<MgbEnumAttr>(aliasBase)
.getBaseRecord()->getID();
} else {
enumID = attr->getBaseRecord()->getID();
}
auto&& enumAlias = ctx.enumAlias;
auto&& iter = enumAlias.find(enumID);
auto enumName = attr->getEnumName();
body += "{\n";
body += formatv(
"auto& e_type = EnumWrapper<{0}::{1}>::type;", className, enumName
);
if (iter == enumAlias.end()) {
os << formatv(
"template<> PyTypeObject EnumWrapper<{0}::{1}>::type={{};\n",
className, enumName);
os << formatv(
"template<> const char* EnumWrapper<{0}::{1}>::name = \"{0}.{1}\";\n",
className, enumName);
std::vector<std::string> pairStr;
for (auto&& i: attr->getEnumMembers()) {
pairStr.push_back(formatv(
"{{normalize_enum(\"{2}\"), {0}::{1}::{2}}",
className, enumName, i));
}
os << formatv(R"(
template<> std::unordered_map<std::string, {0}::{1}>
EnumWrapper<{0}::{1}>::str2type = {{
{2}
};
)", className, enumName, llvm::join(pairStr, ", "));
pairStr.clear();
for (auto&& i: attr->getEnumMembers()) {
pairStr.push_back(formatv(
"{{{0}::{1}::{2}, normalize_enum(\"{2}\")}",
className, enumName, i));
}
os << formatv(R"(
template<> std::unordered_map<{0}::{1}, std::string>
EnumWrapper<{0}::{1}>::type2str = {{
{2}
};
)", className, enumName, llvm::join(pairStr, ", "));
body += formatv(R"(
e_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
e_type.tp_name = "megengine.core._imperative_rt.ops.{0}.{1}";
e_type.tp_basicsize = sizeof(EnumWrapper<{0}::{1}>);
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "{0}.{1}";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_repr = EnumWrapper<{0}::{1}>::py_repr;
e_type.tp_richcompare = EnumWrapper<{0}::{1}>::tp_richcompare;
mgb_assert(PyType_Ready(&e_type) >= 0);
)", className, enumName);
for (auto&& i: attr->getEnumMembers()) {
body += formatv(R"({{
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<EnumWrapper<{0}::{1}>*>(inst)->value = {0}::{1}::{2};
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "{2}", inst) >= 0);
})", className, enumName, i);
}
enumAlias.emplace(enumID, std::make_pair(className, enumName));
}
body += formatv(R"(
PyType_Modified(&e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "{0}", reinterpret_cast<PyObject*>(&e_type)) >= 0);
)", enumName);
body += "}\n";
}
}
// generate getsetters
std::vector<std::string> getsetters;
for (auto &&i : op.getMgbAttributes()) {
getsetters.push_back(formatv(
"{{\"{1}\", py_get_generic({0}, {1}), py_set_generic({0}, {1}), \"{1}\", NULL},",
className, i.name));
}
// generate tp_init
std::string initBody;
if (!op.getMgbAttributes().empty()) {
initBody += "static const char* kwlist[] = {";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr.name);
});
initBody += "NULL};\n";
initBody += " PyObject ";
std::vector<std::string> attrs;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
attrs.push_back(formatv("*{0} = NULL", attr.name));
});
initBody += llvm::join(attrs, ", ") + ";\n";
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
initBody += std::string(op.getMgbAttributes().size(), 'O');
initBody += "\", const_cast<char**>(kwlist)";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv(" ,&{0}", attr.name);
});
initBody += "))\n";
initBody += " return -1;\n";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv(R"(
if ({1}) {{
try {{
reinterpret_cast<PyOp({0})*>(self)->inst().{1} =
pyobj_convert_generic<decltype({0}::{1})>::from({1});
} catch(py::error_already_set& e) {{
e.restore();
return -1;
} catch(py::builtin_exception& e) {{
e.set_error();
return -1;
} catch(...) {{
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
return -1;
}
}
)", className, attr.name);
});
}
initBody += "\n return 0;";
os << formatv(R"(
PyOpDefBegin({0}) // {{
static PyGetSetDef py_getsetters[];
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd({0})
PyGetSetDef PyOp({0})::py_getsetters[] = {{
{1}
{{NULL} /* Sentinel */
};
int PyOp({0})::py_init(PyObject *self, PyObject *args, PyObject *kwds) {{
{2}
}
void _init_py_{0}(py::module m) {{
using py_op = PyOp({0});
auto& py_type = PyOpType({0});
py_type = {{PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.{0}";
py_type.tp_basicsize = sizeof(PyOp({0}));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "{0}";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
{3}
PyType_Modified(&py_type);
m.add_object("{0}", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace({0}::typeinfo(), &py_type).second);
}
)",
op.getCppClassName(), llvm::join(getsetters, "\n "), initBody, body);
}
static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
std::function<void(raw_ostream&, MgbOp&)> callback) {
auto op_base_class = keeper.getClass("Op");
......@@ -360,13 +535,26 @@ static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) {
}
static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) {
PybindContext ctx;
EnumContext ctx;
using namespace std::placeholders;
for_each_operator(os, keeper,
std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx)));
return false;
}
static bool gen_op_def_python_c_extension(raw_ostream &os, RecordKeeper &keeper) {
EnumContext ctx;
using namespace std::placeholders;
for_each_operator(os, keeper,
std::bind(gen_op_def_python_c_extension_single, _1, _2, std::ref(ctx)));
os << "#define INIT_ALL_OP(m)";
for_each_operator(os, keeper, [&](raw_ostream& os, MgbOp& op) {
os << formatv(" \\\n _init_py_{0}(m);", op.getCppClassName());
});
os << "\n";
return false;
}
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv);
......@@ -379,5 +567,8 @@ int main(int argc, char **argv) {
if (action == ActionType::Pybind) {
return TableGenMain(argv[0], &gen_op_def_pybind11);
}
if (action == ActionType::CPython) {
return TableGenMain(argv[0], &gen_op_def_python_c_extension);
}
return -1;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册