/**
 * \file imperative/python/src/tensor.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#pragma GCC diagnostic ignored "-Wmissing-field-initializers"

#include <variant>

#include <string>
#include <unordered_map>

#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"

#include "./pyext17.h"

namespace mgb::imperative::python {

M
Megvii Engine Team 已提交
26
template <typename T, typename B = pybind11::object>
27 28
struct ObjectPtr : B {
    using B::B;
M
Megvii Engine Team 已提交
29 30
    T& operator*() { return reinterpret_cast<T&>(*B::ptr()); }
    T* operator->() { return reinterpret_cast<T*>(B::ptr()); }
31 32
};

M
Megvii Engine Team 已提交
33
}  // namespace mgb::imperative::python

#include "./grad_info.h"   // for struct GradInfo
#include "./trace_info.h"  // for struct TraceInfo

namespace mgb::imperative::python {

struct GradKey;

extern interpreter::Interpreter::Channel* interpreter_for_py;

class SharedHandle {
    using Handle = interpreter::Interpreter::Handle;
    static_assert(std::is_pointer_v<Handle>);
    std::shared_ptr<std::remove_pointer_t<Handle>> holder;

public:
M
Megvii Engine Team 已提交
50 51 52 53 54 55
    inline explicit SharedHandle(Handle handle)
            : holder(handle, [](auto* h) {
                  if (h) {
                      interpreter_for_py->del(h);
                  }
              }) {}
56 57 58 59 60
    SharedHandle(const SharedHandle&) = default;
    SharedHandle& operator=(const SharedHandle&) = default;
    SharedHandle(SharedHandle&&) = default;
    SharedHandle& operator=(SharedHandle&&) = default;

M
Megvii Engine Team 已提交
61
    inline Handle get() { return holder.get(); }
62 63
};

64 65 66 67
// impl in grad.cpp
class GradInfoCollection {
private:
    SmallVector<GradInfo> m_storage;
M
Megvii Engine Team 已提交
68

69 70
protected:
    void _shrink();
M
Megvii Engine Team 已提交
71

72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
public:
    bool contains(GradKey* key);
    GradInfo& operator[](GradKey* key);
    GradInfo& at(GradKey* key);
    bool empty() {
        _shrink();
        return m_storage.empty();
    }
    auto begin() {
        _shrink();
        return m_storage.begin();
    }
    auto end() {
        _shrink();
        return m_storage.end();
    }
M
Megvii Engine Team 已提交
88
    size_t count(GradKey* key) { return contains(key) ? 1 : 0; }
89 90
};

91 92 93 94 95 96 97
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
    using flags_t = uint64_t;

    struct Flags {
        static constexpr flags_t SCALAR = 1;
        static constexpr flags_t GRAD = 1 << 1;
        static constexpr flags_t TRACE = 1 << 2;
98
        static constexpr flags_t MODULE_TRACE = 1 << 3;
99 100 101 102
    };

    flags_t m_flags = 0;

103
    GradInfoCollection m_grad_info_dict;
104 105
    TraceInfo m_trace_info;
    SharedHandle m_handle;
106 107
    std::string user_custom_name;
    std::string automatic_name;
108
    cg::VarNode* m_var;
109
    pybind11::object m_module_trace_info;
110 111 112

    using Handle = interpreter::Interpreter::Handle;

113
    inline Tensor() : m_handle(nullptr), m_var(nullptr) {}
114
    inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
M
Megvii Engine Team 已提交
115 116 117
    inline explicit Tensor(SharedHandle handle)
            : m_handle(std::move(handle)), m_var(nullptr) {}
    inline explicit Tensor(cg::VarNode* var) : m_handle(nullptr), m_var(var) {}
118

119 120 121 122 123
    ~Tensor() = default;

    inline std::shared_ptr<Tensor> copy() {
        auto ret = std::make_shared<Tensor>(m_handle);
        ret->m_flags = m_flags;
124
        ret->m_grad_info_dict = m_grad_info_dict;
125
        ret->m_trace_info = m_trace_info;
126
        ret->m_var = m_var;
127 128 129
        return ret;
    }

130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
    inline DType dtype() {
        if (m_var) {
            return m_var->dtype();
        }
        return interpreter_for_py->get_dtype(m_handle.get());
    }
    inline CompNode comp_node() {
        if (m_var) {
            return m_var->comp_node();
        }
        return interpreter_for_py->get_device(m_handle.get());
    }
    inline TensorShape shape() {
        if (m_var) {
            return m_var->shape();
        }
        return interpreter_for_py->get_shape(m_handle.get());
    }
148 149 150 151 152
};

struct TensorWrapper {
    std::shared_ptr<Tensor> m_tensor;

M
Megvii Engine Team 已提交
153 154
    inline TensorWrapper(std::shared_ptr<Tensor> tensor = {})
            : m_tensor(std::move(tensor)) {}
155 156 157 158 159 160 161 162
    TensorWrapper(PyObject* args, PyObject* kwargs);
    ~TensorWrapper() = default;

    static constexpr auto tp_name = pybind11::detail::_("Tensor");

    using wrap_t = pyext17::wrap<TensorWrapper>;
    friend wrap_t;

M
Megvii Engine Team 已提交
163 164 165
    inline static TensorWrapper* cast(PyObject* obj) {
        return reinterpret_cast<wrap_t*>(obj)->inst();
    }
166
    inline static TensorWrapper* try_cast(PyObject* obj) {
M
Megvii Engine Team 已提交
167 168
        if (!wrap_t::type().isinstance(obj))
            return nullptr;
169
        return cast(obj);
170
    }
M
Megvii Engine Team 已提交
171 172 173
    inline ObjectPtr<TensorWrapper, pybind11::handle> self() {
        return wrap_t::pycast(this);
    }
174 175 176 177 178 179 180 181 182

    template <typename... Args>
    static ObjectPtr<Tensor> make(Args&&... args) {
        auto* op = wrap_t::cnew(std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    template <typename... Args>
    static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) {
M
Megvii Engine Team 已提交
183
        auto* op = wrap_t::cnew_with_type(pytype, std::forward<Args>(args)...);
184 185 186 187 188 189 190 191
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    PyObject* shape();
    PyObject* dtype();
    PyObject* device();
    PyObject* numpy();
    void reset(PyObject*);
192
    PyObject* detach();
193 194
    PyObject* isscalar();
    void setscalar();
195
    void unsetscalar();
196 197
    PyObject* _dev_tensor();
    void _drop();
198
    PyObject* varnode();
199
    void reset_varnode();
200
    PyObject* handle();
M
Megvii Engine Team 已提交
201
    void set_handle(PyObject*);
202 203

    PyObject* mixin_handle();
204
    PyObject* recording();
205
    PyObject* copied();
206 207

    void set_mixin_handle(PyObject*);
208 209 210
    void set_recording(PyObject*);

    PyObject* compiled_info();
M
Megvii Engine Team 已提交
211
    void set_compiled_info(PyObject*);
212
    PyObject* trace_mixin_info();
M
Megvii Engine Team 已提交
213
    void set_trace_mixin_info(PyObject*);
214
    PyObject* module_trace_info();
M
Megvii Engine Team 已提交
215
    void set_module_trace_info(PyObject*);
216
    PyObject* user_custom_name();
M
Megvii Engine Team 已提交
217
    void set_user_custom_name(PyObject*);
218
    PyObject* automatic_name();
M
Megvii Engine Team 已提交
219
    void set_automatic_name(PyObject*);
220
    PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
221 222
};

223 224 225 226
struct PySymbolVar {
    cg::VarNode* m_node = nullptr;
    bool is_scalar = false;
    PySymbolVar() = default;
M
Megvii Engine Team 已提交
227
    PySymbolVar(VarNode* m) : m_node(m) {}
228
};

PyObject* py_apply(
        PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */);

