#include "./ops.h"
#include "./helper.h"
#include "./tensor.h"

#include "megbrain/common.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/rng.h"
#include "megbrain/imperative/ops/utility.h"

#include <Python.h>
#include <unordered_map>

namespace py = pybind11;
using namespace mgb::imperative;

namespace {
// Returns a copy of `in` with every character passed through toupper(); used
// as the lookup key for the enum name tables (mem2value) below.
auto normalize_enum(const std::string& in) {
    std::string ret;
    for (auto&& c : in) {
        ret += toupper(c);
    }
    return ret;
}
}  // anonymous namespace

// Catch clauses that translate a C++ exception into a pending Python error and
// make the enclosing C-API callback return RETVAL. Must directly follow a
// `try { ... }` block.
#define CATCH_ALL(RETVAL)                               \
    catch (py::error_already_set & e) {                 \
        e.restore();                                    \
        return RETVAL;                                  \
    }                                                   \
    catch (py::builtin_exception & e) {                 \
        e.set_error();                                  \
        return RETVAL;                                  \
    }                                                   \
    catch (std::exception & e) {                        \
        PyErr_SetString(PyExc_RuntimeError, e.what());  \
        return RETVAL;                                  \
    }

namespace {
// PyOp(Foo) is the C struct wrapping op `Foo`; PyOpType(Foo) is its
// PyTypeObject.
#define PyOp(name)     Py##name
#define PyOpType(name) PyOp(name)::py_type

// Opens the wrapper struct for op `name`; closed by PyOpDefEnd(name), which
// also defines the static py_type member.
#define PyOpDefBegin(name)                                \
    struct PyOp(name) : PyOpDef {                         \
        using Ty = name;                                  \
        Ty& inst() { return op->cast_final_safe<Ty>(); }  \
        static PyTypeObject py_type;

#define PyOpDefEnd(name) \
    }                    \
    ;                    \
    PyTypeObject PyOpType(name);

// Returns Py_True / Py_False from the enclosing function according to the
// rich-comparison opcode `op` applied to (val1, val2).
#define RETURN_RICHCOMPARE(val1, val2, op)                                \
    do {                                                                  \
        switch (op) {                                                     \
            case Py_EQ:                                                   \
                if ((val1) == (val2))                                     \
                    Py_RETURN_TRUE;                                       \
                Py_RETURN_FALSE;                                          \
            case Py_NE:                                                   \
                if ((val1) != (val2))                                     \
                    Py_RETURN_TRUE;                                       \
                Py_RETURN_FALSE;                                          \
            case Py_LT:                                                   \
                if ((val1) < (val2))                                      \
                    Py_RETURN_TRUE;                                       \
                Py_RETURN_FALSE;                                          \
            case Py_GT:                                                   \
                if ((val1) > (val2))                                      \
                    Py_RETURN_TRUE;                                       \
                Py_RETURN_FALSE;                                          \
            case Py_LE:                                                   \
                if ((val1) <= (val2))                                     \
                    Py_RETURN_TRUE;                                       \
                Py_RETURN_FALSE;                                          \
            case Py_GE:                                                   \
                if ((val1) >= (val2))                                     \
                    Py_RETURN_TRUE;                                       \
                Py_RETURN_FALSE;                                          \
            default:                                                      \
                Py_FatalError("Unreachable C code path reached");         \
        }                                                                 \
    } while (0)

// tp_new for a generated wrapper T: allocates the Python object and stores a
// default-made op (T::Ty::make()) in it. Returns NULL if allocation failed.
template <typename T>
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
    PyObject* obj = type->tp_alloc(type, 0);
    T* self = reinterpret_cast<T*>(obj);
    if (self != NULL) {
        self->op = T::Ty::make();
    }
    return obj;
}

