diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index dd598795b4956bd1f91032de8a807840b5a00799..1fd766362822e0e3aad634aa301abecc6158befc 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -70,7 +70,7 @@ std::shared_ptr make_backward_graph( for (size_t i = 0; i < ctx.nargs; ++i) { inputs[i].comp_node = ctx.args[i]->comp_node(); inputs[i].layout.dtype = ctx.args[i]->dtype(); - input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn); + input_requires_grad[i] = python::input_requires_grad(ctx, i); } auto result = std::make_shared( proxy_graph_detail::make_backward_graph( @@ -82,21 +82,6 @@ std::shared_ptr make_backward_graph( return result; } -struct BackwardContext { - PyTypeObject* pytype = nullptr; - - auto wrap_tensor(std::shared_ptr t) { - if (pytype) { - return TensorWrapper::make(pytype, std::move(t)); - } - return TensorWrapper::make(std::move(t)); - } - - auto wrap_tensor(Tensor* t) { - return wrap_tensor(t->shared_from_this()); - } -}; - struct BackwardGraphWithClosure { std::shared_ptr backward_graph; SmallVector> closure; @@ -270,7 +255,7 @@ struct GradFn : std::enable_shared_from_this { // same length as inputs (of forward op) SmallVector dsts; // encapsules actual function to compute gradient - std::variant backward; + std::variant backward; // a flag used during backward bool in_ref_keeper = false; @@ -335,8 +320,7 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); } auto grad_rule = py::getattr(op->obj, "_grad_rule"); - auto pyret = (scoped_disable(Flags::GRAD), - py::reinterpret_steal(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr))); // comma expression + auto pyret = py::reinterpret_steal(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr)); auto [outputs, backward] = py::cast>(pyret); ret_grad_fn.emplace(std::move(backward), ctx.nargs); if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) { @@ -388,9 +372,25 @@ apply_result_t apply_grad(ApplyContext& ctx) { } GradFnHelper grad_fn_holder; - auto outputs = ctx.op->same_type() ? - python_grad_rule(ctx, grad_fn_holder) : - backward_graph_grad_rule(ctx, grad_fn_holder); + auto outputs = [&]() { + auto _ = scoped_disable(Flags::GRAD); + if (ctx.op->same_type()) { + return python_grad_rule(ctx, grad_fn_holder); + } + auto&& registry = grad_rule_registry(); + auto&& it = registry.find(ctx.op->dyn_typeinfo()); + if (it != registry.end()) { + auto&& maker = grad_fn_holder.emplace().maker(ctx); + try { + auto ret = it->second(ctx, maker); + maker.finalize(); + return ret; + } catch (GradRuleFallback&) { + grad_fn_holder.emplace(); + } + } + return backward_graph_grad_rule(ctx, grad_fn_holder); + }(); auto& grad_fn = grad_fn_holder.grad_fn; if (!grad_fn) { @@ -407,7 +407,7 @@ apply_result_t apply_grad(ApplyContext& ctx) { mgb_assert(0); } else { for (size_t i = 0; i < ctx.nargs; ++i) { - if (backward.input_has_grad(i)) { + if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) { auto& input_grad_info = ctx.args[i]->m_grad_info; grad_fn->dsts.emplace_back(input_grad_info); // register as grad producer @@ -487,18 +487,8 @@ void accum_grad(std::shared_ptr& grad, T&& delta) { grad = std::forward(delta); return; } - static ApplyContext ctx; - if (!ctx.op) { - ctx.op = std::shared_ptr(new Elemwise(Elemwise::Mode::ADD)); - ctx.nargs = 2; - } - Tensor* args[2] = {grad.get(), delta.get()}; - ctx.args = args; - ctx.flags = grad->m_flags | delta->m_flags; - if (is_tracing) { - ctx.flags |= Flags::TRACE; - } - grad = apply(ctx)[0]; + static std::shared_ptr op = std::shared_ptr(new Elemwise(Elemwise::Mode::ADD)); + grad = apply(op, grad, std::forward(delta))[0]; } void GradKey::backward(std::vector tensors, std::vector grads) { @@ -582,4 +572,9 @@ GradKey::~GradKey() { cleanup(); } +std::unordered_map& grad_rule_registry() { + static std::unordered_map registry; + return registry; +} + } // namespace mgb::imperative::python diff --git a/imperative/python/src/grad.h b/imperative/python/src/grad.h index e94c229226d68e08a80a4bda0b5e9923d2007127..864f6f5b1a75f25e582b031b63c302585324eac6 100644 --- a/imperative/python/src/grad.h +++ b/imperative/python/src/grad.h @@ -45,6 +45,117 @@ struct GradKeyWrapper { void backward(std::vector, std::vector); }; +struct BackwardContext { + PyTypeObject* pytype = nullptr; + + auto wrap_tensor(std::shared_ptr t) { + if (pytype) { + return TensorWrapper::make(pytype, std::move(t)); + } + return TensorWrapper::make(std::move(t)); + } + + auto wrap_tensor(Tensor* t) { + return wrap_tensor(t->shared_from_this()); + } +}; + +struct CustomBackward { + using BackwardFn = std::function; + BackwardFn m_backward; + SmallVector m_input_has_grad; + struct OutputAttr {bool requires_grad = true, captured = true;}; + SmallVector m_output_attrs; + +public: + template + void operator()(BackwardContext& ctx, T&& grads, R&& receiver) { + size_t nargs = grads.size(); + Tensor* args[nargs]; + for (size_t i = 0; i < nargs; ++i) { + args[i] = grads[i]; + } + auto ret = m_backward(ctx, args, nargs); + for (size_t i = 0; i < ret.size(); ++i) { + if (auto&& t = ret[i]) { + receiver(i, std::move(t)); + } + } + } + + bool input_has_grad(size_t i) {return m_input_has_grad[i];} + bool output_requires_grad(size_t i) {return m_output_attrs[i].requires_grad;} + bool output_captured(size_t i) {return m_output_attrs[i].captured;} + + class Maker { + bool output_size_set = false, input_has_grad_initialized = false; + CustomBackward& target; + ApplyContext& ctx; + + void init_input_has_grad() { + if (!input_has_grad_initialized) { + input_has_grad_initialized = true; + target.m_input_has_grad.resize(ctx.nargs, true); + } + } + + public: + Maker(CustomBackward& target_, ApplyContext& ctx_) : target(target_), ctx(ctx_) {} + + template + Maker& backward(F&& f) { + mgb_assert(!target.m_backward); + target.m_backward = std::forward(f); + return *this; + } + // mandatory + Maker& output_size(size_t sz) { + mgb_assert(!output_size_set); + output_size_set = true; + target.m_output_attrs.resize(sz); + return *this; + } + // optional, defaults to all true + Maker& input_has_grad(size_t i, bool v) { + init_input_has_grad(); + target.m_input_has_grad.at(i) = v; + return *this; + } + // optional, defaults to all true + Maker& output_requires_grad(size_t i, bool v) { + target.m_output_attrs.at(i).requires_grad = v; + return *this; + } + // optional, defaults to all true + Maker& output_captured(size_t i, bool v) { + target.m_output_attrs.at(i).captured = v; + return *this; + } + + void finalize() { + mgb_assert(output_size_set); + init_input_has_grad(); + } + }; + + Maker maker(ApplyContext& ctx) {return {*this, ctx};} +}; + +using GradRuleFn = std::function; + +std::unordered_map& grad_rule_registry(); + +inline bool input_requires_grad(const ApplyContext& ctx, size_t i) { + return bool(ctx.args[i]->m_grad_info.grad_fn); +} + +struct GradRuleFallback : std::exception {}; + +template +bool register_grad_rule(Typeinfo* typeinfo, T&& rule) { + return grad_rule_registry().emplace(typeinfo, std::forward(rule)).second; +} + } // namespace mgb::imperative::python namespace pybind11::detail { diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp new file mode 100644 index 0000000000000000000000000000000000000000..46fa7269a8c6844131323ba91748cb9d2ea5eb28 --- /dev/null +++ b/imperative/python/src/grad_override.cpp @@ -0,0 +1,63 @@ +/** + * \file imperative/python/src/grad_override.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "./grad.h" +#include "megbrain/imperative/ops/autogen.h" + +namespace mgb::imperative::python { +namespace { + +std::shared_ptr get_shape(Tensor* x) { + static auto op = GetVarShape::make(); + return python::apply(op, x)[0]; +} + +std::shared_ptr reduce_to(Tensor* x, Tensor* s) { + static auto op = Reduce::make(); + return python::apply(op, x, s)[0]; +} + +apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) { + auto& op = ctx.op->cast_final_safe(); + if (op.mode == Elemwise::Mode::ADD) { + mgb_assert(ctx.nargs == 2); + std::array, 2> input_shapes; + for (size_t i = 0; i < 2; ++i) { + if (input_requires_grad(ctx, i)) { + input_shapes[i] = get_shape(ctx.args[i]); + } + } + maker.output_size(1).output_captured(0, false); + maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) { + mgb_assert(ngrads == 1); + Tensor* grad = grads[0]; + apply_result_t ret(2); + for (size_t i = 0; i < 2; ++i) { + if (shapes[i]) { + ret[i] = reduce_to(grad, shapes[i].get()); + } + } + return ret; + }); + return apply(ctx); + } + throw GradRuleFallback(); +} + +struct Init { + Init() { + auto& reg = grad_rule_registry(); + reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule); + } +} _; + +} // namespace +} // namespace mgb::imperative::python diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index addb1891b1ea165fc03c7eec8e866214fff06156..f6060666b97f190988874fc01a59276e8d438fbe 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -199,12 +199,59 @@ using apply_result_t = SmallVector, 8>; apply_result_t apply(ApplyContext& ctx); -void init_tensor(pybind11::module); +template +decltype(auto) resolve_arrow(T&& p) { + if constexpr (std::is_pointer_v>) { + auto* ret = p; + return ret; + } else { + auto probe = [](auto&& p) -> decltype(p.operator->()) {}; + if constexpr (std::is_invocable_v) { + return resolve_arrow(p.operator->()); + } else { + return p; + } + } +} + +template +constexpr bool is_all_tensor_ptr = (... && std::is_same_v())), Tensor*>); -extern bool is_tracing; +extern bool is_tracing; // FIXME: should use ApplyContext::global_enable extern bool is_symbolic; extern bool is_compiled; +template , int> = 0> +apply_result_t apply(std::shared_ptr op, Args&&... args) { + ApplyContext ctx; + Tensor* arg_arr[] = {resolve_arrow(args)...}; + ctx.flags = (0 | ... | args->m_flags); + ctx.flags |= is_tracing ? Tensor::Flags::TRACE : 0; + ctx.args = arg_arr; + ctx.nargs = sizeof...(args); + ctx.op = std::move(op); + return apply(ctx); +} + +template +auto apply(std::shared_ptr op, T&& tensors) + -> std::enable_if_t, + apply_result_t> { + ApplyContext ctx; + ctx.op = std::move(op); + ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0; + ctx.nargs = tensors.size(); + Tensor* args[ctx.nargs]; + ctx.args = args; + for (size_t i = 0; i < ctx.nargs; ++i) { + args[i] = resolve_arrow(tensors[i]); + ctx.flags |= args[i]->m_flags; + } + return apply(ctx); +} + +void init_tensor(pybind11::module); + extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode; extern pybind11::object cpp_apply_backward_varnode;