grad.cpp 4.5 KB
Newer Older
1 2 3 4
/**
 * \file imperative/python/src/grad.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
6 7 8 9 10 11 12
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./grad.h"
13

14
#include "megbrain/imperative/backward_graph_opt.h"
15
#include "megbrain/imperative/ops/autogen.h"
M
Megvii Engine Team 已提交
16
#include "megbrain/imperative/proxy_graph_detail.h"
17
#include "megbrain/imperative/resource_manager.h"
18 19
#include "megbrain/utils/mempool.h"

20 21
#include "range/v3/all.hpp"

22
#include "./helper.h"
23 24
#include "./transformation.h"

25
namespace py = pybind11;
26
namespace views = ranges::views;
27 28 29 30

namespace mgb::imperative::python {

namespace {
31
//! File-local registry from a grad key to the wrapper that owns it.
//! Entries are inserted by GradKeyWrapper::enter() (which stores `this`),
//! removed by GradKeyWrapper::exit(), and looked up by GradKeyWrapper::get().
//! The map holds raw pointers and does not manage the wrappers' lifetime.
std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map;
32 33
}

34
//! Construct an idle wrapper; key and transformation are set later by enter().
GradKeyWrapper::GradKeyWrapper() = default;
35

M
Megvii Engine Team 已提交
36
/**
 * Attach a tensor to this grad key, optionally with a Python callback.
 *
 * \param args  args[0] must be a Tensor; args[1] is the callback object, or
 *              None when no callback is wanted.
 * \param nargs number of entries in \p args; must be exactly 2.
 * \throws py::type_error if \p nargs is not 2 or args[0] is not a Tensor.
 *
 * The tensor's underlying value is replaced by the first output of the
 * AttachGrad operation.
 */
void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
    if (nargs != 2) {
        throw py::type_error("expect 2 arguments");
    }
    auto* wrapper = TensorWrapper::try_cast(args[0]);
    if (!wrapper) {
        throw py::type_error("argument 1 must be Tensor");
    }
    // An empty py::object is falsy, which is how "no callback" is represented.
    py::object callback = (args[1] == Py_None)
                                ? py::object()
                                : py::reinterpret_borrow<py::object>(args[1]);
    // Wraps the single incoming value in a Tensor and forwards it to the Python
    // callback, if any. NOTE(review): presumably invoked with the computed
    // gradient -- confirm against the AttachGrad implementation.
    GenericFunction on_value = [=](Span<ValueRef> inputs) -> ValueRefList {
        mgb_assert(inputs.size() == 1);
        if (callback) {
            callback(TensorWrapper::make(py_tensor_type, inputs[0]));
        }
        return {};
    };
    auto attached = imperative::apply(
            AttachGrad(m_key), wrapper->m_tensor->data(),
            FunctionValue::make(on_value))[0];
    wrapper->m_tensor->reset(attached);
}

61 62 63 64 65
/**
 * Apply GradBackward for the given key on a set of tensors and their grads.
 *
 * \param self    wrapper whose key is used for the backward pass.
 * \param tensors list of Tensor objects.
 * \param grads   list of Tensor objects, same length as \p tensors.
 * \throws py::type_error if any element of either list is not a Tensor.
 *
 * The operation receives all tensor values followed by all grad values, in
 * list order.
 */
void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) {
    mgb_assert(tensors.size() == grads.size());
    std::vector<ValueRef> args;
    args.reserve(tensors.size() + grads.size());
    // try_cast returns nullptr for objects that are not Tensors, so check
    // before dereferencing instead of crashing on bad input.
    auto append_value = [&args](py::handle item, const char* error) {
        auto* tw = TensorWrapper::try_cast(item.ptr());
        if (!tw) {
            throw py::type_error(error);
        }
        args.push_back(tw->m_tensor->data());
    };
    for (auto&& tensor : tensors) {
        append_value(tensor, "expect Tensor in tensors");
    }
    for (auto&& grad : grads) {
        append_value(grad, "expect Tensor in grads");
    }
    imperative::apply(GradBackward(self->m_key), {args.data(), args.size()});
}

