/**
 * \file imperative/python/src/tensor.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"

#include <variant>

#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"
#include <string>

#include "./pyext17.h"

namespace mgb::imperative::python {

/*!
 * \brief Python object holder of type B that also exposes the referenced
 *      PyObject storage as a T.
 *
 * The pointer returned by B::ptr() is reinterpreted as T without any runtime
 * check, so T has to describe what the referenced object really is.
 */
template<typename T, typename B = pybind11::object>
struct ObjectPtr : B {
    using B::B;

    //! Reinterpret the raw pointer held by B as a T*.
    T* operator->() {
        return reinterpret_cast<T*>(B::ptr());
    }

    //! Reinterpret the object referenced by B as a T&.
    T& operator*() {
        return *operator->();
    }
};

} // namespace mgb::imperative::python

#include "./grad_info.h" // for struct GradInfo
#include "./trace_info.h" // for struct TraceInfo

namespace mgb::imperative::python {

39
extern interpreter::Interpreter::Channel* interpreter_for_py;
40 41 42 43 44 45 46 47

/*!
 * \brief Reference-counted owner of an interpreter handle.
 *
 * Copies share one underlying handle. When the last copy is destroyed the
 * handle is released with interpreter_for_py->del().
 */
class SharedHandle {
    using Handle = interpreter::Interpreter::Handle;
    static_assert(std::is_pointer_v<Handle>);
    std::shared_ptr<std::remove_pointer_t<Handle>> holder;

public:
    //! Take ownership of \p handle; a null handle is accepted.
    inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){
        // a shared_ptr built from (pointer, deleter) runs the deleter even for
        // a null pointer, so null has to be filtered out before calling del()
        if (h) {
            interpreter_for_py->del(h);
        }
    }) {}
    SharedHandle(const SharedHandle&) = default;
    SharedHandle& operator=(const SharedHandle&) = default;
    SharedHandle(SharedHandle&&) = default;
    SharedHandle& operator=(SharedHandle&&) = default;

    //! Borrow the raw handle (may be null); ownership stays with this object.
    inline Handle get() const {return holder.get();}
};


struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
    using flags_t = uint64_t;

    struct Flags {
        static constexpr flags_t SCALAR = 1;
        static constexpr flags_t GRAD = 1 << 1;
        static constexpr flags_t TRACE = 1 << 2;
    };

    flags_t m_flags = 0;

    GradInfo m_grad_info;
    TraceInfo m_trace_info;
    SharedHandle m_handle;
75 76
    std::string user_custom_name;
    std::string automatic_name;
77
    cg::VarNode* m_var;
78 79 80

    using Handle = interpreter::Interpreter::Handle;

81
    inline Tensor() : m_handle(nullptr), m_var(nullptr) {}
82 83 84 85
    inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
    inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {}
    inline explicit Tensor(cg::VarNode *var) : m_handle(nullptr), m_var(var) {}

86 87 88 89 90 91 92
    ~Tensor() = default;

    inline std::shared_ptr<Tensor> copy() {
        auto ret = std::make_shared<Tensor>(m_handle);
        ret->m_flags = m_flags;
        ret->m_grad_info = m_grad_info;
        ret->m_trace_info = m_trace_info;
93
        ret->m_var = m_var;
94 95 96
        return ret;
    }

97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
    inline DType dtype() {
        if (m_var) {
            return m_var->dtype();
        }
        return interpreter_for_py->get_dtype(m_handle.get());
    }
    inline CompNode comp_node() {
        if (m_var) {
            return m_var->comp_node();
        }
        return interpreter_for_py->get_device(m_handle.get());
    }
    inline TensorShape shape() {
        if (m_var) {
            return m_var->shape();
        }
        return interpreter_for_py->get_shape(m_handle.get());
    }
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
};


struct TensorWrapper {
    std::shared_ptr<Tensor> m_tensor;

    inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) : m_tensor(std::move(tensor)) {}
    TensorWrapper(PyObject* args, PyObject* kwargs);
    ~TensorWrapper() = default;

    static constexpr auto tp_name = pybind11::detail::_("Tensor");

    using wrap_t = pyext17::wrap<TensorWrapper>;
    friend wrap_t;

    inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();}
131
    inline static TensorWrapper* try_cast(PyObject* op) {
132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
        if (!wrap_t::type().isinstance(op)) return nullptr;
        return cast(op);
    }
    inline ObjectPtr<TensorWrapper, pybind11::handle> self() {return wrap_t::pycast(this);}

    template <typename... Args>
    static ObjectPtr<Tensor> make(Args&&... args) {
        auto* op = wrap_t::cnew(std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    template <typename... Args>
    static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) {
        auto* op = wrap_t::cnew_with_type(pytype,std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    PyObject* shape();
    PyObject* dtype();
    PyObject* device();
    PyObject* numpy();
    void reset(PyObject*);
154
    PyObject* detach();
155 156
    PyObject* isscalar();
    void setscalar();
157
    void unsetscalar();
158 159 160 161
    PyObject* _dev_tensor();
    void _swap_in();
    void _swap_out();
    void _drop();
162
    PyObject* varnode();
163
    void reset_varnode();
164 165 166 167
    PyObject* handle();
    void set_handle(PyObject *);

    PyObject* mixin_handle();
168
    PyObject* recording();
169
    PyObject* copied();
170 171

    void set_mixin_handle(PyObject*);
172 173 174 175 176 177
    void set_recording(PyObject*);

    PyObject* compiled_info();
    void set_compiled_info(PyObject *);
    PyObject* trace_mixin_info();
    void set_trace_mixin_info(PyObject *);
178 179 180 181
    PyObject* user_custom_name();
    void set_user_custom_name(PyObject *);
    PyObject* automatic_name();
    void set_automatic_name(PyObject *);
182
    PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
183 184
};

struct PySymbolVar {
    cg::VarNode* m_node = nullptr;
    bool is_scalar = false;
    PySymbolVar() = default;
    PySymbolVar(VarNode *m): m_node(m){}
};

//! Declared here, defined out of line. Receives a self object and an array of
//! \p nargs argument objects; a kwnames parameter is not part of the signature
//! (it is commented out below).
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */);

struct ApplyContext {
195
    static Tensor::flags_t global_disable;
196
    static Tensor::flags_t global_enable;
197

198
    Tensor::flags_t flags = 0;
199 200 201
    std::shared_ptr<OpDef> op;
    Tensor*const* args;
    size_t nargs;
202
    PyTypeObject* pytype = nullptr;
203
    bool backward = false;
204 205 206 207 208 209 210 211 212 213 214 215

    class scoped_disable : NonCopyableObj {
        Tensor::flags_t saved_flags;

    public:
        scoped_disable(Tensor::flags_t flags) : saved_flags(ApplyContext::global_disable) {
            ApplyContext::global_disable |= flags;
        }
        ~scoped_disable() {
            ApplyContext::global_disable = saved_flags;
        }
    };
216 217 218 219 220 221
};

//! Output list of apply(): shared tensors in a SmallVector with size parameter 8.
using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;

//! Core entry point, defined out of line; the overloads in this header fill an
//! ApplyContext and forward to it.
apply_result_t apply(ApplyContext& ctx);

/*!
 * \brief Follow operator-> until a raw pointer, or an object without
 *      operator->, is reached.
 *
 * \param p a raw pointer, an object providing operator-> (for example a smart
 *      pointer), or any other object
 * \return for a raw pointer: that pointer, by value; for an object with
 *      operator->: the result of resolving p.operator->() recursively;
 *      otherwise \p p itself, perfectly forwarded
 */
template <typename T>
decltype(auto) resolve_arrow(T&& p) {
    if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
        // go through a local so that decltype(auto) deduces the plain pointer
        // type instead of a reference to the argument
        auto* ret = p;
        return ret;
    } else {
        // never called; only its trailing return type matters, which makes the
        // lambda non-invocable when p.operator->() is ill-formed
        auto probe = [](auto&& q) -> decltype(q.operator->()) { return q.operator->(); };
        if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
            return resolve_arrow(p.operator->());
        } else {
            return std::forward<T>(p);
        }
    }
}

