// tensor.h
#pragma once
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"

#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>

#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"

#include "./pyext17.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/utils/span.h"

namespace mgb::imperative::python {

M
Megvii Engine Team 已提交
20
template <typename T, typename B = pybind11::object>
21 22
struct ObjectPtr : B {
    using B::B;
M
Megvii Engine Team 已提交
23 24
    T& operator*() { return reinterpret_cast<T&>(*B::ptr()); }
    T* operator->() { return reinterpret_cast<T*>(B::ptr()); }
25 26
};

M
Megvii Engine Team 已提交
27
}  // namespace mgb::imperative::python
namespace mgb::imperative::python {

31
extern interpreter::Interpreter::Channel* interpreter_for_py;
32
extern PyTypeObject* py_tensor_type;
33
extern PyTypeObject* py_varnode_type;
34
extern PyTypeObject* py_external_type;
35
extern pybind11::handle py_device_type;
36 37
extern PyObject* cpp_use_symbolic_shape;
extern PyObject* cpp_astensor1d;
38

39
struct Tensor {
40
private:
41
    ValueRef m_data;
42
    std::string m_name;
M
Megvii Engine Team 已提交
43

44
public:
45 46
    using Handle = interpreter::Interpreter::Handle;

47
    inline explicit Tensor(ValueRef data) : m_data{data} {}
48

49 50
    ~Tensor() = default;

51
    inline Tensor copy() { return Tensor(imperative::apply(DupTensor(), data())[0]); }
52

53 54 55 56 57 58
    inline DType dtype() { return *data().dtype(); }
    inline CompNode comp_node() { return *data().device(); }
    inline std::optional<ValueShape> shape() {
        auto shape = data().shape();
        if (!shape) {
            return {};
59
        }
60
        return *shape;
61
    }
62
    inline Format format() { return *data().format(); }
63 64 65 66 67
    inline void set_format(std::string format) {
        if (!format.empty()) {
            m_data = imperative::apply(SetFormat(format), m_data)[0];
        }
    }
68 69 70 71 72
    inline HostValue::ref_t numpy() { return data().numpy(); }
    inline void reset(ValueRef value) {
        m_data = value;
        if (!m_name.empty()) {
            set_name(m_name);
73 74
        }
    }
75
    inline ValueRef data() const { return m_data.unwrap(); }
76 77
    bool is_scalar() { return data().is_scalar(); }
    inline std::string name() { return m_name; }
78
    inline size_t value_id() { return m_data.id(); }
79 80 81 82 83
    inline void set_name(std::string name) {
        m_name = name;
        if (!name.empty()) {
            auto output = imperative::apply(RenameValue(name), m_data)[0];
            m_data = output;
84 85
        }
    }
86 87 88
};

struct TensorWrapper {
89
public:
90
    std::optional<Tensor> m_tensor;
91

92
    inline TensorWrapper(ValueRef value) { m_tensor.emplace(value); }
93 94 95 96 97 98 99 100
    TensorWrapper(PyObject* args, PyObject* kwargs);
    ~TensorWrapper() = default;

    static constexpr auto tp_name = pybind11::detail::_("Tensor");

    using wrap_t = pyext17::wrap<TensorWrapper>;
    friend wrap_t;

M
Megvii Engine Team 已提交
101 102 103
    inline static TensorWrapper* cast(PyObject* obj) {
        return reinterpret_cast<wrap_t*>(obj)->inst();
    }
104
    inline static TensorWrapper* try_cast(PyObject* obj) {
M
Megvii Engine Team 已提交
105 106
        if (!wrap_t::type().isinstance(obj))
            return nullptr;
107
        return cast(obj);
108
    }
M
Megvii Engine Team 已提交
109 110 111
    inline ObjectPtr<TensorWrapper, pybind11::handle> self() {
        return wrap_t::pycast(this);
    }
112 113 114 115 116 117 118 119 120

    template <typename... Args>
    static ObjectPtr<Tensor> make(Args&&... args) {
        auto* op = wrap_t::cnew(std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    template <typename... Args>
    static ObjectPtr<Tensor> make(PyTypeObject* pytype, Args&&... args) {
M
Megvii Engine Team 已提交
121
        auto* op = wrap_t::cnew_with_type(pytype, std::forward<Args>(args)...);
122 123 124 125 126 127
        return pybind11::reinterpret_steal<ObjectPtr<Tensor>>(op);
    }

    PyObject* shape();
    PyObject* dtype();
    PyObject* device();
128
    PyObject* format();
129 130
    PyObject* numpy();
    void reset(PyObject*);
131
    PyObject* detach();
132
    PyObject* isscalar();
133
    PyObject* value_id();
134 135
    PyObject* _dev_tensor();
    void _drop();
136
    PyObject* varnode();
137
    PyObject* recording();
138
    PyObject* copied();
139
    PyObject* module_trace_info();
M
Megvii Engine Team 已提交
140
    void set_module_trace_info(PyObject*);
141
    void _set_format(PyObject*);
142 143
    void _set_name(PyObject*);
    PyObject* _detail();
144 145
    PyObject* _var();
    PyObject* _graph();
146 147
    PyObject* _is_external_value();
    PyObject* _external_obj();
148
    void _watch();
149 150
};

M
Megvii Engine Team 已提交
151 152
PyObject* py_apply(
        PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */);
153

154 155
void init_tensor(pybind11::module);

M
Megvii Engine Team 已提交
156
extern PyObject* cpp_apply_module_trace;
157

M
Megvii Engine Team 已提交
158
}  // namespace mgb::imperative::python
namespace pybind11::detail {

M
Megvii Engine Team 已提交
162 163 164
template <>
struct type_caster<mgb::imperative::python::TensorWrapper>
        : mgb::imperative::python::TensorWrapper::wrap_t::caster {};
165

M
Megvii Engine Team 已提交
166
}  // namespace pybind11::detail