/** * \file imperative/python/src/grad.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include "./tensor.h" #include #include namespace mgb::imperative::python { apply_result_t apply_grad(ApplyContext& ctx); struct GradKey : std::enable_shared_from_this, NonCopyableObj { std::string name; bool active = true; GradInfo::head_t free_vars_head; std::vector> tape; ~GradKey(); void attach(Tensor* tensor, pybind11::object callback); void backward(std::vector, std::vector); void cleanup(); }; struct GradKeyWrapper { using wrap_t = pyext17::wrap; static constexpr auto tp_name = pybind11::detail::_("GradKey"); std::shared_ptr m_key; inline GradKeyWrapper() : m_key(std::make_shared()) {} void attach(PyObject*const* args, size_t nargs); void backward(std::vector, std::vector); }; } // namespace mgb::imperative::python namespace pybind11::detail { template<> struct type_caster : mgb::imperative::python::GradKeyWrapper::wrap_t::caster {}; } // namespace pybind11::detail