diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index 17e15dba3bcea37e61f1ecf960a3bfcf7b85a7ec..dd598795b4956bd1f91032de8a807840b5a00799 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -12,6 +12,7 @@ #include "./grad.h" #include "megbrain/imperative/proxy_graph_detail.h" #include "megbrain/imperative/ops/autogen.h" +#include "megbrain/imperative/ops/utility.h" #include "megbrain/utils/mempool.h" #include "range/v3/all.hpp" @@ -21,6 +22,9 @@ namespace views = ranges::views; namespace mgb::imperative::python { +using scoped_disable = ApplyContext::scoped_disable; +using Flags = Tensor::Flags; + namespace { struct GradSlotWeakPtr { @@ -78,6 +82,21 @@ std::shared_ptr make_backward_graph( return result; } +struct BackwardContext { + PyTypeObject* pytype = nullptr; + + auto wrap_tensor(std::shared_ptr t) { + if (pytype) { + return TensorWrapper::make(pytype, std::move(t)); + } + return TensorWrapper::make(std::move(t)); + } + + auto wrap_tensor(Tensor* t) { + return wrap_tensor(t->shared_from_this()); + } +}; + struct BackwardGraphWithClosure { std::shared_ptr backward_graph; SmallVector> closure; @@ -119,7 +138,7 @@ struct BackwardGraphWithClosure { } template - void operator()(T&& grads, R&& receiver) { + void operator()(BackwardContext&, T&& grads, R&& receiver) { Tensor* args[closure.size() + grads.size()]; size_t nargs = 0; for (auto&& t : closure) { @@ -143,7 +162,7 @@ struct BackwardGraphWithClosure { ApplyContext ctx; ctx.op = backward_graph->backward; - ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0; + ctx.flags = is_tracing ? Flags::TRACE : 0; ctx.nargs = nargs; ctx.args = args; for (size_t i = 0; i < nargs; ++i) { @@ -174,6 +193,47 @@ struct BackwardGraphWithClosure { } }; +struct PythonBackward { + py::object pyfunc; + size_t input_size; + + PythonBackward(py::object f, size_t nin) + : pyfunc(f), input_size(nin) {} + + template + void operator()(BackwardContext& ctx, T&& grads, R&& receiver) { + auto args = py::tuple(grads.size()); + for (size_t i = 0; i < grads.size(); ++i) { + auto&& g = grads[i]; + args[i] = g ? ctx.wrap_tensor(g) : py::none(); + } + auto input_grads = py::reinterpret_steal(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr)); + if (input_grads.is_none()) return; + if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) { + if (input_size != 1) { + throw py::value_error("custom grad rule returned wrong number of grads"); + } + receiver(0, tw->m_tensor); + return; + } + if (py::len(input_grads) != input_size) { + throw py::value_error("custom grad rule returned wrong number of grads"); + } + for (auto [i, g] : views::enumerate(input_grads)) { + if (g.is_none()) continue; + auto* tw = TensorWrapper::try_cast(g.ptr()); + if (!tw) { + throw py::type_error("custom grad rule returned non-tensor"); + } + receiver(i, tw->m_tensor); + } + } + + static constexpr bool input_has_grad(size_t) {return true;} + static constexpr bool output_requires_grad(size_t) {return true;} + static constexpr bool output_captured(size_t) {return true;} +}; + } // namespace struct GradProducerRecord : intrusive_list::Node { @@ -210,7 +270,7 @@ struct GradFn : std::enable_shared_from_this { // same length as inputs (of forward op) SmallVector dsts; // encapsules actual function to compute gradient - std::variant backward; + std::variant backward; // a flag used during backward bool in_ref_keeper = false; @@ -268,6 +328,30 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra return outputs; } +apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { + auto* op = ctx.op->try_cast_final(); + py::tuple pyin(ctx.nargs); + for (size_t i = 0; i < ctx.nargs; ++i) { + pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); + } + auto grad_rule = py::getattr(op->obj, "_grad_rule"); + auto pyret = (scoped_disable(Flags::GRAD), + py::reinterpret_steal(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr))); // comma expression + auto [outputs, backward] = py::cast>(pyret); + ret_grad_fn.emplace(std::move(backward), ctx.nargs); + if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) { + return {tw->m_tensor}; + } + apply_result_t ret; + ret.reserve(py::len(outputs)); + for (auto&& i : outputs) { + auto* tw = TensorWrapper::try_cast(i.ptr()); + mgb_assert(tw); + ret.push_back(tw->m_tensor); + } + return ret; +} + } // namespace apply_result_t apply_grad(ApplyContext& ctx) { @@ -290,21 +374,23 @@ apply_result_t apply_grad(ApplyContext& ctx) { // cleanup stale grad info // under what condition? tensor->m_grad_info = {}; - tensor->m_flags &= ~Tensor::Flags::GRAD; + tensor->m_flags &= ~Flags::GRAD; } } else { - tensor->m_flags &= ~Tensor::Flags::GRAD; + tensor->m_flags &= ~Flags::GRAD; } } - ctx.flags &= ~Tensor::Flags::GRAD; + ctx.flags &= ~Flags::GRAD; if (!grad_key) { return apply(ctx); } GradFnHelper grad_fn_holder; - auto outputs = backward_graph_grad_rule(ctx, grad_fn_holder); + auto outputs = ctx.op->same_type() ? + python_grad_rule(ctx, grad_fn_holder) : + backward_graph_grad_rule(ctx, grad_fn_holder); auto& grad_fn = grad_fn_holder.grad_fn; if (!grad_fn) { @@ -341,7 +427,7 @@ apply_result_t apply_grad(ApplyContext& ctx) { grad_info.grad_fn = grad_fn; grad_info.idx = i; grad_info.insert_after(grad_key->free_vars_head); - outputs[i]->m_flags |= Tensor::Flags::GRAD; + outputs[i]->m_flags |= Flags::GRAD; } } } @@ -357,7 +443,7 @@ void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) { if (nargs != 2) { throw py::type_error("expect 2 arguments"); } - auto* tw = TensorWrapper::cast_safe(args[0]); + auto* tw = TensorWrapper::try_cast(args[0]); if (!tw) { throw py::type_error("argument 1 must be Tensor"); } @@ -390,14 +476,15 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) { grad_fn->key = shared_from_this(); grad_fn->slots.resize(1); tensor->m_grad_info.insert_after(free_vars_head); - tensor->m_flags |= Tensor::Flags::GRAD; + tensor->m_flags |= Flags::GRAD; } tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback); } -void accum_grad(std::shared_ptr& grad, std::shared_ptr&& delta) { +template +void accum_grad(std::shared_ptr& grad, T&& delta) { if (!grad) { - grad = std::forward(delta); + grad = std::forward(delta); return; } static ApplyContext ctx; @@ -409,7 +496,7 @@ void accum_grad(std::shared_ptr& grad, std::shared_ptr&& delta) ctx.args = args; ctx.flags = grad->m_flags | delta->m_flags; if (is_tracing) { - ctx.flags |= Tensor::Flags::TRACE; + ctx.flags |= Flags::TRACE; } grad = apply(ctx)[0]; } @@ -440,6 +527,7 @@ void GradKey::backward(std::vector tensors, std::vector> ref_keeper; ref_keeper.reserve(tape.size()); // back-propagation in reverse order @@ -456,7 +544,7 @@ void GradKey::backward(std::vector tensors, std::vectorslots, [](auto&& slot) {return slot.grad.get();}); - backward(std::forward(grads), grad_receiver); + backward(bctx, std::forward(grads), grad_receiver); } }, grad_fn->backward); diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index ba76bd2e4006bb83fd911bc3c8204a384d7e0c31..b5083e9c906f065774ff32850a1edb7ac2d82d92 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -14,6 +14,7 @@ #include "megbrain/imperative.h" #include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/opr_attr.h" +#include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/autogen.h" #include @@ -245,6 +246,35 @@ void _init_py_backward_graph(py::module m) { mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second); } +struct PyOpBase : PyOpDef { + static PyTypeObject py_type; + + static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) { + auto* obj = type->tp_alloc(type, 0); + if (obj) { + auto* self = reinterpret_cast(obj); + new(&self->op) decltype(self->op); + } + return obj; + } +}; +PyTypeObject PyOpBase::py_type; + +void _init_py_op_base(py::module m) { + using py_op = PyOpBase; + auto& py_type = PyOpBase::py_type; + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase"; + py_type.tp_basicsize = sizeof(py_op); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "PyOpBase"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_op::tp_new; + mgb_assert(PyType_Ready(&py_type) >= 0); + m.add_object("PyOpBase", reinterpret_cast(&py_type)); +} + /*********** end of hand-write opdefs **************/ // auto generated opdefs @@ -260,9 +290,16 @@ bool type_caster::load(handle src, bool convert) { return false; } value = reinterpret_cast(obj)->op; + if (!value) { + // opdef only defined in Python + value = std::make_shared(reinterpret_borrow(src)); + } return true; } handle type_caster::cast(const OpDef& op, return_value_policy, handle) { + if (auto* pyop = op.try_cast_final()) { + return object(pyop->obj).release(); + } PyTypeObject* pytype; auto& c2p = PyOp(OpDef)::ctype2pytype; auto&& iter = c2p.find(op.dyn_typeinfo()); @@ -283,5 +320,6 @@ handle type_caster::cast(const OpDef& op, return_value_policy, handle) { void init_ops(py::module m) { _init_py_op_def(m); _init_py_backward_graph(m); + _init_py_op_base(m); INIT_ALL_OP(m) } diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 3806a31f09d7afc0da9de694e99215446980d4ba..6f68d48397bc40b15b97bb7b49eabfd35c8c53e2 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -11,6 +11,7 @@ #include "megbrain/dtype.h" #include "megbrain/common.h" +#include "megbrain/imperative/ops/utility.h" #include "./tensor.h" #include "./grad.h" @@ -22,10 +23,12 @@ #include #include +#include #include namespace py = pybind11; +namespace views = ranges::views; namespace mgb::imperative::python { @@ -69,21 +72,45 @@ SET_UNSET_PROP(compiled) bool skip_tracing = false; +Tensor::flags_t ApplyContext::global_disable = 0; + apply_result_t apply(ApplyContext& ctx) { // emulating scalar should be put to specific op's apply, e.g., // elementwise, reduce, typecvt. Currently it's still handled at python // side. It could be move to C++ side if it has an impact on performance - if (ctx.flags & Tensor::Flags::SCALAR) { + auto flags = ctx.flags & ~ApplyContext::global_disable; + + if (flags & Tensor::Flags::SCALAR) { // TODO: emulate scalar } - if (ctx.flags & Tensor::Flags::GRAD) { + if (flags & Tensor::Flags::GRAD) { return apply_grad(ctx); } - if (ctx.flags & Tensor::Flags::TRACE) { + if (flags & Tensor::Flags::TRACE) { return apply_trace(ctx); } else { + if (auto* op = ctx.op->try_cast_final()) { + py::tuple pyin(ctx.nargs); + for (size_t i = 0; i < ctx.nargs; ++i) { + pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this()); + } + auto f = py::getattr(op->obj, "_default_rule"); + auto pyout = py::reinterpret_steal(PyObject_Call(f.ptr(), pyin.ptr(), nullptr)); + if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) { + return {tw->m_tensor}; + } + apply_result_t ret; + ret.reserve(py::len(pyout)); + for (auto&& i : pyout) { + auto* tw = TensorWrapper::try_cast(i.ptr()); + mgb_assert(tw); + ret.push_back(tw->m_tensor); + } + return ret; + } + SmallVector handles(ctx.nargs); for (size_t i = 0; i < ctx.nargs; ++i) { handles[i] = ctx.args[i]->m_handle.get(); @@ -125,12 +152,13 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje SmallVector tensors(nargs); ctx.args = &tensors[0]; ctx.nargs = nargs; + ctx.pytype = pytype; if (strstr(op->ob_type->tp_name, "BackwardGraph")) { ctx.backward = true; } for (size_t i = 0; i < nargs; ++i) { - if (TensorWrapper* tw = TensorWrapper::cast_safe(args[i])) { + if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { auto* t = tensors[i] = tw->m_tensor.get(); ctx.flags |= t->m_flags; } else { @@ -166,7 +194,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { if (nargs == 0) { throw py::type_error("too few arguments"); } - if (auto* t = cast_safe(tup[0].ptr())) { + if (auto* t = try_cast(tup[0].ptr())) { if (nargs > 1) { throw py::type_error("expect 1 argument"); } @@ -211,7 +239,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { auto ret = pyf(*tup); auto py_ret = py::reinterpret_borrow(ret); - if (auto* t = cast_safe(py_ret[0].ptr())) { + if (auto* t = try_cast(py_ret[0].ptr())) { m_tensor = t->m_tensor; } return; @@ -349,7 +377,7 @@ PyObject* TensorWrapper::varnode() { } void TensorWrapper::reset(PyObject* tensor) { - TensorWrapper* t = TensorWrapper::cast_safe(tensor); + TensorWrapper* t = TensorWrapper::try_cast(tensor); if (!t) { throw py::type_error("expect Tensor"); } @@ -446,7 +474,7 @@ uint8_t max_priority(SmallVector types) { } } -// Returns the data type with sufficient size to hold all types of +// Returns the data type with sufficient size to hold all types of // category `cat` in the list `types`. PyArray_Descr* promote_types(SmallVector types, uint8_t cat) { // Return value: New reference @@ -507,7 +535,7 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { for (size_t i = 0; i < nargs; ++i) { PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; if (handle == Py_None) continue; - TensorWrapper* tw = TensorWrapper::cast_safe(handle); + TensorWrapper* tw = TensorWrapper::try_cast(handle); if (tw) { mgb::DType type = tw->m_tensor->dtype(); auto&& descr = npy::dtype_mgb2np_descr(type); @@ -562,7 +590,7 @@ CompNode _get_device(PyObject*const* args, size_t nargs) { CompNode cn; for (size_t i = 0; i < nargs; ++i) { PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; - TensorWrapper* tw = TensorWrapper::cast_safe(handle); + TensorWrapper* tw = TensorWrapper::try_cast(handle); if (tw) { if (!valid) { cn = tw->m_tensor->comp_node(); diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index 8dd9256d4e6615a2f7808d3d164971bb8999e719..addb1891b1ea165fc03c7eec8e866214fff06156 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -124,7 +124,7 @@ struct TensorWrapper { friend wrap_t; inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast(op)->inst();} - inline static TensorWrapper* cast_safe(PyObject* op) { + inline static TensorWrapper* try_cast(PyObject* op) { if (!wrap_t::type().isinstance(op)) return nullptr; return cast(op); } @@ -173,11 +173,26 @@ struct TensorWrapper { PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */); struct ApplyContext { + static Tensor::flags_t global_disable; + Tensor::flags_t flags; std::shared_ptr op; Tensor*const* args; size_t nargs; + PyTypeObject* pytype = nullptr; bool backward = false; + + class scoped_disable : NonCopyableObj { + Tensor::flags_t saved_flags; + + public: + scoped_disable(Tensor::flags_t flags) : saved_flags(ApplyContext::global_disable) { + ApplyContext::global_disable |= flags; + } + ~scoped_disable() { + ApplyContext::global_disable = saved_flags; + } + }; }; using apply_result_t = SmallVector, 8>; diff --git a/imperative/python/src/trace.cpp b/imperative/python/src/trace.cpp index a698a9698c0f8524ed2c05fe03148bb65bbd0f51..fb01d6b36fe048d589b0e600400a79db87d62ebf 100644 --- a/imperative/python/src/trace.cpp +++ b/imperative/python/src/trace.cpp @@ -85,7 +85,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { // assumption: python function always returns PyList auto tup = py::reinterpret_borrow(ret); for (auto i = 0; i < tup.size(); i++) { - auto tw = TensorWrapper::cast_safe(tup[i].ptr()); + auto tw = TensorWrapper::try_cast(tup[i].ptr()); outputs.emplace_back(tw->m_tensor); } return outputs; diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5ec2d4ef58534c0afcbc87398f6de0a9ada2d5ba --- /dev/null +++ b/imperative/src/impl/ops/utility.cpp @@ -0,0 +1,21 @@ +/** + * \file imperative/src/impl/ops/utility.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/imperative/ops/utility.h" +#include "megbrain/imperative/ops/opr_attr.h" +#include "megbrain/opr/utility.h" +#include "../op_trait.h" + +namespace mgb::imperative { + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/ops/utility.h b/imperative/src/include/megbrain/imperative/ops/utility.h new file mode 100644 index 0000000000000000000000000000000000000000..b688e986b2384764b430bf5087ca18ef3c3c417c --- /dev/null +++ b/imperative/src/include/megbrain/imperative/ops/utility.h @@ -0,0 +1,38 @@ +/** + * \file imperative/src/include/megbrain/imperative/ops/utility.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megbrain/imperative/op_def.h" + +#include "megbrain/utils/hash.h" + +#include + +namespace mgb::imperative { + +struct GenericPyOp final : OpDefImplBase { + pybind11::object obj; + + GenericPyOp(pybind11::object obj_) : obj(std::move(obj_)) {}; + + size_t hash() const override { + return pybind11::hash(obj); + } + + bool is_same_st(const Hashable& rhs) const override { + return obj.equal(static_cast(rhs).obj); + } + + MGB_DYN_TYPE_OBJ_FINAL_DECL; +}; + +} // namespace mgb::imperative