struct ApplyContext {
234
    static Tensor::flags_t global_disable;
235
    static Tensor::flags_t global_enable;
236

237
    Tensor::flags_t flags = 0;
238
    std::shared_ptr<OpDef> op;
M
Megvii Engine Team 已提交
239
    Tensor* const* args;
240
    size_t nargs;
241
    PyTypeObject* pytype = nullptr;
242
    bool backward = false;
243 244 245 246 247

    class scoped_disable : NonCopyableObj {
        Tensor::flags_t saved_flags;

    public:
M
Megvii Engine Team 已提交
248 249
        scoped_disable(Tensor::flags_t flags)
                : saved_flags(ApplyContext::global_disable) {
250 251
            ApplyContext::global_disable |= flags;
        }
M
Megvii Engine Team 已提交
252
        ~scoped_disable() { ApplyContext::global_disable = saved_flags; }
253
    };
254 255 256 257 258 259
};

using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;

apply_result_t apply(ApplyContext& ctx);

260 261 262 263 264 265 266 267 268 269
template <typename T>
decltype(auto) resolve_arrow(T&& p) {
    if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
        auto* ret = p;
        return ret;
    } else {
        auto probe = [](auto&& p) -> decltype(p.operator->()) {};
        if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
            return resolve_arrow(p.operator->());
        } else {
270
            return std::forward<T>(p);
271 272 273 274 275
        }
    }
}

