/**
 * \file imperative/python/src/tensor.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include <variant>

#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"

#include "./pyext17.h"

namespace mgb::imperative::python {

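// Wrap a pybind11 object (or handle) so the held PyObject can be accessed as
// if it were a T. The reinterpret_casts are only valid when the caller knows
// the actual layout of the object, e.g. TensorWrapper::self() below, which
// returns an ObjectPtr<TensorWrapper, pybind11::handle>.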
template<typename T, typename B = pybind11::object>
struct ObjectPtr : B {
    using B::B;
    T& operator*() {return reinterpret_cast<T&>(*B::ptr());}
    T* operator->() {return reinterpret_cast<T*>(B::ptr());}
};

} // namespace mgb::imperative::python

#include "./grad_info.h" // for struct GradInfo
#include "./trace_info.h" // for struct TraceInfo

namespace mgb::imperative::python {

extern interpreter::Interpreter::Channel* interpreter_for_py;

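// Reference-counted ownership of an interpreter handle: copies share the same
// handle, and the custom deleter releases it through the interpreter exactly
// once, when the last SharedHandle referring to it is destroyed. Sketch
// (assuming `h` is a freshly allocated interpreter handle):
//     SharedHandle a{h};
//     SharedHandle b = a;  // interpreter_for_py->del(h) runs after both die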
class SharedHandle {
    using Handle = interpreter::Interpreter::Handle;
    static_assert(std::is_pointer_v<Handle>);
    std::shared_ptr<std::remove_pointer_t<Handle>> holder;

public:
    inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){
        if (h) {
            interpreter_for_py->del(h);
        }
    }) {}
    SharedHandle(const SharedHandle&) = default;
    SharedHandle& operator=(const SharedHandle&) = default;
    SharedHandle(SharedHandle&&) = default;
    SharedHandle& operator=(SharedHandle&&) = default;

    inline Handle get() {return holder.get();}
};


struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
    using flags_t = uint64_t;

    struct Flags {
        static constexpr flags_t SCALAR = 1;
        static constexpr flags_t GRAD = 1 << 1;
        static constexpr flags_t TRACE = 1 << 2;
    };

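    // flag bits combine by bitwise OR; e.g. a 0-dim tensor tracked for grad
    // would carry (Flags::SCALAR | Flags::GRAD) in m_flags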
    flags_t m_flags = 0;

    GradInfo m_grad_info;
    TraceInfo m_trace_info;
    SharedHandle m_handle;
    cg::VarNode* m_var;
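    // when m_var is set, the accessors below read metadata from the graph
    // VarNode; otherwise they query the interpreter through m_handle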

    using Handle = interpreter::Interpreter::Handle;

    inline Tensor() : m_handle(nullptr), m_var(nullptr) {}
    inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
    inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {}
    inline explicit Tensor(cg::VarNode *var) : m_handle(nullptr), m_var(var) {}

    ~Tensor() = default;

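    // shallow copy: the clone shares the interpreter handle (via SharedHandle
    // reference counting) and duplicates the grad/trace metadata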
    inline std::shared_ptr<Tensor> copy() {
        auto ret = std::make_shared<Tensor>(m_handle);
        ret->m_flags = m_flags;
        ret->m_grad_info = m_grad_info;
        ret->m_trace_info = m_trace_info;
        ret->m_var = m_var;
        return ret;
    }

    inline DType dtype() {
        if (m_var) {
            return m_var->dtype();
        }
        return interpreter_for_py->get_dtype(m_handle.get());
    }
    inline CompNode comp_node() {
        if (m_var) {
            return m_var->comp_node();
        }
        return interpreter_for_py->get_device(m_handle.get());
    }
    inline TensorShape shape() {
        if (m_var) {
            return m_var->shape();
        }
        return interpreter_for_py->get_shape(m_handle.get());
    }
};


struct TensorWrapper {
    std::shared_ptr<Tensor> m_tensor;

    inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) : m_tensor(std::move(tensor)) {}
    TensorWrapper(PyObject* args, PyObject* kwargs);
    ~TensorWrapper() = default;

    static constexpr auto tp_name = pybind11::detail::_("Tensor");

    using wrap_t = pyext17::wrap<TensorWrapper>;
    friend wrap_t;

    inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();}
    inline static TensorWrapper* try_cast(PyObject* op) {
        if (!wrap_t::type().isinstance(op)) return nullptr;
        return cast(op);
    }
    inline ObjectPtr<TensorWrapper, pybind11::handle> self() {return wrap_t::pycast(this);}

    template <typename... Args>
    static ObjectPtr<Tensor> make(Args&&... args) {
        auto* op = wrap_t::cnew(std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

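    // same as make() above, but with an explicit PyTypeObject, presumably so
    // the wrapper can be instantiated as a Python-level subclass of Tensor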
    template <typename... Args>
    static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) {
        auto* op = wrap_t::cnew_with_type(pytype, std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    PyObject* shape();
    PyObject* dtype();
    PyObject* device();
    PyObject* numpy();
    void reset(PyObject*);
    PyObject* detach();
    PyObject* isscalar();
    void setscalar();
    PyObject* _dev_tensor();
    void _swap_in();
    void _swap_out();
    void _drop();
    PyObject* varnode();
    void reset_varnode();
    PyObject* handle();
    void set_handle(PyObject *);

    PyObject* mixin_handle();
    PyObject* recording();
    PyObject* copied();

    void set_mixin_handle(PyObject*);
    void set_recording(PyObject*);

    PyObject* compiled_info();
    void set_compiled_info(PyObject *);
    PyObject* trace_mixin_info();
    void set_trace_mixin_info(PyObject *);
};


PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */);

struct ApplyContext {
    static Tensor::flags_t global_disable;

    Tensor::flags_t flags;
    std::shared_ptr<OpDef> op;
    Tensor*const* args;
    size_t nargs;
    PyTypeObject* pytype = nullptr;
    bool backward = false;

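    // RAII guard: ORs `flags` into global_disable on construction and
    // restores the previous mask on destruction. Usage sketch:
    //     ApplyContext::scoped_disable guard(Tensor::Flags::TRACE);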
    class scoped_disable : NonCopyableObj {
        Tensor::flags_t saved_flags;

    public:
        scoped_disable(Tensor::flags_t flags) : saved_flags(ApplyContext::global_disable) {
            ApplyContext::global_disable |= flags;
        }
        ~scoped_disable() {
            ApplyContext::global_disable = saved_flags;
        }
    };
};

using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;

apply_result_t apply(ApplyContext& ctx);

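// resolve_arrow unwraps pointer-like arguments: a raw pointer is returned
// unchanged, while anything providing operator-> is unwrapped recursively, so
// e.g. a std::shared_ptr<Tensor> resolves to Tensor*.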
template <typename T>
decltype(auto) resolve_arrow(T&& p) {
    if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
        auto* ret = p;
        return ret;
    } else {
        auto probe = [](auto&& p) -> decltype(p.operator->()) {};
        if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
            return resolve_arrow(p.operator->());
        } else {
            return std::forward<T>(p);
        }
    }
}

template <typename... Args>
constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);

extern bool is_tracing; // FIXME: should use ApplyContext::global_enable
extern bool is_compiled;

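// Variadic front end: each argument may be a Tensor* or anything whose
// operator-> chain reaches a Tensor* (e.g. std::shared_ptr<Tensor>), so a
// call could look like apply(op_def, x, y) with hypothetical op_def, x, y.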
template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
    ApplyContext ctx;
    Tensor* arg_arr[] = {resolve_arrow(args)...};
    ctx.flags = (0 | ... | args->m_flags);
    ctx.flags |= is_tracing ? Tensor::Flags::TRACE : 0;
    ctx.args = arg_arr;
    ctx.nargs = sizeof...(args);
    ctx.op = std::move(op);
    return apply(ctx);
}

template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors)
        -> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
                            apply_result_t> {
    ApplyContext ctx;
    ctx.op = std::move(op);
    ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
    ctx.nargs = tensors.size();
    // use SmallVector instead of a variable-length array (a non-standard
    // compiler extension) for the temporary argument buffer
    SmallVector<Tensor*, 8> args(ctx.nargs);
    ctx.args = args.data();
    for (size_t i = 0; i < ctx.nargs; ++i) {
        args[i] = resolve_arrow(tensors[i]);
        ctx.flags |= args[i]->m_flags;
    }
    return apply(ctx);
}

inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) {
    ApplyContext ctx;
    ctx.op = std::move(op);
    ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
    ctx.nargs = nargs;
    ctx.args = args;
    for (size_t i = 0; i < nargs; ++i) {
        ctx.flags |= args[i]->m_flags;
    }
    return apply(ctx);
}

void init_tensor(pybind11::module);

extern PyObject *cpp_apply_with_tracing, *cpp_apply_compiled_mode;
extern PyObject *cpp_apply_backward_varnode;

} // namespace mgb::imperative::python

namespace pybind11::detail {

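// Expose pyext17's caster as a pybind11 type_caster so pybind11-bound
// functions can accept and return TensorWrapper directly.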
template<> struct type_caster<mgb::imperative::python::TensorWrapper> : mgb::imperative::python::TensorWrapper::wrap_t::caster {};

} // namespace pybind11::detail