// Default (de)serialization of an op attribute: a plain pybind11 cast in both
// directions. Specialized for enums further below.
template <typename T, typename SNIFAE = void>
struct serialization {
    static T load(py::object obj) { return py::cast<T>(obj); }
    template <
            typename U, typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>>
    static py::object dump(U&& t) {
        return py::cast(std::forward<U>(t));
    }
};

// tp_dealloc for a wrapper T: drops the held op, then frees the object.
template <typename T>
void py_dealloc_generic(PyObject* obj) {
    reinterpret_cast<T*>(obj)->op.reset();
    Py_TYPE(obj)->tp_free(obj);
}

// Getter for the data member `attr` of the op wrapped by T; returns a new
// reference to the converted value.
template <typename T, typename U, U T::Ty::*attr>
PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
    auto& op = reinterpret_cast<T*>(obj)->inst();
    return py::cast(op.*attr).release().ptr();
}
#define py_get_generic(name, attr) \
    py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

// Setter for the data member `attr` of the op wrapped by T. Deleting the
// attribute (value == NULL) is rejected with TypeError. Returns 0 on success,
// -1 with a Python error set otherwise.
template <typename T, typename U, U T::Ty::*attr>
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    auto& op = reinterpret_cast<T*>(obj)->inst();
    try {
        // TODO: remove this guard which is used for pybind11 implicit conversion
        py::detail::loader_life_support guard{};
        op.*attr = py::cast<U>(py::handle(value));
    }
    CATCH_ALL(-1)
    return 0;
}
#define py_set_generic(name, attr) \
    py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

// Python object layout shared by every op wrapper: the object header followed
// by a shared_ptr to the wrapped OpDef.
struct PyOpDef {
    PyObject_HEAD std::shared_ptr<OpDef> op;
    static PyTypeObject py_type;
    // maps the C++ dynamic type of an op to the Python type that wraps it
    static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
    static PyGetSetDef py_getsetters[];
    static Py_hash_t tp_hash(PyObject* obj);
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op);
    static PyObject* py_repr(PyObject* self) {
        return py::cast(reinterpret_cast<PyOpDef*>(self)->op->make_name())
                .release()
                .ptr();
    }
};
PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;

// Getter for OpDef.scope.
PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
    return py::cast(reinterpret_cast<PyOp(OpDef)*>(obj)->op->scope())
            .release()
            .ptr();
}

// Setter for OpDef.scope; the attribute cannot be deleted.
int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    try {
        reinterpret_cast<PyOp(OpDef)*>(obj)->op->set_scope(
                py::cast<std::string>(py::handle(value)));
    }
    CATCH_ALL(-1)
    return 0;
}

PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
        {const_cast<char*>("scope"), py_get_scope, py_set_scope,
         const_cast<char*>("scope"), NULL},
        {NULL}};

Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) {
    return static_cast<Py_hash_t>(reinterpret_cast<PyOp(OpDef)*>(obj)->op->hash());
}

// Only == and != are answered, and only between two OpDef wrappers that both
// hold an op. In every other case NotImplemented is returned so that Python
// applies its default comparison; in particular `other` is never reinterpreted
// as a wrapper unless its type has been checked.
PyObject* PyOp(OpDef)::tp_richcompare(PyObject* self, PyObject* other, int op) {
    if ((op == Py_EQ || op == Py_NE) &&
        PyObject_TypeCheck(other, &PyOpType(OpDef))) {
        auto& lhs = reinterpret_cast<PyOp(OpDef)*>(self)->op;
        auto& rhs = reinterpret_cast<PyOp(OpDef)*>(other)->op;
        // an op defined purely in Python has an empty `op` (see PyOpBase)
        if (lhs && rhs) {
            bool same = lhs->is_same(*rhs);
            RETURN_RICHCOMPARE(same, true, op);
        }
    }
    Py_RETURN_NOTIMPLEMENTED;
}

template <typename T>
struct EnumTrait;

