From 9279104b111ef1bb924af3218d9276f3504882bd Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 20 Aug 2021 19:11:48 +0800
Subject: [PATCH] feat(mge): add opdef serialization and apply_module_trace

GitOrigin-RevId: 5b45bded1de8e1fb36447d4469423ef68ff627e8
---
 .../experimental/traced_module/__init__.py    |  7 ++
 .../traced_module/serialization.py            | 34 ++++++++
 imperative/python/src/module_trace.cpp        | 41 +++++++++
 imperative/python/src/module_trace.h          | 20 +++++
 imperative/python/src/ops.cpp                 | 78 +++++++++++++++++
 imperative/python/src/pyext17.h               |  9 +-
 imperative/python/src/tensor.cpp              | 39 ++++++++-
 imperative/python/src/tensor.h                | 13 ++-
 .../test/unit/core/test_serialization.py      | 27 ++++++
 .../tablegen/targets/python_c_extension.cpp   | 83 ++++++++++++++++++-
 10 files changed, 337 insertions(+), 14 deletions(-)
 create mode 100644 imperative/python/megengine/experimental/traced_module/__init__.py
 create mode 100644 imperative/python/megengine/experimental/traced_module/serialization.py
 create mode 100644 imperative/python/src/module_trace.cpp
 create mode 100644 imperative/python/src/module_trace.h

diff --git a/imperative/python/megengine/experimental/traced_module/__init__.py b/imperative/python/megengine/experimental/traced_module/__init__.py
new file mode 100644
index 000000000..f92f1aa01
--- /dev/null
+++ b/imperative/python/megengine/experimental/traced_module/__init__.py
@@ -0,0 +1,7 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/imperative/python/megengine/experimental/traced_module/serialization.py b/imperative/python/megengine/experimental/traced_module/serialization.py
new file mode 100644
index 000000000..ec9596216
--- /dev/null
+++ b/imperative/python/megengine/experimental/traced_module/serialization.py
@@ -0,0 +1,34 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from typing import Dict
+
+from ...core._imperative_rt import OpDef
+from ...core.ops import builtin
+from ...version import __version__
+
+OPDEF_PARAM_LOADER = {}
+
+
+def get_opdef_state(obj: OpDef) -> Dict:
+    state = obj.__getstate__()
+    state["type"] = type(obj)
+    state["version"] = __version__
+    return state
+
+
+def load_opdef_from_state(state: Dict) -> OpDef:
+    assert "type" in state and issubclass(state["type"], OpDef)
+    assert "version" in state
+    opdef_type = state.pop("type")
+    if opdef_type in OPDEF_PARAM_LOADER:
+        loader = OPDEF_PARAM_LOADER[opdef_type]
+        state = loader(state)
+    state.pop("version")
+    opdef_obj = opdef_type()
+    opdef_obj.__setstate__(state)
+    return opdef_obj
diff --git a/imperative/python/src/module_trace.cpp b/imperative/python/src/module_trace.cpp
new file mode 100644
index 000000000..44cb72768
--- /dev/null
+++ b/imperative/python/src/module_trace.cpp
@@ -0,0 +1,41 @@
+/**
+ * \file imperative/python/src/module_trace.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "./module_trace.h"
+#include "./helper.h"  // include op pybind11 caster
+
+namespace py = pybind11;
+
+namespace mgb::imperative::python {
+
+apply_result_t apply_module_trace(ApplyContext& ctx) {
+    apply_result_t outputs;
+
+    auto args = py::tuple(ctx.nargs + 1);
+    args[0] = py::cast(ctx.op);
+    for (size_t i = 0; i < ctx.nargs; i++) {
+        args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
+    }
+    auto pyout = PyObject_Call(cpp_apply_module_trace, args.ptr(), nullptr);
+    if (!pyout) throw py::error_already_set();
+    auto ret = py::reinterpret_steal<py::object>(pyout);
+
+    // assumption: python function always returns PyList
+    auto tup = py::reinterpret_borrow<py::list>(ret);
+    for (size_t i = 0; i < tup.size(); i++) {
+        auto tw = TensorWrapper::try_cast(tup[i].ptr());
+        outputs.emplace_back(tw->m_tensor);
+    }
+    return outputs;
+}
+
+} // namespace mgb::imperative::python
diff --git a/imperative/python/src/module_trace.h b/imperative/python/src/module_trace.h
new file mode 100644
index 000000000..d99228abc
--- /dev/null
+++ b/imperative/python/src/module_trace.h
@@ -0,0 +1,20 @@
+/**
+ * \file imperative/python/src/module_trace.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "./tensor.h"
+
+namespace mgb::imperative::python {
+
+apply_result_t apply_module_trace(ApplyContext& ctx);
+
+} // namespace mgb::imperative::python
diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp
index d821894f9..5cbefadc5 100644
--- a/imperative/python/src/ops.cpp
+++ b/imperative/python/src/ops.cpp
@@ -88,6 +88,19 @@ PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
     return obj;
 }
 
+template <typename T, typename SFINAE = void>
+struct serialization {
+    static T load(py::object obj) {
+        return py::cast<T>(obj);
+    }
+    template <typename U = T,
+              typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>>
+    static py::object dump(U&& t) {
+        return py::cast(std::forward<U>(t));
+    }
+};
+
+
 template <typename T>
 void py_dealloc_generic(PyObject* obj) {
     reinterpret_cast<T*>(obj)->op.reset();
@@ -127,6 +140,13 @@ struct PyOpDef {
     static PyGetSetDef py_getsetters[];
     static Py_hash_t tp_hash(PyObject *obj);
     static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op);
+    static PyObject* py_repr(PyObject* self) {
+        return py::cast(
+                       reinterpret_cast<PyOp(OpDef)*>(self)->op->make_name())
+                .release()
+                .ptr();
+    }
+
 };
 PyTypeObject PyOpType(OpDef);
 std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;
@@ -191,6 +211,13 @@ struct EnumWrapper {
                 std::string(name) + "."
                 + reinterpret_cast<EnumWrapper*>(self)->to_string())
                 .release().ptr();
     }
+
+    static PyObject* py_dump(PyObject* self) {
+        return py::cast(reinterpret_cast<EnumWrapper*>(self)->to_string())
+                .release()
+                .ptr();
+    }
+
     static PyObject* tp_richcompare(PyObject *self, PyObject *other, int op) {
         if (op == Py_EQ || op == Py_NE) {
             T lhs, rhs;
@@ -279,6 +306,19 @@ struct BitCombinedEnumWrapper {
                 reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
             .release().ptr();
     }
+
+    static PyObject* py_dump(PyObject* self) {
+        std::vector<std::string> result;
+        auto value = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value;
+        uint32_t value_int = static_cast<uint32_t>(value);
+        for (uint32_t i = 0; i < 32; i++) {
+            if (value_int >> i & 1) {
+                result.push_back(members[i]);
+            }
+        }
+        return py::tuple(py::cast(result)).release().ptr();
+    }
+
     static PyObject* py_or(PyObject* self, PyObject* other) {
         if(!(self->ob_type == other->ob_type)){
             return PyErr_Format(
@@ -326,6 +366,24 @@ struct BitCombinedEnumWrapper {
                 return false;
             }
         }
+        if (py::isinstance<py::tuple>(src)) {
+            auto params = py::cast<std::vector<std::string>>(src);
+            bool first = true;
+            for (auto s : params){
+                auto&& iter = mem2value.find(normalize_enum(s));
+                if (iter != mem2value.end()) {
+                    if (first) {
+                        value = iter->second;
+                        first = false;
+                    } else {
+                        value |= iter->second;
+                    }
+                } else {
+                    return false;
+                }
+            }
+            return true;
+        }
         if (py::isinstance<py::int_>(obj)) {
             auto v = py::cast<std::underlying_type_t<T>>(src);
             if(v > EnumTrait<T>::max) {
@@ -351,6 +409,25 @@ struct BitCombinedEnumWrapper {
     }
 };
 
+template <typename T>
+struct serialization<T,
+        std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> {
+    static T load(py::object obj) {
+        auto caster = pybind11::detail::type_caster<T>();
+        if (caster.load(obj, true)) {
+            return caster;
+        } else {
+            PyErr_SetString(PyExc_RuntimeError,
+                            "load failed\n");
+            return caster;
+        }
+    }
+    static py::object dump(T t) {
+        return py::cast(t).attr("dump")();
+    }
+};
+
+
 void _init_py_op_def(py::module m) {
     using py_op = PyOp(OpDef);
     auto& py_type = PyOpType(OpDef);
@@ -363,6 +440,7 @@ void _init_py_op_def(py::module m) {
     py_type.tp_hash = PyOp(OpDef)::tp_hash;
     py_type.tp_richcompare = PyOp(OpDef)::tp_richcompare;
     py_type.tp_getset = py_op::py_getsetters;
+    py_type.tp_repr = py_op::py_repr;
     mgb_assert(PyType_Ready(&py_type) >= 0);
     m.add_object("OpDef", reinterpret_cast<PyObject*>(&py_type));
 }
diff --git a/imperative/python/src/pyext17.h b/imperative/python/src/pyext17.h
index 4f10a207a..8d8ace864 100644
--- a/imperative/python/src/pyext17.h
+++ b/imperative/python/src/pyext17.h
@@ -451,18 +451,11 @@ public:
     template <typename... Args>
     static PyObject* cnew(Args&&... args) {
         auto* pytype = type().operator->();
-        auto* self = pytype->tp_alloc(pytype, 0);
-        auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
-        if constexpr (has_vectorcall && tp_vectorcall::valid) {
-            reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
-        }
-        new(inst) T(std::forward<Args>(args)...);
-        return self;
+        return cnew_with_type(pytype, std::forward<Args>(args)...);
     }
 
     template <typename... Args>
     static PyObject* cnew_with_type(PyTypeObject* pytype, Args&&... args) {
-
         auto* self = pytype->tp_alloc(pytype, 0);
         auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
         if constexpr (has_vectorcall && tp_vectorcall::valid) {
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 3fe21c4ed..1e17f794e 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -20,6 +20,7 @@
 #include "./tensor.h"
 #include "./grad.h"
 #include "./trace.h"
+#include "./module_trace.h"
 #include "./common.h"
 #include "./numpy_dtypes.h"
 #include "./graph_rt.h"
@@ -41,6 +42,7 @@ interpreter::Interpreter::Channel* interpreter_for_py;
 
 PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing;
 PyObject *cpp_apply_backward_varnode;
+PyObject *cpp_apply_module_trace;
 
 std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
     if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
@@ -70,6 +72,7 @@ std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
 REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
 REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing)
 REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)
+REGISTE_APPLY_FUNC(cpp_apply_module_trace)
 
 #undef REGISTE_APPLY_FUNC
 
@@ -79,6 +82,14 @@ Tensor::flags_t ApplyContext::global_enable = 0;
 
 void set_tracing() { ApplyContext::global_enable |= Tensor::Flags::TRACE; }
 void unset_tracing() { ApplyContext::global_enable &= ~Tensor::Flags::TRACE; }
+void set_module_tracing() { ApplyContext::global_enable |= Tensor::Flags::MODULE_TRACE; }
+void unset_module_tracing() { ApplyContext::global_enable &= ~Tensor::Flags::MODULE_TRACE; }
+bool is_tracing_module() {
+    return ApplyContext::global_enable & Tensor::Flags::MODULE_TRACE;
+}
+
+
 bool skip_tracing = false;
 
 apply_result_t apply(ApplyContext& ctx) {
@@ -117,6 +128,11 @@ apply_result_t apply(ApplyContext& ctx) {
         return ret;
     }
 
+    if (flags & Tensor::Flags::MODULE_TRACE) {
+        return apply_module_trace(ctx);
+    }
+
     if (flags & Tensor::Flags::TRACE) {
         return apply_trace(ctx);
     } else {
@@ -310,6 +326,21 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording)
 
 #undef REGISTE_TENSORWRAPPER_FUNC
 
+PyObject* TensorWrapper::module_trace_info() {
+    if (!m_tensor->m_module_trace_info.ptr()) {
+        PyErr_SetString(PyExc_AttributeError,
+                        "has no attribute '_NodeMixin__node', please "
+                        "set it first");
+        return nullptr;
+    }
+    return m_tensor->m_module_trace_info.inc_ref().ptr();
+}
+
+void TensorWrapper::set_module_trace_info(PyObject* obj) {
+    m_tensor->m_module_trace_info = py::reinterpret_borrow<py::object>(obj);
+}
+
 
 #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
     PyObject* TensorWrapper::member() { \
@@ -495,7 +526,9 @@ void TensorWrapper::reset(PyObject* tensor) {
     }
     std::string user_custom_name = m_tensor->user_custom_name;
     std::string automatic_name = m_tensor->automatic_name;
+    auto module_trace_info = m_tensor->m_module_trace_info;
     m_tensor = t->m_tensor;
+    m_tensor->m_module_trace_info = module_trace_info;
     m_tensor->user_custom_name = user_custom_name;
     m_tensor->automatic_name = automatic_name;
 }
@@ -856,6 +889,7 @@ void init_tensor(py::module m) {
         .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
         .def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name")
         .def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name")
+        .def_getset<&TensorWrapper::module_trace_info, &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
         .finalize();
     if (!tensor_type) throw py::error_already_set();
     py::setattr(m, "Tensor", tensor_type);
@@ -998,7 +1032,7 @@ void init_tensor(py::module m) {
     m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
     m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
     m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);
-
+    m.def("set_cpp_apply_module_trace", &set_cpp_apply_module_trace);
     m.attr("skip_tracing") = &skip_tracing;
 
     py::class_<SharedHandle>(m, "SharedHandle")
@@ -1016,6 +1050,9 @@ void init_tensor(py::module m) {
     m.def("set_allow_higher_order_directive", [](bool value){
         GradKey::allow_higher_order_directive = value;
     });
+    m.def("set_module_tracing", &set_module_tracing);
+    m.def("unset_module_tracing", &unset_module_tracing);
+    m.def("is_tracing_module", &is_tracing_module);
 }
 
 #undef MGE_PY_INTERFACE
diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h
index 5860a5c36..156c48613 100644
--- a/imperative/python/src/tensor.h
+++ b/imperative/python/src/tensor.h
@@ -96,6 +96,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
         static constexpr flags_t SCALAR = 1;
         static constexpr flags_t GRAD = 1 << 1;
         static constexpr flags_t TRACE = 1 << 2;
+        static constexpr flags_t MODULE_TRACE = 1 << 3;
     };
 
     flags_t m_flags = 0;
@@ -106,6 +107,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     std::string user_custom_name;
     std::string automatic_name;
     cg::VarNode* m_var;
+    pybind11::object m_module_trace_info;
 
     using Handle = interpreter::Interpreter::Handle;
 
@@ -158,10 +160,10 @@ struct TensorWrapper {
     using wrap_t = pyext17::wrap<TensorWrapper>;
     friend wrap_t;
 
-    inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();}
-    inline static TensorWrapper* try_cast(PyObject* op) {
-        if (!wrap_t::type().isinstance(op)) return nullptr;
-        return cast(op);
+    inline static TensorWrapper* cast(PyObject* obj) {return reinterpret_cast<wrap_t*>(obj)->inst();}
+    inline static TensorWrapper* try_cast(PyObject* obj) {
+        if (!wrap_t::type().isinstance(obj)) return nullptr;
+        return cast(obj);
     }
     inline ObjectPtr self() {return wrap_t::pycast(this);}
 
@@ -206,6 +208,8 @@ struct TensorWrapper {
     void set_compiled_info(PyObject *);
     PyObject* trace_mixin_info();
     void set_trace_mixin_info(PyObject *);
+    PyObject* module_trace_info();
+    void set_module_trace_info(PyObject *);
     PyObject* user_custom_name();
     void set_user_custom_name(PyObject *);
     PyObject* automatic_name();
@@ -331,6 +335,7 @@ void init_tensor(pybind11::module);
 
 extern PyObject *cpp_apply_with_tracing;
 extern PyObject *cpp_apply_backward_varnode;
+extern PyObject *cpp_apply_module_trace;
 
 } // namespace mgb::imperative::python
diff --git a/imperative/python/test/unit/core/test_serialization.py b/imperative/python/test/unit/core/test_serialization.py
index aed29932c..7eae4ea3e 100644
--- a/imperative/python/test/unit/core/test_serialization.py
+++ b/imperative/python/test/unit/core/test_serialization.py
@@ -14,6 +14,11 @@ import numpy as np
 
 import megengine as mge
 from megengine import Parameter, Tensor
+from megengine.core.ops import builtin
+from megengine.experimental.traced_module.serialization import (
+    get_opdef_state,
+    load_opdef_from_state,
+)
 
 
 def test_tensor_serialization():
@@ -86,3 +91,25 @@ def test_compatibility():
 
     test_old_tensor("tensor_v1_1.mge")
     test_old_tensor("tensor_v1_2.mge")
+
+
+def test_opdef_serialization():
+    with TemporaryFile() as f:
+        x = builtin.Elemwise(mode="Add")
+        pickle.dump(get_opdef_state(x), f)
+        f.seek(0)
+        load_x = load_opdef_from_state(pickle.load(f))
+        assert x == load_x
+
+    with TemporaryFile() as f:
+        x = builtin.Convolution(stride_h=9, compute_mode="float32")
+        x.strategy = (
+            builtin.Convolution.Strategy.PROFILE
+            | builtin.Convolution.Strategy.HEURISTIC
+            | builtin.Convolution.Strategy.REPRODUCIBLE
+        )
+        pickle.dump(get_opdef_state(x), f)
+        f.seek(0)
+        load_x = load_opdef_from_state(pickle.load(f))
+        assert x.strategy == load_x.strategy
+        assert x == load_x
diff --git a/imperative/tablegen/targets/python_c_extension.cpp b/imperative/tablegen/targets/python_c_extension.cpp
index c17506e13..ff1356225 100644
--- a/imperative/tablegen/targets/python_c_extension.cpp
+++ b/imperative/tablegen/targets/python_c_extension.cpp
@@ -34,6 +34,7 @@ private:
     void emit_class();
     void emit_py_init();
     void emit_py_getsetters();
+    void emit_py_methods();
     Initproc emit_initproc();
 
     MgbOp& op;
@@ -133,9 +134,16 @@ void $0(PyTypeObject& py_type) {
     if (firstOccur) {
         os << tgfmt(R"(
+    static PyMethodDef tp_methods[] = {
+        {const_cast<char*>("dump"), (PyCFunction)$enumTpl<$opClass::$enumClass>::py_dump, METH_NOARGS, NULL},
+        {NULL}  /* Sentinel */
+    };
+)", &ctx);
+        os << tgfmt(R"(
     static PyType_Slot slots[] = {
         {Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
         {Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
+        {Py_tp_methods, tp_methods},
 )", &ctx);
         if (attr->getEnumCombinedFlag()) {
             // only bit combined enum could new instance because bitwise operation,
@@ -212,17 +220,62 @@ Initproc OpDefEmitter::emit() {
     emit_class();
     emit_py_init();
     emit_py_getsetters();
+    emit_py_methods();
     return emit_initproc();
 }
 
 void OpDefEmitter::emit_class() {
+    auto&& className = op.getCppClassName();
+    std::string method_defs;
+    std::vector<std::string> body;
+
+    llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
+        body.push_back(formatv(R"(
+            {{"{0}", serialization<decltype(opdef.{0})>::dump(opdef.{0})})",
+                attr.name));
+    });
+
+    method_defs += formatv(R"(
+    static PyObject* getstate(PyObject* self, PyObject*) {{
+        auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
+        static_cast<void>(opdef);
+        std::unordered_map<std::string, py::object> state {{
+            {1}
+        };
+        return py::cast(state).release().ptr();
+    })", className, llvm::join(body, ","));
+
+    body.clear();
+    llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
+        body.push_back(formatv(R"(
+        {{
+            auto&& iter = state.find("{0}");
+            if (iter != state.end()) {{
+                opdef.{0} = serialization<decltype(opdef.{0})>::load(iter->second);
+            }
+        })", attr.name));
+    });
+
+    method_defs += formatv(R"(
+    static PyObject* setstate(PyObject* self, PyObject* args) {{
+        PyObject* dict = PyTuple_GetItem(args, 0);
+        if (!dict) return NULL;
+        auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
+        auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
+        static_cast<void>(opdef);
+        {1}
+        Py_RETURN_NONE;
+    })", className, llvm::join(body, "\n"));
+
     os << tgfmt(R"(
 PyOpDefBegin($_self) // {
     static PyGetSetDef py_getsetters[];
+    static PyMethodDef tp_methods[];
+    $0
     static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
 // };
 PyOpDefEnd($_self)
-)", &ctx);
+)", &ctx, method_defs);
 }
 
 void OpDefEmitter::emit_py_init() {
@@ -302,6 +355,33 @@ PyGetSetDef PyOp($_self)::py_getsetters[] = {
 )", &ctx, llvm::join(llvm::map_range(op.getMgbAttributes(), f), "\n    "));
 }
 
+void OpDefEmitter::emit_py_methods() {
+    // generate methods
+    std::string method_defs;
+    std::vector<std::string> method_items;
+    {
+        auto&& className = op.getCppClassName();
+        // generate getstate
+        method_items.push_back(formatv(
+                "{{const_cast<char*>(\"__getstate__\"), PyOp({0})::getstate, METH_NOARGS, \"{0} getstate\"},",
+                className));
+
+        // generate setstate
+        method_items.push_back(formatv(
+                "{{const_cast<char*>(\"__setstate__\"), PyOp({0})::setstate, METH_VARARGS, \"{0} setstate\"},",
+                className));
+    }
+
+    os << tgfmt(R"(
+PyMethodDef PyOp($_self)::tp_methods[] = {
+    $0
+    {NULL}  /* Sentinel */
+};
+)", &ctx, llvm::join(method_items, "\n    "));
+}
+
 Initproc OpDefEmitter::emit_initproc() {
     std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
     std::string subclass_init_call;
@@ -321,6 +401,7 @@ void $0(py::module m) {
     py_type.tp_dealloc = py_dealloc_generic<py_op>;
     py_type.tp_new = py_new_generic<py_op>;
     py_type.tp_init = py_op::py_init;
+    py_type.tp_methods = py_op::tp_methods;
     py_type.tp_getset = py_op::py_getsetters;
     mgb_assert(PyType_Ready(&py_type) >= 0);
     $1
-- 
GitLab
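
Reviewer note (not part of the patch): to show how the new entry points compose
end to end, here is a minimal Python sketch of the serialization round trip,
assuming a MegEngine build that includes this patch. The Elemwise loader and
its "legacy_field" key are hypothetical, added only to illustrate the
OPDEF_PARAM_LOADER hook; everything else uses names defined in the patch.

    import pickle

    from megengine.core.ops import builtin
    from megengine.experimental.traced_module.serialization import (
        OPDEF_PARAM_LOADER,
        get_opdef_state,
        load_opdef_from_state,
    )

    # Hypothetical per-opdef loader: adapt a state dict produced by an older
    # version before it reaches __setstate__ (signature: state -> state).
    def _elemwise_loader(state):
        state.pop("legacy_field", None)  # "legacy_field" is illustrative only
        return state

    OPDEF_PARAM_LOADER[builtin.Elemwise] = _elemwise_loader

    op = builtin.Elemwise(mode="Add")
    blob = pickle.dumps(get_opdef_state(op))  # params plus "type" and "version"
    assert load_opdef_from_state(pickle.loads(blob)) == op

Storing "type" and "version" in the pickled dict, rather than pickling the
OpDef object directly, is what lets a later release reroute an old state
through a registered loader before instantiating the op.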