//! True when resolve_arrow() yields exactly Tensor* for every type in Args
//! (also true for an empty pack). Used to constrain the variadic apply().
template <typename... Args>
constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);

/*!
 * \brief Apply \p op to tensors passed as individual arguments.
 *
 * Each argument must resolve to Tensor* through resolve_arrow() (raw pointers,
 * smart pointers, ...). The context flags are the OR of all inputs' m_flags.
 */
template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
    // raw pointers to the inputs; must stay alive until apply(ctx) returns
    Tensor* inputs[] = {resolve_arrow(args)...};

    ApplyContext ctx;
    ctx.op = std::move(op);
    ctx.args = inputs;
    ctx.nargs = sizeof...(args);
    ctx.flags = (0 | ... | args->m_flags);
    return apply(ctx);
}

251
inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) {
252 253
    ApplyContext ctx;
    ctx.op = std::move(op);
254
    ctx.nargs = nargs;
255
    ctx.args = args;
256
    for (size_t i = 0; i < nargs; ++i) {
257 258 259 260 261
        ctx.flags |= args[i]->m_flags;
    }
    return apply(ctx);
}

262 263 264 265 266 267
template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors)
        -> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
                            apply_result_t> {
    size_t nargs = tensors.size();
    Tensor* args[nargs];
268
    for (size_t i = 0; i < nargs; ++i) {
269
        args[i] = resolve_arrow(tensors[i]);
270
    }
271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297
    return apply(op, args, nargs);
}

/*!
 * \brief Run \p graph on \p nargs tensors given as a pointer array.
 *
 * Every input is turned into a shared_ptr with shared_from_this(), so each
 * tensor must already be owned by a std::shared_ptr. Ops inside the graph are
 * executed through apply(op, inputs); constants are uploaded with
 * interpreter_for_py->put() and wrapped into new Tensor objects.
 */
inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) {
    using TensorList = SmallVector<std::shared_ptr<Tensor>>;

    TensorList owned_inputs;
    for (size_t idx = 0; idx != nargs; ++idx) {
        owned_inputs.push_back(args[idx]->shared_from_this());
    }

    // callback used by the graph to execute a single op
    auto run_op = [](std::shared_ptr<OpDef> op, TensorList inputs) {
        return apply(op, inputs);
    };
    // callback used by the graph to materialize a constant value
    auto make_const = [](imperative::TensorPtr value) {
        return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
    };
    return graph.apply(owned_inputs, run_op, make_const);
}

/*!
 * \brief Run \p graph on a container of tensors.
 *
 * Enabled when the container's elements resolve to Tensor* through
 * resolve_arrow(), matching the container overload that takes an OpDef.
 * \p tensors must provide size() and operator[].
 */
template <typename T>
auto apply(Subgraph graph, T&& tensors)
        -> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
                            apply_result_t> {
    size_t nargs = tensors.size();
    // variable-length array (compiler extension), kept to avoid a heap allocation
    Tensor* args[nargs];
    for (size_t i = 0; i < nargs; ++i) {
        args[i] = resolve_arrow(tensors[i]);
    }
    // graph is a by-value parameter that is not used afterwards
    return apply(std::move(graph), args, nargs);
}

300 301
void init_tensor(pybind11::module);

302
extern PyObject *cpp_apply_with_tracing;
303
extern PyObject *cpp_apply_backward_varnode;
304

305 306 307 308 309 310 311
} // namespace mgb::imperative::python

namespace pybind11::detail {

//! Lets pybind11 convert TensorWrapper arguments and return values by reusing
//! the caster supplied by TensorWrapper::wrap_t.
template<> struct type_caster<mgb::imperative::python::TensorWrapper> : mgb::imperative::python::TensorWrapper::wrap_t::caster {};

} // namespace pybind11::detail