// Common members of the Python wrappers for enum type T: the stored value, the
// Python type object, the member-name table, the normalized-name -> value map
// and the pre-created instances.
#define PyEnumHead                                           \
    static_assert(std::is_enum_v<T>);                        \
    PyObject_HEAD T value;                                   \
    constexpr static const char* name = EnumTrait<T>::name;  \
    static PyTypeObject* type;                               \
    static const char* members[];                            \
    static std::unordered_map<std::string, T> mem2value;     \
    static PyObject* pyobj_insts[];

// Python wrapper for an ordinary (non bit-combined) enum T.
template <typename T>
struct EnumWrapper {
    PyEnumHead

    // name of the current member, looked up in `members` by value
    std::string to_string() const {
        return members[static_cast<size_t>(value)];
    }

    // repr: "<EnumName>.<Member>"
    static PyObject* py_repr(PyObject* self) {
        return py::cast(
                       std::string(name) + "." +
                       reinterpret_cast<EnumWrapper*>(self)->to_string())
                .release()
                .ptr();
    }

    // dump: the bare member name as a str
    static PyObject* py_dump(PyObject* self) {
        return py::cast(reinterpret_cast<EnumWrapper*>(self)->to_string())
                .release()
                .ptr();
    }

    // ==/!= against anything load() accepts; an operand that cannot be loaded
    // compares unequal. Other operators are NotImplemented.
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
        if (op == Py_EQ || op == Py_NE) {
            T lhs, rhs;
            if (load(other, rhs) && load(self, lhs)) {
                RETURN_RICHCOMPARE(lhs, rhs, op);
            } else {
                RETURN_RICHCOMPARE(0, 1, op);
            }
        }
        Py_RETURN_NOTIMPLEMENTED;
    }

    // Converts `src` to T. Accepts an instance of the wrapper type or a str
    // naming a member (matched after normalize_enum). Returns false, leaving
    // `value` untouched, for anything else.
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
        if (PyObject_TypeCheck(obj, type)) {
            value = reinterpret_cast<EnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
            auto&& iter = mem2value.find(normalize_enum(py::cast<std::string>(src)));
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
        }
        return false;
    }

    // Returns a new reference to the pre-created instance for `value`.
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        PyObject* obj = pyobj_insts[v];
        Py_INCREF(obj);
        return obj;
    }
};

// Python wrapper for an enum T whose members are single bits that may be
// OR-ed together; bit i corresponds to members[i].
template <typename T>
struct BitCombinedEnumWrapper {
    PyEnumHead

    // "None" for zero, otherwise the set members joined with " + ", each
    // written as "<EnumName>.<Member>".
    std::string to_string() const {
        uint32_t value_int = static_cast<uint32_t>(value);
        if (value_int == 0) {
            return "None";
        } else {
            std::string ret;
            bool first = true;
            for (uint32_t i = 0; i < 32; i++) {
                if (value_int >> i & 1) {
                    if (!first) {
                        ret += " + ";
                    } else {
                        first = false;
                    }
                    ret += (std::string(name) + "." + members[i]);
                }
            }
            return ret;
        }
    }

    // tp_new: with no argument creates the zero value; with one argument
    // converts it through load() and raises RuntimeError if that fails.
    static PyObject* py_new_combined_enum(
            PyTypeObject* type, PyObject* args, PyObject*) {
        if (!PyTuple_Size(args)) {
            PyObject* obj = type->tp_alloc(type, 0);
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
            return obj;
        } else {
            PyObject* input;
            if (!PyArg_ParseTuple(args, "|O", &input)) {
                return nullptr;
            }
            T value;
            if (load(input, value)) {
                return cast(value);
            } else {
                PyErr_SetString(
                        PyExc_RuntimeError,
                        mgb::ssprintf(
                                "Cannot convert type %s to type %s\n",
                                input->ob_type->tp_name, name)
                                .c_str());
                return nullptr;
            }
        }
    }

