/**
 * \file imperative/python/src/ops.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./ops.h"

#include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/autogen.h"

#include <string>
#include <unordered_map>

namespace py = pybind11;
using namespace mgb::imperative;

namespace {

auto normalize_enum(const std::string& in) {
    std::string ret;
    for (auto&& c : in) {
        ret += toupper(c);
    }
    return ret;
}

} // anonymous namespace

#define CATCH_ALL(RETVAL)                              \
    catch(py::error_already_set& e) {                  \
        e.restore();                                   \
        return RETVAL;                                 \
    } catch(py::builtin_exception& e) {                \
        e.set_error();                                 \
        return RETVAL;                                 \
    } catch(std::exception& e) {                       \
        PyErr_SetString(PyExc_RuntimeError, e.what()); \
        return RETVAL;                                 \
    }

namespace {

#define PyOp(name) Py##name
#define PyOpType(name) PyOp(name)::py_type

#define PyOpDefBegin(name)                               \
    struct PyOp(name) : PyOpDef {                        \
        using Ty = name;                                 \
        Ty& inst() { return op->cast_final_safe<Ty>(); } \
        static PyTypeObject py_type;

#define PyOpDefEnd(name) \
    };                   \
    PyTypeObject PyOpType(name);

#define RETURN_RICHCOMPARE(val1, val2, op)                                      \
    do {                                                                        \
        switch (op) {                                                           \
        case Py_EQ: if ((val1) == (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE;      \
        case Py_NE: if ((val1) != (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE;      \
        case Py_LT: if ((val1) < (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE;       \
        case Py_GT: if ((val1) > (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE;       \
        case Py_LE: if ((val1) <= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE;      \
        case Py_GE: if ((val1) >= (val2)) Py_RETURN_TRUE; Py_RETURN_FALSE;      \
        default:                                                                \
            Py_FatalError("Unreachable C code path reached");                   \
        }                                                                       \
    } while (0)

template <typename T>
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
    PyObject* obj = type->tp_alloc(type, 0);
    T* self = reinterpret_cast<T*>(obj);
    if (self != NULL) {
        self->op = T::Ty::make();
    }
    return obj;
}

template <typename T>
void py_dealloc_generic(PyObject* obj) {
    reinterpret_cast<T*>(obj)->op.reset();
    Py_TYPE(obj)->tp_free(obj);
}

template <typename T, typename U, U T::Ty::*attr>
PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
    auto& op = reinterpret_cast<T*>(obj)->inst();
    return py::cast(op.*attr).release().ptr();
}

#define py_get_generic(name, attr) \
    py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

template <typename T, typename U, U T::Ty::*attr>
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    auto& op = reinterpret_cast<T*>(obj)->inst();
    try {
        // TODO: remove this guard which is used for pybind11 implicit conversion
        py::detail::loader_life_support guard{};
        op.*attr = py::cast<U>(py::handle(value));
    } CATCH_ALL(-1)
    return 0;
}

#define py_set_generic(name, attr) \
    py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
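// Illustrative sketch (kept in a comment, not compiled): the generated
// "opdef.cpy.inl" included below is expected to use the two macros above to wire
// each op attribute into a CPython getset table. For a hypothetical generated op
// `Axis` with a member `int32_t axis` (placeholder names, not a real MegEngine
// op), a generated table would look roughly like:
//
//     PyGetSetDef PyOp(Axis)::py_getsetters[] = {
//         {const_cast<char*>("axis"),
//          py_get_generic(Axis, axis),   // py_get_generic_impl<PyOpAxis, int32_t, &Axis::axis>
//          py_set_generic(Axis, axis),   // py_set_generic_impl<PyOpAxis, int32_t, &Axis::axis>
//          const_cast<char*>("axis"), NULL},
//         {NULL}  /* sentinel */
//     };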
struct PyOpDef {
    PyObject_HEAD
    std::shared_ptr<OpDef> op;
    static PyTypeObject py_type;
    static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
    static PyGetSetDef py_getsetters[];
    static Py_hash_t tp_hash(PyObject *obj);
    static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op);
};
PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;

PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
    return py::cast(
        reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope()).release().ptr();
}

int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    try {
        reinterpret_cast<PyOp(OpDef)*>(obj)->op
            ->set_scope(py::cast<std::string>(py::handle(value)));
    } CATCH_ALL(-1)
    return 0;
}

PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
    {const_cast<char*>("scope"), py_get_scope, py_set_scope, "scope", NULL},
    {NULL}
};

Py_hash_t PyOp(OpDef)::tp_hash(PyObject *obj) {
    return static_cast<Py_hash_t>(
        reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash());
}

PyObject* PyOp(OpDef)::tp_richcompare(PyObject *self, PyObject *other, int op) {
    bool same = reinterpret_cast<PyOp(OpDef)*>(self)->op->is_same(
        *reinterpret_cast<PyOp(OpDef)*>(other)->op);
    if (op == Py_EQ || op == Py_NE) {
        RETURN_RICHCOMPARE(same, true, op);
    }
    Py_RETURN_NOTIMPLEMENTED;
}

template <typename T>
struct EnumTrait;

#define PyEnumHead                                          \
    static_assert(std::is_enum_v<T>);                       \
    PyObject_HEAD                                           \
    T value;                                                \
    constexpr static const char *name = EnumTrait<T>::name; \
    static PyTypeObject type;                               \
    static const char* members[];                           \
    static std::unordered_map<std::string, T> mem2value;    \
    static PyObject* pyobj_insts[];

template <typename T>
struct EnumWrapper {
    PyEnumHead
    std::string to_string() const {
        return members[static_cast<size_t>(value)];
    }
    static PyObject* py_repr(PyObject* self) {
        return py::cast(
            std::string(name) + "." +
            reinterpret_cast<EnumWrapper*>(self)->to_string())
            .release().ptr();
    }
    static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) {
        T lhs = reinterpret_cast<EnumWrapper*>(self)->value,
          rhs = reinterpret_cast<EnumWrapper*>(other)->value;
        if (op == Py_EQ || op == Py_NE) {
            RETURN_RICHCOMPARE(lhs, rhs, op);
        }
        Py_RETURN_NOTIMPLEMENTED;
    }
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
        if (PyObject_TypeCheck(obj, &type)) {
            value = reinterpret_cast<EnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
            auto&& iter = mem2value.find(
                normalize_enum(py::cast<std::string>(src)));
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
        }
        return false;
    }
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        PyObject* obj = pyobj_insts[v];
        Py_INCREF(obj);
        return obj;
    }
};
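// Illustrative sketch (kept in a comment, not compiled): EnumWrapper<T> relies on
// an EnumTrait<T> specialization plus definitions of the statics declared by
// PyEnumHead, all of which are expected to come from the generated code. For a
// hypothetical `enum class Rounding { FLOOR, CEIL }` (placeholder, not a real
// MegEngine parameter) this would look roughly like:
//
//     template <>
//     struct EnumTrait<Rounding> {
//         static constexpr const char* name = "Rounding";
//         static constexpr std::underlying_type_t<Rounding> max = 1;  // largest member value
//     };
//     template <> PyTypeObject EnumWrapper<Rounding>::type = {};
//     template <> const char* EnumWrapper<Rounding>::members[] = {"FLOOR", "CEIL"};
//     template <> std::unordered_map<std::string, Rounding>
//             EnumWrapper<Rounding>::mem2value = {{"FLOOR", Rounding::FLOOR},
//                                                 {"CEIL", Rounding::CEIL}};
//     template <> PyObject* EnumWrapper<Rounding>::pyobj_insts[2] = {nullptr, nullptr};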
template <typename T>
struct BitCombinedEnumWrapper {
    PyEnumHead
    static PyNumberMethods number_methods;
    std::string to_string() const {
        uint32_t value_int = static_cast<uint32_t>(value);
        if (value_int == 0) {
            return "None";
        } else {
            std::string ret;
            bool first = true;
            for (uint32_t i = 0; i < 32; i++) {
                if (value_int >> i & 1) {
                    if (!first) {
                        ret += " + ";
                    } else {
                        first = false;
                    }
                    ret += (std::string(name) + "." + members[i]);
                }
            }
            return ret;
        }
    }
    static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) {
        if (!PyTuple_Size(args)) {
            PyObject* obj = type->tp_alloc(type, 0);
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
            return obj;
        } else {
            PyObject* input;
            if (!PyArg_ParseTuple(args, "|O", &input)) {
                return nullptr;
            }
            T value;
            if (load(input, value)) {
                return cast(value);
            } else {
                PyErr_SetString(PyExc_RuntimeError,
                    mgb::ssprintf("Cannot convert type %s to type %s\n",
                        input->ob_type->tp_name, name).c_str());
                return nullptr;
            }
        }
    }
    static PyObject* py_repr(PyObject* self) {
        return py::cast(
            reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
            .release().ptr();
    }
    static PyObject* py_or(PyObject* self, PyObject* other) {
        if (!(self->ob_type == other->ob_type)) {
            return PyErr_Format(
                PyExc_RuntimeError,
                "Operand in or operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
        return cast(lhs | rhs);
    }
    static PyObject* py_and(PyObject* self, PyObject* other) {
        if (!(self->ob_type == other->ob_type)) {
            return PyErr_Format(
                PyExc_RuntimeError,
                "Operand in and operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
        return cast(lhs & rhs);
    }
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
        if (op == Py_EQ || op == Py_NE) {
            RETURN_RICHCOMPARE(lhs, rhs, op);
        }
        Py_RETURN_NOTIMPLEMENTED;
    }
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
        if (PyObject_TypeCheck(obj, &type)) {
            value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
            auto&& iter = mem2value.find(
                normalize_enum(py::cast<std::string>(src)));
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
        }
        if (py::isinstance<py::int_>(obj)) {
            auto v = py::cast<std::underlying_type_t<T>>(src);
            if (v > EnumTrait<T>::max) {
                return false;
            }
            value = static_cast<T>(v);
            return true;
        }
        return false;
    }
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        if ((!v) || (v & (v - 1))) {
            PyTypeObject* pytype = &type;
            PyObject* obj = pytype->tp_alloc(pytype, 0);
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
            return obj;
        } else {
            PyObject* obj = pyobj_insts[__builtin_ctz(v)];
            Py_INCREF(obj);
            return obj;
        }
    }
};

void _init_py_op_def(py::module m) {
    using py_op = PyOp(OpDef);
    auto& py_type = PyOpType(OpDef);
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.OpDef";
    py_type.tp_basicsize = sizeof(PyOp(OpDef));
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "OpDef";
    py_type.tp_base = &PyBaseObject_Type;
    py_type.tp_hash = PyOp(OpDef)::tp_hash;
    py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
    py_type.tp_getset = py_op::py_getsetters;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
}
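// Illustrative sketch (kept in a comment, not compiled): every concrete op emitted
// into "opdef.cpy.inl" is expected to get an init function shaped like
// _init_py_op_def above and _init_py_backward_graph below. For the hypothetical
// generated op `Axis` used in the earlier sketches it would look roughly like:
//
//     void _init_py_Axis(py::module m) {
//         using py_op = PyOp(Axis);
//         auto& py_type = PyOpType(Axis);
//         py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
//         py_type.tp_name = "megengine.core._imperative_rt.ops.Axis";
//         py_type.tp_basicsize = sizeof(PyOp(Axis));
//         py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
//         py_type.tp_doc = "Axis";
//         py_type.tp_base = &PyOpType(OpDef);
//         py_type.tp_dealloc = py_dealloc_generic<py_op>;
//         py_type.tp_new = py_new_generic<py_op>;
//         py_type.tp_getset = py_op::py_getsetters;
//         mgb_assert(PyType_Ready(&py_type) >= 0);
//         m.add_object("Axis", reinterpret_cast<PyObject*>(&py_type));
//         mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Axis::typeinfo(), &py_type).second);
//     }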
/*********** begin of hand-write opdefs **************/

PyOpDefBegin(BackwardGraph) // {{
// };
PyOpDefEnd(BackwardGraph)

void _init_py_backward_graph(py::module m) {
    using py_op = PyOp(BackwardGraph);
    auto& py_type = PyOpType(BackwardGraph);
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph";
    py_type.tp_basicsize = sizeof(PyOp(BackwardGraph));
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "BackwardGraph";
    py_type.tp_base = &PyOpType(OpDef);
    py_type.tp_dealloc = py_dealloc_generic<py_op>;
    py_type.tp_new = py_new_generic<py_op>;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    // FIXME: rewrite the interpret function in CPython instead of wrapping it
    // directly with pybind11::cpp_function
    auto interpret = py::cpp_function(
        [](OpDef& self, py::object pyf, py::object pyc,
           const mgb::SmallVector<py::object>& inputs) {
            auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
                return py::cast<mgb::SmallVector<py::object>>(
                    pyf(op.shared_from_this(), inputs));
            };
            auto c = [pyc](const TensorPtr& tensor) {
                return pyc(tensor->dev_tensor());
            };
            return self.cast_final_safe<BackwardGraph>().graph().interpret(f, c, inputs);
        });
    mgb_assert(PyDict_SetItemString(
        py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0);
    PyType_Modified(&py_type);
    m.add_object("BackwardGraph", reinterpret_cast<PyObject*>(&py_type));
    mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(
        BackwardGraph::typeinfo(), &py_type).second);
}

struct PyOpBase : PyOpDef {
    static PyTypeObject py_type;

    static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
        auto* obj = type->tp_alloc(type, 0);
        if (obj) {
            auto* self = reinterpret_cast<PyOpBase*>(obj);
            new(&self->op) decltype(self->op);
        }
        return obj;
    }
};
PyTypeObject PyOpBase::py_type;

void _init_py_op_base(py::module m) {
    using py_op = PyOpBase;
    auto& py_type = PyOpBase::py_type;
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
    py_type.tp_basicsize = sizeof(py_op);
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "PyOpBase";
    py_type.tp_base = &PyOpType(OpDef);
    py_type.tp_dealloc = py_dealloc_generic<py_op>;
    py_type.tp_new = py_op::tp_new;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type));
}

/*********** end of hand-write opdefs **************/

// auto generated opdefs
#include "opdef.cpy.inl"

#undef CATCH_ALL
} // anonymous namespace

namespace PYBIND11_NAMESPACE {
namespace detail {

bool type_caster<OpDef>::load(handle src, bool convert) {
    PyObject* obj = src.ptr();
    if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) {
        return false;
    }
    value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!value) {
        // opdef only defined in Python
        value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
    }
    return true;
}

handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
    if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
        return object(pyop->obj).release();
    }
    PyTypeObject* pytype;
    auto& c2p = PyOp(OpDef)::ctype2pytype;
    auto&& iter = c2p.find(op.dyn_typeinfo());
    if (iter != c2p.end()) { // FIXME: should always meet this condition
        pytype = iter->second;
    } else { // which means an unregistered op type, just make it an opaque op type
        // currently, only OprAttr goes into this branch
        pytype = &PyOpType(OpDef);
    }
    PyObject* obj = pytype->tp_alloc(pytype, 0);
    mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef)));
    reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
    return py::handle(obj);
}

#define ENUM_CASTER_IMPL(T)                                                     \
    bool type_caster<T>::load(handle src, bool) {                               \
        return EnumWrapper<T>::load(src, value);                                \
    }                                                                           \
    handle type_caster<T>::cast(const T& value, return_value_policy, handle) {  \
        return EnumWrapper<T>::cast(value);                                     \
    }
FOR_EACH_ENUM_PARAM(ENUM_CASTER_IMPL)

#define BIT_COMBINED_ENUM_CASTER_IMPL(T)                                        \
    bool type_caster<T>::load(handle src, bool) {                               \
        return BitCombinedEnumWrapper<T>::load(src, value);                     \
    }                                                                           \
    handle type_caster<T>::cast(const T& value, return_value_policy, handle) {  \
        return BitCombinedEnumWrapper<T>::cast(value);                          \
    }
FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)

} // detail
} // PYBIND11_NAMESPACE
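// Illustrative sketch (kept in a comment, not compiled): FOR_EACH_ENUM_PARAM and
// FOR_EACH_BIT_COMBINED_ENUM_PARAM are assumed to be provided by the generated
// headers and to apply the given macro to every generated enum type. For the
// hypothetical `Rounding` enum from the earlier sketch, ENUM_CASTER_IMPL(Rounding)
// expands to roughly:
//
//     bool type_caster<Rounding>::load(handle src, bool) {
//         return EnumWrapper<Rounding>::load(src, value);
//     }
//     handle type_caster<Rounding>::cast(const Rounding& value, return_value_policy, handle) {
//         return EnumWrapper<Rounding>::cast(value);
//     }
//
// which is what lets pybind11-bound functions accept either the wrapper object or
// a member-name string (normalized by normalize_enum) for such parameters.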
void init_ops(py::module m) {
    _init_py_op_def(m);
    _init_py_backward_graph(m);
    _init_py_op_base(m);
    INIT_ALL_OP(m)
}
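// Illustrative sketch (kept in a comment, not compiled): init_ops() is the entry
// point declared in "./ops.h"; the module initialization code elsewhere is
// expected to call it with the target submodule, roughly like:
//
//     PYBIND11_MODULE(_imperative_rt, m) {   // placeholder, not the actual module init
//         init_ops(m.def_submodule("ops"));
//     }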