/**
 * \file imperative/python/src/grad.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma GCC diagnostic ignored "-Wmissing-field-initializers"

#include "./grad.h"
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/utils/mempool.h"

#include "range/v3/all.hpp"

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

using scoped_disable = ApplyContext::scoped_disable;
using Flags = Tensor::Flags;

namespace {

struct GradSlotWeakPtr {
    std::weak_ptr<GradFn> grad_fn;
    size_t idx;
};

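// Build, or fetch from a thread-local cache, the optimized backward graph for
// the op in `ctx`. The cache key combines the op with each input's dtype,
// comp_node and requires-grad flag; a null result means the op has no backward
// graph and needs no gradient bookkeeping.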
std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
        ApplyContext& ctx, const apply_result_t& outputs) {
    // hash
    using OptimizedBackwardGraphCache = OpMethResultCache<
            std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
    thread_local OptimizedBackwardGraphCache cache;
    decltype(cache)::key_t cache_key{ctx.op};
    SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
    SmallVector<bool>& input_requires_grad = std::get<0>(cache_key.extras);
    input_descs.resize(ctx.nargs);
    input_requires_grad.resize(ctx.nargs);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        input_descs[i].layout.dtype = ctx.args[i]->dtype();
        input_descs[i].comp_node = ctx.args[i]->comp_node();
        input_requires_grad[i] = python::input_requires_grad(ctx, i);
    }

    auto iter = cache.find(cache_key);
    if (iter != cache.end()) {
        return iter->second;
    }

    // slow path
    SmallVector<bool> output_has_grad(outputs.size(), true);
    std::shared_ptr<OptimizedBackwardGraphResult> ret;
    auto bg = OpDef::make_backward_graph(
            *ctx.op, input_descs, input_requires_grad, output_has_grad);
    if (!bg.graph.empty()) {
        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
    }
    cache.emplace(cache_key, ret);
    return ret;
}

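// Grad function backed by an optimized backward graph. `closure` holds the
// tensors selected by save_for_backward (plus any precomputed intermediates);
// the two offsets index into save_for_backward for the "output kept" and
// "grad propagates" sections described in the constructor below.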
struct BackwardGraphWithClosure {
    std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
    SmallVector<std::shared_ptr<Tensor>> closure;
    size_t output_mask_offset;
    size_t grad_mask_offset;

    BackwardGraphWithClosure(
            std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_,
            ApplyContext& ctx, const apply_result_t& outputs)
            : backward_graph(backward_graph_),
              output_mask_offset(ctx.nargs),
              grad_mask_offset(ctx.nargs + outputs.size()) {
        // save_for_backward[0:nargs]:
        //     whether input is kept for backward
        //
        // save_for_backward[nargs:nargs+outputs.size()]:
        //     whether output is kept for backward
        //
        // save_for_backward[-outputs.size():]:
        //     whether gradient of output can propagate to any input
        //
        // Example:
        //     perform c = a * b, with a.requires_grad == True and
        //     b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
        auto& save_for_backward = backward_graph->save_for_backward;
        mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
        size_t count = std::count_if(
                save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
        if (!backward_graph->precomp.empty()) {
            auto&& irng = ranges::span(ctx.args, ctx.nargs);
            auto&& orng = views::transform(outputs, [](auto&& i) { return i.get(); });
            auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
            closure.reserve(precomp.size() + count);
            std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
        } else {
            closure.reserve(count);
        }
        for (size_t i = 0; i < ctx.nargs; ++i) {
            if (save_for_backward[i]) {
                closure.push_back(ctx.args[i]->shared_from_this());
            }
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
            if (save_for_backward[ctx.nargs + i]) {
                closure.push_back(outputs[i]);
            }
        }
    }

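    // Run the backward graph: the argument list is the captured closure
    // followed by the output grads that save_for_backward marks as required.
    // If a required grad is missing, propagation is skipped; otherwise the
    // resulting input grads are handed to `receiver` for every input with
    // input_has_grad set.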
    template <typename T, typename R>
    void operator()(BackwardContext&, T&& grads, R&& receiver) {
        Tensor* args[closure.size() + grads.size()];
        size_t nargs = 0;
        for (auto&& t : closure) {
            args[nargs++] = t.get();
        }
        bool null_grad = false;
        for (size_t i = 0; i < grads.size(); ++i) {
            if (backward_graph->save_for_backward[grad_mask_offset + i]) {
                if (grads[i]) {
                    if (null_grad) {
                        PyErr_SetString(PyExc_NotImplementedError, "report to devs");
                        throw py::error_already_set();
                    }
                    args[nargs++] = grads[i];
                } else {
                    null_grad = true;
                }
            }
        }
        if (null_grad)
            return;

        auto igrads = apply(backward_graph->backward, args, nargs);
        auto&& it = igrads.begin();
        for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
            if (p) {
                receiver(i, std::move(*it));
                ++it;
            }
        }
    }

    bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; }

    bool output_requires_grad(size_t i) {
        return backward_graph->save_for_backward[grad_mask_offset + i];
    }

    bool output_captured(size_t i) {
        return backward_graph->save_for_backward[output_mask_offset + i];
    }
};

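// Grad function backed by a user-supplied Python callable. The callable
// receives one (possibly None) grad per forward output and must return either
// None, a single tensor when there is exactly one input, or a sequence with
// one entry (tensor or None) per forward input.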
struct PythonBackward {
    py::object pyfunc;
    size_t input_size;

    PythonBackward(py::object f, size_t nin) : pyfunc(f), input_size(nin) {}

    template <typename T, typename R>
    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
        auto args = py::tuple(grads.size());
        for (size_t i = 0; i < grads.size(); ++i) {
            auto&& g = grads[i];
            args[i] = g ? ctx.wrap_tensor(g) : py::none();
        }
        auto input_grads = py::reinterpret_steal<py::object>(
                PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
        if (!input_grads)
            throw py::error_already_set();
        if (input_grads.is_none())
            return;
        if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
            if (input_size != 1) {
                throw py::value_error(
                        "custom grad rule returned wrong number of grads");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(input_grads.ptr());
            }
            receiver(0, tw->m_tensor);
            return;
        }
        if (py::len(input_grads) != input_size) {
            throw py::value_error("custom grad rule returned wrong number of grads");
        }
        for (auto [i, g] : views::enumerate(input_grads)) {
            if (g.is_none())
                continue;
            auto* tw = TensorWrapper::try_cast(g.ptr());
            if (!tw) {
                throw py::type_error("custom grad rule returned non-tensor");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(g.ptr());
            }
            receiver(i, tw->m_tensor);
        }
    }

    static constexpr bool input_has_grad(size_t) { return true; }
    static constexpr bool output_requires_grad(size_t) { return true; }
    static constexpr bool output_captured(size_t) { return true; }
};

}  // namespace

struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
    using Base = intrusive_list::Node<GradProducerRecord>;

    GradProducerRecord() = default;
    GradProducerRecord(GradProducerRecord::head_t& head)
            : Base(intrusive_list::after_t{}, head) {}
    // GradProducerRecord(GradProducerRecord&&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&&) = default;
};

struct GradSlot {
    std::shared_ptr<Tensor> grad;
    py::object callback;
    GradProducerRecord::head_t producer_head;
};

struct GradSlotProducerPtr : GradSlotPtr {
    GradProducerRecord producer_record;

    GradSlotProducerPtr() = default;
    GradSlotProducerPtr(GradInfo& info)
            : GradSlotPtr(info), producer_record(info->producer_head) {}
};

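// A node on the backward tape. One GradFn is recorded per forward op and per
// active GradKey; `slots` accumulate grads for the op's outputs, `dsts` point
// back to the slots of the inputs that produced them, and `backward` holds the
// strategy used to turn output grads into input grads.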
struct GradFn : std::enable_shared_from_this<GradFn> {
    static MemPool<GradFn> pool;

    std::weak_ptr<GradKey> key;
    // slots for receiving and accumulating grads
    // same length as outputs (of forward op)
    SmallVector<GradSlot> slots;
    // where to send and accumulate grads
    // same length as inputs (of forward op)
    SmallVector<GradSlotProducerPtr> dsts;
    // encapsulates the actual function used to compute gradients
    std::variant<
            std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward>
            backward;
    // a flag used during backward
    bool in_ref_keeper = false;

    static void deleter(GradFn* ptr) { pool.free(ptr); }

    static std::shared_ptr<GradFn> make() {
        return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
    }

    void clear() {
        key.reset();
        slots.clear();
        dsts.clear();
        backward.emplace<std::monostate>();
    }
};

GradSlotPtr::operator bool() const {
    return bool(grad_fn);
}

GradSlot* GradSlotPtr::operator->() {
    return &grad_fn->slots[idx];
}

namespace {

class GradFnHelper {
    std::shared_ptr<GradFn> grad_fn;

    GradFn* get() {
        if (!grad_fn) {
            grad_fn = std::make_shared<GradFn>();
        }
        return grad_fn.get();
    }

    friend apply_result_t imperative::python::apply_grad(ApplyContext&);

public:
    template <typename T, typename... Args>
    auto& emplace(Args&&... args) {
        return get()->backward.emplace<T>(std::forward<Args>(args)...);
    }

    void reset() { grad_fn = nullptr; }
};

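// Fallback grad rule: run the forward op and, if a backward graph exists for
// it, record a BackwardGraphWithClosure in `ret_grad_fn`; otherwise leave the
// grad fn empty so no gradient is tracked for this op.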
apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    // copy inputs first, otherwise trace would create an InputNode for each usage
    ApplyContext ctx_dup = ctx;
    SmallVector<std::shared_ptr<Tensor>> inputs_copy;
    SmallVector<Tensor*> inputs_copy_weak;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        Tensor* input = ctx.args[i];
        inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]);
        inputs_copy_weak.push_back(inputs_copy.back().get());
        inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
        if (input->m_flags & Flags::GRAD) {
            inputs_copy.back()->m_flags |= Flags::GRAD;
        }
    }
    ctx_dup.args = inputs_copy_weak.data();

    auto outputs = apply(ctx_dup);

    auto backward_graph = make_backward_graph(ctx_dup, outputs);
    if (!backward_graph) {
        return outputs;
    }
    ret_grad_fn.emplace<BackwardGraphWithClosure>(
            std::move(backward_graph), ctx_dup, outputs);

    return outputs;
}

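// Grad rule for GenericPyOp: call the op's Python `_grad_rule`, which returns
// (outputs, backward); the backward callable is recorded as a PythonBackward.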
apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    auto* op = ctx.op->try_cast_final<GenericPyOp>();
    py::tuple pyin(ctx.nargs);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
    }
    auto grad_rule = py::getattr(op->obj, "_grad_rule");
    auto pyret = py::reinterpret_steal<py::object>(
            PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
    if (!pyret)
        throw py::error_already_set();
    auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
    ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
    if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
        return {tw->m_tensor};
    }
    apply_result_t ret;
    ret.reserve(py::len(outputs));
    for (auto&& i : outputs) {
        auto* tw = TensorWrapper::try_cast(i.ptr());
        mgb_assert(tw);
        ret.push_back(tw->m_tensor);
    }
    return ret;
}

}  // namespace

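// Entry point for gradient-aware dispatch. Collect the active GradKeys attached
// to the inputs, run the forward op with GRAD temporarily disabled (via the
// python rule, a registered custom rule, or the backward-graph fallback), then,
// if a grad fn was produced, record one GradFn per active key on that key's tape.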
apply_result_t apply_grad(ApplyContext& ctx) {
    std::unordered_set<std::shared_ptr<GradKey>> grad_keys;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        auto* tensor = ctx.args[i];
        if (!tensor->m_grad_info_dict.empty()) {
            size_t grad_cnt = 0;
            for (auto&& grad_info : tensor->m_grad_info_dict) {
                auto input_grad_key = grad_info.grad_fn->key.lock();
                if (input_grad_key && input_grad_key->active &&
                    !input_grad_key->is_blocked()) {
                    grad_keys.insert(input_grad_key);
                    grad_cnt++;
                }
            }
            if (!grad_cnt) {
                tensor->m_flags &= ~Flags::GRAD;
            }
        } else {
            tensor->m_flags &= ~Flags::GRAD;
        }
    }

    ctx.flags &= ~Flags::GRAD;

    if (grad_keys.empty()) {
        return apply(ctx);
    } else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) {
        PyErr_SetString(
                PyExc_NotImplementedError,
                "second order directive not enabled, please call "
                "'megengine.experimental.enable_higher_order_directive'");
        throw pyext17::py_err_set();
    }

    GradFnHelper grad_fn_holder;
    auto outputs = [&]() {
        auto _ = scoped_disable(Flags::GRAD);
        if (ctx.op->same_type<GenericPyOp>()) {
            return python_grad_rule(ctx, grad_fn_holder);
        }
        auto&& registry = grad_rule_registry();
        auto&& it = registry.find(ctx.op->dyn_typeinfo());
        if (it != registry.end()) {
            auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
            if (auto ret = it->second(ctx, maker)) {
                maker.finalize();
                return *ret;
            }
            grad_fn_holder.reset();
        }
        return backward_graph_grad_rule(ctx, grad_fn_holder);
    }();

    if (!grad_fn_holder.grad_fn) {
        return outputs;
    }

    for (auto&& grad_key : grad_keys) {
        auto grad_fn = std::make_shared<GradFn>();
        grad_fn->backward = grad_fn_holder.grad_fn->backward;
        grad_fn->key = grad_key;
        grad_fn->slots.resize(outputs.size());
        grad_fn->dsts.reserve(ctx.nargs);

        std::visit(
                [&](auto& backward) {
                    using T = std::decay_t<decltype(backward)>;
                    if constexpr (std::is_same_v<T, std::monostate>) {
                        mgb_assert(0);
                    } else {
                        for (size_t i = 0; i < ctx.nargs; ++i) {
                            if (backward.input_has_grad(i) &&
                                input_requires_grad(ctx, i) &&
                                ctx.args[i]->m_grad_info_dict.count(grad_key.get())) {
                                auto& input_grad_info =
                                        ctx.args[i]->m_grad_info_dict.at(
                                                grad_key.get());
                                grad_fn->dsts.emplace_back(input_grad_info);
                                // register as grad producer
                                grad_fn->dsts.back().producer_record.insert_after(
                                        input_grad_info->producer_head);
                            } else {
                                grad_fn->dsts.emplace_back();
                            }
                        }
                        for (size_t i = 0; i < outputs.size(); ++i) {
                            if (backward.output_requires_grad(i)) {
                                if (backward.output_captured(i)) {
                                    // avoid reference cycle [Tensor <-> GradFn]
                                    static std::shared_ptr<OpDef> op =
                                            std::make_shared<FastpathCopy>();
                                    outputs[i] = python::apply(op, outputs[i])[0];
                                }
                                // populate grad info of output tensor
                                auto& grad_info =
                                        outputs[i]->m_grad_info_dict[grad_key.get()];
                                grad_info.grad_fn = grad_fn;
                                grad_info.idx = i;
                                grad_info.insert_after(grad_key->free_vars_head);
                                outputs[i]->m_flags |= Flags::GRAD;
                            }
                        }
                    }
                },
                grad_fn->backward);

        // record forward history
        grad_key->tape.emplace_back(grad_fn);
    }

    return outputs;
}

PyObject* GradKeyWrapper::get_priority() {
    return py::cast(m_key->priority).release().ptr();
}

void GradKeyWrapper::set_priority(pybind11::handle priority) {
    m_key->priority = py::cast<int>(priority);
}

void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
    if (nargs != 2) {
        throw py::type_error("expect 2 arguments");
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        throw py::type_error("argument 1 must be Tensor");
    }
    auto* tensor = tw->m_tensor.get();
    py::object callback;
    if (args[1] != Py_None) {
        callback = py::reinterpret_borrow<py::object>(args[1]);
    }
    m_key->attach(tensor, std::move(callback));
}

//! After attach, GradKey is weakly referred to by tensor->m_grad_info_dict[this].grad_fn->key
void GradKey::attach(Tensor* tensor, pybind11::object callback) {
    if (!active) {
        throw py::value_error("grad key finalized");
    }

    if (tensor->m_grad_info_dict.count(this)) {
        if (tensor->m_grad_info_dict.at(this)->callback) {
            throw py::value_error("callback already set on this tensor");
        }
    } else {
        auto& grad_info = tensor->m_grad_info_dict[this];
        grad_info.idx = 0;
        auto& grad_fn = grad_info.grad_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn->key = shared_from_this();
        grad_fn->slots.resize(1);
        grad_info.insert_after(free_vars_head);
        tensor->m_flags |= Flags::GRAD;
    }
    tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback);
}

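// Accumulate `delta` into `grad`: the first contribution is moved in directly,
// later contributions are combined with an elementwise ADD.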
template <typename T>
void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
    if (!grad) {
        grad = std::forward<T>(delta);
        return;
    }
    static std::shared_ptr<OpDef> op =
            std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
    grad = apply(op, grad, std::forward<T>(delta))[0];
}

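// Run backward for this key: seed the grads of `tensors` with `grads`, then
// walk the tape in reverse, invoking each GradFn's backward strategy and
// accumulating results into its destination slots. A slot's callback fires
// when its last producer has contributed; cleanup() releases the tape when
// the traversal ends.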
void GradKey::backward(
        std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    if (!active) {
        throw py::value_error("finalized");
    }
    if (tensors.size() != grads.size()) {
        throw py::value_error("tensor and grad size mismatch");
    }

    // this GradKey is marked inactive here
    active = false;
    struct CleanupGuard {
        GradKey* owner;
        size_t priority_backup;
        CleanupGuard(GradKey* this_) : owner(this_) {
            priority_backup = sm_min_priority;
            sm_min_priority = owner->priority + 1;
        }
        ~CleanupGuard() {
            owner->cleanup();
            sm_min_priority = priority_backup;
        }
    } _cleanup_guard(this);

    if (tape.empty())
        return;

    BackwardContext bctx;
    if (!grads.empty()) {
        bctx.pytype = Py_TYPE(grads[0]->self().ptr());
    }

    for (size_t i = 0; i < tensors.size(); ++i) {
        if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) {
            continue;
        }
        auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this);
        grad_info->grad = grads[i]->m_tensor;
    }

    std::vector<std::shared_ptr<GradFn>> ref_keeper;
    ref_keeper.reserve(tape.size());

    // back-propagation in reverse order
    for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
        auto&& grad_fn = tape[k].lock();
        if (!grad_fn)
            continue;

        auto grad_receiver = [&](size_t i, auto&& g) {
            auto& dst = grad_fn->dsts[i];
            if (dst) {
                accum_grad(dst->grad, std::forward<decltype(g)>(g));
            }
        };
        std::visit(
                [&](auto&& backward) {
                    using T = std::decay_t<decltype(backward)>;
                    if constexpr (std::is_same_v<T, std::monostate>) {
                        mgb_assert(0);
                    } else {
                        auto&& grads = views::transform(
                                grad_fn->slots,
                                [](auto&& slot) { return slot.grad.get(); });
                        backward(
                                bctx, std::forward<decltype(grads)>(grads),
                                grad_receiver);
                    }
                },
                grad_fn->backward);

        for (auto&& dst : grad_fn->dsts) {
            if (!dst.grad_fn)
                continue;
            if (!dst.grad_fn->in_ref_keeper) {
                // after grad_fn is cleared, refcnt of subsequent grad_fn
                // could drop to 0
                dst.grad_fn->in_ref_keeper = true;
                ref_keeper.push_back(dst.grad_fn);
            }
            if (!dst.producer_record.next && dst->callback && dst->grad) {
                // I'm the last grad producer, invoke callback
                dst->callback(bctx.wrap_tensor(dst->grad));
            }
        }
        grad_fn->clear();
    }  // finish tape loop
}

void GradKey::cleanup() {
    active = false;
    tape.clear();
    for (intrusive_list::Iterator it(free_vars_head); it;) {
        it->grad_fn.reset();
        (it++)->unlink();
    }
}

void GradKeyWrapper::backward(
        std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    m_key->backward(std::move(tensors), std::move(grads));
}

PyObject* GradKeyWrapper::get_name() {
    return py::cast(m_key->name).release().ptr();
}

void GradKeyWrapper::set_name(py::handle name) {
    m_key->name = py::cast<std::string>(name);
}

PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
    if (nargs != 1) {
        PyErr_SetString(PyExc_TypeError, "expect 1 argument");
        return nullptr;
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        PyErr_SetString(PyExc_TypeError, "expect Tensor");
        return nullptr;
    }
    if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}

int GradKey::sm_min_priority = std::numeric_limits<int>::min();

GradKey::~GradKey() {
    cleanup();
}

std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
    static std::unordered_map<Typeinfo*, GradRuleFn> registry;
    return registry;
}

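// Drop GradInfo entries whose GradFn is gone or whose GradKey has expired, so
// that contains()/operator[]/at() below only ever see live entries.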
void GradInfoCollection::_shrink() {
    auto pred = [](GradInfo& info) {
        return !(info.grad_fn) || info.grad_fn->key.expired();
    };
    auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred);
    m_storage.erase(iter, m_storage.end());
}

bool GradInfoCollection::contains(GradKey* key) {
    _shrink();
    for (auto&& grad_info : m_storage) {
        if (grad_info.grad_fn->key.lock().get() == key) {
            return true;
        }
    }
    return false;
}

GradInfo& GradInfoCollection::operator[](GradKey* key) {
    _shrink();
    for (auto&& grad_info : m_storage) {
        if (grad_info.grad_fn->key.lock().get() == key) {
            return grad_info;
        }
    }
    m_storage.emplace_back();
    return m_storage.back();
}

GradInfo& GradInfoCollection::at(GradKey* key) {
    _shrink();
    for (auto&& grad_info : m_storage) {
        if (grad_info.grad_fn->key.lock().get() == key) {
            return grad_info;
        }
    }
    mgb_assert(false);
}

}  // namespace mgb::imperative::python