    static PyObject* py_repr(PyObject* self) {
        return py::cast(reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
                .release()
                .ptr();
    }

    // dump: tuple of the names of all set members (empty tuple for zero)
    static PyObject* py_dump(PyObject* self) {
        std::vector<std::string> result;
        auto value = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value;
        uint32_t value_int = static_cast<uint32_t>(value);
        for (uint32_t i = 0; i < 32; i++) {
            if (value_int >> i & 1) {
                result.push_back(members[i]);
            }
        }
        return py::tuple(py::cast(result)).release().ptr();
    }

    // `self | other`; both operands must have exactly the same Python type.
    static PyObject* py_or(PyObject* self, PyObject* other) {
        if (!(self->ob_type == other->ob_type)) {
            return PyErr_Format(
                    PyExc_RuntimeError,
                    "Operand in or operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
        return cast(lhs | rhs);
    }

    // `self & other`; both operands must have exactly the same Python type.
    static PyObject* py_and(PyObject* self, PyObject* other) {
        if (!(self->ob_type == other->ob_type)) {
            return PyErr_Format(
                    PyExc_RuntimeError,
                    "Operand in and operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
        return cast(lhs & rhs);
    }

    // ==/!= against anything load() accepts; an operand that cannot be loaded
    // compares unequal. Other operators are NotImplemented.
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
        if (op == Py_EQ || op == Py_NE) {
            T lhs, rhs;
            if (load(other, rhs) && load(self, lhs)) {
                RETURN_RICHCOMPARE(lhs, rhs, op);
            } else {
                RETURN_RICHCOMPARE(0, 1, op);
            }
        }
        Py_RETURN_NOTIMPLEMENTED;
    }

    // Converts `src` to T. Accepts an instance of the wrapper type, a str
    // naming one member, a tuple of member names (OR-ed together) or an int
    // not greater than EnumTrait<T>::max. Returns false, leaving `value`
    // untouched, when the conversion is not possible.
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
        if (PyObject_TypeCheck(obj, type)) {
            value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
            auto&& iter = mem2value.find(normalize_enum(py::cast<std::string>(src)));
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
        }
        if (py::isinstance<py::tuple>(src)) {
            auto params = py::cast<std::vector<std::string>>(src);
            // Accumulate in a local starting from the zero value: an empty
            // tuple (which is what py_dump emits for zero) then loads as zero
            // instead of leaving `value` unassigned, and `value` is not
            // half-written when a later name turns out to be unknown.
            T combined = T();
            for (auto&& s : params) {
                auto&& iter = mem2value.find(normalize_enum(s));
                if (iter == mem2value.end()) {
                    return false;
                }
                combined |= iter->second;
            }
            value = combined;
            return true;
        }
        if (py::isinstance<py::int_>(obj)) {
            auto v = py::cast<std::underlying_type_t<T>>(src);
            if (v > EnumTrait<T>::max) {
                return false;
            }
            value = static_cast<T>(v);
            return true;
        }
        return false;
    }

    // Returns a new reference for `value`: the pre-created instance when
    // exactly one bit is set, a freshly allocated object otherwise.
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        // zero, or more than one bit set
        if ((!v) || (v & (v - 1))) {
            PyObject* obj = type->tp_alloc(type, 0);
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
            return obj;
        } else {
            PyObject* obj = pyobj_insts[__builtin_ctz(v)];
            Py_INCREF(obj);
            return obj;
        }
    }
};

// (De)serialization of enum attributes: load through the enum's type_caster,
// dump by calling the Python object's "dump" method.
template <typename T>
struct serialization<T, std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> {
    static T load(py::object obj) {
        auto caster = pybind11::detail::type_caster<T>();
        if (caster.load(obj, true)) {
            return caster;
        } else {
            // NOTE(review): this sets a Python error but still returns a
            // value, so the failure is only visible through PyErr_Occurred();
            // confirm that callers check for it.
            PyErr_SetString(PyExc_RuntimeError, "load faild \n");
            return caster;
        }
    }
    static py::object dump(T t) { return py::cast(t).attr("dump")(); }
};

// Fills in and registers the base Python type "OpDef" on module `m`.
void _init_py_op_def(py::module m) {
    using py_op = PyOp(OpDef);
    auto& py_type = PyOpType(OpDef);
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.OpDef";
    py_type.tp_basicsize = sizeof(PyOp(OpDef));
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "OpDef";
    py_type.tp_base = &PyBaseObject_Type;
    py_type.tp_hash = PyOp(OpDef)::tp_hash;
    py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
    py_type.tp_getset = py_op::py_getsetters;
    py_type.tp_repr = py_op::py_repr;
    py_type.tp_dealloc = py_dealloc_generic<PyOp(OpDef)>;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
}

/*********** begin of hand-write opdefs **************/
// Base type for ops implemented in Python: tp_new constructs an empty `op`
// pointer instead of making a C++ op.
struct PyOpBase : PyOpDef {
    static PyTypeObject py_type;

    static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
        auto* obj = type->tp_alloc(type, 0);
        if (obj) {
            auto* self = reinterpret_cast<PyOpBase*>(obj);
            new (&self->op) decltype(self->op);
        }
        return obj;
    }
};
PyTypeObject PyOpBase::py_type;

// Fills in and registers the Python type "PyOpBase" (derived from OpDef) on
// module `m`.
void _init_py_op_base(py::module m) {
    using py_op = PyOpBase;
    auto& py_type = PyOpBase::py_type;
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
    py_type.tp_basicsize = sizeof(py_op);
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "PyOpBase";
    py_type.tp_base = &PyOpType(OpDef);
    py_type.tp_dealloc = py_dealloc_generic<py_op>;
    py_type.tp_new = py_op::tp_new;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type));
}