73 74 75 76 77
/**
 * Build a Python callable that runs the backward closure for \p tensors.
 *
 * \param self    wrapper whose key is used to obtain the closure.
 * \param tensors list of Tensor objects the closure is created for.
 * \return a Python function taking a list of Tensors; it forwards their
 *         values to the closure and returns None.
 * \throws py::type_error if an element of \p tensors is not a Tensor (also
 *         raised by the returned function when handed a null Tensor).
 */
pybind11::function GradKeyWrapper::get_backward_closure(
        GradKeyWrapper* self, py::list tensors) {
    std::vector<ValueRef> args;
    args.reserve(tensors.size());
    for (auto&& tensor : tensors) {
        // try_cast returns nullptr for objects that are not Tensors.
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        if (!tw) {
            throw py::type_error("expect Tensor in tensors");
        }
        args.push_back(tw->m_tensor->data());
    }
    auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0];
    auto closure = closure_value.as_ref<FunctionValue>();
    auto py_function = [closure](std::vector<TensorWrapper*> tensors) {
        std::vector<ValueRef> args;
        args.reserve(tensors.size());
        for (auto* tw : tensors) {
            // NOTE(review): guards against a null entry produced by argument
            // conversion (e.g. None) -- confirm whether the caster allows it.
            if (!tw) {
                throw py::type_error("expect Tensor");
            }
            args.push_back(tw->m_tensor->data());
        }
        (*closure)(args);
    };
    return pybind11::cpp_function(py_function);
}

91
//! Convert the stored name to a Python object and hand its reference to the
//! caller (the py::object releases ownership instead of decrementing it).
PyObject* GradKeyWrapper::get_name() {
    py::object name_object = py::cast(m_name);
    return name_object.release().ptr();
}

//! Store \p name (converted to std::string) and, when a key currently exists,
//! propagate the new name to it as well.
void GradKeyWrapper::set_name(py::handle name) {
    m_name = py::cast<std::string>(name);
    if (!m_key) {
        // No key yet; enter() applies m_name when it creates one.
        return;
    }
    m_key->name(m_name);
}

M
Megvii Engine Team 已提交
102
/**
 * Query whether a tensor is attached to this grad key.
 *
 * \param args  args[0] must be a Tensor.
 * \param nargs number of entries in \p args; must be exactly 1.
 * \return Py_True or Py_False (new reference), or nullptr with a Python
 *         TypeError set when the arguments are invalid.
 */
PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
    if (nargs != 1) {
        PyErr_SetString(PyExc_TypeError, "expect 1 argument");
        return nullptr;
    }
    auto* wrapper = TensorWrapper::try_cast(args[0]);
    if (wrapper == nullptr) {
        PyErr_SetString(PyExc_TypeError, "expect Tensor");
        return nullptr;
    }
    // The first output of IsAttachedTo carries the boolean answer.
    if (!imperative::apply(IsAttachedTo(m_key), wrapper->m_tensor->data())[0]
                 .cast<BoolValue>()) {
        Py_RETURN_FALSE;
    }
    Py_RETURN_TRUE;
}

119
void GradKeyWrapper::enter() {
120 121 122 123
    m_transformation = std::make_shared<GradTransformation>();
    m_key = m_transformation->key();
    m_key->name(m_name);
    grad_key_map[m_key] = this;
124 125
    TransformationManager::get_instance().register_at<TransformationManager::Grad>(
            m_transformation);
126 127
}

128 129 130
void GradKeyWrapper::exit() {
    TransformationManager::get_instance().unregister<TransformationManager::Grad>(
            m_transformation);
131 132
    grad_key_map.erase(m_key);
    m_key = {};
133
    m_transformation.reset();
134 135
}

136 137
void GradKeyWrapper::suppress() {
    m_transformation->suppress();
138 139
}

140 141
void GradKeyWrapper::resume() {
    m_transformation->resume();
142 143
}

144 145
//! Look up the wrapper registered for \p key in grad_key_map.
//! \throws std::out_of_range if no wrapper is registered for \p key.
GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) {
    GradKeyWrapper* wrapper = grad_key_map.at(key);
    return wrapper;
}

148
//! Nothing to release explicitly; members clean up after themselves.
GradKeyWrapper::~GradKeyWrapper() = default;
149

M
Megvii Engine Team 已提交
150
}  // namespace mgb::imperative::python