/**
 * \file imperative/python/src/grad.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma GCC diagnostic ignored "-Wmissing-field-initializers"

#include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/utils/mempool.h"

#include "range/v3/all.hpp"

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

using scoped_disable = ApplyContext::scoped_disable;
using Flags = Tensor::Flags;

namespace {

struct GradSlotWeakPtr {
    std::weak_ptr<GradFn> grad_fn;
    size_t idx;
};

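// Build (or fetch from a thread-local cache) the optimized backward graph for
// the op in `ctx`. The cache key is the op together with each input's
// dtype/comp_node and requires-grad mask. Returns nullptr when the op has no
// backward graph.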
std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
        ApplyContext& ctx, const apply_result_t& outputs) {
    // hash: build the cache key from the op plus each input's dtype, comp_node
    // and requires-grad flag
    using OptimizedBackwardGraphCache = OpMethResultCache<std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
    thread_local OptimizedBackwardGraphCache cache;
    decltype(cache)::key_t cache_key{ctx.op};
    SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
    SmallVector<bool>& input_requires_grad = std::get<0>(cache_key.extras);
    input_descs.resize(ctx.nargs);
    input_requires_grad.resize(ctx.nargs);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        input_descs[i].layout.dtype = ctx.args[i]->dtype();
        input_descs[i].comp_node = ctx.args[i]->comp_node();
        input_requires_grad[i] = python::input_requires_grad(ctx, i);
    }

    auto iter = cache.find(cache_key);
    if (iter != cache.end()) {
        return iter->second;
    }

    // slow path
    SmallVector<bool> output_has_grad(outputs.size(), true);
    std::shared_ptr<OptimizedBackwardGraphResult> ret;
    auto bg = OpDef::make_backward_graph(
            *ctx.op, input_descs, input_requires_grad, output_has_grad);
    if (!bg.graph.empty()) {
        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
    }
    cache.emplace(cache_key, ret);
    return ret;
}

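// Backward implementation driven by an OptimizedBackwardGraphResult. The
// constructor captures in `closure` the precomputed tensors plus whichever
// inputs/outputs the backward graph marked as needed; operator() later feeds
// the closure and the incoming output grads through the backward graph.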
struct BackwardGraphWithClosure {
    std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
    SmallVector<std::shared_ptr<Tensor>> closure;
    size_t output_mask_offset;
    size_t grad_mask_offset;

    BackwardGraphWithClosure(std::shared_ptr<OptimizedBackwardGraphResult> backward_graph_,
                             ApplyContext& ctx, const apply_result_t& outputs)
            : backward_graph(backward_graph_),
              output_mask_offset(ctx.nargs),
              grad_mask_offset(ctx.nargs + outputs.size()) {
        // save_for_backward[0:nargs]:
        //     whether input is kept for backward
        //
        // save_for_backward[nargs:nargs+outputs.size()]:
        //     whether output is kept for backward
        //
        // save_for_backward[-outputs.size():]:
        //     whether gradient of output can propagate to any input
        //
        // Example:
        //     perform c = a * b, with a.requires_grad == True and
        //     b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
        auto& save_for_backward = backward_graph->save_for_backward;
        mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
        size_t count = std::count_if(save_for_backward.begin(),
                                     save_for_backward.end(),
                                     ranges::identity{});
        if (!backward_graph->precomp.empty()) {
            auto&& irng = ranges::span(ctx.args, ctx.nargs);
            auto&& orng = views::transform(outputs, [](auto&& i){return i.get();});
            auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));
            closure.reserve(precomp.size() + count);
            std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
        } else {
            closure.reserve(count);
        }
        for (size_t i = 0; i < ctx.nargs; ++i) {
            if (save_for_backward[i]) {
                closure.push_back(ctx.args[i]->shared_from_this());
            }
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
            if (save_for_backward[ctx.nargs + i]) {
                closure.push_back(outputs[i]);
            }
        }
    }

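    // Run the backward graph: its arguments are the captured closure followed
    // by whichever output grads it asked for; the resulting input grads are
    // handed to `receiver(input_index, grad)`.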
    template <typename T, typename R>
    void operator()(BackwardContext&, T&& grads, R&& receiver) {
        Tensor* args[closure.size() + grads.size()];
        size_t nargs = 0;
        for (auto&& t : closure) {
            args[nargs++] = t.get();
        }
        bool null_grad = false;
        for (size_t i = 0; i < grads.size(); ++i) {
            if (backward_graph->save_for_backward[grad_mask_offset + i]) {
                if (grads[i]) {
                    if (null_grad) {
                        PyErr_SetString(PyExc_NotImplementedError, "report to devs");
                        throw py::error_already_set();
                    }
                    args[nargs++] = grads[i];
                } else {
                    null_grad = true;
                }
            }
        }
        if (null_grad) return;

        auto igrads = apply(backward_graph->backward, args, nargs);
        auto&& it = igrads.begin();
        for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
            if (p) {
                receiver(i, std::move(*it));
                ++it;
            }
        }
    }

    bool input_has_grad(size_t i) {
        return backward_graph->input_has_grad[i];
    }

    bool output_requires_grad(size_t i) {
        return backward_graph->save_for_backward[grad_mask_offset + i];
    }

    bool output_captured(size_t i) {
        return backward_graph->save_for_backward[output_mask_offset + i];
    }
};

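// Backward implementation that defers to a user-defined Python `backward`
// callable (as returned by a custom grad rule). Output grads are wrapped as
// Python tensors, the callable is invoked, and its results are routed to
// `receiver`.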
struct PythonBackward {
    py::object pyfunc;
    size_t input_size;

    PythonBackward(py::object f, size_t nin)
            : pyfunc(f), input_size(nin) {}

    template <typename T, typename R>
    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
        auto args = py::tuple(grads.size());
        for (size_t i = 0; i < grads.size(); ++i) {
            auto&& g = grads[i];
            args[i] = g ? ctx.wrap_tensor(g) : py::none();
        }
        auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
        if (!input_grads) throw py::error_already_set();
        if (input_grads.is_none()) return;
        if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
            if (input_size != 1) {
                throw py::value_error("custom grad rule returned wrong number of grads");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(input_grads.ptr());
            }
            receiver(0, tw->m_tensor);
            return;
        }
        if (py::len(input_grads) != input_size) {
            throw py::value_error("custom grad rule returned wrong number of grads");
        }
        for (auto [i, g] : views::enumerate(input_grads)) {
            if (g.is_none()) continue;
            auto* tw = TensorWrapper::try_cast(g.ptr());
            if (!tw) {
                throw py::type_error("custom grad rule returned non-tensor");
            }
            if (!ctx.pytype) {
                ctx.pytype = Py_TYPE(g.ptr());
            }
            receiver(i, tw->m_tensor);
        }
    }

    static constexpr bool input_has_grad(size_t) {return true;}
    static constexpr bool output_requires_grad(size_t) {return true;}
    static constexpr bool output_captured(size_t) {return true;}
};

} // namespace

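// Intrusive-list node recording that some GradFn contributes gradients to a
// GradSlot; during backward, a slot's callback is only invoked by the last
// remaining producer.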
struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
    using Base = intrusive_list::Node<GradProducerRecord>;

    GradProducerRecord() = default;
    GradProducerRecord(GradProducerRecord::head_t& head) : Base(intrusive_list::after_t{}, head) {}
    // GradProducerRecord(GradProducerRecord&&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&) = default;
    // GradProducerRecord& operator=(GradProducerRecord&&) = default;
};

struct GradSlot {
    std::shared_ptr<Tensor> grad;
    py::object callback;
    GradProducerRecord::head_t producer_head;
};

struct GradSlotProducerPtr : GradSlotPtr {
    GradProducerRecord producer_record;

    GradSlotProducerPtr() = default;
    GradSlotProducerPtr(GradInfo& info) : GradSlotPtr(info), producer_record(info->producer_head) {}
};

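// One node of the backward tape. A GradFn is created for every grad-tracked
// apply and for every attached leaf tensor; it holds the grad slots of its
// outputs, the destinations its input grads are accumulated into, and the
// backward implementation used to compute them.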
struct GradFn : std::enable_shared_from_this<GradFn> {
    static MemPool<GradFn> pool;

    std::weak_ptr<GradKey> key;
    // slots for receiving and accumulating grads
    // same length as outputs (of forward op)
    SmallVector<GradSlot> slots;
    // where to send and accumulate grads
    // same length as inputs (of forward op)
    SmallVector<GradSlotProducerPtr> dsts;
    // encapsulates the actual function used to compute gradients
    std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward> backward;
    // flag used during backward: set once this GradFn has been pushed into ref_keeper
    bool in_ref_keeper = false;

    static void deleter(GradFn* ptr) {
        pool.free(ptr);
    }

    static std::shared_ptr<GradFn> make() {
        return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
    }

    void clear() {
        key.reset();
        slots.clear();
        dsts.clear();
        backward.emplace<std::monostate>();
    }
};

GradSlotPtr::operator bool() const {
    return bool(grad_fn);
}

GradSlot* GradSlotPtr::operator->() {
    return &grad_fn->slots[idx];
}

namespace {

class GradFnHelper {
    std::shared_ptr<GradFn> grad_fn;

    GradFn* get() {
        if (!grad_fn) {
            grad_fn = std::make_shared<GradFn>();
        }
        return grad_fn.get();
    }

    friend apply_result_t imperative::python::apply_grad(ApplyContext&);

public:
    template<typename T, typename... Args>
    auto& emplace(Args&&... args) {
        return get()->backward.emplace<T>(std::forward<Args>(args)...);
    }

    void reset() { grad_fn = nullptr; }
};

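// Default grad rule: re-apply the op on FastpathCopy copies of the inputs (so
// the trace sees a single use per tensor), then attach a
// BackwardGraphWithClosure built from the op's backward graph; if no backward
// graph exists, the outputs are returned without a grad function.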
apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    // copy inputs first, otherwise the trace would create an InputNode for each usage
    ApplyContext ctx_dup = ctx;
    SmallVector<std::shared_ptr<Tensor>> inputs_copy;
    SmallVector<Tensor*> inputs_copy_weak;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        Tensor* input = ctx.args[i];
        inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]);
        inputs_copy_weak.push_back(inputs_copy.back().get());
        inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
        if (input->m_flags & Flags::GRAD) {
            inputs_copy.back()->m_flags |= Flags::GRAD;
        }
    }
    ctx_dup.args = inputs_copy_weak.data();

    auto outputs = apply(ctx_dup);

    auto backward_graph = make_backward_graph(ctx_dup, outputs);
    if (!backward_graph) {
        return outputs;
    }
    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx_dup, outputs);

    return outputs;
}

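// Grad rule for GenericPyOp: call the op's Python `_grad_rule`, which returns
// (outputs, backward); the backward callable is stored as a PythonBackward.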
apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    auto* op = ctx.op->try_cast_final<GenericPyOp>();
    py::tuple pyin(ctx.nargs);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
    }
    auto grad_rule = py::getattr(op->obj, "_grad_rule");
    auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
    if (!pyret) throw py::error_already_set();
    auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
    ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
    if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
        return {tw->m_tensor};
    }
    apply_result_t ret;
    ret.reserve(py::len(outputs));
    for (auto&& i : outputs) {
        auto* tw = TensorWrapper::try_cast(i.ptr());
        mgb_assert(tw);
        ret.push_back(tw->m_tensor);
    }
    return ret;
}

} // namespace

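// Entry point of grad tracking in the apply dispatch chain: collect the active
// GradKeys attached to the inputs, run the op with grad tracking disabled
// through a grad rule that also builds the backward implementation, then
// record a GradFn on the tape of every involved GradKey and tag the outputs.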
apply_result_t apply_grad(ApplyContext& ctx) {
    std::unordered_set<std::shared_ptr<GradKey>> grad_keys;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        auto* tensor = ctx.args[i];
        if (!tensor->m_grad_info_dict.empty()) {
            size_t grad_cnt = 0;
            for (auto&& grad_info: tensor->m_grad_info_dict) {
                auto input_grad_key = grad_info.grad_fn->key.lock();
                if (input_grad_key && input_grad_key->active && !input_grad_key->is_blocked()) {
                    grad_keys.insert(input_grad_key);
                    grad_cnt++;
                }
            }
            if (!grad_cnt) {
                tensor->m_flags &= ~Flags::GRAD;
            }
        } else {
            tensor->m_flags &= ~Flags::GRAD;
        }
    }

    ctx.flags &= ~Flags::GRAD;

    if (grad_keys.empty()) {
        return apply(ctx);
    } else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) {
        PyErr_SetString(
                PyExc_NotImplementedError,
                "second order directive not enabled, please call "
                "'megengine.experimental.enable_higher_order_directive'");
        throw pyext17::py_err_set();
    }

    GradFnHelper grad_fn_holder;
    auto outputs = [&]() {
        auto _ = scoped_disable(Flags::GRAD);
        if (ctx.op->same_type<GenericPyOp>()) {
            return python_grad_rule(ctx, grad_fn_holder);
        }
        auto&& registry = grad_rule_registry();
        auto&& it = registry.find(ctx.op->dyn_typeinfo());
        if (it != registry.end()) {
            auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
            if (auto ret = it->second(ctx, maker)) {
                maker.finalize();
                return *ret;
            }
            grad_fn_holder.reset();
        }
        return backward_graph_grad_rule(ctx, grad_fn_holder);
    }();

    if (!grad_fn_holder.grad_fn) {
        return outputs;
    }

    for (auto&& grad_key: grad_keys) {
        auto grad_fn = std::make_shared<GradFn>();
        grad_fn->backward = grad_fn_holder.grad_fn->backward;
        grad_fn->key = grad_key;
        grad_fn->slots.resize(outputs.size());
        grad_fn->dsts.reserve(ctx.nargs);

        std::visit([&](auto& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_assert(0);
            } else {
                for (size_t i = 0; i < ctx.nargs; ++i) {
                    if (backward.input_has_grad(i) && input_requires_grad(ctx, i) && ctx.args[i]->m_grad_info_dict.count(grad_key.get())) {
                        auto& input_grad_info = ctx.args[i]->m_grad_info_dict.at(grad_key.get());
                        grad_fn->dsts.emplace_back(input_grad_info);
                        // register as grad producer
                        grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
                    } else {
                        grad_fn->dsts.emplace_back();
                    }
                }
                for (size_t i = 0; i < outputs.size(); ++i) {
                    if (backward.output_requires_grad(i)) {
                        if (backward.output_captured(i)) {
                            // avoid reference cycle [Tensor <-> GradFn]
                            static std::shared_ptr<OpDef> op = std::make_shared<FastpathCopy>();
                            outputs[i] = python::apply(op, outputs[i])[0];
                        }
                        // populate grad info of output tensor
                        auto& grad_info = outputs[i]->m_grad_info_dict[grad_key.get()];
                        grad_info.grad_fn = grad_fn;
                        grad_info.idx = i;
                        grad_info.insert_after(grad_key->free_vars_head);
                        outputs[i]->m_flags |= Flags::GRAD;
                    }
                }
            }
        }, grad_fn->backward);

        // record forward history
        grad_key->tape.emplace_back(grad_fn);
    }

    return outputs;
}

PyObject* GradKeyWrapper::get_priority() {
    return py::cast(m_key->priority).release().ptr();
}

void GradKeyWrapper::set_priority(pybind11::handle priority) {
    m_key->priority = py::cast<int>(priority);
}

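// Python-facing attach(tensor, callback): callback may be None; forwards to
// GradKey::attach after unwrapping the TensorWrapper.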
void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
    if (nargs != 2) {
        throw py::type_error("expect 2 arguments");
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        throw py::type_error("argument 1 must be Tensor");
    }
    auto* tensor = tw->m_tensor.get();
    py::object callback;
    if (args[1] != Py_None) {
        callback = py::reinterpret_borrow<py::object>(args[1]);
    }
    m_key->attach(tensor, std::move(callback));
}

//! GradKey is weakly referred to by tensor->m_grad_info.grad_fn->key after attach
void GradKey::attach(Tensor* tensor, pybind11::object callback) {
    if (!active) {
        throw py::value_error("grad key finalized");
    }

    if (tensor->m_grad_info_dict.count(this)) {
        if (tensor->m_grad_info_dict.at(this)->callback) {
            throw py::value_error("callback already set on this tensor");
        }
    } else {
        auto& grad_info = tensor->m_grad_info_dict[this];
        grad_info.idx = 0;
        auto& grad_fn = grad_info.grad_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn->key = shared_from_this();
        grad_fn->slots.resize(1);
        grad_info.insert_after(free_vars_head);
        tensor->m_flags |= Flags::GRAD;
    }
    tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback);
}

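// Accumulate `delta` into `grad`: take ownership directly on first use,
// otherwise combine the two with a cached Elemwise ADD op.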
template<typename T>
void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
    if (!grad) {
        grad = std::forward<T>(delta);
        return;
    }
    static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
    grad = apply(op, grad, std::forward<T>(delta))[0];
}

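// Run backward for this key: seed the grad slots of `tensors` with `grads`,
// then walk the tape in reverse order, letting each GradFn compute its input
// grads, accumulating them into the destination slots and firing attached
// callbacks once a slot has received all of its contributions.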
void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    if (!active) {
        throw py::value_error("finalized");
    }
    if (tensors.size() != grads.size()) {
        throw py::value_error("tensor and grad size mismatch");
    }

    // this GradKey is marked inactive here
    active = false;
    struct CleanupGuard {
        GradKey* owner;
        size_t priority_backup;
        CleanupGuard(GradKey* this_) : owner(this_) {
            priority_backup = sm_min_priority;
            sm_min_priority = owner->priority + 1;
        }
        ~CleanupGuard() {
            owner->cleanup();
            sm_min_priority = priority_backup;
        }
    } _cleanup_guard(this);

    if (tape.empty()) return;

    BackwardContext bctx;
    if (!grads.empty()) {
        bctx.pytype = Py_TYPE(grads[0]->self().ptr());
    }

    for (size_t i = 0; i < tensors.size(); ++i) {
        if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) {
            continue;
        }
        auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this);
        grad_info->grad = grads[i]->m_tensor;
    }

    std::vector<std::shared_ptr<GradFn>> ref_keeper;
    ref_keeper.reserve(tape.size());

    // back-propagation in reverse order
    for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
        auto&& grad_fn = tape[k].lock();
        if (!grad_fn) continue;

        auto grad_receiver = [&](size_t i, auto&& g) {
            auto& dst = grad_fn->dsts[i];
            if (dst) {
                accum_grad(dst->grad, std::forward<decltype(g)>(g));
            }
        };
        std::visit([&](auto&& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_assert(0);
            } else {
                auto&& grads = views::transform(grad_fn->slots, [](auto&& slot) {return slot.grad.get();});
                backward(bctx, std::forward<decltype(grads)>(grads), grad_receiver);
            }
        }, grad_fn->backward);

        for (auto&& dst : grad_fn->dsts) {
            if (!dst.grad_fn) continue;
            if (!dst.grad_fn->in_ref_keeper) {
                // after grad_fn is cleared, the refcount of a downstream
                // grad_fn could drop to 0, so keep it alive in ref_keeper
                dst.grad_fn->in_ref_keeper = true;
                ref_keeper.push_back(dst.grad_fn);
            }
            if (!dst.producer_record.next && dst->callback && dst->grad) {
                // I'm the last grad producer, invoke callback
                dst->callback(bctx.wrap_tensor(dst->grad));
            }
        }
        grad_fn->clear();
    } // finish tape loop
}

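// Drop the tape and unlink every attached tensor's grad info so that GradFns
// (and the tensors captured in their closures) can be released.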
void GradKey::cleanup() {
    active = false;
    tape.clear();
    for (intrusive_list::Iterator it(free_vars_head); it;) {
        it->grad_fn.reset();
        (it++)->unlink();
    }
}

void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
    m_key->backward(std::move(tensors), std::move(grads));
}

PyObject* GradKeyWrapper::get_name() {
    return py::cast(m_key->name).release().ptr();
}

void GradKeyWrapper::set_name(py::handle name) {
    m_key->name = py::cast<std::string>(name);
}

PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
    if (nargs != 1) {
        PyErr_SetString(PyExc_TypeError, "expect 1 argument");
        return nullptr;
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        PyErr_SetString(PyExc_TypeError, "expect Tensor");
        return nullptr;
    }
    if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}

int GradKey::sm_min_priority = std::numeric_limits<int>::min();

GradKey::~GradKey() {
    cleanup();
}

std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
    static std::unordered_map<Typeinfo*, GradRuleFn> registry;
    return registry;
}

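// Remove entries whose GradKey has already expired, so the lookups below only
// ever see live keys.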
void GradInfoCollection::_shrink() {
    auto pred = [](GradInfo& info){ return !(info.grad_fn) || info.grad_fn->key.expired(); };
    auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred);
    m_storage.erase(iter, m_storage.end());
}

bool GradInfoCollection::contains(GradKey* key) {
    _shrink();
    for (auto&& grad_info: m_storage) {
        if (grad_info.grad_fn->key.lock().get() == key) {
            return true;
        }
    }
    return false;
}

GradInfo& GradInfoCollection::operator[](GradKey* key) {
    _shrink();
    for (auto&& grad_info: m_storage) {
        if (grad_info.grad_fn->key.lock().get() == key) {
            return grad_info;
        }
    }
    m_storage.emplace_back();
    return m_storage.back();
}

GradInfo& GradInfoCollection::at(GradKey* key) {
    _shrink();
    for (auto&& grad_info: m_storage) {
        if (grad_info.grad_fn->key.lock().get() == key) {
            return grad_info;
        }
    }
    mgb_assert(false);
}

} // namespace mgb::imperative::python