/**
 * \file imperative/python/src/tensor.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include <variant>

#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"

#include "./pyext17.h"

namespace mgb::imperative::python {

// Thin smart-pointer view over a pybind11 handle/object: exposes the wrapped
// instance through operator* / operator->.
template <typename T, typename B = pybind11::object>
struct ObjectPtr : B {
    using B::B;
    T& operator*() {return reinterpret_cast<T&>(*B::ptr());}
    T* operator->() {return reinterpret_cast<T*>(B::ptr());}
};

} // namespace mgb::imperative::python

#include "./grad_info.h"   // for struct GradInfo
#include "./trace_info.h"  // for struct TraceInfo

namespace mgb::imperative::python {

extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py;

// Reference-counted owner of an interpreter handle; the last copy to go out
// of scope releases the handle back to the interpreter.
class SharedHandle {
    using Handle = interpreter::Interpreter::Handle;
    static_assert(std::is_pointer_v<Handle>);
    std::shared_ptr<std::remove_pointer_t<Handle>> holder;

public:
    inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){
        if (h) {
            interpreter_for_py->del(h);
        }
    }) {}
    SharedHandle(const SharedHandle&) = default;
    SharedHandle& operator=(const SharedHandle&) = default;
    SharedHandle(SharedHandle&&) = default;
    SharedHandle& operator=(SharedHandle&&) = default;

    inline Handle get() {return holder.get();}
};


struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
    using flags_t = uint64_t;

    struct Flags {
        static constexpr flags_t SCALAR = 1;
        static constexpr flags_t GRAD = 1 << 1;
        static constexpr flags_t TRACE = 1 << 2;
    };

    flags_t m_flags = 0;

    GradInfo m_grad_info;
    TraceInfo m_trace_info;
    SharedHandle m_handle;
    cg::VarNode* m_var;

    using Handle = interpreter::Interpreter::Handle;

    inline Tensor() : m_handle(nullptr), m_var(nullptr) {}
    inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
    inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {}
    inline explicit Tensor(cg::VarNode *var) : m_handle(nullptr), m_var(var) {}

    ~Tensor() = default;

    inline std::shared_ptr<Tensor> copy() {
        auto ret = std::make_shared<Tensor>(m_handle);
        ret->m_flags = m_flags;
        ret->m_grad_info = m_grad_info;
        ret->m_trace_info = m_trace_info;
        ret->m_var = m_var;
        return ret;
    }

    inline DType dtype() {
        if (m_var) {
            return m_var->dtype();
        }
        return interpreter_for_py->get_dtype(m_handle.get());
    }
    inline CompNode comp_node() {
        if (m_var) {
            return m_var->comp_node();
        }
        return interpreter_for_py->get_device(m_handle.get());
    }
    inline TensorShape shape() {
        if (m_var) {
            return m_var->shape();
        }
        return interpreter_for_py->get_shape(m_handle.get());
    }
};


struct TensorWrapper {
    std::shared_ptr<Tensor> m_tensor;

    inline TensorWrapper(std::shared_ptr<Tensor> tensor = {}) : m_tensor(std::move(tensor)) {}
    TensorWrapper(PyObject* args, PyObject* kwargs);
    ~TensorWrapper() = default;

    static constexpr auto tp_name = pybind11::detail::_("Tensor");

    using wrap_t = pyext17::wrap<TensorWrapper>;
    friend wrap_t;

    inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();}
    inline static TensorWrapper* try_cast(PyObject* op) {
        if (!wrap_t::type().isinstance(op)) return nullptr;
        return cast(op);
    }
    inline ObjectPtr<TensorWrapper, pybind11::handle> self() {return wrap_t::pycast(this);}

    template <typename... Args>
    static ObjectPtr<TensorWrapper> make(Args&&... args) {
        auto* op = wrap_t::cnew(std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<TensorWrapper>>(op);
    }
    template <typename... Args>
    static ObjectPtr<TensorWrapper> make(PyTypeObject* pytype, Args&&... args) {
        auto* op = wrap_t::cnew_with_type(pytype, std::forward<Args>(args)...);
        return pybind11::reinterpret_steal<ObjectPtr<TensorWrapper>>(op);
    }

    PyObject* shape();
    PyObject* dtype();
    PyObject* device();
    PyObject* numpy();
    void reset(PyObject*);
    PyObject* detach();
    PyObject* isscalar();
    void setscalar();
    PyObject* _dev_tensor();
    void _swap_in();
    void _swap_out();
    void _drop();
    PyObject* varnode();
    void reset_varnode();
    PyObject* handle();
    void set_handle(PyObject*);

    PyObject* data_read();
    PyObject* value_read();
    PyObject* shape_read();
    PyObject* mixin_handle();

    void set_data_read(PyObject*);
    void set_value_read(PyObject*);
    void set_shape_read(PyObject*);
    void set_mixin_handle(PyObject*);
};


PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */);


struct ApplyContext {
    static Tensor::flags_t global_disable;

    Tensor::flags_t flags;
    std::shared_ptr<OpDef> op;
    Tensor*const* args;
    size_t nargs;
    PyTypeObject* pytype = nullptr;
    bool backward = false;

    // RAII guard: disables the given dispatch flags for its lifetime.
    class scoped_disable : NonCopyableObj {
        Tensor::flags_t saved_flags;

    public:
        scoped_disable(Tensor::flags_t flags) : saved_flags(ApplyContext::global_disable) {
            ApplyContext::global_disable |= flags;
        }
        ~scoped_disable() {
            ApplyContext::global_disable = saved_flags;
        }
    };
};

using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;

apply_result_t apply(ApplyContext& ctx);

// Reduce a raw pointer or any smart pointer to the underlying raw pointer by
// recursively applying operator->.
template <typename T>
decltype(auto) resolve_arrow(T&& p) {
    if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
        auto* ret = p;
        return ret;
    } else {
        auto probe = [](auto&& p) -> decltype(p.operator->()) {};
        if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
            return resolve_arrow(p.operator->());
        } else {
            return p;
        }
    }
}

template <typename... Args>
constexpr bool is_all_tensor_ptr =
        (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args&&>())), Tensor*>);

extern bool is_tracing; // FIXME: should use ApplyContext::global_enable
extern bool is_symbolic;
extern bool is_compiled;

template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
    ApplyContext ctx;
    Tensor* arg_arr[] = {resolve_arrow(args)...};
    ctx.flags = (0 | ... | args->m_flags);
    ctx.flags |= is_tracing ? Tensor::Flags::TRACE : 0;
    ctx.args = arg_arr;
    ctx.nargs = sizeof...(args);
    ctx.op = std::move(op);
    return apply(ctx);
}

template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors)
        -> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
                            apply_result_t> {
    ApplyContext ctx;
    ctx.op = std::move(op);
    ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
    ctx.nargs = tensors.size();
    Tensor* args[ctx.nargs];
    ctx.args = args;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        args[i] = resolve_arrow(tensors[i]);
        ctx.flags |= args[i]->m_flags;
    }
    return apply(ctx);
}

void init_tensor(pybind11::module);

extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode;
extern pybind11::object cpp_apply_backward_varnode;

} // namespace mgb::imperative::python

namespace pybind11::detail {

template <>
struct type_caster<mgb::imperative::python::TensorWrapper>
        : mgb::imperative::python::TensorWrapper::wrap_t::caster {};

} // namespace pybind11::detail
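
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; not part of the original header). It shows
// how the two apply() overloads declared above are intended to be called: the
// variadic overload accepts anything that resolve_arrow() reduces to Tensor*,
// while the container overload accepts a sequence whose elements resolve the
// same way. Obtaining a concrete OpDef (shown here as a null placeholder) is
// an assumption about op factories living outside this header.
#if 0
void apply_example(std::shared_ptr<mgb::imperative::python::Tensor> x,
                   std::shared_ptr<mgb::imperative::python::Tensor> y) {
    using namespace mgb::imperative::python;

    // placeholder only: a real OpDef would come from an op factory elsewhere
    std::shared_ptr<mgb::imperative::OpDef> op = nullptr;

    // variadic overload: raw Tensor* arguments; flags, nargs and op are
    // packed into an ApplyContext before dispatching to apply(ApplyContext&)
    apply_result_t r1 = apply(op, x.get(), y.get());

    // container overload: indexing the sequence must yield something that
    // resolve_arrow() turns into Tensor*
    mgb::SmallVector<std::shared_ptr<Tensor>> inputs;
    inputs.push_back(x);
    inputs.push_back(y);
    apply_result_t r2 = apply(op, inputs);
}
#endif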