/**
 * \file imperative/python/src/tensor.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#pragma GCC diagnostic ignored "-Wmissing-field-initializers"

#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>

#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"

#include "./pyext17.h"

#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/utils/span.h"

namespace mgb::imperative::python {

M
Megvii Engine Team 已提交
28
template <typename T, typename B = pybind11::object>
29 30
struct ObjectPtr : B {
    using B::B;
M
Megvii Engine Team 已提交
31 32
    T& operator*() { return reinterpret_cast<T&>(*B::ptr()); }
    T* operator->() { return reinterpret_cast<T*>(B::ptr()); }
33 34
};

M
Megvii Engine Team 已提交
35
}  // namespace mgb::imperative::python
36 37 38

namespace mgb::imperative::python {

39
extern interpreter::Interpreter::Channel* interpreter_for_py;
40
extern PyTypeObject* py_tensor_type;
41

42
struct Tensor : NonCopyableObj {
43
private:
44 45
    std::string m_name;
    ValueRef m_data;
M
Megvii Engine Team 已提交
46

47
public:
48 49
    using Handle = interpreter::Interpreter::Handle;

50
    inline explicit Tensor(ValueRef data) : m_data{data} {}
51

52 53 54
    ~Tensor() = default;

    inline std::shared_ptr<Tensor> copy() {
55
        auto ret = std::make_shared<Tensor>(m_data);
56
        ret->m_name = m_name;
57 58 59
        return ret;
    }

60 61 62 63 64 65
    inline DType dtype() { return *data().dtype(); }
    inline CompNode comp_node() { return *data().device(); }
    inline std::optional<ValueShape> shape() {
        auto shape = data().shape();
        if (!shape) {
            return {};
66
        }
67
        return *shape;
68
    }
69 70 71 72 73
    inline HostValue::ref_t numpy() { return data().numpy(); }
    inline void reset(ValueRef value) {
        m_data = value;
        if (!m_name.empty()) {
            set_name(m_name);
74 75
        }
    }
76 77 78 79 80 81 82 83
    inline ValueRef data() { return m_data.unwrap(); }
    bool is_scalar() { return data().is_scalar(); }
    inline std::string name() { return m_name; }
    inline void set_name(std::string name) {
        m_name = name;
        if (!name.empty()) {
            auto output = imperative::apply(RenameValue(name), m_data)[0];
            m_data = output;
84 85
        }
    }
86 87 88
};

struct TensorWrapper {
89
public:
90 91
    std::shared_ptr<Tensor> m_tensor;

M
Megvii Engine Team 已提交
92
    inline TensorWrapper(std::shared_ptr<Tensor> tensor = {})
93 94 95 96 97
            : m_tensor(std::move(tensor)) {
        mgb_assert(tensor, "empty storage");
    }

    inline TensorWrapper(ValueRef value) : m_tensor(std::make_shared<Tensor>(value)) {}
98 99 100 101 102 103 104 105
    TensorWrapper(PyObject* args, PyObject* kwargs);
    ~TensorWrapper() = default;

    static constexpr auto tp_name = pybind11::detail::_("Tensor");

    using wrap_t = pyext17::wrap<TensorWrapper>;
    friend wrap_t;

M
Megvii Engine Team 已提交
106 107 108
    inline static TensorWrapper* cast(PyObject* obj) {
        return reinterpret_cast<wrap_t*>(obj)->inst();
    }
109
    inline static TensorWrapper* try_cast(PyObject* obj) {
M
Megvii Engine Team 已提交
110 111
        if (!wrap_t::type().isinstance(obj))
            return nullptr;
112
        return cast(obj);
113
    }
M
Megvii Engine Team 已提交
114 115 116
    inline ObjectPtr<TensorWrapper, pybind11::handle> self() {
        return wrap_t::pycast(this);
    }
117 118 119 120 121 122 123 124 125

    template <typename... Args>
    static ObjectPtr<Tensor> make(Args&&... args) {
        auto* op = wrap_t::cnew(std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    template <typename... Args>
    static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) {
M
Megvii Engine Team 已提交
126
        auto* op = wrap_t::cnew_with_type(pytype, std::forward<Args>(args)...);
127 128 129 130 131 132 133 134
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    PyObject* shape();
    PyObject* dtype();
    PyObject* device();
    PyObject* numpy();
    void reset(PyObject*);
135
    PyObject* detach();
136
    PyObject* isscalar();
137 138
    PyObject* _dev_tensor();
    void _drop();
139
    PyObject* varnode();
140
    PyObject* recording();
141
    PyObject* copied();
142
    PyObject* module_trace_info();
M
Megvii Engine Team 已提交
143
    void set_module_trace_info(PyObject*);
144
    void _set_name(PyObject*);
145
    PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
146 147
    PyObject* _detail();
    void _watch();
148 149
};

150 151 152 153
struct PySymbolVar {
    cg::VarNode* m_node = nullptr;
    bool is_scalar = false;
    PySymbolVar() = default;
M
Megvii Engine Team 已提交
154
    PySymbolVar(VarNode* m) : m_node(m) {}
155
};
156

M
Megvii Engine Team 已提交
157 158
PyObject* py_apply(
        PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */);
159

160 161
void init_tensor(pybind11::module);

M
Megvii Engine Team 已提交
162
extern PyObject* cpp_apply_module_trace;
163

M
Megvii Engine Team 已提交
164
}  // namespace mgb::imperative::python
165 166 167

namespace pybind11::detail {

M
Megvii Engine Team 已提交
168 169 170
template <>
struct type_caster<mgb::imperative::python::TensorWrapper>
        : mgb::imperative::python::TensorWrapper::wrap_t::caster {};
171

M
Megvii Engine Team 已提交
172
}  // namespace pybind11::detail