提交 5a1f9134 编写于 作者: M Megvii Engine Team

chore(imperative): remove unnecessary function template

GitOrigin-RevId: 8dd2f8c308061fd510a6a82f09c94f2214e6f4e4
上级 2de2222e
...@@ -33,6 +33,18 @@ auto normalize_enum(const std::string& in) { ...@@ -33,6 +33,18 @@ auto normalize_enum(const std::string& in) {
} }
} // anonymous namespace } // anonymous namespace
#define CATCH_ALL(RETVAL) \
catch(py::error_already_set& e) { \
e.restore(); \
return RETVAL; \
} catch(py::builtin_exception& e) { \
e.set_error(); \
return RETVAL; \
} catch(std::exception& e) { \
PyErr_SetString(PyExc_RuntimeError, e.what()); \
return RETVAL; \
} \
namespace { namespace {
#define PyOp(name) Py##name #define PyOp(name) Py##name
#define PyOpType(name) PyOp(name)::py_type #define PyOpType(name) PyOp(name)::py_type
...@@ -99,14 +111,6 @@ PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) { ...@@ -99,14 +111,6 @@ PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
#define py_get_generic(name, attr) \ #define py_get_generic(name, attr) \
py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr> py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
template<typename T>
PyObject* py_get_scope_impl(PyObject* obj, void* /* closure */) {
// T: PyOpXXX inst(): return XXX in opdef.h.inl
auto& op = reinterpret_cast<T*>(obj)->inst();
return pyobj_convert_generic<std::string>::to(op.scope());
}
#define py_get_scope(class) py_get_scope_impl<PyOp(class)>
template<typename T, typename U, U T::Ty::*attr> template<typename T, typename U, U T::Ty::*attr>
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
if (value == NULL) { if (value == NULL) {
...@@ -116,51 +120,46 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { ...@@ -116,51 +120,46 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
auto& op = reinterpret_cast<T*>(obj)->inst(); auto& op = reinterpret_cast<T*>(obj)->inst();
try { try {
op.*attr = pyobj_convert_generic<U>::from(value); op.*attr = pyobj_convert_generic<U>::from(value);
} CATCH_ALL(-1)
return 0; return 0;
} catch(py::error_already_set& e) {
e.restore();
} catch(py::builtin_exception& e) {
e.set_error();
} catch(...) {
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
}
return -1;
} }
#define py_set_generic(name, attr) \ #define py_set_generic(name, attr) \
py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr> py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
template<typename T>
int py_set_scope_impl(PyObject* obj, PyObject* value, void* /* closure */) {
if (value == NULL) {
PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
return -1;
}
auto& op = reinterpret_cast<T*>(obj)->inst();
try {
op.set_scope(pyobj_convert_generic<std::string>::from(value));
return 0;
} catch(py::error_already_set& e) {
e.restore();
} catch(py::builtin_exception& e) {
e.set_error();
} catch(...) {
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
}
return -1;
}
#define py_set_scope(class) py_set_scope_impl<PyOp(class)>
struct PyOpDef { struct PyOpDef {
PyObject_HEAD PyObject_HEAD
std::shared_ptr<OpDef> op; std::shared_ptr<OpDef> op;
static PyTypeObject py_type; static PyTypeObject py_type;
static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype; static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
static PyGetSetDef py_getsetters[];
static Py_hash_t tp_hash(PyObject *obj); static Py_hash_t tp_hash(PyObject *obj);
static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op); static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op);
}; };
PyTypeObject PyOpType(OpDef); PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype; std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;
PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
return pyobj_convert_generic<std::string>::to(
reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope());
}
int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
if (value == NULL) {
PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
return -1;
}
try {
reinterpret_cast<PyOp(OpDef)*>(obj)->op
->set_scope(pyobj_convert_generic<std::string>::from(value));
} CATCH_ALL(-1)
return 0;
}
PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
{const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL},
{NULL}
};
Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) { Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) {
return static_cast<Py_hash_t>( return static_cast<Py_hash_t>(
reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash()); reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash());
...@@ -225,6 +224,7 @@ struct pyobj_convert_generic<T, ...@@ -225,6 +224,7 @@ struct pyobj_convert_generic<T,
}; };
void _init_py_op_def(py::module m) { void _init_py_op_def(py::module m) {
using py_op = PyOp(OpDef);
auto& py_type = PyOpType(OpDef); auto& py_type = PyOpType(OpDef);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.OpDef"; py_type.tp_name = "megengine.core._imperative_rt.OpDef";
...@@ -234,6 +234,7 @@ void _init_py_op_def(py::module m) { ...@@ -234,6 +234,7 @@ void _init_py_op_def(py::module m) {
py_type.tp_base = &PyBaseObject_Type; py_type.tp_base = &PyBaseObject_Type;
py_type.tp_hash = PyOp(OpDef)::tp_hash; py_type.tp_hash = PyOp(OpDef)::tp_hash;
py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare; py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0); mgb_assert(PyType_Ready(&py_type) >= 0);
m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type)); m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
} }
...@@ -309,6 +310,8 @@ void _init_py_op_base(py::module m) { ...@@ -309,6 +310,8 @@ void _init_py_op_base(py::module m) {
// auto generated opdefs // auto generated opdefs
#include "opdef.cpy.inl" #include "opdef.cpy.inl"
#undef CATCH_ALL
} // anonymous namespace } // anonymous namespace
namespace PYBIND11_NAMESPACE { namespace PYBIND11_NAMESPACE {
......
...@@ -485,52 +485,44 @@ EnumWrapper<{0}::{1}>::type2str = {{ ...@@ -485,52 +485,44 @@ EnumWrapper<{0}::{1}>::type2str = {{
className, i.name)); className, i.name));
} }
getsetters.push_back(formatv(
"{{\"scope\", py_get_scope({0}), py_set_scope({0}), \"scope\", NULL},",
className));
// generate tp_init // generate tp_init
std::string initBody; std::string initBody;
if (!op.getMgbAttributes().empty()) { if (!op.getMgbAttributes().empty()) {
initBody += "static const char* kwlist[] = {"; initBody += "static const char* kwlist[] = {";
std::vector<llvm::StringRef> attr_name_list;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr.name); attr_name_list.push_back(attr.name);
});
attr_name_list.push_back("scope");
llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr);
}); });
initBody += "\"scope\", ";
initBody += "NULL};\n"; initBody += "NULL};\n";
initBody += " PyObject "; initBody += " PyObject ";
std::vector<std::string> attrs; std::vector<std::string> attr_init;
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { llvm::for_each(attr_name_list, [&](auto&& attr) {
attrs.push_back(formatv("*{0} = NULL", attr.name)); attr_init.push_back(formatv("*{0} = NULL", attr));
}); });
initBody += llvm::join(attrs, ", ") + ";\n"; initBody += llvm::join(attr_init, ", ") + ";\n";
initBody += " PyObject *scope = NULL;\n";
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
// an extra slot created for name // an extra slot created for name
initBody += std::string(op.getMgbAttributes().size() + 1, 'O'); initBody += std::string(attr_name_list.size(), 'O');
initBody += "\", const_cast<char**>(kwlist)"; initBody += "\", const_cast<char**>(kwlist)";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { llvm::for_each(attr_name_list, [&](auto&& attr) {
initBody += formatv(", &{0}", attr.name); initBody += formatv(", &{0}", attr);
}); });
initBody += ", &scope";
initBody += "))\n"; initBody += "))\n";
initBody += " return -1;\n"; initBody += " return -1;\n";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv(R"( initBody += formatv(R"(
if ({1}) {{ if ({1}) {{
try {{ try {{
reinterpret_cast<PyOp({0})*>(self)->inst().{1} = reinterpret_cast<PyOp({0})*>(self)->inst().{1} =
pyobj_convert_generic<decltype({0}::{1})>::from({1}); pyobj_convert_generic<decltype({0}::{1})>::from({1});
} catch(py::error_already_set& e) {{ } CATCH_ALL(-1)
e.restore();
return -1;
} catch(py::builtin_exception& e) {{
e.set_error();
return -1;
} catch(...) {{
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
return -1;
}
} }
)", className, attr.name); )", className, attr.name);
}); });
...@@ -538,18 +530,9 @@ EnumWrapper<{0}::{1}>::type2str = {{ ...@@ -538,18 +530,9 @@ EnumWrapper<{0}::{1}>::type2str = {{
initBody += formatv(R"( initBody += formatv(R"(
if (scope) {{ if (scope) {{
try {{ try {{
reinterpret_cast<PyOp({0})*>(self)->inst().set_scope( reinterpret_cast<PyOp(OpDef)*>(self)->op
pyobj_convert_generic<std::string>::from(scope)); ->set_scope(pyobj_convert_generic<std::string>::from(scope));
} catch(py::error_already_set& e) {{ } CATCH_ALL(-1)
e.restore();
return -1;
} catch(py::builtin_exception& e) {{
e.set_error();
return -1;
} catch(...) {{
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
return -1;
}
} }
)", className); )", className);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册