/*********** end of hand-write opdefs **************/

// auto generated opdefs
#include "opdef.cpy.inl"

#undef CATCH_ALL
}  // anonymous namespace

namespace PYBIND11_NAMESPACE {
namespace detail {
// Python -> C++: accepts any instance of the OpDef Python type. When the
// wrapper holds no C++ op, the Python object itself is wrapped in a
// GenericPyOp.
bool type_caster<OpDef>::load(handle src, bool convert) {
    PyObject* obj = src.ptr();
    if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) {
        return false;
    }
    value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!value) {
        // opdef only defined in Python
        value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
    }
    return true;
}

// C++ -> Python: a GenericPyOp yields the Python object it wraps; any other op
// gets a new wrapper of the Python type registered for its dynamic type.
handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
    if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
        return object(pyop->obj).release();
    }
    PyTypeObject* pytype;
    auto& c2p = PyOp(OpDef)::ctype2pytype;
    auto&& iter = c2p.find(op.dyn_typeinfo());
    if (iter != c2p.end()) {  // FIXME: should always meet this condition
        pytype = iter->second;
    } else {
        // which means unregistered op type, just make it as an opaque op type
        // currently, only OprAttr goes into this branch
        pytype = &PyOpType(OpDef);
    }
    PyObject* obj = pytype->tp_alloc(pytype, 0);
    mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef)));
    reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
    return py::handle(obj);
}

// type_caster implementations for every plain enum parameter type, forwarding
// to EnumWrapper.
#define ENUM_CASTER_IMPL(T)                                                   \
    bool type_caster<T>::load(handle src, bool) {                             \
        return EnumWrapper<T>::load(src, value);                              \
    }                                                                         \
    handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
        return EnumWrapper<T>::cast(value);                                   \
    }
FOR_EACH_ENUM_PARAM(ENUM_CASTER_IMPL)

// type_caster implementations for every bit-combined enum parameter type,
// forwarding to BitCombinedEnumWrapper.
#define BIT_COMBINED_ENUM_CASTER_IMPL(T)                                      \
    bool type_caster<T>::load(handle src, bool) {                             \
        return BitCombinedEnumWrapper<T>::load(src, value);                   \
    }                                                                         \
    handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
        return BitCombinedEnumWrapper<T>::cast(value);                        \
    }
FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)

}  // namespace detail
}  // namespace PYBIND11_NAMESPACE

// Registers on module `m`: the OpDef / PyOpBase types, all generated ops, the
// RNG helpers, the SubgraphBuilder class, the JIT switches and the "_custom"
// submodule.
void init_ops(py::module m) {
    _init_py_op_def(m);
    _init_py_op_base(m);
    INIT_ALL_OP(m)

    m.def("new_rng_handle", &rng::new_handle);
    m.def(
            "delete_rng_handle",
            [](size_t handle) {
                // Finish all outstanding work before the handle is deleted:
                // sync the interpreter (if it still exists) and every comp
                // node, surface async errors, and drain the python task queue.
                if (mgb::imperative::python::interpreter_for_py->check_available()) {
                    mgb::imperative::python::interpreter_for_py->sync();
                }
                mgb::CompNode::sync_all();
                mgb::CompNode::foreach ([](mgb::CompNode cn) {
                    auto err = cn.check_async_error();
                    mgb_assert(!err, "%s", err->what());
                });
                py_task_q.wait_all_task_finish();
                rng::delete_handle(handle);
            },
            py::call_guard<py::gil_scoped_release>());
    m.def("set_global_rng_seed", [](uint64_t seed) -> void {
        mgb_assert(
                python::interpreter_for_py->check_available(),
                "set global random seed failed since imperative interpreter has been "
                "destroyed");
        python::interpreter_for_py->sync();
        mgb::CompNode::sync_all();
        rng::set_global_rng_seed(seed);
    });
    m.def("get_global_rng_seed", &rng::get_global_rng_seed);
    m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode);

    // Incrementally assembles a Subgraph. `key` stays null until the first
    // build(); every mutating method asserts that it is still null, so the
    // graph cannot be changed once it has been built.
    struct PySubgraphBuilder {
        explicit PySubgraphBuilder(std::string name) : name{name} {}
        std::string name;
        Subgraph graph;
        mgb::SmallVector<bool> output_grad_mask;
        Subgraph::var_t next_var = 1;
        std::shared_ptr<mgb::Hashable> key = nullptr;

        std::shared_ptr<OpDef> build() {
            if (key == nullptr) {
                key = std::make_shared<UniqueKey>();
            }
            return SubgraphOp::make(
                    name, std::make_shared<Subgraph>(graph), output_grad_mask, key);
        }
    };

    py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
            .def(py::init<std::string>())
            .def(py::init<PySubgraphBuilder>())
            // declares a new graph input and returns its variable id
            .def("input",
                 [](PySubgraphBuilder& self) {
                     mgb_assert(self.key == nullptr);
                     auto var = self.next_var++;
                     self.graph.inputs.push_back(var);
                     return var;
                 })
            // appends `op(inputs)` and returns `nr_outputs` fresh variable ids
            .def("apply",
                 [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
                    Subgraph::vars_t inputs, size_t nr_outputs) {
                     mgb_assert(self.key == nullptr);
                     Subgraph::vars_t outputs;
                     for (size_t i = 0; i < nr_outputs; ++i) {
                         outputs.push_back(self.next_var++);
                     }
                     self.graph.exprs.push_back({op, inputs, outputs});
                     return outputs;
                 })
            // copies `value` into a host tensor of `dtype` on `cn` and adds it
            // to the graph as a constant; returns its variable id
            .def("apply_const",
                 [](PySubgraphBuilder& self, py::object value, mgb::DType dtype,
                    mgb::CompNode cn) {
                     mgb_assert(self.key == nullptr);
                     auto var = self.next_var++;
                     mgb::HostTensorND hvalue(cn);
                     npy::np2tensor(
                             value.cast<py::array>().ptr(),
                             npy::Meth::copy_into(&hvalue), dtype);
                     self.graph.constants.push_back({var, Tensor::make(hvalue)});
                     return var;
                 })
            // sets the graph outputs; every output gets a gradient by default
            .def("outputs",
                 [](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
                     mgb_assert(self.key == nullptr);
                     self.graph.outputs = outputs;
                     self.output_grad_mask.resize(outputs.size(), true);
                 })
            // replaces the per-output gradient mask
            .def("outputs_has_grad",
                 [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
                     mgb_assert(self.key == nullptr);
                     // validate the incoming mask, not the one being replaced
                     mgb_assert(
                             self.graph.outputs.size() == outputs_has_grad.size());
                     self.output_grad_mask = outputs_has_grad;
                 })
            .def("get",
                 [](PySubgraphBuilder& self) {
                     return (std::shared_ptr<OpDef>)self.build();
                 })
            .def("compile",
                 [](PySubgraphBuilder& self, int gopt_level) {
                     return (std::shared_ptr<OpDef>)CompiledOp::make(
                             self.build(), gopt_level);
                 })
            .def("jit_fuse", [](PySubgraphBuilder& self) {
                return (std::shared_ptr<OpDef>)CompiledOp::make(
                        JITFusionOp::make(self.build()));
            });

    m.def("set_jit_enabled", &JITFusionOp::set_enabled);
    bool jit_supported = false;