template <typename... Args>
M
Megvii Engine Team 已提交
276 277
constexpr bool is_all_tensor_ptr =
        (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);
278

279 280 281 282 283 284 285 286 287 288 289
template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
    ApplyContext ctx;
    Tensor* arg_arr[] = {resolve_arrow(args)...};
    ctx.flags = (0 | ... | args->m_flags);
    ctx.args = arg_arr;
    ctx.nargs = sizeof...(args);
    ctx.op = std::move(op);
    return apply(ctx);
}

M
Megvii Engine Team 已提交
290
inline auto apply(std::shared_ptr<OpDef> op, Tensor* const* args, size_t nargs) {
291 292
    ApplyContext ctx;
    ctx.op = std::move(op);
293
    ctx.nargs = nargs;
294
    ctx.args = args;
295
    for (size_t i = 0; i < nargs; ++i) {
296 297 298 299 300
        ctx.flags |= args[i]->m_flags;
    }
    return apply(ctx);
}

301
template <typename T>
M
Megvii Engine Team 已提交
302 303
auto apply(std::shared_ptr<OpDef> op, T&& tensors) -> std::enable_if_t<
        std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>, apply_result_t> {
304 305
    size_t nargs = tensors.size();
    Tensor* args[nargs];
306
    for (size_t i = 0; i < nargs; ++i) {
307
        args[i] = resolve_arrow(tensors[i]);
308
    }
309 310 311
    return apply(op, args, nargs);
}

std::shared_ptr<Tensor> make_const(imperative::TensorPtr value);

M
Megvii Engine Team 已提交
314
inline auto apply(Subgraph graph, Tensor* const* args, size_t nargs) {
315 316 317 318
    SmallVector<std::shared_ptr<Tensor>> inputs;
    for (size_t i = 0; i < nargs; ++i) {
        inputs.push_back(args[i]->shared_from_this());
    }
M
Megvii Engine Team 已提交
319 320 321
    auto apply_functor = [](std::shared_ptr<OpDef> op,
                            SmallVector<std::shared_ptr<Tensor>> inputs,
                            size_t) { return apply(op, std::move(inputs)); };
322
    return graph.apply(inputs, apply_functor, &make_const);
323 324 325
}

template <typename T>
M
Megvii Engine Team 已提交
326 327
auto apply(Subgraph graph, T&& tensors) -> std::enable_if_t<
        std::is_same_v<std::decay_t<decltype(tensors[0])>, Tensor*>, apply_result_t> {
328 329 330 331 332 333
    size_t nargs = tensors.size();
    Tensor* args[nargs];
    for (size_t i = 0; i < nargs; ++i) {
        args[i] = resolve_arrow(tensors[i]);
    }
    return apply(graph, args, nargs);
334 335
}

void init_tensor(pybind11::module);

extern PyObject* cpp_apply_with_tracing;
extern PyObject* cpp_apply_backward_varnode;
extern PyObject* cpp_apply_module_trace;

}  // namespace mgb::imperative::python

namespace pybind11::detail {

M
Megvii Engine Team 已提交
346 347 348
template <>
struct type_caster<mgb::imperative::python::TensorWrapper>
        : mgb::imperative::python::TensorWrapper::wrap_t::caster {};
349

M
Megvii Engine Team 已提交
350
}  // namespace pybind11::detail