diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 19b48eb158ac8e49078199e64e5d024ad3ec91af..47456fa674fdfd5a30d7d258f1a70aee04212459 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -20,6 +20,9 @@ class AttachSpec: __slots__ = "tensor", "callbacks" +_global_priority = 0 + + class GradManager: r""" GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode @@ -118,6 +121,7 @@ class GradManager: self._grad = None self._after_backward_callback = [] self._gradients = {} + self._priority = None def attach(self, tensors: Iterable[Tensor], callbacks=None): r""" @@ -293,6 +297,7 @@ class GradManager: After this call, you will be able to call :meth:`backward`. """ + global _global_priority if self._recording: raise RuntimeError("already recording") grad = Grad() @@ -300,6 +305,9 @@ class GradManager: self._grad = grad for spec in self._attach_specs.values(): self._do_record(spec) + if self._priority is None: + grad._priority = _global_priority + _global_priority -= 1 grad.__enter__() def _do_record(self, spec): @@ -321,11 +329,14 @@ class GradManager: After this call, you will not be able to call :meth:`backward`. """ + global _global_priority if self._grad is not None: self._grad.__exit__(None, None, None) self._grad = None self._recording = False self._gradients = dict() + if self._priority is None: + _global_priority += 1 def __enter__(self): self.record() @@ -333,3 +344,41 @@ class GradManager: def __exit__(self, exc_type, exc_val, exc_tb): self.release() + + def __and__(self, other): + if isinstance(other, GradManager): + return GradManagerGroup([self, other]) + return NotImplemented + + __rand__ = __and__ + + +class GradManagerGroup: + def __init__(self, gms) -> None: + self._gms = list(gms) + + def merge_with(self, other): + if isinstance(other, GradManager): + other = GradManagerGroup([other]) + elif not isinstance(other, GradManagerGroup): + return NotImplemented + return GradManagerGroup([*self._gms, *other._gms]) + + __and__ = merge_with + __rand__ = merge_with + __or__ = merge_with + __ror__ = merge_with + + def __enter__(self): + global _global_priority + _global_priority += 1 + for gm in self._gms: + gm._priority = _global_priority + gm.record() + + def __exit__(self, exc_type, exc_val, exc_tb): + global _global_priority + _global_priority -= 1 + for gm in self._gms: + gm.release() + gm._priority = None diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py index 072f09d2e633ebeb31f8c92de1548d31350dff97..4a6114dcf7e6187f241aba12ffd5c4038aab2c6f 100644 --- a/imperative/python/megengine/core/autodiff/grad.py +++ b/imperative/python/megengine/core/autodiff/grad.py @@ -47,6 +47,14 @@ class Grad: self._impl = GradKey(name) _grad_manager_dict[self._name] = self + @property + def _priority(self): + return self._impl.priority + + @_priority.setter + def _priority(self, priority): + self._impl.priority = priority + @property def _name(self): return self._impl.name diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index 45ce932d08454dfd0bc41d35c53c267702a8e5bf..23507848952b811daffc8ac3d844f295c10b32f9 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -54,7 +54,7 @@ std::shared_ptr make_backward_graph( for (size_t i = 0; i < ctx.nargs; ++i) { *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle()); *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node()); - *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn); + *(bool_ptr++) = !ctx.args[i]->m_grad_info_dict.empty(); } mgb_assert(bool_ptr0 == reinterpret_cast(size_t_ptr) && bool_ptr == reinterpret_cast(buf + buf_size)); @@ -321,7 +321,7 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra for (size_t i = 0; i < ctx.nargs; ++i) { inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]); inputs_copy_weak.push_back(inputs_copy.back().get()); - inputs_copy.back()->m_grad_info = ctx.args[i]->m_grad_info; + inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict; } ApplyContext ctx_dup = ctx; ctx_dup.args = inputs_copy_weak.data(); @@ -365,25 +365,19 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { } // namespace apply_result_t apply_grad(ApplyContext& ctx) { - std::shared_ptr grad_key; + std::unordered_set> grad_keys; for (size_t i = 0; i < ctx.nargs; ++i) { auto* tensor = ctx.args[i]; - if (tensor->m_grad_info.grad_fn) { - auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock(); - // tensor is attached to a live GradKey - if (input_grad_key && input_grad_key->active) { - if (grad_key) { - if (grad_key != input_grad_key) { - PyErr_SetString(PyExc_NotImplementedError, "second order grad"); - throw pyext17::py_err_set(); - } - } else { - grad_key = std::move(input_grad_key); + if (!tensor->m_grad_info_dict.empty()) { + size_t grad_cnt = 0; + for (auto&& grad_info: tensor->m_grad_info_dict) { + auto input_grad_key = grad_info.grad_fn->key.lock(); + if (input_grad_key && input_grad_key->active && !input_grad_key->is_blocked()) { + grad_keys.insert(input_grad_key); + grad_cnt++; } - } else { - // cleanup stale grad info - // under what condition? - tensor->m_grad_info = {}; + } + if (!grad_cnt) { tensor->m_flags &= ~Flags::GRAD; } } else { @@ -393,7 +387,7 @@ apply_result_t apply_grad(ApplyContext& ctx) { ctx.flags &= ~Flags::GRAD; - if (!grad_key) { + if (grad_keys.empty()) { return apply(ctx); } @@ -418,54 +412,65 @@ apply_result_t apply_grad(ApplyContext& ctx) { return backward_graph_grad_rule(ctx, grad_fn_holder); }(); - auto& grad_fn = grad_fn_holder.grad_fn; - if (!grad_fn) { + if (!grad_fn_holder.grad_fn) { return outputs; } - grad_fn->key = grad_key; - grad_fn->slots.resize(outputs.size()); - grad_fn->dsts.reserve(ctx.nargs); + for (auto&& grad_key: grad_keys) { + auto grad_fn = std::make_shared(); + grad_fn->backward = grad_fn_holder.grad_fn->backward; + grad_fn->key = grad_key; + grad_fn->slots.resize(outputs.size()); + grad_fn->dsts.reserve(ctx.nargs); - std::visit([&](auto& backward) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - mgb_assert(0); - } else { - for (size_t i = 0; i < ctx.nargs; ++i) { - if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) { - auto& input_grad_info = ctx.args[i]->m_grad_info; - grad_fn->dsts.emplace_back(input_grad_info); - // register as grad producer - grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head); - } else { - grad_fn->dsts.emplace_back(); + std::visit([&](auto& backward) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + mgb_assert(0); + } else { + for (size_t i = 0; i < ctx.nargs; ++i) { + if (backward.input_has_grad(i) && input_requires_grad(ctx, i) && ctx.args[i]->m_grad_info_dict.count(grad_key.get())) { + auto& input_grad_info = ctx.args[i]->m_grad_info_dict.at(grad_key.get()); + grad_fn->dsts.emplace_back(input_grad_info); + // register as grad producer + grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head); + } else { + grad_fn->dsts.emplace_back(); + } } - } - for (size_t i = 0; i < outputs.size(); ++i) { - if (backward.output_requires_grad(i)) { - if (backward.output_captured(i)) { - // avoid reference cycle [Tensor <-> GradFn] - static std::shared_ptr op = std::shared_ptr(new FastpathCopy()); - outputs[i] = python::apply(op, outputs[i])[0]; + for (size_t i = 0; i < outputs.size(); ++i) { + if (backward.output_requires_grad(i)) { + if (backward.output_captured(i)) { + // avoid reference cycle [Tensor <-> GradFn] + static std::shared_ptr op = std::make_shared(); + outputs[i] = python::apply(op, outputs[i])[0]; + } + // populate grad info of output tensor + auto& grad_info = outputs[i]->m_grad_info_dict[grad_key.get()]; + grad_info.grad_fn = grad_fn; + grad_info.idx = i; + grad_info.insert_after(grad_key->free_vars_head); + outputs[i]->m_flags |= Flags::GRAD; } - // populate grad info of output tensor - auto& grad_info = outputs[i]->m_grad_info; - grad_info.grad_fn = grad_fn; - grad_info.idx = i; - grad_info.insert_after(grad_key->free_vars_head); - outputs[i]->m_flags |= Flags::GRAD; } } - } - }, grad_fn->backward); + }, grad_fn->backward); - // record forward history - grad_key->tape.emplace_back(grad_fn); + // record forward history + grad_key->tape.emplace_back(grad_fn); + } return outputs; } +PyObject* GradKeyWrapper::get_priority() { + return py::cast(m_key->priority).release().ptr(); +} + +void GradKeyWrapper::set_priority(pybind11::handle priority) { + m_key->name = py::cast(priority); +} + void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) { if (nargs != 2) { throw py::type_error("expect 2 arguments"); @@ -488,24 +493,21 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) { throw py::value_error("grad key finalized"); } - if (tensor->m_grad_info.grad_fn) { - if (tensor->m_grad_info.grad_fn->key.lock().get() != this) { - PyErr_SetString(PyExc_NotImplementedError, "second order grad"); - throw pyext17::py_err_set(); - } - if (tensor->m_grad_info->callback) { + if (tensor->m_grad_info_dict.count(this)) { + if (tensor->m_grad_info_dict.at(this)->callback) { throw py::value_error("callback already set on this tensor"); } } else { - tensor->m_grad_info.idx = 0; - auto& grad_fn = tensor->m_grad_info.grad_fn; + auto& grad_info = tensor->m_grad_info_dict[this]; + grad_info.idx = 0; + auto& grad_fn = grad_info.grad_fn; grad_fn = std::make_shared(); grad_fn->key = shared_from_this(); grad_fn->slots.resize(1); - tensor->m_grad_info.insert_after(free_vars_head); + grad_info.insert_after(free_vars_head); tensor->m_flags |= Flags::GRAD; } - tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback); + tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback); } template @@ -530,8 +532,15 @@ void GradKey::backward(std::vector tensors, std::vectorcleanup();} + size_t priority_backup; + CleanupGuard(GradKey* this_) : owner(this_) { + priority_backup = sm_min_priority; + sm_min_priority = owner->priority; + } + ~CleanupGuard() { + owner->cleanup(); + sm_min_priority = priority_backup; + } } _cleanup_guard(this); if (tape.empty()) return; @@ -542,14 +551,16 @@ void GradKey::backward(std::vector tensors, std::vectorm_tensor->m_grad_info; - if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) { - grad_info->grad = grads[i]->m_tensor; + if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) { + continue; } + auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this); + grad_info->grad = grads[i]->m_tensor; } std::vector> ref_keeper; ref_keeper.reserve(tape.size()); + // back-propagation in reverse order for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) { auto&& grad_fn = tape[k].lock(); @@ -619,13 +630,14 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) { PyErr_SetString(PyExc_TypeError, "expect Tensor"); return nullptr; } - auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn; - if (grad_fn && grad_fn->key.lock() == m_key) { + if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) { Py_RETURN_TRUE; } Py_RETURN_FALSE; } +int GradKey::sm_min_priority = 0; + GradKey::~GradKey() { cleanup(); } @@ -635,4 +647,41 @@ std::unordered_map& grad_rule_registry() { return registry; } +void GradInfoCollection::_shrink() { + auto pred = [](GradInfo& info){ return !(info.grad_fn) || info.grad_fn->key.expired(); }; + auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred); + m_storage.erase(iter, m_storage.end()); +} + +bool GradInfoCollection::contains(GradKey* key) { + _shrink(); + for (auto&& grad_info: m_storage) { + if (grad_info.grad_fn->key.lock().get() == key) { + return true; + } + } + return false; +} + +GradInfo& GradInfoCollection::operator[](GradKey* key) { + _shrink(); + for (auto&& grad_info: m_storage) { + if (grad_info.grad_fn->key.lock().get() == key) { + return grad_info; + } + } + m_storage.emplace_back(); + return m_storage.back(); +} + +GradInfo& GradInfoCollection::at(GradKey* key) { + _shrink(); + for (auto&& grad_info: m_storage) { + if (grad_info.grad_fn->key.lock().get() == key) { + return grad_info; + } + } + mgb_assert(false); +} + } // namespace mgb::imperative::python diff --git a/imperative/python/src/grad.h b/imperative/python/src/grad.h index a3fb58e181d68f38e13236c426af0a2e2ae08042..17f28dbbb5e28b118ec30c7e7dc5a67fc98ce688 100644 --- a/imperative/python/src/grad.h +++ b/imperative/python/src/grad.h @@ -26,12 +26,18 @@ struct GradKey : std::enable_shared_from_this, NonCopyableObj { bool active = true; GradInfo::head_t free_vars_head; std::vector> tape; + int priority = 0; ~GradKey(); void attach(Tensor* tensor, pybind11::object callback); void backward(std::vector, std::vector); void cleanup(); + bool is_blocked() const { + return priority < sm_min_priority; + } +private: + static int sm_min_priority; }; struct GradKeyWrapper { @@ -44,6 +50,8 @@ struct GradKeyWrapper { PyObject* get_name(); void set_name(pybind11::handle name); + PyObject* get_priority(); + void set_priority(pybind11::handle priority); void attach(PyObject*const* args, size_t nargs); void backward(std::vector, std::vector); PyObject* is_attached_to(PyObject*const* args, size_t nargs); @@ -150,7 +158,7 @@ using GradRuleFn = std::function& grad_rule_registry(); inline bool input_requires_grad(const ApplyContext& ctx, size_t i) { - return bool(ctx.args[i]->m_grad_info.grad_fn); + return !ctx.args[i]->m_grad_info_dict.empty(); } struct GradRuleFallback : std::exception {}; diff --git a/imperative/python/src/grad_info.h b/imperative/python/src/grad_info.h index b42356e6d233db6fd8fddcc066a421b871374420..f5119144d9c435d9b4c249064d2615d16cae819e 100644 --- a/imperative/python/src/grad_info.h +++ b/imperative/python/src/grad_info.h @@ -15,6 +15,7 @@ namespace mgb::imperative::python { +struct GradKey; struct GradFn; struct GradSlot; @@ -32,6 +33,10 @@ struct GradInfo : GradSlotPtr, intrusive_list::Node(rhs)){} + GradInfo& operator=(const GradInfo& rhs) { + return *this = const_cast(rhs); + } }; } // namespace mgb::imperative::python diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 38be390631cbd1810f0b2d11a940edadda1e1d31..87c254e59e93d639ab9f90ac06fd3f4475181a7e 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -182,7 +182,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje if (py::isinstance(py::handle(args[0]))){ SmallVector vinputs(nargs); for (size_t i = 0; i < nargs; ++i) { - vinputs[i] = py::handle(args[i]).cast()->m_node; + vinputs[i] = py::handle(args[i]).cast()->m_node; } auto op = ctx.op.get(); auto rst = OpDef::apply_on_var_node(*op, vinputs); diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index f9409a3f2afad8b58290e0c693156aa9a3a9e81d..d3f2c9fffa3f0e24e584f211f7403ca10b327abd 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -17,6 +17,7 @@ #include "megbrain/imperative/interpreter.h" #include "pybind11/pybind11.h" #include +#include #include "./pyext17.h" @@ -36,6 +37,8 @@ struct ObjectPtr : B { namespace mgb::imperative::python { +struct GradKey; + extern interpreter::Interpreter::Channel* interpreter_for_py; class SharedHandle { @@ -58,6 +61,34 @@ public: }; +// impl in grad.cpp +class GradInfoCollection { +private: + SmallVector m_storage; +protected: + void _shrink(); +public: + bool contains(GradKey* key); + GradInfo& operator[](GradKey* key); + GradInfo& at(GradKey* key); + bool empty() { + _shrink(); + return m_storage.empty(); + } + auto begin() { + _shrink(); + return m_storage.begin(); + } + auto end() { + _shrink(); + return m_storage.end(); + } + size_t count(GradKey* key) { + return contains(key) ? 1 : 0; + } +}; + + struct Tensor : std::enable_shared_from_this, NonCopyableObj { using flags_t = uint64_t; @@ -69,7 +100,7 @@ struct Tensor : std::enable_shared_from_this, NonCopyableObj { flags_t m_flags = 0; - GradInfo m_grad_info; + GradInfoCollection m_grad_info_dict; TraceInfo m_trace_info; SharedHandle m_handle; std::string user_custom_name; @@ -88,7 +119,7 @@ struct Tensor : std::enable_shared_from_this, NonCopyableObj { inline std::shared_ptr copy() { auto ret = std::make_shared(m_handle); ret->m_flags = m_flags; - ret->m_grad_info = m_grad_info; + ret->m_grad_info_dict = m_grad_info_dict; ret->m_trace_info = m_trace_info; ret->m_var = m_var; return ret; diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index 810246dd0e4686b24653f83eb0d24bf656bf9aea..248d7ca1bccbee6c016a8126587685ea5187716c 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -108,21 +108,24 @@ def test_grad_2(): np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) -@pytest.mark.skip(reason="high order gradient was not implemented yet") def test_2nd_grad(): x_np = np.random.rand(10).astype("float32") x = as_tensor(x_np) ones = as_tensor(np.ones_like(x_np)) grad = Grad().wrt(x, callback=save_to(x)) + grad._priority = -1 grad2 = Grad().wrt(x, callback=save_to(x)) + grad2._priority = 0 y = cos(x) grad(y, ones) + z = x.grad np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) - grad2(x.grad, ones) + x.grad = None + grad2(z, ones) np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np)) diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index 13ebcced7f6d50c6b22f0091aea6adfb1fc87248..185b1d7ac2e26180d075e020afa72f8a3e5a6b3d 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -398,20 +398,6 @@ OP_TRAIT_REG(Copy, Copy) .fallback(); }} // copy -namespace { namespace identity { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { - auto&& op = def.cast_final_safe(); - mgb_assert(inputs.size() == 1); - OperatorNodeConfig config{op.make_name()}; - return opr::Identity::make(inputs[0], config); -} -OP_TRAIT_REG(Identity, Identity) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // identity - namespace { namespace assert_equal { auto apply_on_var_node( const OpDef& def, diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp index 72e72abc862a231d95dce3146e213b061a763a75..d4e8deefa93f53a9db5a8def19f489b5b42139ca 100644 --- a/imperative/src/impl/ops/utility.cpp +++ b/imperative/src/impl/ops/utility.cpp @@ -9,6 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/opr/utility.h" @@ -32,4 +33,25 @@ OP_TRAIT_REG(FastpathCopy,FastpathCopy) MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy); +namespace { namespace identity { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = def.cast_final_safe(); + mgb_assert(inputs.size() == 1); + OperatorNodeConfig config{op.make_name()}; + return opr::Identity::make(inputs[0], config); +} + +auto apply_on_physical_tensor( + const OpDef& def, + const SmallVector& inputs) { + return SmallVector{inputs[0]}; +} +OP_TRAIT_REG(Identity, Identity) + .apply_on_var_node(apply_on_var_node) + .apply_on_physical_tensor(apply_on_physical_tensor) + .fallback(); +}} // identity + } // namespace mgb::imperative