#if MGB_JIT
    jit_supported = true;
#endif
    m.attr("jit_supported") = jit_supported;

    auto custom = submodule(m, "_custom");
    init_custom(custom);
}

// switch cases used by make_custom_op: convert the Python value `kv.second`
// into the custom-op parameter `param_val`, either as a single value or
// element by element for list-typed parameters.
#define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type)    \
    case custom::ParamDynType::dyn_type: {                      \
        param_val = py::handle(kv.second).cast<static_type>();  \
        break;                                                  \
    }

#define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type)                            \
    case custom::ParamDynType::dyn_type: {                                          \
        auto pyvals = py::handle(kv.second).cast<py::list>();                       \
        static_type vals;                                                           \
        using basic_type = custom::get_vector_template_arg_type<static_type>::type; \
        for (auto& pyval : pyvals) {                                                \
            vals.push_back(py::handle(pyval).cast<basic_type>());                   \
        }                                                                           \
        param_val = vals;                                                           \
        break;                                                                      \
    }

// Python entry point `_make_custom_op(op_name, kwargs)`.
//
// Creates the custom op registered under `op_name`, sets each of its
// parameters that appears in the dict `kwargs` (unknown names are skipped with
// a warning) and returns it wrapped in a new OpDef object.
//
// Returns NULL with a Python error set when fewer than two arguments are
// given, when a conversion or op creation throws, or when MegEngine was built
// without custom-op support. No C++ exception is allowed to leave this
// function because it is called directly by the CPython call machinery.
PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) {
    try {
#if MGB_CUSTOM_OP
        if (nargs < 2) {
            PyErr_SetString(
                    PyExc_TypeError,
                    "_make_custom_op expects two arguments: op name and kwargs");
            return nullptr;
        }
        auto op_name = py::handle(args[0]).cast<std::string>();
        auto kwargs = py::handle(args[1]).cast<py::dict>();

        std::shared_ptr<OpDef> opdef = CustomOpDefFactory::inst()->create_opdef(op_name);
        auto& custom_opdef = static_cast<mgb::imperative::CustomOpDef&>(*opdef);
        auto& param = custom_opdef.param();

        for (auto&& kv : kwargs) {
            std::string param_name = py::handle(kv.first).cast<std::string>();

            if (!param.exist(param_name)) {
                mgb_log_warn(
                        "op %s have no param named %s, ignore this param parsed from "
                        "python",
                        op_name.c_str(), param_name.c_str());
                continue;
            }

            auto& param_val = param[param_name];
            switch (param_val.type()) {
                CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
                CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
                CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
                CUSTOM_FOR_BOOL_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
                CUSTOM_FOR_STRING_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
                case custom::ParamDynType::Device: {
                    param_val = to_custom_device(
                            py::handle(kv.second).cast<mgb::CompNode>());
                    break;
                }
                default: {
                    mgb_assert(
                            false, "param dtype of %s:%s is invalid", op_name.c_str(),
                            param_name.c_str());
                }
            }
        }

        PyTypeObject* pytype;
        pytype = &PyOpType(OpDef);
        PyObject* obj = pytype->tp_alloc(pytype, 0);
        if (obj == nullptr) {
            // tp_alloc has already set the Python error
            return nullptr;
        }
        reinterpret_cast<PyOp(OpDef)*>(obj)->op = opdef;
        return obj;
#else
        mgb_assert(
                false,
                "Custom Op is disabled now, please build megengine with Custom Op open");
        return nullptr;
#endif
    } catch (py::error_already_set& e) {
        e.restore();
        return nullptr;
    } catch (py::builtin_exception& e) {
        e.set_error();
        return nullptr;
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}

#undef CUSTOM_CASE_TO_PARSE_LIST
#undef CUSTOM_CASE_TO_PARSE_NON_LIST

// Installs the custom-op library at `path` under `name` and returns the list
// of ops it provides. Asserts when custom ops are disabled at build time.
py::list install_custom(const std::string& name, const std::string& path) {
#if MGB_CUSTOM_OP
    py::list ret;
    const auto& ops_in_lib = custom::LibManager::inst()->install(name, path);
    for (const auto& op : ops_in_lib) {
        ret.append(op);
    }
    return ret;
#else
    mgb_assert(
            false,
            "Custom Op is disabled now, please build megengine with Custom Op open");
    py::list ret;
    return ret;
#endif
}

// Uninstalls the custom-op library registered under `name`; returns the result
// reported by the library manager.
bool uninstall_custom(const std::string& name) {
#if MGB_CUSTOM_OP
    return custom::LibManager::inst()->uninstall(name);
#else
    mgb_assert(
            false,
            "Custom Op is disabled now, please build megengine with Custom Op open");
    return false;
#endif
}

// Returns the names of all ops known to the custom-op factory.
py::list get_custom_op_list(void) {
#if MGB_CUSTOM_OP
    std::vector<std::string> all_ops = CustomOpDefFactory::inst()->op_list();
    py::list ret;
    for (auto& op : all_ops) {
        ret.append(op);
    }
    return ret;
#else
    mgb_assert(
            false,
            "Custom Op is disabled now, please build megengine with Custom Op open");
    py::list ret;
    return ret;
#endif
}

#ifndef METH_FASTCALL
// METH_VARARGS adaptor used when METH_FASTCALL is not available: forwards the
// items of the argument tuple to make_custom_op.
PyObject* py35_make_custom_op(PyObject* self, PyObject* args) {
    auto* arr = &PyTuple_GET_ITEM(args, 0);
    auto size = PyTuple_GET_SIZE(args);
    return make_custom_op(self, arr, size);
};
#endif

// Populates the "_custom" submodule: install / uninstall / list helpers, the
// C++ ABI tag query, and the raw C function `_make_custom_op`.
void init_custom(pybind11::module m) {
    m.def("_install", &install_custom);
    m.def("_uninstall", &uninstall_custom);
    m.def("_get_custom_op_list", &get_custom_op_list);
    m.def("get_custom_op_abi_tag", [](void) -> int {
        int ret = 0;
#ifdef _GLIBCXX_USE_CXX11_ABI
        ret = _GLIBCXX_USE_CXX11_ABI;
#endif
        return ret;
    });

    static PyMethodDef method_def = {
#ifdef METH_FASTCALL
            "_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
#else
            "_make_custom_op", (PyCFunction)py35_make_custom_op, METH_VARARGS, ""
#endif
    };
    auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr);
    pybind11::setattr(m, method_def.ml_name, func);
}