diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp
index 8a5d79e79cb8c91758c321105dcdb007a9c5bb42..17e15dba3bcea37e61f1ecf960a3bfcf7b85a7ec 100644
--- a/imperative/python/src/grad.cpp
+++ b/imperative/python/src/grad.cpp
@@ -14,7 +14,10 @@
 #include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/utils/mempool.h"
 
+#include "range/v3/all.hpp"
+
 namespace py = pybind11;
+namespace views = ranges::views;
 
 namespace mgb::imperative::python {
 
@@ -25,6 +28,152 @@ struct GradSlotWeakPtr {
     size_t idx;
 };
 
+struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
+    std::shared_ptr<void> on_comp_node_finalize() override {
+        clear();
+        return {};
+    }
+} backward_graph_cache;
+
+std::shared_ptr<BackwardGraphResult> make_backward_graph(
+        ApplyContext& ctx, const apply_result_t& outputs) {
+    // hash
+    static_assert(alignof(size_t) % alignof(bool) == 0);
+    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
+    alignas(alignof(size_t)) std::byte buf[buf_size];
+    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
+    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
+    bool* bool_ptr0 = bool_ptr;
+    *(size_t_ptr++) = ctx.op->hash();
+    for (size_t i = 0; i < ctx.nargs; ++i) {
+        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
+        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
+        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
+    }
+    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
+               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
+    size_t key = XXHash{}.update(buf, buf_size).digest();
+
+    auto&& iter = backward_graph_cache.find(key);
+    if (iter != backward_graph_cache.end()) {
+        return iter->second;
+    }
+
+    // slow path
+    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
+    SmallVector<bool> input_requires_grad(ctx.nargs, false);
+    SmallVector<bool> output_has_grad(outputs.size(), true);
+    for (size_t i = 0; i < ctx.nargs; ++i) {
+        inputs[i].comp_node = ctx.args[i]->comp_node();
+        inputs[i].layout.dtype = ctx.args[i]->dtype();
+        input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
+    }
+    auto result = std::make_shared<BackwardGraphResult>(
+            proxy_graph_detail::make_backward_graph(
+                    *ctx.op, inputs, input_requires_grad, output_has_grad));
+    if (!result->backward) {
+        result.reset();
+    }
+    backward_graph_cache.emplace(key, result);
+    return result;
+}
+
+struct BackwardGraphWithClosure {
+    std::shared_ptr<BackwardGraphResult> backward_graph;
+    SmallVector<std::shared_ptr<Tensor>> closure;
+    size_t output_mask_offset;
+    size_t grad_mask_offset;
+
+    BackwardGraphWithClosure(std::shared_ptr<BackwardGraphResult> backward_graph_,
+                             ApplyContext& ctx, const apply_result_t& outputs)
+            : backward_graph(backward_graph_),
+              output_mask_offset(ctx.nargs),
+              grad_mask_offset(ctx.nargs + outputs.size()) {
+        // save_for_backward[0:nargs]:
+        //     whether input is kept for backward
+        //
+        // save_for_backward[nargs:nargs+outputs.size()]:
+        //     whether output is kept for backward
+        //
+        // save_for_backward[-outputs.size():]:
+        //     whether gradient of output can propagate to any input
+        //
+        // Example:
+        //     perform c = a * b, with a.requires_grad == True and
+        //     b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
+        auto& save_for_backward = backward_graph->save_for_backward;
+        mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
+        closure.reserve(std::count_if(save_for_backward.begin(),
+                                      save_for_backward.end(),
+                                      ranges::identity{}));
+        for (size_t i = 0; i < ctx.nargs; ++i) {
+            if (save_for_backward[i]) {
+                closure.push_back(ctx.args[i]->shared_from_this());
+            }
+        }
+        for (size_t i = 0; i < outputs.size(); ++i) {
+            if (save_for_backward[ctx.nargs + i]) {
+                closure.push_back(outputs[i]);
+            }
+        }
+    }
+
+    template <typename T, typename R>
+    void operator()(T&& grads, R&& receiver) {
+        Tensor* args[closure.size() + grads.size()];
+        size_t nargs = 0;
+        for (auto&& t : closure) {
+            args[nargs++] = t.get();
+        }
+        bool null_grad = false;
+        for (size_t i = 0; i < grads.size(); ++i) {
+            if (backward_graph->save_for_backward[grad_mask_offset + i]) {
+                if (grads[i]) {
+                    if (null_grad) {
+                        PyErr_SetString(PyExc_NotImplementedError, "report to devs");
+                        throw py::error_already_set();
+                    }
+                    args[nargs++] = grads[i];
+                } else {
+                    null_grad = true;
+                }
+            }
+        }
+        if (null_grad) return;
+
+        ApplyContext ctx;
+        ctx.op = backward_graph->backward;
+        ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
+        ctx.nargs = nargs;
+        ctx.args = args;
+        for (size_t i = 0; i < nargs; ++i) {
+            ctx.flags |= args[i]->m_flags;
+            mgb_assert(args[i]);
+        }
+
+        auto igrads = apply(ctx);
+        auto&& it = igrads.begin();
+        for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
+            if (p) {
+                receiver(i, std::move(*it));
+                ++it;
+            }
+        }
+    }
+
+    bool input_has_grad(size_t i) {
+        return backward_graph->input_has_grad[i];
+    }
+
+    bool output_requires_grad(size_t i) {
+        return backward_graph->save_for_backward[grad_mask_offset + i];
+    }
+
+    bool output_captured(size_t i) {
+        return backward_graph->save_for_backward[output_mask_offset + i];
+    }
+};
+
 } // namespace
 
 struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
@@ -54,10 +203,15 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
     static MemPool<GradFn> pool;
 
     std::weak_ptr<GradKey> key;
+    // slots for receiving and accumulating grads
+    // same length as outputs (of forward op)
     SmallVector<GradSlot> slots;
+    // where to send and accumulate grads
+    // same length as inputs (of forward op)
     SmallVector<GradSlotProducerPtr> dsts;
-    SmallVector<std::shared_ptr<Tensor>> closure;
-    std::shared_ptr<BackwardGraphResult> backward_graph;
+    // encapsulates the actual function used to compute gradients
+    std::variant<std::monostate, BackwardGraphWithClosure> backward;
+    // a flag used during backward
     bool in_ref_keeper = false;
 
     static void deleter(GradFn* ptr) {
@@ -72,8 +226,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
         key.reset();
         slots.clear();
         dsts.clear();
-        closure.clear();
-        backward_graph.reset();
+        backward.emplace<std::monostate>();
     }
 };
 
@@ -83,54 +236,36 @@ GradSlot* GradSlotPtr::operator->() {
 
 namespace {
 
-struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
-    std::shared_ptr<void> on_comp_node_finalize() override {
-        clear();
-        return {};
-    }
-} backward_graph_cache;
+class GradFnHelper {
+    std::shared_ptr<GradFn> grad_fn;
 
-std::shared_ptr<BackwardGraphResult> make_backward_graph(
-        ApplyContext& ctx, const apply_result_t& outputs) {
-    // hash
-    static_assert(alignof(size_t) % alignof(bool) == 0);
-    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
-    alignas(alignof(size_t)) std::byte buf[buf_size];
-    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
-    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
-    bool* bool_ptr0 = bool_ptr;
-    *(size_t_ptr++) = ctx.op->hash();
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
-        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
-        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
+    GradFn* get() {
+        if (!grad_fn) {
+            grad_fn = std::make_shared<GradFn>();
+        }
+        return grad_fn.get();
     }
-    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
-               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
-    size_t key = XXHash{}.update(buf, buf_size).digest();
-    auto&& iter = backward_graph_cache.find(key);
-    if (iter != backward_graph_cache.end()) {
-        return iter->second;
-    }
+    friend apply_result_t imperative::python::apply_grad(ApplyContext&);
 
-    // slow path
-    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
-    SmallVector<bool> input_requires_grad(ctx.nargs, false);
-    SmallVector<bool> output_has_grad(outputs.size(), true);
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        inputs[i].comp_node = ctx.args[i]->comp_node();
-        inputs[i].layout.dtype = ctx.args[i]->dtype();
-        input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
+public:
+    template<typename T, typename... Args>
+    auto& emplace(Args&&... args) {
+        return get()->backward.emplace<T>(std::forward<Args>(args)...);
     }
-    auto result = std::make_shared<BackwardGraphResult>(
-            proxy_graph_detail::make_backward_graph(
-                    *ctx.op, inputs, input_requires_grad, output_has_grad));
-    if (!result->backward) {
-        result.reset();
+};
+
+apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
+    auto outputs = apply(ctx);
+
+    auto backward_graph = make_backward_graph(ctx, outputs);
+    if (!backward_graph) {
+        return outputs;
     }
-    backward_graph_cache.emplace(key, result);
-    return result;
+
+    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx, outputs);
+
+    return outputs;
 }
 
 } // namespace
@@ -164,76 +299,53 @@ apply_result_t apply_grad(ApplyContext& ctx) {
 
     ctx.flags &= ~Tensor::Flags::GRAD;
 
-    // perform forward apply_op or trace
-    auto outputs = apply(ctx);
-
     if (!grad_key) {
-        return outputs;
+        return apply(ctx);
     }
 
-    auto backward_graph = make_backward_graph(ctx, outputs);
-    if (!backward_graph) {
+    GradFnHelper grad_fn_holder;
+    auto outputs = backward_graph_grad_rule(ctx, grad_fn_holder);
+
+    auto& grad_fn = grad_fn_holder.grad_fn;
+    if (!grad_fn) {
         return outputs;
     }
 
-    auto grad_fn = std::make_shared<GradFn>();
     grad_fn->key = grad_key;
     grad_fn->slots.resize(outputs.size());
-    grad_fn->backward_graph = std::move(backward_graph);
-
     grad_fn->dsts.reserve(ctx.nargs);
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        if (grad_fn->backward_graph->input_has_grad[i]) {
-            auto& input_grad_info = ctx.args[i]->m_grad_info;
-            grad_fn->dsts.emplace_back(input_grad_info);
-            grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
-        } else {
-            grad_fn->dsts.emplace_back();
-        }
-    }
-    auto& save_for_backward = grad_fn->backward_graph->save_for_backward;
-    grad_fn->closure.reserve(std::count_if(save_for_backward.begin(), save_for_backward.end(), [](bool p){return p;}));
-
-    // given op, taking gradient of output_tensor_list wrt input_tensor_list:
-    //
-    // save_for_backward[0:nargs-1]: whether input tensor requires gradient,
-    // i.e., whether it is in input_tensor_list
-    //
-    // save_for_backward[nargs:nargs+outputs.size()-1]: whether output tensor is
-    // needed to calculate gradients
-    //
-    // save_for_backward[-outputs.size():]: whether output tensor is in
-    // output_tensor_list
-    //
-    // Example: perform c = a * b, where a is input data, b is parameter to be
-    // optimized, save_for_backward = [1, 1, 0, 1]
-    mgb_assert(ctx.nargs + 2 * outputs.size() == save_for_backward.size());
-
-    // record input tensors needed to take grad
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        if (save_for_backward[i]) {
-            grad_fn->closure.push_back(ctx.args[i]->shared_from_this());
-        }
-    }
-    // record output tensors needed to take grad
-    for (size_t i = 0; i < outputs.size(); ++i) {
-        bool requires_grad = save_for_backward[ctx.nargs + outputs.size() + i];
-        if (save_for_backward[ctx.nargs + i]) {
-            grad_fn->closure.push_back(outputs[i]);
-            if (requires_grad) {
-                // avoid reference cycle [Tensor <-> GradFn]
-                outputs[i] = outputs[i]->copy();
+    std::visit([&](auto& backward) {
+        using T = std::decay_t<decltype(backward)>;
+        if constexpr (std::is_same_v<T, std::monostate>) {
+            mgb_assert(0);
+        } else {
+            for (size_t i = 0; i < ctx.nargs; ++i) {
+                if (backward.input_has_grad(i)) {
+                    auto& input_grad_info = ctx.args[i]->m_grad_info;
+                    grad_fn->dsts.emplace_back(input_grad_info);
+                    // register as grad producer
+                    grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
+                } else {
+                    grad_fn->dsts.emplace_back();
+                }
+            }
+            for (size_t i = 0; i < outputs.size(); ++i) {
+                if (backward.output_requires_grad(i)) {
+                    if (backward.output_captured(i)) {
+                        // avoid reference cycle [Tensor <-> GradFn]
+                        outputs[i] = outputs[i]->copy();
+                    }
+                    // populate grad info of output tensor
+                    auto& grad_info = outputs[i]->m_grad_info;
+                    grad_info.grad_fn = grad_fn;
+                    grad_info.idx = i;
+                    grad_info.insert_after(grad_key->free_vars_head);
+                    outputs[i]->m_flags |= Tensor::Flags::GRAD;
+                }
             }
         }
-        if (requires_grad) {
-            auto& grad_info = outputs[i]->m_grad_info;
-            grad_info.grad_fn = grad_fn;
-            grad_info.idx = i;
-            grad_info.insert_after(grad_key->free_vars_head);
-            outputs[i]->m_flags |= Tensor::Flags::GRAD;
-        }
-    }
+    }, grad_fn->backward);
 
     // record forward history
     grad_key->tape.emplace_back(grad_fn);
@@ -334,54 +446,30 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
     for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
         auto&& grad_fn = tape[k].lock();
         if (!grad_fn) continue;
-        if (grad_fn->backward_graph) {
-            for (size_t i = 0; i < grad_fn->slots.size(); ++i) {
-                // grad_fn->dsts correspond to input tensors during forward
-                // calculation, grad_fn->slots correspond to output tensors.
-                // condition true means the output tensor has gradient for
-                // back-propagation
-                if (grad_fn->backward_graph->save_for_backward[grad_fn->dsts.size() + grad_fn->slots.size() + i]) {
-                    grad_fn->closure.push_back(std::move(grad_fn->slots[i].grad));
-                }
-            }
-            ApplyContext ctx;
-            ctx.op = grad_fn->backward_graph->backward;
-            ctx.flags = 0;
-            ctx.nargs = grad_fn->closure.size();
-            Tensor* args[ctx.nargs];
-            for (size_t i = 0; i < ctx.nargs; ++i) {
-                args[i] = grad_fn->closure[i].get();
-                mgb_assert(args[i]);
-                ctx.flags |= args[i]->m_flags;
-            }
-            ctx.args = args;
-
-            if (is_tracing)
-                ctx.flags |= Tensor::Flags::TRACE;
-            auto grads = apply(ctx);
-
-            size_t j = 0;
-            for (size_t i = 0; i < grad_fn->dsts.size(); ++i) {
-                if (grad_fn->backward_graph->input_has_grad[i]) {
-                    auto& dst = grad_fn->dsts[i];
-                    // grads[j] is consumed in accum_grad
-                    accum_grad(dst->grad, std::move(grads[j]));
-                    ++j;
-                }
+        auto grad_receiver = [&](size_t i, auto&& g) {
+            accum_grad(grad_fn->dsts[i]->grad, std::forward<decltype(g)>(g));
+        };
+        std::visit([&](auto&& backward) {
+            using T = std::decay_t<decltype(backward)>;
+            if constexpr (std::is_same_v<T, std::monostate>) {
+                mgb_assert(0);
+            } else {
+                auto&& grads = views::transform(grad_fn->slots, [](auto&& slot) {return slot.grad.get();});
+                backward(std::forward<decltype(grads)>(grads), grad_receiver);
             }
-            mgb_assert(j == grads.size());
-        }
+        }, grad_fn->backward);
+
         for (auto&& dst : grad_fn->dsts) {
            if (!dst.grad_fn) continue;
            if (!dst.grad_fn->in_ref_keeper) {
+                // after grad_fn is cleared, refcnt of subsequent grad_fn
+                // could drop to 0
                dst.grad_fn->in_ref_keeper = true;
                ref_keeper.push_back(dst.grad_fn);
            }
-            // grad_fn->clear will unlink current dst.producer_record
-            // such that if dst.producer_record.next is false, dst accumulates
-            // all the gradients
            if (!dst.producer_record.next && dst->callback && dst->grad) {
+                // I'm the last grad producer, invoke callback
                dst->callback(TensorWrapper::make(pytype, dst->grad));
            }
        }
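For reference, the save_for_backward layout that output_mask_offset and grad_mask_offset index into can be pictured with a minimal standalone sketch. It is not part of the patch; the [0, 1, 0, 1] mask below is the illustrative c = a * b case from the comment in BackwardGraphWithClosure's constructor.

// sketch only: the three mask regions of save_for_backward
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    std::size_t nargs = 2, nout = 1;
    // c = a * b with a.requires_grad == True and b.requires_grad == False:
    // [a kept?, b kept?, c kept?, does grad(c) reach any input?] == [0, 1, 0, 1]
    std::vector<bool> save_for_backward = {false, true, false, true};
    std::size_t output_mask_offset = nargs;       // "output kept for backward" flags
    std::size_t grad_mask_offset = nargs + nout;  // "output grad is consumed" flags
    for (std::size_t i = 0; i < nargs; ++i)
        std::printf("input %zu captured: %d\n", i, int(save_for_backward[i]));
    for (std::size_t i = 0; i < nout; ++i)
        std::printf("output %zu captured: %d, grad used: %d\n", i,
                    int(save_for_backward[output_mask_offset + i]),
                    int(save_for_backward[grad_mask_offset + i]));
    return 0;
}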
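GradFn::backward now relies on the std::variant + std::visit + if constexpr pattern, with std::monostate meaning "no backward recorded". A self-contained sketch of that dispatch follows; BackwardRule is a hypothetical stand-in for BackwardGraphWithClosure, not a MegEngine type.

// sketch only: variant-based backward dispatch
#include <cassert>
#include <cstdio>
#include <type_traits>
#include <variant>

struct BackwardRule {
    void operator()(float output_grad) const {
        std::printf("input grad = %f\n", output_grad * 2.0f);
    }
};

int main() {
    std::variant<std::monostate, BackwardRule> backward;
    backward.emplace<BackwardRule>();  // analogous to GradFnHelper::emplace<T>(...)
    std::visit([](auto&& rule) {
        using T = std::decay_t<decltype(rule)>;
        if constexpr (std::is_same_v<T, std::monostate>) {
            assert(!"tape entry without a backward rule");
        } else {
            rule(1.0f);  // invoke the stored backward rule
        }
    }, backward);
    backward.emplace<std::monostate>();  // analogous to GradFn::clear()
    return 0;
}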
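make_backward_graph builds its cache key by packing the op hash, per-input dtype/comp_node hashes, and requires_grad flags into one contiguous buffer and hashing it once. A rough standalone approximation of that idea, using std::hash in place of MegEngine's XXHash (make_key and its arguments are hypothetical names):

// sketch only: one hash over a packed signature buffer
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

std::size_t make_key(std::size_t op_hash,
                     const std::vector<std::size_t>& input_hashes,
                     const std::vector<bool>& requires_grad) {
    // layout: [op hash | per-input hashes... | per-input bool flags...]
    std::string buf;
    buf.append(reinterpret_cast<const char*>(&op_hash), sizeof(op_hash));
    for (std::size_t h : input_hashes)
        buf.append(reinterpret_cast<const char*>(&h), sizeof(h));
    for (bool b : requires_grad)
        buf.push_back(b ? 1 : 0);
    return std::hash<std::string>{}(buf);
}

int main() {
    // two inputs: hashed dtype/comp_node info plus their requires_grad flags
    return static_cast<int>(make_key(0x1234, {42, 43}, {true, false}) & 0xff);
}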