/**
 * \file imperative/python/src/grad.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./grad.h"

#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/resource_manager.h"
#include "megbrain/utils/mempool.h"

#include "range/v3/all.hpp"

#include "./helper.h"
#include "./transformation.h"

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

namespace {
// Maps each GradKey back to its Python-facing wrapper; see GradKeyWrapper::get.
std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map;
}  // namespace

GradKeyWrapper::GradKeyWrapper() {}

void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
    if (nargs != 2) {
        throw py::type_error("expect 2 arguments");
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        throw py::type_error("argument 1 must be Tensor");
    }
    py::object callback;
    if (args[1] != Py_None) {
        callback = py::reinterpret_borrow<py::object>(args[1]);
    }
    // Wrap the optional Python callback so the grad transformation can hand it
    // the computed gradient as a Tensor.
    GenericFunction generic_callback = [=](Span<ValueRef> inputs) -> ValueRefList {
        mgb_assert(inputs.size() == 1);
        if (callback) {
            callback(TensorWrapper::make(py_tensor_type, inputs[0]));
        }
        return {};
    };
    auto attached_value = imperative::apply(
            AttachGrad(m_key), tw->m_tensor->data(),
            FunctionValue::make(generic_callback))[0];
    tw->m_tensor->reset(attached_value);
}

void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) {
    std::vector<ValueRef> args;
    mgb_assert(tensors.size() == grads.size());
    for (auto&& tensor : tensors) {
        args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
    }
    for (auto&& grad : grads) {
        args.push_back(TensorWrapper::try_cast(grad.ptr())->m_tensor->data());
    }
    imperative::apply(GradBackward(self->m_key), {args.data(), args.size()});
}

pybind11::function GradKeyWrapper::get_backward_closure(
        GradKeyWrapper* self, py::list tensors) {
    std::vector<ValueRef> args;
    for (auto&& tensor : tensors) {
        args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
    }
    auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0];
    auto closure = closure_value.as_ref<FunctionValue>();
    auto py_function = [closure](std::vector<TensorWrapper*> tensors) {
        std::vector<ValueRef> args;
        for (auto* tw : tensors) {
            args.push_back(tw->m_tensor->data());
        }
        (*closure)(args);
    };
    return pybind11::cpp_function(py_function);
}

PyObject* GradKeyWrapper::get_name() {
    return py::cast(m_name).release().ptr();
}

void GradKeyWrapper::set_name(py::handle name) {
    m_name = py::cast<std::string>(name);
    if (m_key) {
        m_key->name(m_name);
    }
}

PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
    if (nargs != 1) {
        PyErr_SetString(PyExc_TypeError, "expect 1 argument");
        return nullptr;
    }
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        PyErr_SetString(PyExc_TypeError, "expect Tensor");
        return nullptr;
    }
    if (imperative::apply(IsAttachedTo(m_key), tw->m_tensor->data())[0]
                .cast<BoolValue>()) {
        Py_RETURN_TRUE;
    }
    Py_RETURN_FALSE;
}

void GradKeyWrapper::enter() {
    // Install a fresh grad transformation and remember the key -> wrapper mapping.
    m_transformation = std::make_shared<GradTransformation>();
    m_key = m_transformation->key();
    m_key->name(m_name);
    grad_key_map[m_key] = this;
    TransformationManager::get_instance()
            .register_at<TransformationManager::Segment::Grad>(m_transformation);
}

void GradKeyWrapper::exit() {
    TransformationManager::get_instance()
            .unregister<TransformationManager::Segment::Grad>(m_transformation);
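    // Drop the key -> wrapper mapping and clear per-scope state so this wrapper
    // can be entered again with a fresh key.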
    grad_key_map.erase(m_key);
    m_key = {};
    m_transformation.reset();
}

void GradKeyWrapper::suppress() {
    m_transformation->suppress();
}

void GradKeyWrapper::resume() {
    m_transformation->resume();
}

GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) {
    return grad_key_map.at(key);
}

GradKeyWrapper::~GradKeyWrapper() {}

}  // namespace mgb::imperative::python
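
// Illustrative sketch only (not part of the original file): from the Python side
// these bindings are typically driven roughly as below. The module path and
// method spellings here are assumptions; see megengine's GradManager for the
// real entry points.
//
//     key = GradKey()              # backed by GradKeyWrapper
//     key.enter()                  # push a GradTransformation
//     key.attach(x, callback)      # GradKeyWrapper::attach
//     y = f(x)                     # forward computation recorded under the key
//     key.backward([y], [dy])      # GradKeyWrapper::backward, fires callbacks
//     key.exit()                   # pop the transformation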