diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index 818dd3783a9a381b1f829415751205d112b91b91..bbb1f4b26d1cda1435b6ffd0df62ce7ac9e4e21a 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -16,6 +16,7 @@ import numpy as np from .. import _config from .._imperative_rt.common import CompNode from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion +from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar from ..ops import builtin from . import amp from .indexing import getitem, setitem @@ -508,12 +509,8 @@ def _reduce(mode): elif self.dtype == np.bool_: data = data.astype("int32") if axis is None: - data = data.reshape(-1) assert not keepdims, "can not set axis=None and keepdims=True" - - op = builtin.Reduce(mode=mode, axis=0) - (result,) = apply(op, data) - result = _remove_axis(result, 0) + result = _reduce_to_scalar(builtin.Reduce(mode=mode), data) elif isinstance(axis, collections.abc.Iterable): axis = _normalize_axis(self.ndim, axis, reverse=True) for ai in axis: diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py index 95a3dd1c975f5617e4a4fa1b5a6f5b050db7aa56..07fa14c1eb5364ea0c0f50a93fb00658f286cf21 100644 --- a/imperative/python/megengine/optimizer/sgd.py +++ b/imperative/python/megengine/optimizer/sgd.py @@ -69,7 +69,7 @@ class SGD(Optimizer): inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")) if inplace_mode: _neg_lr = tensor(-lr, dtype="float32") - c1 = tensor([1.0]) + c1 = tensor(1.0) for param in param_group["params"]: if param.grad is None: diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index e06aa9d3659ca1e0d5a606961d46f15d2fc1fec6..9c060ac8629e02f8a32de8c03332ad2f9bbf4320 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -84,14 +84,15 @@ class Tensor(_Tensor, ArrayMethodMixin): device: str = None, is_const: bool = False, no_cache: bool = False, - name: str = "", + name: str = None, ): if name is None: name = "" + else: + self._set_name(name) self._custom_name = name self._name = name self._short_name = name - self._set_name(self._name) self._prefix = None @property diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index 75f1ba8dc14b041eb7ff722d38882cc585f6b4ff..54deb79b40c8c702d3e7d4697f22ec2fa9bb4204 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -46,17 +46,17 @@ void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { if (args[1] != Py_None) { callback = py::reinterpret_borrow(args[1]); } - GenericFunction generic_callback = - [=](Span inputs) -> std::vector { + GenericFunction generic_callback = [=](Span inputs) -> ValueRefList { mgb_assert(inputs.size() == 1); if (callback) { callback(TensorWrapper::make(py_tensor_type, inputs[0])); } return {}; }; - tw->m_tensor->reset(imperative::apply( + auto attached_value = imperative::apply( AttachGrad(m_key), tw->m_tensor->data(), - FunctionValue::make(generic_callback))[0]); + FunctionValue::make(generic_callback))[0]; + tw->m_tensor->reset(attached_value); } void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) { diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index 5269113db30fc5e8f9c153e40c5306d07b0106b1..ce41a96f26a47cfde46b547c06267b24687ee306 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -98,7 +98,7 @@ ValueRef make_empty_tensor( return res; } -std::optional> elemwise_grad_rule( +std::optional elemwise_grad_rule( const OpDef& op, Span inputs, Span inputs_require_grad, CustomBackward& backward) { auto& elemwise = op.cast_final_safe(); @@ -117,7 +117,7 @@ std::optional> elemwise_grad_rule( maker.backward([shapes = std::move(input_shapes)](Span grads) { mgb_assert(grads.size() == 1); ValueRef grad = grads[0]; - std::vector ret(2); + ValueRefList ret(2); if (!grad) { return ret; } @@ -132,7 +132,7 @@ std::optional> elemwise_grad_rule( return imperative::apply(ApplyOp(op), inputs); } -std::optional> reshape_grad_rule( +std::optional reshape_grad_rule( const OpDef& op, Span inputs, Span inputs_require_grad, CustomBackward& backward) { mgb_assert(inputs.size() == 2); @@ -147,7 +147,7 @@ std::optional> reshape_grad_rule( maker.backward([shapes = std::move(input_shapes)](Span grads) { mgb_assert(grads.size() == 1); ValueRef grad = grads[0]; - std::vector ret(2); + ValueRefList ret(2); if (!grad) { return ret; } @@ -162,7 +162,7 @@ std::optional> reshape_grad_rule( return imperative::apply(ApplyOp(op), inputs); } -std::optional> subtensor_grad_rule( +std::optional subtensor_grad_rule( const OpDef& op, Span inputs, Span inputs_require_grad, CustomBackward& backward) { auto&& subtensor = op.cast_final_safe(); @@ -180,9 +180,9 @@ std::optional> subtensor_grad_rule( grad_op_ = std::move(grad_op)](Span grads) { mgb_assert(grads.size() == 1); ValueRef grad = grads[0]; - std::vector ret(1); + ValueRefList ret(1); if (grad && inputs[0]) { - SmallVector args_(inputs.size() + 1); + ValueRefList args_(inputs.size() + 1); auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); args_[0] = zeros; args_[1] = grad; @@ -197,7 +197,7 @@ std::optional> subtensor_grad_rule( return imperative::apply(ApplyOp(op), inputs); } -std::optional> indexingMultiAxisVec_grad_rule( +std::optional indexingMultiAxisVec_grad_rule( const OpDef& op, Span inputs, Span inputs_require_grad, CustomBackward& backward) { auto&& indexingMultiAxisVec = op.cast_final_safe(); @@ -215,9 +215,9 @@ std::optional> indexingMultiAxisVec_grad_rule( grad_op_ = std::move(grad_op)](Span grads) { mgb_assert(grads.size() == 1); ValueRef grad = grads[0]; - std::vector ret(1); + ValueRefList ret(1); if (grad && inputs[0]) { - SmallVector args_(inputs.size() + 1); + ValueRefList args_(inputs.size() + 1); auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); args_[0] = zeros; args_[1] = grad; @@ -232,7 +232,7 @@ std::optional> indexingMultiAxisVec_grad_rule( return imperative::apply(ApplyOp(op), inputs); } -std::optional> reduce_grad_rule( +std::optional reduce_grad_rule( const OpDef& op, Span inputs, Span inputs_require_grad, CustomBackward& backward) { auto& reduce = op.cast_final_safe(); @@ -251,7 +251,7 @@ std::optional> reduce_grad_rule( maker.backward([shapes = std::move(input_shapes)](Span grads) { mgb_assert(grads.size() == 1); ValueRef grad = grads[0]; - std::vector ret(1); + ValueRefList ret(1); if (grad && shapes[0]) { ret[0] = broadcast_to(grad, shapes[0]); } @@ -261,7 +261,7 @@ std::optional> reduce_grad_rule( return imperative::apply(ApplyOp(op), inputs); } -std::optional> addAxis_grad_rule( +std::optional addAxis_grad_rule( const OpDef& op, Span inputs, Span inputs_require_grad, CustomBackward& backward) { auto&& addAxis = op.cast_final_safe(); @@ -274,7 +274,7 @@ std::optional> addAxis_grad_rule( maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span grads) { mgb_assert(grads.size() == 1); ValueRef grad = grads[0]; - std::vector ret(1); + ValueRefList ret(1); if (grad && flag_) { ret[0] = imperative::apply(*grad_op_, grad)[0]; } @@ -284,7 +284,7 @@ std::optional> addAxis_grad_rule( return imperative::apply(op, inputs); } -std::optional> removeAxis_grad_rule( +std::optional removeAxis_grad_rule( const OpDef& op, Span inputs, Span inputs_require_grad, CustomBackward& backward) { auto&& removeAxis = op.cast_final_safe(); @@ -297,7 +297,7 @@ std::optional> removeAxis_grad_rule( maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span grads) { mgb_assert(grads.size() == 1); ValueRef grad = grads[0]; - std::vector ret(1); + ValueRefList ret(1); if (grad && flag_) { ret[0] = imperative::apply(*grad_op_, grad)[0]; } @@ -307,7 +307,7 @@ std::optional> removeAxis_grad_rule( return imperative::apply(op, inputs); } -std::optional> fastpathcopy_grad_rule( +std::optional fastpathcopy_grad_rule( const OpDef& op, Span inputs, Span inputs_require_grad, CustomBackward& backward) { mgb_assert(inputs.size() == 1); @@ -316,7 +316,7 @@ std::optional> fastpathcopy_grad_rule( maker.backward([](Span grads) { mgb_assert(grads.size() == 1); ValueRef grad = grads[0]; - std::vector ret(1); + ValueRefList ret(1); if (grad) { ret[0] = grad; } diff --git a/imperative/python/src/module_trace.h b/imperative/python/src/module_trace.h index eee35d6ce132c15ce948ea96a50cc08d33147cd4..835219fec5bf9f193e68825c412b7c5a9fb23183 100644 --- a/imperative/python/src/module_trace.h +++ b/imperative/python/src/module_trace.h @@ -25,24 +25,23 @@ private: py::function m_hook_fn; int m_enabled = 0; - std::vector apply_module_trace_hook( - const OpDef& op, Span input_values) { + ValueRefList apply_module_trace_hook(const OpDef& op, Span input_values) { py::list input_tws; for (auto&& input_value : input_values) { input_tws.append(TensorWrapper::make(py_tensor_type, input_value)); } py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws); - std::vector outputs; + ValueRefList outputs(output_tws.size()); + auto it = outputs.begin(); for (auto&& output_tw : output_tws) { - outputs.push_back( - TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data()); + *(it++) = TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data(); } return outputs; } public: ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {} - std::vector apply_transformation( + ValueRefList apply_transformation( const Operator& op, Span inputs) override { if (op.is() && m_enabled > 0) { auto outputs = apply_module_trace_hook(op.cast().op(), inputs); diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 8d160e9aaa58e21dfa37259f46f6f99f2e9b91e1..36f69ba4ccc2d338d40fdffaa39f7c3da5279687 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -87,7 +87,7 @@ PyObject* py_apply( --nargs; auto op = py::handle(py_op).cast>(); - SmallVector tensors(nargs); + SmallVector tensors(nargs); if (py::isinstance(py::handle(args[0]))) { // swap to a special context to reuse scalar handle @@ -100,16 +100,15 @@ PyObject* py_apply( Transformation::top()); std::make_shared()->register_at( Transformation::top()); - SmallVector inputs(nargs); for (size_t i = 0; i < nargs; ++i) { auto* py_input = py::handle(args[i]).cast(); ValueRef input = SymbolValue::make(py_input->m_node); if (py_input->is_scalar) { input = ScalarValue::make(input); } - inputs[i] = input; + tensors[i] = input; } - auto outputs = imperative::apply(*op, inputs); + auto outputs = imperative::apply(*op, tensors); auto ret = pybind11::tuple(outputs.size()); auto typeobj = py::handle(args[0]).get_type(); for (size_t i = 0; i < outputs.size(); ++i) { @@ -140,7 +139,7 @@ PyObject* py_apply( } } - auto outputs = imperative::apply(ApplyOp(*op), {tensors.data(), nargs}); + auto outputs = imperative::apply(*op, tensors); size_t nout = outputs.size(); auto ret = py::tuple(nout); for (size_t i = 0; i < nout; ++i) { @@ -214,16 +213,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { if (!name.empty()) { m_tensor->reset( imperative::apply(RenameValue(name), m_tensor->data())[0]); - mgb_assert( - ((std::string&)*m_tensor->data().name()) == name, - "result name incorrect"); - } - - if (data.ndim() == 0) { - mgb_assert(m_tensor->is_scalar(), "result should be scalar"); } } } + mgb_assert(m_tensor->data()); } PyObject* TensorWrapper::module_trace_info() { @@ -1384,15 +1377,20 @@ void init_tensor(py::module m) { std::function array_comparator; bool compare_value(ValueRef lhs, ValueRef rhs) { - if (!lhs.shape()->eq(*rhs.shape())) { + auto lvalue = lhs.numpy(); + auto rvalue = rhs.numpy(); + if (lvalue->shape() != rvalue->shape()) { return false; } - HostTensorND lvalue = lhs.numpy()->as_nd(true); - HostTensorND rvalue = rhs.numpy()->as_nd(true); + if (lvalue->shape().is_scalar()) { + return lvalue->item() == rvalue->item(); + } + HostTensorND lnd = lvalue->as_nd(true); + HostTensorND rnd = rvalue->as_nd(true); auto larr = py::reinterpret_steal( - npy::ndarray_from_tensor(lvalue, npy::ShareType::TRY_SHARE)); + npy::ndarray_from_tensor(lnd, npy::ShareType::TRY_SHARE)); auto rarr = py::reinterpret_steal( - npy::ndarray_from_tensor(rvalue, npy::ShareType::TRY_SHARE)); + npy::ndarray_from_tensor(rnd, npy::ShareType::TRY_SHARE)); return array_comparator(larr, rarr); } @@ -1539,6 +1537,19 @@ void init_tensor(py::module m) { } }); + m.def("reduce_to_scalar", [](py::object op, py::object tensor) { + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + auto make_scalar_shape = [&](CompNode device) { + return imperative::apply( + CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}), + HostStorage::make(device))[0]; + }; + auto output = imperative::apply( + *op.cast>(), tw->m_tensor->data(), + make_scalar_shape(tw->m_tensor->comp_node()))[0]; + return TensorWrapper::make(py_tensor_type, output); + }); + m.def("name_tensor", [](std::string name, py::object tensor) { auto* tw = TensorWrapper::try_cast(tensor.ptr()); auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0]; @@ -1546,9 +1557,9 @@ void init_tensor(py::module m) { }); m.def("is_grad_attached", [](std::vector tensors) -> bool { - SmallVector values; - for (auto&& tensor : tensors) { - values.push_back(tensor.cast().m_tensor->data()); + ValueRefList values(tensors.size()); + for (size_t i = 0; i < tensors.size(); ++i) { + values[i] = tensors[i].cast().m_tensor->data(); } auto outputs = imperative::apply(GetGradKey(), values); if (outputs[0].is()) { @@ -1559,9 +1570,9 @@ void init_tensor(py::module m) { }); m.def("get_grad_key", [](std::vector tensors) -> py::object { - SmallVector values; - for (auto&& tensor : tensors) { - values.push_back(tensor.cast().m_tensor->data()); + ValueRefList values(tensors.size()); + for (size_t i = 0; i < tensors.size(); ++i) { + values[i] = tensors[i].cast().m_tensor->data(); } auto outputs = imperative::apply(GetGradKey(), values); if (auto* grad_key_val = outputs[0].as()) { @@ -1578,7 +1589,7 @@ void init_tensor(py::module m) { mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr())); auto* key = reinterpret_cast(py_key.ptr())->inst(); GenericFunction generic_backward_fn = - [backward_fn](Span output_grads) -> std::vector { + [backward_fn](Span output_grads) -> ValueRefList { py::list output_grad_tws; for (auto&& output_grad : output_grads) { if (output_grad) { @@ -1589,23 +1600,25 @@ void init_tensor(py::module m) { } } py::tuple input_grad_tws = backward_fn(*output_grad_tws); - std::vector input_grads; - for (auto&& input_grad_tw : input_grad_tws) { + ValueRefList input_grads(input_grad_tws.size()); + for (size_t i = 0; i < input_grad_tws.size(); ++i) { + auto input_grad_tw = input_grad_tws[i]; if (!input_grad_tw.is_none()) { - input_grads.push_back( - py::cast(input_grad_tw).m_tensor->data()); + input_grads[i] = + py::cast(input_grad_tw).m_tensor->data(); } else { - input_grads.push_back({}); + input_grads[i] = {}; } } return input_grads; }; - SmallVector values; - for (auto&& input : inputs) { - values.push_back(input.cast().m_tensor->data()); + ValueRefList values(inputs.size() + outputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + values[i] = inputs[i].cast().m_tensor->data(); } - for (auto&& output : outputs) { - values.push_back(output.cast().m_tensor->data()); + for (size_t i = 0; i < outputs.size(); ++i) { + values[i + inputs.size()] = + outputs[i].cast().m_tensor->data(); } auto wrapped_output_values = imperative::apply( SetGrad(key->m_key, generic_backward_fn, inputs.size()), values); diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index c438aea6c7b53d532bef83bb2872f2b19a4ad19e..e2a02d1163a89ecce87b6802409d183146d6a16c 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -39,7 +39,7 @@ namespace mgb::imperative::python { extern interpreter::Interpreter::Channel* interpreter_for_py; extern PyTypeObject* py_tensor_type; -struct Tensor : std::enable_shared_from_this, NonCopyableObj { +struct Tensor : NonCopyableObj { private: std::string m_name; ValueRef m_data; @@ -52,7 +52,7 @@ public: ~Tensor() = default; inline std::shared_ptr copy() { - auto ret = std::make_shared(m_data.unwrap()); + auto ret = std::make_shared(m_data); ret->m_name = m_name; return ret; } diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h index df10d2c52c2a6b99be348e365c81be13f1687df0..d07666b70ad72823809d53b245b4bffee8cb2304 100644 --- a/imperative/python/src/transformation.h +++ b/imperative/python/src/transformation.h @@ -11,7 +11,15 @@ #pragma once +#include +#include + +#include "pybind11/pybind11.h" + +#include "megbrain/imperative/dispatch.h" #include "megbrain/imperative/transformation.h" +#include "megbrain/imperative/value.h" +#include "megbrain/utils/small_vector.h" namespace mgb::imperative::python { struct TransformationManager { @@ -58,4 +66,14 @@ struct TransformationManager { return sl_instance; } }; + +class PyValue final : public MixinValueImpl { +public: + using MixinValueImpl::MixinValueImpl; + + std::string to_string() const { + return pybind11::str((const pybind11::object&)*this).cast(); + } +}; + } // namespace mgb::imperative::python diff --git a/imperative/src/impl/basic_operators.cpp b/imperative/src/impl/basic_operators.cpp index c0835b249abd5b3d633db480b1f0a28a473816f5..234da0d08cc2516c1bc3ca1278fa9ebff33d2d6c 100644 --- a/imperative/src/impl/basic_operators.cpp +++ b/imperative/src/impl/basic_operators.cpp @@ -45,7 +45,7 @@ CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout) layout.is_contiguous() || layout.is_empty(), "layout should be contiguous"); } -auto CreateTensor::parse(Span inputs) -> Args { +auto CreateTensor::parse(Span inputs) const -> Args { Args result; for (auto&& input : inputs) { if (auto host_storage = input.as_ref()) { diff --git a/imperative/src/impl/dispatch.cpp b/imperative/src/impl/dispatch.cpp index 2fc49f139f16beaa249622001773607a527d5eaf..2a60713756cedb070a86e416fa599633eabfd00c 100644 --- a/imperative/src/impl/dispatch.cpp +++ b/imperative/src/impl/dispatch.cpp @@ -16,70 +16,67 @@ #include "megbrain/imperative/utils/map.h" namespace mgb { + +void imperative_log_profile_begin(const char* message); +void imperative_log_profile(const char* message); +void imperative_log_profile_end(const char* message); + namespace imperative { -std::vector apply(const Operator& op, Span inputs) { - static bool log_dispatch = MGB_GETENV("MGE_LOG_OP_DISPATCH"); - bool enable_watch = ValueRef::any_watching(); - auto& context = Transformation::get_context(); - size_t& depth = context.next_transformation; - static const char tabs_storage[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t"; - const char* tabs = tabs_storage + sizeof(tabs_storage) / sizeof(char) - depth - 1; - bool log_current_dispatch = log_dispatch; - if (enable_watch) { - for (size_t i = 0; i < inputs.size(); ++i) { - auto& input = inputs[i]; - if (input.watching()) { - log_current_dispatch = true; - mgb_log_debug("%sinput[%zu] is %s", tabs, i, input.to_string().c_str()); - debug::notify_event("apply"); - } - } - } - // entrance - std::vector outputs; - if (depth >= context.transformations.size()) { - // fallback - if (log_current_dispatch) { - mgb_log_debug( - "%sfallback apply %s in %s", tabs, op.to_string().c_str(), - imperative::to_string(inputs).c_str()); +namespace { +MGB_NOINLINE void copy_outputs( + ForwardAllocator& allocator, ValueRefList& outputs) { + size_t nr_outputs = outputs.size(); + if (mgb_likely(nr_outputs == 1)) { + ValueRef output_copy; + output_copy = outputs[0]; + allocator.clear(); + outputs = ValueRefList({output_copy}); + } else if (!outputs.empty()) { + SmallVector outputs_copy(nr_outputs); + for (size_t i = 0; i < nr_outputs; ++i) { + outputs_copy[i] = outputs[i]; } - outputs = op.fallback(inputs); + outputs.clear(); + allocator.clear(); + outputs = {outputs_copy.begin(), outputs_copy.end()}; } else { - // dispatch to stack top - auto& transformation = *context.transformations[depth]; - ++depth; - context.frames.push_back({op, inputs}); - CleanupGuard _{[&] { - context.frames.pop_back(); - --depth; - }}; - if (log_current_dispatch) { - mgb_log_debug( - "%s%s apply %s in %s", tabs, transformation.name().c_str(), - op.to_string().c_str(), imperative::to_string(inputs).c_str()); - } - outputs = transformation.apply_transformation(op, inputs); + allocator.clear(); } - if (log_current_dispatch) { - mgb_log_debug("%sreturn %s", tabs, imperative::to_string(outputs).c_str()); +} +} // namespace + +ValueRefList apply(const Operator& op, Span inputs) { + auto& context = Transformation::get_context(); + size_t& depth = context.next_transformation; + bool top = depth == 0; + auto outputs = ([&] { + if (mgb_unlikely(depth >= context.transformations.size())) { + return op.fallback(inputs); + } else { + auto& transformation = *context.transformations[depth++]; + CleanupGuard _{[&] { --depth; }}; + return transformation.apply_transformation(op, inputs); + } + })(); + if (mgb_unlikely(top)) { + copy_outputs(context.allocator, outputs); } return outputs; } -std::vector apply(const OpDef& def, Span inputs) { +ValueRefList apply(const OpDef& def, Span inputs) { return imperative::apply(ApplyOp{def}, inputs); } -std::vector apply(Subgraph graph, Span inputs) { +ValueRefList apply(const Subgraph& graph, Span inputs) { SmallVector inputs_storage; for (size_t i = 0; i < inputs.size(); ++i) { inputs_storage.push_back(inputs[i]); } auto apply_functor = [](std::shared_ptr op, SmallVector inputs, size_t) { - auto outputs = imperative::apply(ApplyOp(*op), inputs); + auto outputs = imperative::apply(*op, inputs); return SmallVector(outputs.begin(), outputs.end()); }; auto make_const = [](TensorPtr constant) -> ValueRef { @@ -101,7 +98,7 @@ std::vector apply(Subgraph graph, Span inputs) { DeviceStorage::make(device_value.storage()))[0]; }; auto outputs = graph.apply(inputs_storage, apply_functor, make_const); - return {outputs.begin(), outputs.end()}; + return ValueRefList{outputs.begin(), outputs.end()}; } } // namespace imperative diff --git a/imperative/src/impl/interpreter/stack_manager.h b/imperative/src/impl/interpreter/stack_manager.h index 6d45da7b396185e544a6254bb7c32564c6cd625e..75b1963ce5b930e075124d720248a2b102a0ddf1 100644 --- a/imperative/src/impl/interpreter/stack_manager.h +++ b/imperative/src/impl/interpreter/stack_manager.h @@ -126,7 +126,7 @@ public: m_frames[m_frames.size() - 1 - i] = {node, node->version()}; node = node->parent(); } - mgb_assert(node->is_root(), ""); + mgb_assert(node->is_root()); } Trace() = default; std::string to_string() const { diff --git a/imperative/src/impl/operator.cpp b/imperative/src/impl/operator.cpp index 4337adf910568cdf867101e5c64379beb707c14d..d5aa013d2115d26b1fdd65a5eefb84ff551c859c 100644 --- a/imperative/src/impl/operator.cpp +++ b/imperative/src/impl/operator.cpp @@ -3,7 +3,7 @@ namespace mgb { namespace imperative { -std::vector Operator::fallback(Span inputs) const { +ValueRefList Operator::fallback(Span inputs) const { mgb_throw(MegBrainError, "no fallback implementation for %s", to_string().c_str()); } diff --git a/imperative/src/impl/physical_tensor.cpp b/imperative/src/impl/physical_tensor.cpp index 7f4c0710a4582b8c48555512cfd01df7941dd0c0..5b0f52943896e653dab9e71901c4b27fbc4842b6 100644 --- a/imperative/src/impl/physical_tensor.cpp +++ b/imperative/src/impl/physical_tensor.cpp @@ -99,19 +99,22 @@ Tensor::Tensor( Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) { constexpr int size_threshold = TensorShape::MAX_NDIM; - if (hv.layout().total_nr_elems() <= size_threshold) { + size_t nr_elems = hv.layout().total_nr_elems(); + if (nr_elems <= size_threshold) { m_value = hv; } - MGB_RECORD_EVENT( - profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(), - dev_tensor().raw_ptr()); - dev_tensor().copy_from_fixlayout(hv); - // even though hv is saved in m_value, Tensor itself could be - // released before copy completes - MGB_RECORD_EVENT( - profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(), - hv.raw_ptr(), dev_tensor().raw_ptr()); - AsyncReleaser::inst()->add(hv); + if (nr_elems) { + MGB_RECORD_EVENT( + profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(), + dev_tensor().raw_ptr()); + dev_tensor().copy_from_fixlayout(hv); + // even though hv is saved in m_value, Tensor itself could be + // released before copy completes + MGB_RECORD_EVENT( + profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(), + hv.raw_ptr(), dev_tensor().raw_ptr()); + AsyncReleaser::inst()->add(hv); + } } Tensor::Tensor(const DeviceTensorND& dv, const HostTensorND& hv) { diff --git a/imperative/src/impl/profiler/chrome_timeline.cpp b/imperative/src/impl/profiler/chrome_timeline.cpp index 775ac85c09110179f90570d4e996c145fa323157..e2d3548aa9d70dfa8f7cf0435c081a08eff91e3b 100644 --- a/imperative/src/impl/profiler/chrome_timeline.cpp +++ b/imperative/src/impl/profiler/chrome_timeline.cpp @@ -310,7 +310,8 @@ struct ChromeTimelineEventVisitor : EventVisitor { } else if constexpr (std::is_same_v) { new_host_event("TensorGetProp", 'X') .dur(0) - .args(current_tensor->detail(current->time)); + .args(current_tensor->detail(current->time)) + .arg("kind", imperative::to_string(event.prop)); } else if constexpr (std::is_same_v) { new_host_event("TensorWaitProp", 'B'); } else if constexpr (std::is_same_v) { diff --git a/imperative/src/impl/transformations/eval.cpp b/imperative/src/impl/transformations/eval.cpp index 289975a38e374d2b92c9100af29104d6ed763c19..35ec7e76e1f4977b5caa8fed201ef6e2bbd17df1 100644 --- a/imperative/src/impl/transformations/eval.cpp +++ b/imperative/src/impl/transformations/eval.cpp @@ -15,71 +15,109 @@ namespace mgb { namespace imperative { -std::vector InterpreterTransformation::apply_transformation( - const Operator& op, Span inputs) { - if (auto* op_val = op.as()) { - if (op_val->op().same_type()) { - return {inputs[0]}; - } - SmallVector input_handles; - SmallVector output_handles; - CleanupGuard _{[&] { - for (auto handle : output_handles) { - if (handle) { - m_channel->del(handle); - } +DTypeValue::ref_t InterpreterInfo::dtype() const { + if (!m_dtype) { + m_dtype = DTypeValue::make(handle()->channel()->get_dtype(handle()->handle())); + } + return m_dtype; +} + +CompNodeValue::ref_t InterpreterInfo::comp_node() const { + if (!m_comp_node) { + m_comp_node = CompNodeValue::make( + handle()->channel()->get_device(handle()->handle())); + } + return m_comp_node; +} + +ShapeValue::ref_t InterpreterInfo::shape() const { + if (!m_shape) { + m_shape = ShapeValue::make( + ValueShape::from(handle()->channel()->get_shape(handle()->handle()))); + } + return m_shape; +} + +ValueRefList InterpreterTransformation::apply_op( + const ApplyOp& apply_op, Span inputs) { + if (apply_op.op().same_type()) { + return {inputs[0]}; + } + SmallVector input_handles; + SmallVector output_handles; + CleanupGuard _{[&] { + for (auto handle : output_handles) { + if (handle) { + m_channel->del(handle); } - }}; - for (auto input : inputs) { - input_handles.push_back(*input.cast().handle()); - } - output_handles = - m_channel->apply_op(op_val->op().shared_from_this(), input_handles); - std::vector outputs; - for (auto& handle : output_handles) { - outputs.push_back(InterpreterValue::make(share_handle(handle))); - handle = nullptr; } - return outputs; + }}; + for (auto input : inputs) { + input_handles.push_back(input.cast().handle()->handle()); + } + output_handles = + m_channel->apply_op(apply_op.op().shared_from_this(), input_handles); + ValueRefList outputs(output_handles.size()); + for (size_t i = 0; i < output_handles.size(); ++i) { + outputs[i] = InterpreterValue::make(share_handle(output_handles[i])); + output_handles[i] = nullptr; + } + return outputs; +} + +ValueRefList InterpreterTransformation::apply_get_attr( + const GetAttr& get_attr, Span inputs) { + auto& input = inputs.item().cast(); + ValueRef output; + switch (get_attr.attr()) { + case GetAttr::DType: + output = input.dtype(); + break; + case GetAttr::Shape: + output = input.shape(); + break; + case GetAttr::Device: + output = input.comp_node(); + break; + case GetAttr::Value: + output = HostValue::make(m_channel->get_value(input.handle()->handle())); + break; + case GetAttr::Data: + output = DeviceValue::make( + m_channel->get_dev_tensor(input.handle()->handle())); + break; + default: + mgb_throw( + MegBrainError, "Interpreter: malformed GetAttr: %s", + get_attr.to_string().c_str()); + } + return {output}; +} + +ValueRefList InterpreterTransformation::apply_create_tensor( + const CreateTensor& create_tensor, Span inputs) { + auto args = create_tensor.parse(inputs); + if (!args.device) { + // implies H2D + mgb_assert(args.host, "neither host and device value is valid"); + return {InterpreterValue::make(share_handle( + m_channel->put(*args.host, args.kind == CreateTensor::Unique)))}; + } else { + return {InterpreterValue::make(share_handle(m_channel->put( + *args.device, args.host ? *args.host : HostTensorND())))}; + } +} + +ValueRefList InterpreterTransformation::apply_transformation( + const Operator& op, Span inputs) { + if (auto* op_val = op.as()) { + return apply_op(*op_val, inputs); } else if (auto* get_attr = op.as()) { - Handle handle = *inputs[0].cast().handle(); - ValueRef output; - switch (get_attr->attr()) { - case GetAttr::DType: - output = DTypeValue::make(m_channel->get_dtype(handle)); - break; - case GetAttr::Shape: - output = ShapeValue::make( - ValueShape::from(m_channel->get_shape(handle))); - break; - case GetAttr::Device: - output = CompNodeValue::make(m_channel->get_device(handle)); - break; - case GetAttr::Value: - output = HostValue::make(m_channel->get_value(handle)); - break; - case GetAttr::Data: - output = DeviceValue::make(m_channel->get_dev_tensor(handle)); - break; - default: - mgb_throw( - MegBrainError, "Interpreter: malformed GetAttr: %s", - op.to_string().c_str()); - } - return {output}; + return apply_get_attr(*get_attr, inputs); } else if (auto* create_tensor = op.as()) { - auto args = create_tensor->parse(inputs); - if (!args.device) { - // implies H2D - mgb_assert(args.host, "neither host and device value is valid"); - return {InterpreterValue::make(share_handle( - m_channel->put(*args.host, args.kind == CreateTensor::Unique)))}; - } else { - return {InterpreterValue::make(share_handle(m_channel->put( - *args.device, args.host ? *args.host : HostTensorND())))}; - } + return apply_create_tensor(*create_tensor, inputs); } else if (auto* dtr_command = op.as()) { - auto handle = *inputs[0].cast().handle(); + auto handle = inputs[0].cast().handle()->handle(); switch (dtr_command->kind()) { case DTRCommand::Drop: m_channel->drop(handle); diff --git a/imperative/src/impl/transformations/grad.cpp b/imperative/src/impl/transformations/grad.cpp index bc968a607f01c9b18abb6631d3d9efa8ce2c7ca5..50d7743981597fd33f6858585c131f521dacfbc3 100644 --- a/imperative/src/impl/transformations/grad.cpp +++ b/imperative/src/impl/transformations/grad.cpp @@ -64,12 +64,13 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( size_t count = std::count_if( save_for_backward.begin(), save_for_backward.end(), ranges::identity{}); if (!backward_graph->precomp.empty()) { - SmallVector inputs_and_outputs; + ValueRefList inputs_and_outputs(inputs.size() + outputs.size()); + auto it = inputs_and_outputs.begin(); for (auto&& input : inputs) { - inputs_and_outputs.push_back(input); + *it++ = input; } for (auto&& output : outputs) { - inputs_and_outputs.push_back(output); + *it++ = output; } auto precomp = imperative::apply(backward_graph->precomp, inputs_and_outputs); closure.reserve(precomp.size() + count); @@ -89,7 +90,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( } } void BackwardGraphWithClosure::operator()( - std::vector grads, std::function receiver) { + ValueRefList grads, std::function receiver) { ValueRef args[closure.size() + grads.size()]; size_t nargs = 0; for (auto&& value : closure) { @@ -120,7 +121,7 @@ void BackwardGraphWithClosure::operator()( } void CustomBackward::operator()( - std::vector grads, std::function receiver) { + ValueRefList grads, std::function receiver) { size_t nargs = grads.size(); ValueRef args[nargs]; for (size_t i = 0; i < nargs; ++i) { @@ -201,9 +202,10 @@ void GradKey::backward() { mgb_throw(AssertionError, "invalid backward"); } else { mgb_assert(grad_fn->m_slots.size() > 0); - std::vector grads; + ValueRefList grads (grad_fn->m_slots.size()); + auto iter = grads.begin(); for (auto&& slot : grad_fn->m_slots) { - grads.push_back(slot.m_grad); + *iter++ = slot.m_grad; } backward(grads, grad_receiver); } @@ -254,21 +256,28 @@ void GradKey::freeze() { m_frozen = true; } -std::vector GradTransformation::apply_transformation( +ValueRefList GradTransformation::apply_transformation( const Operator& op, Span inputs) { - auto unwrap_inputs = [this](Span inputs) -> SmallVector { - SmallVector unwrapped_inputs; - for (auto&& input : inputs) { - if (auto grad_value = as_grad_value(input)) { - unwrapped_inputs.push_back(grad_value->m_value); + auto fallback = [&] { + ValueRefList unwrapped_inputs(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + if (auto grad_value = as_grad_value(inputs[i])) { + unwrapped_inputs[i] = grad_value->m_value; } else { - unwrapped_inputs.push_back(input); + unwrapped_inputs[i] = inputs[i]; } } - return unwrapped_inputs; + return imperative::apply(op, unwrapped_inputs); }; + if (auto* get_attr = op.as()) { + if (auto grad_value = as_grad_value(inputs.item())) { + return imperative::apply(op, grad_value->m_value); + } else { + return imperative::apply(op, inputs); + } + } if (m_suppressed) { - return imperative::apply(op, unwrap_inputs(inputs)); + return fallback(); } if (auto* op_val = op.as()) { size_t nr_require_grad = 0; @@ -284,20 +293,21 @@ std::vector GradTransformation::apply_transformation( if (nr_require_grad == 0) { return imperative::apply(op, inputs); } - SmallVector captured_inputs; - SmallVector inputs_require_grad; + ValueRefList captured_inputs(inputs.size()); + SmallVector inputs_require_grad(inputs.size()); // capture value so that trace could assume input as same auto capture_value = [](ValueRef value) { // TODO: fastpath copy shouldn't be an OpDef return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0]; }; - for (auto& input : inputs) { + for (size_t i = 0; i < inputs.size(); ++i) { + auto& input = inputs[i]; if (auto grad_value = as_grad_value(input)) { - captured_inputs.push_back(capture_value(grad_value->m_value)); - inputs_require_grad.push_back(true); + captured_inputs[i] = capture_value(grad_value->m_value); + inputs_require_grad[i] = true; } else { - captured_inputs.push_back(capture_value(input)); - inputs_require_grad.push_back(false); + captured_inputs[i] = capture_value(input); + inputs_require_grad[i] = false; } } decltype(std::declval().m_backward) backward_storage; @@ -373,9 +383,11 @@ std::vector GradTransformation::apply_transformation( mgb_assert(!grad_fn->m_slots.empty()); m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()}); return outputs; + } else if (op.is()) { + return imperative::apply(op, inputs); } else if (auto* attach_grad = op.as()) { if (!has_key(attach_grad->key())) { - return imperative::apply(op, unwrap_inputs(inputs)); + return fallback(); } auto tensor = inputs[0]; GenericFunction callback = (GenericFunction&)inputs[1].cast(); @@ -386,7 +398,7 @@ std::vector GradTransformation::apply_transformation( return {record_grad(output)}; } else if (auto* grad_backward = op.as()) { if (!has_key(grad_backward->key())) { - return imperative::apply(op, unwrap_inputs(inputs)); + return fallback(); } size_t nr_grads = inputs.size() / 2; mgb_assert(nr_grads * 2 == inputs.size()); @@ -416,7 +428,7 @@ std::vector GradTransformation::apply_transformation( backward.m_output_attrs = SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true}); backward.m_backward = set_grad->grad_fn(); - std::vector outputs; + ValueRefList outputs(nr_outputs); grad_fn->m_key = m_key; grad_fn->m_slots.resize(nr_outputs); grad_fn->m_dests.reserve(nr_inputs); @@ -439,13 +451,13 @@ std::vector GradTransformation::apply_transformation( } else { grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i)); } - outputs.push_back(record_grad(grad_value)); + outputs[i] = record_grad(grad_value); } m_key->m_tape.push_back({grad_fn, nullptr}); return outputs; } else if (auto* gbc = op.as()) { if (gbc->key() != m_key) { - return imperative::apply(op, unwrap_inputs(inputs)); + return fallback(); } return {FunctionValue::make(make_backward_closure(inputs))}; } else if (op.is()) { @@ -471,21 +483,8 @@ std::vector GradTransformation::apply_transformation( } else { return imperative::apply(op, inputs); } - } else if (op.is()) { - return imperative::apply(op, inputs); } else { - SmallVector unwrapped_inputs; - for (auto&& input : inputs) { - if (auto grad_value = as_grad_value(input)) { - unwrapped_inputs.push_back(grad_value->m_value); - } else { - unwrapped_inputs.push_back(input); - } - } - auto outputs = imperative::apply( - op, {unwrapped_inputs.data(), unwrapped_inputs.size()}); - mgb_assert(op.kind() == Operator::GetAttrLike || outputs.empty()); - return outputs; + return fallback(); } } @@ -500,8 +499,7 @@ GenericFunction GradTransformation::make_backward_closure(Span ys) { y_slots.emplace_back(); } } - GenericFunction closure = [grad_key, - y_slots](Span dys) -> std::vector { + GenericFunction closure = [grad_key, y_slots](Span dys) -> ValueRefList { size_t nr_grads = y_slots.size(); mgb_assert(dys.size() == nr_grads); for (size_t i = 0; i < nr_grads; ++i) { diff --git a/imperative/src/impl/transformations/lazy.cpp b/imperative/src/impl/transformations/lazy.cpp index 332645da38f1a1a5ff1e0ccc00d65a09ed4079b4..af37e39b01bf0b6d152ab1c271a20c0575257c7b 100644 --- a/imperative/src/impl/transformations/lazy.cpp +++ b/imperative/src/impl/transformations/lazy.cpp @@ -21,7 +21,7 @@ namespace mgb { namespace imperative { -std::vector LazyEvalTransformation::apply_transformation( +ValueRefList LazyEvalTransformation::apply_transformation( const Operator& op, Span inputs) { if (auto* op_val = op.as()) { static std::unordered_set mm_io_ops = { @@ -59,9 +59,9 @@ std::vector LazyEvalTransformation::apply_transformation( mgb_assert(!output_nodes.empty()); m_io_link = SymbolVar(output_nodes[0]); } - std::vector outputs; - for (auto&& output_node : output_nodes) { - outputs.push_back(record_var(output_node)); + ValueRefList outputs(output_nodes.size()); + for (size_t i = 0; i < output_nodes.size(); ++i) { + outputs[i] = record_var(output_nodes[i]); } return outputs; } else if (auto* create_tensor = op.as()) { diff --git a/imperative/src/impl/transformations/scalar.cpp b/imperative/src/impl/transformations/scalar.cpp index eabd9efde8525e522ca0f54d134a61342f7ad976..7250847a3eee2f8b90969aeaf7a6c888c8a48262 100644 --- a/imperative/src/impl/transformations/scalar.cpp +++ b/imperative/src/impl/transformations/scalar.cpp @@ -19,26 +19,8 @@ namespace imperative { namespace { -using ScalarRule = std::function(const OpDef&, Span)>; -static std::unordered_map< - Typeinfo*, std::function(const OpDef&, Span)>> - scalar_rules; - -ValueRef unwrap_input(ValueRef input) { - if (auto scalar_input = input.as_ref()) { - return scalar_input->value(); - } else { - return input; - } -} - -std::vector unwrap_inputs(Span inputs) { - std::vector unwrapped_inputs; - for (auto&& input : inputs) { - unwrapped_inputs.push_back(unwrap_input(input)); - } - return unwrapped_inputs; -} +using ScalarRule = ValueRefList (*)(const OpDef&, Span, Span); +static std::unordered_map scalar_rules; ValueRef make_scalar_shape(CompNode device) { HostTensorND scalar_shape(device, {1}, dtype::Int32()); @@ -49,9 +31,6 @@ ValueRef make_scalar_shape(CompNode device) { } bool is_scalar_shape(ValueRef shape) { - if (shape.is()) { - return false; - } // may have performance issue auto shape_of_shape = shape.shape(); if (!shape_of_shape) { @@ -61,74 +40,65 @@ bool is_scalar_shape(ValueRef shape) { return *shape_of_shape == ValueShape{0}; } -template -void register_scalar_rule(std::vector (*rule)(const T&, Span)) { - scalar_rules[T::typeinfo()] = [rule](const OpDef& def, Span inputs) { - return (*rule)(def.cast_final_safe(), inputs); +template , Span)> +void register_scalar_rule() { + scalar_rules[T::typeinfo()] = [](const OpDef& def, Span inputs, + Span inputs_mask) { + return (*rule)(def.cast_final_safe(), inputs, inputs_mask); }; } -std::vector elemwise_rule(const Elemwise& elem, Span inputs) { +template +ValueRefList elemwise_rule( + const TOpDef& op_def, Span inputs, Span inputs_mask) { + if constexpr (nr_inputs != 0) { + mgb_assert(inputs.size() == inputs.size(), "inputs size mismatch"); + } bool all_scalar = true; - for (auto&& input : inputs) { - if (!input.is()) { + for (auto&& input_mask : inputs_mask) { + if (!input_mask) { all_scalar = false; - break; } } - auto output = imperative::apply(elem, unwrap_inputs(inputs))[0]; + auto outputs = imperative::apply(op_def, inputs); if (all_scalar) { - return {ScalarValue::make(output)}; - } else { - return {output}; + outputs[0] = ScalarValue::make(outputs[0]); } + return outputs; } -std::vector remove_axis_rule( - const RemoveAxis& remove_axis, Span inputs) { - mgb_assert(inputs.size() == 1); - mgb_assert(!inputs[0].is()); - auto output = imperative::apply(remove_axis, inputs)[0]; - bool is_scalar = inputs[0].shape()->ndim == remove_axis.axis.size(); +ValueRefList remove_axis_rule( + const RemoveAxis& remove_axis, Span inputs, Span inputs_mask) { + mgb_assert(!inputs_mask.item()); + bool is_scalar = inputs.item().shape()->ndim == remove_axis.axis.size(); + if (is_scalar && remove_axis.axis.size() == 1) { + return {ScalarValue::make(inputs.item())}; + } + auto outputs = imperative::apply(remove_axis, inputs); if (is_scalar) { - return {ScalarValue::make(output)}; - } else { - return {output}; + outputs[0] = ScalarValue::make(outputs[0]); } + return outputs; } -std::vector reduce_rule(const Reduce& reduce, Span inputs) { +ValueRefList reduce_rule( + const Reduce& reduce, Span inputs, Span inputs_mask) { if (inputs.size() == 1) { - return imperative::apply(reduce, unwrap_inputs(inputs)); + return imperative::apply(reduce, inputs); } mgb_assert(inputs.size() == 2); bool is_scalar = is_scalar_shape(inputs[1]); if (is_scalar) { - auto unwrapped_input = unwrap_input(inputs[0]); - CompNode device = *unwrapped_input.device(); - return {ScalarValue::make(imperative::apply( - reduce, unwrapped_input, make_scalar_shape(device))[0])}; - } - auto output = imperative::apply(reduce, unwrap_inputs(inputs))[0]; - if (is_scalar) { - return {ScalarValue::make(output)}; - } else { - return {output}; - } -} - -std::vector typecvt_rule(const TypeCvt& typecvt, Span inputs) { - mgb_assert(inputs.size() == 1); - if (auto scalar_input = inputs[0].as_ref()) { + CompNode device = *inputs[0].device(); return {ScalarValue::make( - imperative::apply(typecvt, scalar_input->value())[0])}; - } else { - return imperative::apply(typecvt, inputs); + imperative::apply(reduce, inputs[0], make_scalar_shape(device))[0])}; } + return imperative::apply(reduce, inputs); } -std::vector collective_comm_rule( - const CollectiveComm& collective_comm, Span inputs) { +ValueRefList collective_comm_rule( + const CollectiveComm& collective_comm, Span inputs, + Span inputs_mask) { mgb_assert(inputs.size() == 1); static std::unordered_set modes = { CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN, @@ -138,17 +108,17 @@ std::vector collective_comm_rule( if (modes.count(collective_comm.mode) == 0) { return imperative::apply(collective_comm, inputs); } - if (auto scalar_input = inputs[0].as_ref()) { - return {ScalarValue::make( - imperative::apply(collective_comm, scalar_input->value())[0])}; + if (inputs_mask.item()) { + return {ScalarValue::make(imperative::apply(collective_comm, inputs[0])[0])}; } else { return imperative::apply(collective_comm, inputs); } } -std::vector param_pack_split_rule( - const ParamPackSplit& param_pack_split, Span inputs) { - auto outputs = imperative::apply(param_pack_split, unwrap_inputs(inputs)); +ValueRefList param_pack_split_rule( + const ParamPackSplit& param_pack_split, Span inputs, + Span inputs_mask) { + auto outputs = imperative::apply(param_pack_split, inputs); size_t nr_outputs = outputs.size(); mgb_assert(nr_outputs == param_pack_split.shapes.size()); for (size_t i = 0; i < nr_outputs; ++i) { @@ -159,29 +129,28 @@ std::vector param_pack_split_rule( return outputs; } -std::vector dot_rule(const Dot& dot, Span inputs) { - return {ScalarValue::make(imperative::apply(dot, unwrap_inputs(inputs))[0])}; +ValueRefList dot_rule(const Dot& dot, Span inputs, Span inputs_mask) { + return {ScalarValue::make(imperative::apply(dot, inputs)[0])}; } -std::vector add_axis_rule(const AddAxis& add_axis, Span inputs) { +ValueRefList add_axis_rule( + const AddAxis& add_axis, Span inputs, Span inputs_mask) { mgb_assert(inputs.size() == 1); - if (auto scalar_input = inputs[0].as_ref()) { + if (inputs_mask.item()) { mgb_assert(add_axis.axis[0] == 0); if (add_axis.axis.size() == 1) { - return {scalar_input->value()}; + return {inputs[0]}; } else { std::vector axis(add_axis.axis.begin() + 1, add_axis.axis.end()); - return imperative::apply( - ApplyOp(*AddAxis::make(axis, add_axis.scope())), - scalar_input->value()); + return imperative::apply(*AddAxis::make(axis, add_axis.scope()), inputs[0]); } } else { return imperative::apply(add_axis, inputs); } } -std::vector remote_recv_rule( - const RemoteRecv& remote_recv, Span inputs) { +ValueRefList remote_recv_rule( + const RemoteRecv& remote_recv, Span inputs, Span inputs_mask) { if (remote_recv.shape.empty()) { std::vector shape = {1}; auto remote_recv_no_scalar = RemoteRecv::make( @@ -189,32 +158,32 @@ std::vector remote_recv_rule( remote_recv.rank_from, remote_recv.cn, shape, remote_recv.dtype, remote_recv.backend); remote_recv_no_scalar->set_scope(remote_recv.scope()); - return imperative::apply( - ApplyOp(*remote_recv_no_scalar), unwrap_inputs(inputs)); + return imperative::apply(ApplyOp(*remote_recv_no_scalar), inputs); } else { - return imperative::apply(remote_recv, unwrap_inputs(inputs)); + return imperative::apply(remote_recv, inputs); } } -std::vector check_no_finite_rule( - const CheckNonFinite& check_no_finite, Span inputs) { - auto outputs = imperative::apply(check_no_finite, unwrap_inputs(inputs)); +ValueRefList check_no_finite_rule( + const CheckNonFinite& check_no_finite, Span inputs, + Span inputs_mask) { + auto outputs = imperative::apply(check_no_finite, inputs); mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch"); outputs.back() = ScalarValue::make(outputs.back()); for (size_t i = 0; i < inputs.size(); ++i) { - if (inputs[i].is()) { + if (inputs_mask[i]) { outputs[i] = ScalarValue::make(outputs[i]); } } return outputs; } -std::vector subtensor_rule( - const Subtensor& subtensor, Span inputs) { +ValueRefList subtensor_rule( + const Subtensor& subtensor, Span inputs, Span inputs_mask) { mgb_assert(inputs.size() >= 1); auto input = inputs[0]; bool is_scalar; - mgb_assert(!input.is(), "subtensor shouldn't have scalar input"); + mgb_assert(!inputs_mask[0], "subtensor shouldn't have scalar input"); if (auto shape = input.shape()) { size_t ndim = input.shape()->ndim; for (auto&& [axis, begin, end, step, idx] : subtensor.items) { @@ -226,25 +195,25 @@ std::vector subtensor_rule( } else { is_scalar = false; } - auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0]; + auto outputs = imperative::apply(subtensor, inputs); if (is_scalar) { - return {ScalarValue::make(output)}; - } else { - return {output}; + outputs[0] = ScalarValue::make(outputs[0]); } + return outputs; } -std::vector get_var_shape_rule( - const GetVarShape& get_var_shape, Span inputs) { +ValueRefList get_var_shape_rule( + const GetVarShape& get_var_shape, Span inputs, + Span inputs_mask) { bool all_scalar = true; mgb_assert(inputs.size() >= 1); - for (auto&& input : inputs) { - if (!input.is()) { + for (auto&& input_mask : inputs_mask) { + if (!input_mask) { all_scalar = false; } } if (all_scalar) { - auto device = inputs[0].cast().value().device(); + auto device = inputs[0].device(); auto storage = HostStorage::make(*device); // storage->ensure_size(1); return imperative::apply( @@ -252,88 +221,49 @@ std::vector get_var_shape_rule( CreateTensor::Const, *device, dtype::Int32(), ValueShape{0}), storage); } else { - return imperative::apply(get_var_shape, unwrap_inputs(inputs)); - } -} - -std::vector fastpath_copy_rule( - const FastpathCopy& fastpath_copy, Span inputs) { - mgb_assert(inputs.size() == 1); - bool is_scalar = inputs[0].is(); - auto output = imperative::apply(fastpath_copy, unwrap_inputs(inputs))[0]; - if (is_scalar) { - return {ScalarValue::make(output)}; - } else { - return {output}; + return imperative::apply(get_var_shape, inputs); } } -std::vector reshape_rule(const Reshape& reshape, Span inputs) { +ValueRefList reshape_rule( + const Reshape& reshape, Span inputs, Span inputs_mask) { mgb_assert(inputs.size() == 2); bool is_scalar = is_scalar_shape(inputs[1]); - auto unwrapped_input = inputs[0].is() - ? inputs[0].cast().value() - : inputs[0]; if (is_scalar) { return {ScalarValue::make(imperative::apply( - reshape, unwrapped_input, - make_scalar_shape(*unwrapped_input.device()))[0])}; + reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; } else { - return imperative::apply(reshape, unwrap_inputs(inputs)); + return imperative::apply(reshape, inputs); } } -std::vector broadcast_rule( - const Broadcast& broadcast, Span inputs) { +ValueRefList broadcast_rule( + const Broadcast& broadcast, Span inputs, Span inputs_mask) { mgb_assert(inputs.size() == 2); bool is_scalar = is_scalar_shape(inputs[1]); - auto unwrapped_input = inputs[0].is() - ? inputs[0].cast().value() - : inputs[0]; if (is_scalar) { return {ScalarValue::make(imperative::apply( - broadcast, unwrapped_input, - make_scalar_shape(*unwrapped_input.device()))[0])}; - } else { - return imperative::apply(broadcast, unwrap_inputs(inputs)); - } -} - -std::vector copy_rule(const Copy& copy, Span inputs) { - mgb_assert(inputs.size() == 1); - bool is_scalar = inputs[0].is(); - if (is_scalar) { - return {ScalarValue::make(imperative::apply(copy, unwrap_inputs(inputs))[0])}; - } else { - return imperative::apply(copy, unwrap_inputs(inputs)); - } -} - -std::vector inplace_add_rule( - const InplaceAdd& inplace_add, Span inputs) { - mgb_assert(inputs.size() == 4); - bool is_scalar = inputs[0].is(); - if (is_scalar) { - return {ScalarValue::make( - imperative::apply(inplace_add, unwrap_inputs(inputs))[0])}; + broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; } else { - return imperative::apply(inplace_add, unwrap_inputs(inputs)); + return imperative::apply(broadcast, inputs); } } template -std::vector subgraph_op_rule(const T& op, Span inputs) { +ValueRefList subgraph_op_rule( + const T& op, Span inputs, Span inputs_mask, + const Type& scalar_type) { // TODO: add flag instead of assume bool all_scalar = true; - for (auto&& input : inputs) { - if (!input.is()) { + for (auto&& input_mask : inputs_mask) { + if (!input_mask) { all_scalar = false; } } - auto outputs = imperative::apply(op, unwrap_inputs(inputs)); + auto outputs = imperative::apply(op, inputs); if (all_scalar) { for (auto& output : outputs) { - output = ScalarValue::make(output); + output = scalar_type.make(output); } } return outputs; @@ -341,67 +271,54 @@ std::vector subgraph_op_rule(const T& op, Span inputs) { struct ScalarRuleRegistry { ScalarRuleRegistry() { - register_scalar_rule(elemwise_rule); - register_scalar_rule(remove_axis_rule); - register_scalar_rule(reduce_rule); - register_scalar_rule(typecvt_rule); - register_scalar_rule(collective_comm_rule); - register_scalar_rule(param_pack_split_rule); - register_scalar_rule(dot_rule); - register_scalar_rule(add_axis_rule); - register_scalar_rule(remote_recv_rule); - register_scalar_rule(check_no_finite_rule); - register_scalar_rule(subtensor_rule); - register_scalar_rule(get_var_shape_rule); - register_scalar_rule(fastpath_copy_rule); - register_scalar_rule(reshape_rule); - register_scalar_rule(broadcast_rule); - register_scalar_rule(copy_rule); - register_scalar_rule(inplace_add_rule); - register_scalar_rule(subgraph_op_rule); - register_scalar_rule(subgraph_op_rule); + register_scalar_rule>(); + register_scalar_rule(); + register_scalar_rule(); + register_scalar_rule>(); + register_scalar_rule(); + register_scalar_rule(); + register_scalar_rule(); + register_scalar_rule(); + register_scalar_rule(); + register_scalar_rule(); + register_scalar_rule(); + register_scalar_rule(); + register_scalar_rule>(); + register_scalar_rule(); + register_scalar_rule(); + register_scalar_rule>(); + register_scalar_rule>(); + register_scalar_rule>(); + register_scalar_rule>(); } } _; } // namespace -std::vector ScalarTransformation::apply_transformation( - const Operator& op, Span inputs) { - if (auto apply_op = op.as()) { - auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo()); - if (iter != scalar_rules.end()) { - return iter->second(apply_op->op(), inputs); - } else { - // TODO: repeat op - return imperative::apply(op, unwrap_inputs(inputs)); - } - } else if (auto* create_tensor = op.as()) { - if (create_tensor->shape().is_scalar()) { - ValueShape scalar_shape = {1}; - CreateTensor scalar_op( - create_tensor->kind(), create_tensor->device(), - create_tensor->dtype(), scalar_shape); - return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])}; - } else { - return imperative::apply(op, inputs); - } - } else if (auto* get_attr = op.as()) { - bool is_scalar = inputs.as_array<1>()[0].is(); - auto output = imperative::apply(op, unwrap_inputs(inputs))[0]; - if (!is_scalar) { - return {output}; +ValueRefList ScalarTransformation::apply_get_attr( + const GetAttr& get_attr, Span inputs) { + auto&& input = inputs.item(); + bool is_scalar = input.is(); + if (!is_scalar) { + return imperative::apply(get_attr, input); + } + auto unwrapped_input = input.cast().value(); + if (get_attr.attr() == GetAttr::Shape) { + if (!m_empty_shape) { + m_empty_shape = ShapeValue::make(); } - switch (get_attr->attr()) { - case GetAttr::Shape: { - // Scalar Shape - return {ShapeValue::make()}; - } + return {m_empty_shape}; + } else { + auto outputs = imperative::apply(get_attr, unwrapped_input); + auto& output = outputs[0]; + switch (get_attr.attr()) { case GetAttr::Value: { auto& hv = output.cast(); mgb_assert( hv.shape() == ValueShape({1}), "underlying value should has shape {1}, got %s", hv.shape().to_string().c_str()); - return {HostValue::make(hv.dtype(), ValueShape(), hv.storage())}; + output = HostValue::make(hv.dtype(), ValueShape(), hv.storage()); + break; } case GetAttr::Data: { auto& dv = output.cast(); @@ -409,22 +326,67 @@ std::vector ScalarTransformation::apply_transformation( dv.shape() == ValueShape({1}), "underlying value should has shape {1}, got %s", dv.shape().to_string().c_str()); - return {DeviceValue::make(dv.dtype(), ValueShape(), dv.storage())}; + output = DeviceValue::make(dv.dtype(), ValueShape(), dv.storage()); + break; } default: - return {output}; + break; + } + return outputs; + } +} + +ValueRefList ScalarTransformation::apply_transformation( + const Operator& op, Span inputs) { + if (auto* get_attr = op.as()) { + // fastpath for GetAttr + return apply_get_attr(*get_attr, inputs); + } + size_t nr_inputs = inputs.size(); + ValueRefList unwrapped_inputs(nr_inputs); + bool inputs_mask[nr_inputs]; + for (size_t i = 0; i < inputs.size(); ++i) { + if (auto scalar_value = inputs[i].as_ref()) { + unwrapped_inputs[i] = scalar_value->value(); + inputs_mask[i] = true; + } else { + unwrapped_inputs[i] = inputs[i]; + inputs_mask[i] = false; + } + } + auto fallback = [&] { return imperative::apply(op, unwrapped_inputs); }; + if (auto apply_op = op.as()) { + auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo()); + if (iter != scalar_rules.end()) { + return iter->second( + apply_op->op(), unwrapped_inputs, {inputs_mask, nr_inputs}); + } else { + // TODO: repeat op + return fallback(); + } + } else if (auto* create_tensor = op.as()) { + if (create_tensor->shape().is_scalar()) { + ValueShape scalar_shape = {1}; + CreateTensor scalar_op( + create_tensor->kind(), create_tensor->device(), + create_tensor->dtype(), scalar_shape); + return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])}; + } else { + return imperative::apply(op, inputs); } } else if (op.as()) { - return {BoolValue::make(inputs.as_array<1>()[0].is())}; + mgb_assert(nr_inputs == 1); + return {BoolValue::make(inputs_mask[0])}; } else if (op.is()) { - bool is_scalar = inputs.as_array<1>()[0].is(); + mgb_assert(nr_inputs == 1); + bool is_scalar = inputs_mask[0]; + auto outputs = fallback(); if (is_scalar) { - return {ScalarValue::make(imperative::apply(op, unwrap_inputs(inputs))[0])}; - } else { - return imperative::apply(op, inputs); + outputs[0] = ScalarValue::make(outputs[0]); } + return outputs; } else { - return imperative::apply(op, unwrap_inputs(inputs)); + return fallback(); } }; diff --git a/imperative/src/impl/transformations/tangent.cpp b/imperative/src/impl/transformations/tangent.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a7a0ee0516c4df9f4876fca72ce36bada3407865 --- /dev/null +++ b/imperative/src/impl/transformations/tangent.cpp @@ -0,0 +1,25 @@ +/** + * \file imperative/src/impl/transformations/tangent.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/imperative/transformations/tangent.h" + +namespace mgb { +namespace imperative { + +ValueRefList TangentTransformation::apply_transformation( + const Operator& op, Span inputs) { + if (auto* apply_op = op.as()) { + } + mgb_assert(false); +} + +} // namespace imperative +} // namespace mgb diff --git a/imperative/src/impl/transformations/trace.cpp b/imperative/src/impl/transformations/trace.cpp index 9a8074c8326a0e15431223c767b5c731dba76bd6..d7c0cd993eda5f2aa6d7e7d84309fb2cdebea0fd 100644 --- a/imperative/src/impl/transformations/trace.cpp +++ b/imperative/src/impl/transformations/trace.cpp @@ -153,7 +153,7 @@ VarNodeArray TraceResult::dump( return output_nodes; } -std::vector TracingTransformation::apply_transformation( +ValueRefList TracingTransformation::apply_transformation( const Operator& op, Span inputs) { if (auto* op_value = op.as()) { SmallVector unwrapped_inputs; @@ -180,11 +180,12 @@ std::vector TracingTransformation::apply_transformation( } const_cast(op_value->op()).set_scope(scopes_join); auto unwrapped_outputs = imperative::apply(op, unwrapped_inputs); - std::vector wrapped_outputs; + ValueRefList wrapped_outputs(unwrapped_outputs.size()); SmallVector output_ids; - for (auto&& output : unwrapped_outputs) { + for (size_t i = 0; i < unwrapped_outputs.size(); ++i) { + auto&& output = unwrapped_outputs[i]; auto wrapped_output = record_var(output, false, VarKind::Internal); - wrapped_outputs.push_back(wrapped_output); + wrapped_outputs[i] = wrapped_output; output_ids.push_back(wrapped_output->id()); } m_seq.push_back({op_value->op().shared_from_this(), input_ids, output_ids}); @@ -375,6 +376,11 @@ void CompiledTransformation::compile() { return accessor; }; std::vector var_accessors(m_vars.size()); + auto exc_setter = std::bind( + &CompiledTransformation::set_exception, this, std::placeholders::_1); + for (auto&& accessor : var_accessors) { + accessor.exc_setter = exc_setter; + } for (auto&& item : m_seq) { bool require_link = bool(item.op) && mm_io_ops.count(item.op->dyn_typeinfo()); VarNodeArray input_vars; @@ -509,8 +515,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { } } -TracedValue::ref_t CompiledTransformation::trace_output(size_t id) { - auto traced_value = TracedValue::make(id); +auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t { + auto traced_value = TracedValue::make(id, &m_vars[id], &m_var_accessors[id]); m_weak_values.push_back(traced_value); return traced_value; } @@ -520,64 +526,99 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() { return m_seq[m_pc++]; } -std::vector CompiledTransformation::apply_transformation( +ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const { + if (!m_shape) { + trace_assert(m_accessor->shape_getter, "shape unreadable"); + m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter())); + } + return m_shape; +} + +DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const { + if (!m_dtype) { + m_dtype = DTypeValue::make(m_var->dtype); + } + return m_dtype; +} + +CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const { + if (!m_comp_node) { + m_comp_node = CompNodeValue::make(m_var->device); + } + return m_comp_node; +} +auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& { + return *m_accessor; +} + +ValueRefList CompiledTransformation::apply_op( + const ApplyOp& apply_op, Span inputs) { + auto& item = next_instruction(); + trace_assert(inputs.size() == item.inputs.size(), "input size mismatch"); + trace_assert(apply_op.op().is_same(*item.op), "operator mismatch"); + for (size_t i = 0; i < inputs.size(); ++i) { + trace_input(item.inputs[i], inputs[i]); + } + ValueRefList outputs(item.outputs.size()); + for (size_t i = 0; i < item.outputs.size(); ++i) { + outputs[i] = trace_output(item.outputs[i]); + } + return outputs; +} + +ValueRefList CompiledTransformation::apply_get_attr( + const GetAttr& get_attr, Span inputs) { + if (auto* traced_value = inputs[0].as()) { + ValueRef output; + auto& var_accessor = traced_value->accessor(); + switch (get_attr.attr()) { + case GetAttr::Shape: + output = traced_value->shape(); + break; + case GetAttr::Data: + trace_assert(var_accessor.data_getter, "data unreadable"); + output = DeviceValue::make(var_accessor.data_getter()); + break; + case GetAttr::Value: + trace_assert(var_accessor.value_getter, "value unreadable"); + output = HostValue::make(var_accessor.value_getter()); + break; + case GetAttr::DType: + output = traced_value->dtype(); + break; + case GetAttr::Device: + output = traced_value->comp_node(); + default: + break; + } + return {output}; + } else { + return imperative::apply(get_attr, inputs); + } +} + +ValueRefList CompiledTransformation::apply_create_tensor( + const CreateTensor& create_tensor, Span inputs) { + if (create_tensor.kind() == CreateTensor::NoTrace) { + return imperative::apply(create_tensor, inputs); + } + auto& item = next_instruction(); + trace_assert(item.op == nullptr, "operator mismatch"); + auto input_id = item.inputs[0]; + auto output_id = item.outputs[0]; + auto tensor = imperative::apply(create_tensor, inputs)[0]; + trace_input(input_id, tensor); + return {trace_output(output_id)}; +} + +ValueRefList CompiledTransformation::apply_transformation( const Operator& op, Span inputs) { if (auto* op_value = op.as()) { - auto& item = next_instruction(); - SmallVector unwrapped_inputs; - SmallVector wrapped_inputs; - trace_assert(inputs.size() == item.inputs.size(), "input size mismatch"); - trace_assert(op_value->op().is_same(*item.op), "operator mismatch"); - for (size_t i = 0; i < inputs.size(); ++i) { - trace_input(item.inputs[i], inputs[i]); - } - std::vector outputs; - for (auto&& output_id : item.outputs) { - outputs.push_back(trace_output(output_id)); - } - return outputs; + return apply_op(*op_value, inputs); } else if (auto* create_tensor = op.as()) { - if (create_tensor->kind() == CreateTensor::NoTrace) { - return imperative::apply(op, inputs); - } - auto& item = next_instruction(); - trace_assert(item.op == nullptr, "operator mismatch"); - auto input_id = item.inputs[0]; - auto output_id = item.outputs[0]; - auto tensor = imperative::apply(op, inputs)[0]; - trace_input(input_id, tensor); - return {trace_output(output_id)}; + return apply_create_tensor(*create_tensor, inputs); } else if (auto* get_attr = op.as()) { - if (auto* traced_value = inputs[0].as()) { - ValueRef output; - auto& var = m_vars[traced_value->id()]; - auto& var_accessor = m_var_accessors[traced_value->id()]; - switch (get_attr->attr()) { - case GetAttr::Shape: - trace_assert(var_accessor.shape_getter, "shape unreadable"); - output = ShapeValue::make( - ValueShape::from(var_accessor.shape_getter())); - break; - case GetAttr::Data: - trace_assert(var_accessor.data_getter, "data unreadable"); - output = DeviceValue::make(var_accessor.data_getter()); - break; - case GetAttr::Value: - trace_assert(var_accessor.value_getter, "value unreadable"); - output = HostValue::make(var_accessor.value_getter()); - break; - case GetAttr::DType: - output = DTypeValue::make(var.dtype); - break; - case GetAttr::Device: - output = CompNodeValue::make(var.device); - default: - break; - } - return {output}; - } else { - return imperative::apply(op, inputs); - } + return apply_get_attr(*get_attr, inputs); } else if (auto* trace_mark_var = op.as()) { auto& item = next_instruction(); trace_assert(item.op == nullptr, "operator mismatch"); diff --git a/imperative/src/impl/value.cpp b/imperative/src/impl/value.cpp index 5bbb8038d70cf054da45208b815fadd04cd19a07..857f295ab8c08f554e841bf0a3f62559871af82b 100644 --- a/imperative/src/impl/value.cpp +++ b/imperative/src/impl/value.cpp @@ -8,50 +8,58 @@ namespace mgb { namespace imperative { namespace { -static thread_local size_t nr_watched_values = 0; -static thread_local uint64_t nr_values = 0; -static thread_local bool recording_values = false; -static thread_local std::vector recorded_values; +static /*thread_local*/ size_t nr_watched_values = 0; +static /*thread_local*/ uint64_t nr_values = 0; +static /*thread_local*/ bool recording_values = false; +static /*thread_local*/ std::vector recorded_values; static WeakValueMap registered_values; } // namespace ValueRef::storage_t& ValueRef::storage() const { - if (!m_storage) { + if (mgb_likely(!m_storage->m_successor.m_storage)) { return m_storage; } - if (auto& storage = m_storage->m_successor.m_storage) { - while (storage->m_successor.m_storage) { - storage = storage->m_successor.m_storage; - } - return storage; - } else { - return m_storage; + while (m_storage->m_successor.m_storage) { + m_storage = m_storage->m_successor.m_storage; + } + return m_storage; +} + +const Value* ValueRef::as(size_t typecode) const { + auto&& storage = this->storage(); + if (storage->m_typecode != typecode) { + return nullptr; } + return static_cast(storage.get()); +} + +bool ValueRef::is(size_t typecode) const { + return this->storage()->m_typecode == typecode; } TypedValueRef ValueRef::dev_tensor() const { - return imperative::apply(GetAttr(GetAttr::Data), *this)[0].as_ref(); + return imperative::apply(GetAttr(GetAttr::Data), *this)[0].cast_ref(); } TypedValueRef ValueRef::numpy() const { - return imperative::apply(GetAttr(GetAttr::Value), *this)[0].as_ref(); + return imperative::apply(GetAttr(GetAttr::Value), *this)[0].cast_ref(); } TypedValueRef ValueRef::device() const { return imperative::apply(GetAttr(GetAttr::Device), *this)[0] - .as_ref(); + .cast_ref(); } TypedValueRef ValueRef::shape() const { - return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].as_ref(); + return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].cast_ref(); } TypedValueRef ValueRef::dtype() const { - return imperative::apply(GetAttr(GetAttr::DType), *this)[0].as_ref(); + return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref(); } TypedValueRef ValueRef::name() const { - return imperative::apply(GetName(), *this)[0].as_ref(); + return imperative::apply(GetName(), *this)[0].cast_ref(); } bool ValueRef::is_scalar() const { @@ -75,13 +83,15 @@ void ValueRef::unwatch() const { } ValueRef ValueRef::unwrap() const { - ValueRef value = *this; auto& context = Transformation::get_context(); - for (size_t i = 0; i < context.next_transformation; ++i) { - value = context.transformations[i]->unwrap(value); + if (mgb_unlikely(context.next_transformation)) { + ValueRef value = *this; + for (size_t i = 0; i < context.next_transformation; ++i) { + value = context.transformations[i]->unwrap(value); + } + return value; } - mgb_assert(value); - return value; + return *this; } std::string ValueRef::to_string() const { @@ -101,13 +111,11 @@ std::string ValueRef::raw_type() const { return types[m_storage->m_typecode].name(); } -uint64_t ValueRef::id() const { - return m_storage ? m_storage->m_id : std::numeric_limits::max(); -} - bool ValueRef::watching() const { - auto storage = this->storage(); - return storage && storage->m_watching; + if (!m_storage) { + return false; + } + return this->storage()->m_watching; } ValueRef ValueRef::make(ValueRef::storage_t storage) { @@ -186,5 +194,96 @@ void Value::try_rethrow() { } } +inline void ValueRefList::init(size_t nr_elems) { + m_size = nr_elems; + if (m_size > 0) { + if (m_size == 1) { + m_data = inline_storage(); + } else { + auto& context = Transformation::get_context(); + m_data = context.allocator.allocate(m_size); + } + for (size_t i = 0; i < m_size; ++i) { + new (m_data + i) ValueRef(); + } + } else { + m_data = nullptr; + } +} + +ValueRefList::ValueRefList(size_t nr_elems) { + init(nr_elems); +} + +ValueRefList::ValueRefList(std::initializer_list values) + : ValueRefList(values.begin(), values.end()) {} + +ValueRefList::ValueRefList(const ValueRefList& rhs) + : ValueRefList(rhs.cbegin(), rhs.cend()) {} + +ValueRefList::ValueRefList(ValueRefList&& rhs) : ValueRefList() { + m_size = rhs.m_size; + if (rhs.m_data == rhs.inline_storage()) { + m_data = inline_storage(); + new (m_data) ValueRef(); + m_data[0] = std::move(rhs.m_data[0]); + } else { + m_data = rhs.m_data; + rhs.m_data = nullptr; + rhs.m_size = 0; + } +} + +ValueRefList& ValueRefList::operator=(const ValueRefList& rhs) { + if (this == &rhs) { + return *this; + } + clear(); + init(rhs.m_size); + for (size_t i = 0; i < m_size; ++i) { + m_data[i] = rhs.m_data[i]; + } + return *this; +} + +ValueRefList& ValueRefList::operator=(ValueRefList&& rhs) { + if (this == &rhs) { + return *this; + } + clear(); + if (rhs.m_data == rhs.inline_storage()) { + m_data = inline_storage(); + new (m_data) ValueRef(); + m_data[0] = rhs.m_data[0]; + m_size = 1; + rhs.clear(); + } else { + m_data = rhs.m_data; + m_size = rhs.m_size; + rhs.m_data = nullptr; + rhs.m_size = 0; + } + return *this; +} + +ValueRefList::~ValueRefList() { + clear(); +} + +void ValueRefList::clear() { + for (size_t i = 0; i < m_size; ++i) { + m_data[i].~ValueRef(); + } + if (m_data) { + if (m_size != 1) { + Transformation::get_context().allocator.deallocate(m_data, m_size); + } else { + mgb_assert(m_data == inline_storage()); + } + } + m_data = nullptr; + m_size = 0; +} + } // namespace imperative } // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/basic_operators.h b/imperative/src/include/megbrain/imperative/basic_operators.h index ee821557f89d6d7eaaebc7c99897a2bf83329ab3..dd22185fea8eceb9c54f6c5035e0bb81909e0790 100644 --- a/imperative/src/include/megbrain/imperative/basic_operators.h +++ b/imperative/src/include/megbrain/imperative/basic_operators.h @@ -24,8 +24,6 @@ namespace imperative { class GradKey; -using GenericFunction = std::function(Span)>; - /** * \brief apply an OpDef to values * @@ -37,7 +35,7 @@ private: public: ApplyOp(const OpDef& op) : m_op(op) {} - const OpDef& op() { return m_op; } + const OpDef& op() const { return m_op; } std::string to_string() const override; }; @@ -106,7 +104,7 @@ public: * \param inputs contains host_storage and device_storage * \return Args unpacked args */ - Args parse(Span inputs); + Args parse(Span inputs) const; Kind kind() const { return m_kind; } CompNode device() const { return m_device; } @@ -129,11 +127,11 @@ private: public: DTRCommand(Kind kind) : m_kind(kind) {} - Kind kind() { return m_kind; } + Kind kind() const { return m_kind; } std::string to_string() const override; - std::vector fallback(Span inputs) const override { return {}; } + ValueRefList fallback(Span inputs) const override { return {}; } }; // deprecated @@ -141,9 +139,7 @@ class GetName final : public OperatorImpl { public: std::string to_string() const override; - std::vector fallback(Span inputs) const override { - return {ValueRef()}; - } + ValueRefList fallback(Span inputs) const override { return {ValueRef()}; } }; /** @@ -161,7 +157,7 @@ public: std::string to_string() const override; - std::vector fallback(Span inputs) const override { + ValueRefList fallback(Span inputs) const override { return {inputs.as_array<1>()[0]}; } }; diff --git a/imperative/src/include/megbrain/imperative/basic_values.h b/imperative/src/include/megbrain/imperative/basic_values.h index 487ed002bbcd64d0ebe1e0f64f6b954c6e97bffa..5789b66c4ac8aeb76e165df8a386c48b57f0ea48 100644 --- a/imperative/src/include/megbrain/imperative/basic_values.h +++ b/imperative/src/include/megbrain/imperative/basic_values.h @@ -23,7 +23,7 @@ namespace imperative { class GradKey; -using GenericFunction = std::function(Span)>; +using GenericFunction = std::function)>; class ShapeValue final : public MixinValueImpl { public: @@ -97,6 +97,10 @@ public: ValueShape shape() const { return m_shape; } CompNode device() const { return m_storage.comp_node(); } HostTensorStorage storage() const { return m_storage; } + DTypeScalar item() const { + mgb_assert(m_shape.is_scalar()); + return DTypeScalar::make_from_raw(m_dtype, m_storage.ptr()); + } HostTensorND as_nd(bool allow_scalar = false) const; }; diff --git a/imperative/src/include/megbrain/imperative/dispatch.h b/imperative/src/include/megbrain/imperative/dispatch.h index 3a1033ce5523916b4efb714e79049b642213bd65..a81facb13ee0fab7b54f6eaaf0322ad334c90f4d 100644 --- a/imperative/src/include/megbrain/imperative/dispatch.h +++ b/imperative/src/include/megbrain/imperative/dispatch.h @@ -36,11 +36,11 @@ namespace imperative { * * \param op * \param inputs - * \return std::vector + * \return ValueRefList */ -std::vector apply(const Operator& op, Span inputs); -std::vector apply(const OpDef& def, Span inputs); -std::vector apply(Subgraph graph, Span inputs); +ValueRefList apply(const Operator& op, Span inputs); +ValueRefList apply(const OpDef& def, Span inputs); +ValueRefList apply(const Subgraph& graph, Span inputs); template constexpr bool is_all_value_ref_v = @@ -49,7 +49,7 @@ constexpr bool is_all_value_ref_v = template static auto apply(T&& op, TArgs&&... args) - -> std::enable_if_t, std::vector> { + -> std::enable_if_t, ValueRefList> { ValueRef args_arr[sizeof...(TArgs)] = {std::forward(args)...}; return imperative::apply( std::forward(op), @@ -63,7 +63,7 @@ static auto apply(T&& op, TContainer&& container) -> std::enable_if_t< ValueRef> && std::is_same_v && !std::is_same_v, Span>, - std::vector> { + ValueRefList> { return imperative::apply( std::forward(op), Span(container.data(), container.size())); } diff --git a/imperative/src/include/megbrain/imperative/operator.h b/imperative/src/include/megbrain/imperative/operator.h index 587c729721c5e8810053e55c1367ee76f287d9f1..0bce20df66d1eb6b6f51d4f9a1ff82b310fbed0f 100644 --- a/imperative/src/include/megbrain/imperative/operator.h +++ b/imperative/src/include/megbrain/imperative/operator.h @@ -25,6 +25,8 @@ namespace mgb { namespace imperative { +using GenericFunction = std::function)>; + /** * \brief base class for all operators * @@ -49,25 +51,24 @@ public: Kind kind() const { return m_kind; } template - U* as() const { + const U* as() const { if (m_typecode != U::TYPE_CODE) { return nullptr; } - return static_cast(const_cast(this)); + return static_cast(this); } template bool is() const { - return as() != nullptr; + return m_typecode == U::TYPE_CODE; } template bool is() const { return kind() == kKind; } template - U& cast() const { - U* ptr = as(); - mgb_assert(ptr); - return *ptr; + const U& cast() const { + mgb_assert(m_typecode == U::TYPE_CODE); + return static_cast(*this); } virtual std::string to_string() const = 0; @@ -77,9 +78,9 @@ public: * implementation. * * \param inputs - * \return std::vector + * \return ValueRefList */ - virtual std::vector fallback(Span inputs) const; + virtual ValueRefList fallback(Span inputs) const; std::type_index type() const { return registered_types()[m_typecode]; } diff --git a/imperative/src/include/megbrain/imperative/profiler.h b/imperative/src/include/megbrain/imperative/profiler.h index 1084e73c432ae76b49f0f9bf54676dbc9238be85..da50ad449fce09e3577da9e3f8e182dce65921a6 100644 --- a/imperative/src/include/megbrain/imperative/profiler.h +++ b/imperative/src/include/megbrain/imperative/profiler.h @@ -123,7 +123,6 @@ public: template static uint64_t record(TArgs&&... args) { auto& profiler = get_instance(); - // auto& mem_pool = get_mem_pool(); if constexpr (sm_debug) { Status expected = Running; mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording)); diff --git a/imperative/src/include/megbrain/imperative/transformation.h b/imperative/src/include/megbrain/imperative/transformation.h index 322f3cc5f394e3025f7f49fd655ca95cf593f741..ae0c74a8d7548f4e560e01d822ac47baa89cc764 100644 --- a/imperative/src/include/megbrain/imperative/transformation.h +++ b/imperative/src/include/megbrain/imperative/transformation.h @@ -18,6 +18,7 @@ #include "megbrain/common.h" #include "megbrain/imperative/subgraph.h" +#include "megbrain/imperative/utils/allocator.h" #include "megbrain/imperative/utils/local_ptr.h" #include "megbrain/imperative/utils/span.h" @@ -25,6 +26,7 @@ namespace mgb { namespace imperative { class ValueRef; +class ValueRefList; class Operator; class Transformation; @@ -43,6 +45,7 @@ struct TransformationContext { // TODO: deprecate TransformationGuard, let next_transformation == frames.size() size_t next_transformation = 0; std::vector frames; + ForwardAllocator allocator; }; /** @@ -86,9 +89,9 @@ public: * * \param op * \param inputs - * \return std::vector + * \return ValueRefList */ - virtual std::vector apply_transformation( + virtual ValueRefList apply_transformation( const Operator& op, Span inputs) = 0; virtual ValueRef unwrap(ValueRef value) = 0; @@ -187,11 +190,12 @@ public: std::swap(context.transformations, current_context.transformations); std::swap(context.scopes, current_context.scopes); std::swap(context.next_transformation, current_context.next_transformation); + std::swap(context.allocator, current_context.allocator); } static TransformationContext& get_context(); - friend std::vector apply(const Operator& op, Span inputs); + friend ValueRefList apply(const Operator& op, Span inputs); friend class ValueRef; }; diff --git a/imperative/src/include/megbrain/imperative/transformations/eval.h b/imperative/src/include/megbrain/imperative/transformations/eval.h index ff5b21532ca02b7de63326be65b61089475c44d1..58d874dd0d8d16c55723350b28a8c814931fa57d 100644 --- a/imperative/src/include/megbrain/imperative/transformations/eval.h +++ b/imperative/src/include/megbrain/imperative/transformations/eval.h @@ -23,16 +23,38 @@ public: using Handle = interpreter::Interpreter::Handle; using Channel = interpreter::Interpreter::Channel; + class RAIIHandle : public NonCopyableObj { + private: + Handle m_handle = nullptr; + Channel* m_channel = nullptr; + + public: + RAIIHandle(Handle handle, Channel* channel) + : m_handle(handle), m_channel(channel) {} + ~RAIIHandle() { m_channel->del(m_handle); } + + Handle handle() const { return m_handle; } + + Channel* channel() const { return m_channel; } + }; + private: - std::shared_ptr m_handle = nullptr; + LocalPtr m_handle; std::string m_name; + mutable DTypeValue::ref_t m_dtype; + mutable CompNodeValue::ref_t m_comp_node; + mutable ShapeValue::ref_t m_shape; public: InterpreterInfo() = default; - InterpreterInfo(std::shared_ptr handle, std::string name = {}) + InterpreterInfo(LocalPtr handle, std::string name = {}) : m_handle(handle), m_name(name) {} - std::shared_ptr handle() const { return m_handle; } + const LocalPtr& handle() const { return m_handle; } + + DTypeValue::ref_t dtype() const; + CompNodeValue::ref_t comp_node() const; + ShapeValue::ref_t shape() const; std::string name() const { return m_name; } }; @@ -60,6 +82,7 @@ class InterpreterTransformation final : public Transformation { public: using Interpreter = interpreter::Interpreter; using Handle = Interpreter::Handle; + using SharedHandle = LocalPtr; using Channel = Interpreter::Channel; private: @@ -71,7 +94,14 @@ public: Channel* channel() { return m_channel.get(); } - std::vector apply_transformation( + ValueRefList apply_op(const ApplyOp& apply_op, Span inputs); + + ValueRefList apply_get_attr(const GetAttr& get_attr, Span inputs); + + ValueRefList apply_create_tensor( + const CreateTensor& create_tensor, Span inputs); + + ValueRefList apply_transformation( const Operator& op, Span inputs) override; ValueRef unwrap(ValueRef value) override { @@ -81,14 +111,8 @@ public: std::string name() const override { return "InterpreterTransformation"; } - std::shared_ptr share_handle(Handle handle) { - return std::shared_ptr( - new Handle(handle), [channel = m_channel.get()](Handle* ptr) { - if (ptr) { - channel->del(*ptr); - delete ptr; - } - }); + SharedHandle share_handle(Handle handle) { + return SharedHandle::make(handle, m_channel.get()); } }; diff --git a/imperative/src/include/megbrain/imperative/transformations/grad.h b/imperative/src/include/megbrain/imperative/transformations/grad.h index d5f2e399f156fead4d73a2af6709744b3664e386..0f1db4ce8c9b04502946e1521c1e3f666a2f006b 100644 --- a/imperative/src/include/megbrain/imperative/transformations/grad.h +++ b/imperative/src/include/megbrain/imperative/transformations/grad.h @@ -34,9 +34,7 @@ struct BackwardGraphWithClosure { std::shared_ptr backward_graph, std::shared_ptr op, Span inputs, Span outputs); - void operator()( - std::vector grads, - std::function receiver); + void operator()(ValueRefList grads, std::function receiver); bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; } @@ -50,12 +48,11 @@ struct BackwardGraphWithClosure { struct CustomBackward; -using GradRuleFn = - std::function(Span inputs, CustomBackward&)>; +using GradRuleFn = std::function inputs, CustomBackward&)>; struct CustomBackward { - using BackwardFn = std::function(Span)>; - using BackwardRule = std::function>( + using BackwardFn = std::function)>; + using BackwardRule = std::function( const OpDef&, Span, Span, CustomBackward&)>; BackwardFn m_backward; SmallVector m_input_has_grad; @@ -65,9 +62,7 @@ struct CustomBackward { SmallVector m_output_attrs; public: - void operator()( - std::vector grads, - std::function receiver); + void operator()(ValueRefList grads, std::function receiver); bool input_has_grad(size_t i) { return m_input_has_grad[i]; } bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; } @@ -188,7 +183,7 @@ public: std::string to_string() const override; - bool has_key(std::shared_ptr key) const { return m_key == key; } + bool has_key(const std::shared_ptr& key) const { return m_key == key; } const GradSlotPtr& slot_for(std::shared_ptr key) const { mgb_assert(m_key == key); @@ -287,7 +282,7 @@ public: return false; } - std::vector apply_transformation( + ValueRefList apply_transformation( const Operator& op, Span inputs) override; ValueRef unwrap(ValueRef value) override { @@ -314,7 +309,7 @@ private: public: std::string to_string() const override { return "DetachValue"; } - std::vector fallback(Span inputs) const override { + ValueRefList fallback(Span inputs) const override { return {inputs.as_array<1>()[0]}; } }; @@ -325,7 +320,7 @@ private: public: AttachGrad(std::shared_ptr key) : m_key(key) {} - std::shared_ptr key() { return m_key; } + std::shared_ptr key() const { return m_key; } std::string to_string() const override { return ssprintf("AttachGradValue{key=%s}", m_key->name().c_str()); @@ -339,7 +334,7 @@ private: public: GradBackward(std::shared_ptr key) : m_key(key) {} - std::shared_ptr key() { return m_key; } + std::shared_ptr key() const { return m_key; } std::string to_string() const override { return ssprintf("GradBackwardValue{key=%s}", m_key->name().c_str()); @@ -352,13 +347,13 @@ private: public: IsAttachedTo(std::shared_ptr key) : m_key(key) {} - std::shared_ptr key() { return m_key; } + std::shared_ptr key() const { return m_key; } std::string to_string() const override { return ssprintf("IsAttachedToValue{key=%s}", m_key->name().c_str()); } - std::vector fallback(Span inputs) const override { + ValueRefList fallback(Span inputs) const override { return {BoolValue::make(false)}; } }; @@ -373,9 +368,9 @@ public: SetGrad(std::shared_ptr key, GenericFunction grad_fn, size_t nr_inputs) : m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {} - GenericFunction grad_fn() { return m_grad_fn; } + GenericFunction grad_fn() const { return m_grad_fn; } - size_t nr_inputs() { return m_nr_inputs; } + size_t nr_inputs() const { return m_nr_inputs; } std::string to_string() const override { return ssprintf("SetGradValue{key=%s}", m_key->name().c_str()); @@ -388,9 +383,7 @@ public: std::string to_string() const override { return ssprintf("GetGradKeyValue{}"); } - std::vector fallback(Span inputs) const override { - return {ValueRef()}; - } + ValueRefList fallback(Span inputs) const override { return {ValueRef()}; } }; class GetBackwardColsure @@ -401,7 +394,7 @@ private: public: GetBackwardColsure(std::shared_ptr key) : m_key(key) {} - std::shared_ptr key() { return m_key; } + std::shared_ptr key() const { return m_key; } std::string to_string() const override { return ssprintf("GetBackwardClosure{key=%s}", m_key->name().c_str()); diff --git a/imperative/src/include/megbrain/imperative/transformations/lazy.h b/imperative/src/include/megbrain/imperative/transformations/lazy.h index 6a29f41fc7401a4d355ef141dfdfc247a7f8c03b..07ce55019a362058205dcbf2b755bc843d78ac1e 100644 --- a/imperative/src/include/megbrain/imperative/transformations/lazy.h +++ b/imperative/src/include/megbrain/imperative/transformations/lazy.h @@ -81,7 +81,7 @@ public: ComputingGraph::Options& options() { return m_graph->options(); } - std::vector apply_transformation( + ValueRefList apply_transformation( const Operator& op, Span inputs) override; ValueRef unwrap(ValueRef value) override { diff --git a/imperative/src/include/megbrain/imperative/transformations/scalar.h b/imperative/src/include/megbrain/imperative/transformations/scalar.h index a56fa093df724e329302b900e43e3aa593b2cbd8..5db4bfda6574a86f9bb8f933dc78cc8e520e6b78 100644 --- a/imperative/src/include/megbrain/imperative/transformations/scalar.h +++ b/imperative/src/include/megbrain/imperative/transformations/scalar.h @@ -11,6 +11,7 @@ #pragma once +#include "megbrain/imperative/basic_operators.h" #include "megbrain/imperative/dispatch.h" #include "megbrain/imperative/ops/autogen.h" @@ -45,8 +46,10 @@ public: */ class ScalarTransformation final : public Transformation { private: + ShapeValue::ref_t m_empty_shape; // [] public: - std::vector apply_transformation( + ValueRefList apply_get_attr(const GetAttr& get_attr, Span inputs); + ValueRefList apply_transformation( const Operator& op, Span inputs) override; ValueRef unwrap(ValueRef value) override { diff --git a/imperative/src/include/megbrain/imperative/transformations/symbol.h b/imperative/src/include/megbrain/imperative/transformations/symbol.h index 2032ef42aeb47c4bc4bf747645893be1c2ffa1c2..56f76df2cec612630bcdd850dd4b66fad3ee1f55 100644 --- a/imperative/src/include/megbrain/imperative/transformations/symbol.h +++ b/imperative/src/include/megbrain/imperative/transformations/symbol.h @@ -50,7 +50,7 @@ private: public: SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {} - std::vector apply_transformation( + ValueRefList apply_transformation( const Operator& op, Span inputs) override { if (auto* apply_op = op.as()) { SmallVector input_nodes; @@ -58,9 +58,9 @@ public: input_nodes.push_back(input.cast().node()); } auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); - std::vector outputs; - for (auto&& output_node : output_nodes) { - outputs.push_back(SymbolValue::make(output_node)); + ValueRefList outputs(output_nodes.size()); + for (size_t i = 0; i < output_nodes.size(); ++i) { + outputs[i] = SymbolValue::make(output_nodes[i]); } return outputs; } else if (auto* create_tensor = op.as()) { diff --git a/imperative/src/include/megbrain/imperative/transformations/tangent.h b/imperative/src/include/megbrain/imperative/transformations/tangent.h new file mode 100644 index 0000000000000000000000000000000000000000..bbd21520a7f896852231259b824981f97f5804fa --- /dev/null +++ b/imperative/src/include/megbrain/imperative/transformations/tangent.h @@ -0,0 +1,36 @@ +/** + * \file imperative/src/include/megbrain/imperative/grad.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megbrain/imperative/basic_operators.h" +#include "megbrain/imperative/operator.h" +#include "megbrain/imperative/transformation.h" +#include "megbrain/imperative/value.h" + +namespace mgb::imperative { + +struct TangentInfo { + ValueRef value; + ValueRef tangent; +}; + +class TangentTransformation final : public Transformation { +public: + ValueRefList apply_transformation( + const Operator& op, Span inputs) override; + + ValueRef unwrap(ValueRef value) override { mgb_assert(false); } + + std::string name() const override { return "Tangent"; } +}; + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/transformations/trace.h b/imperative/src/include/megbrain/imperative/transformations/trace.h index fceb3ec2f223bf5ff70ae0e1d4813c6e4060bed9..134e01107d59e2d6ed1bcbb20c326665f1d9f19e 100644 --- a/imperative/src/include/megbrain/imperative/transformations/trace.h +++ b/imperative/src/include/megbrain/imperative/transformations/trace.h @@ -126,25 +126,6 @@ public: void on_unwatch() override { value().unwatch(); } }; -class TracedInfo { -private: - size_t m_id = 0; - -public: - TracedInfo() = default; - TracedInfo(size_t id) : m_id(id) {} - size_t id() const { return m_id; } -}; - -class TracedValue final : public MixinValueImpl { -public: - using MixinValueImpl::MixinValueImpl; - - std::string to_string() const override { - return ssprintf("TracedValue{\"id\"=%zu}", id()); - } -}; - /** * \brief trace operation sequence to TraceResult * @@ -202,7 +183,7 @@ public: return value; } - std::vector apply_transformation( + ValueRefList apply_transformation( const Operator& op, Span inputs) override; ValueRef unwrap(ValueRef value) override { @@ -248,6 +229,40 @@ public: std::function data_getter; std::function value_getter; std::function data_setter; + std::function exc_setter; + }; + + class TracedInfo { + private: + size_t m_id = 0; + VarInfo* m_var = nullptr; + VarAccessor* m_accessor = nullptr; + mutable ShapeValue::ref_t m_shape; + mutable DTypeValue::ref_t m_dtype; + mutable CompNodeValue::ref_t m_comp_node; + + public: + TracedInfo() = default; + TracedInfo(size_t id, VarInfo* var, VarAccessor* accessor) + : m_id(id), m_var(var), m_accessor(accessor) {} + size_t id() const { return m_id; } + ShapeValue::ref_t shape() const; + DTypeValue::ref_t dtype() const; + CompNodeValue::ref_t comp_node() const; + const VarAccessor& accessor() const; + + void set_exception(std::exception_ptr exc) const { + m_accessor->exc_setter(exc); + } + }; + + class TracedValue final : public MixinValueImpl { + public: + using MixinValueImpl::MixinValueImpl; + + std::string to_string() const override { + return ssprintf("TracedValue{\"id\"=%zu}", id()); + } }; private: @@ -319,7 +334,14 @@ public: TraceResult::SeqItem& next_instruction(); - std::vector apply_transformation( + ValueRefList apply_op(const ApplyOp& apply_op, Span inputs); + + ValueRefList apply_get_attr(const GetAttr& get_attr, Span inputs); + + ValueRefList apply_create_tensor( + const CreateTensor& create_tensor, Span inputs); + + ValueRefList apply_transformation( const Operator& op, Span inputs) override; void on_unregister() noexcept override; diff --git a/imperative/src/include/megbrain/imperative/utils/allocator.h b/imperative/src/include/megbrain/imperative/utils/allocator.h index 4bc9cca9b9708fd1249d1a3f119c7723cc409dba..f800ec256226fdf16350365466eee0ddfff7f866 100644 --- a/imperative/src/include/megbrain/imperative/utils/allocator.h +++ b/imperative/src/include/megbrain/imperative/utils/allocator.h @@ -36,12 +36,12 @@ private: public: Allocator(pool_type* pool) : m_pool(pool) {} - T* allocate(size_type n) { + pointer allocate(size_type n) { mgb_assert(n == 1); return m_pool->alloc(sizeof(T)); } - void deallocate(pointer* p, size_type n) { + void deallocate(pointer p, size_type n) { mgb_assert(n == 1); m_pool->free(p); } @@ -68,4 +68,114 @@ public: bool operator!=(const ThreadLocalAllocatorAdapter& rhs) const { return false; } }; -} // namespace mgb::imperative \ No newline at end of file +template +class ForwardAllocator { +public: + using value_type = T; + using size_type = std::size_t; + using pointer = T*; + + static constexpr size_t alignment = alignof(T); + static constexpr size_t element_offset = + sizeof(T) + + ((sizeof(T) % alignment) ? 0 : (alignment - sizeof(T) % alignment)); + +private: + struct Block { + std::unique_ptr data; + size_t size = 0; + size_t capacity = 0; + + T* allocate(size_type n) { + static_assert(element_offset > std::max(alignment, sizeof(T))); + size_t begin = size; + size_t end = begin + element_offset * n; + if (end > capacity) { + return nullptr; + } + size = end; + return reinterpret_cast(data.get() + begin); + } + + void reset() { size = 0; } + }; + std::vector m_used; + std::optional m_current; + size_t block_size = 16 * 1024 * 1024; + size_t nr_allocated = 0; + +private: + Block allocate_block() { + block_size *= 2; + return Block{std::make_unique(block_size), 0, block_size}; + } + +public: + pointer allocate(size_type n) { + if (!m_current) { + m_current.emplace(allocate_block()); + } + pointer pointer = m_current->allocate(n); + while (pointer == nullptr) { + m_used.push_back(allocate_block()); + std::swap(m_used.back(), *m_current); + pointer = m_current->allocate(n); + } + nr_allocated++; + return pointer; + } + + void deallocate(pointer p, size_type n) { + mgb_assert(nr_allocated > 0); + nr_allocated--; + } + + void clear() { + if (mgb_likely(m_used.empty())) { + // fastpath + if (m_current) { + m_current->reset(); + } + } else { + // trim + *m_current = allocate_block(); + m_used.clear(); + } + mgb_assert(nr_allocated == 0); + } + + bool operator==(const ForwardAllocator& rhs) const { return &rhs == this; } + bool operator!=(const ForwardAllocator& rhs) const { return &rhs != this; } +}; + +template typename TAllocator> +class ProxyAllocator { +public: + using value_type = T; + using size_type = typename TAllocator::size_type; + using pointer = typename TAllocator::pointer; + +private: + TAllocator* m_impl; + +public: + T* allocate(size_type n) { return m_impl->allocate(n); } + + void deallocate(pointer* p, size_type n) { return m_impl->deallocate(p, n); } + + bool operator==(const ProxyAllocator& rhs) const { + if (m_impl == rhs.m_impl) { + return true; + } else if (bool(m_impl) ^ bool(rhs.m_impl)) { + return false; + } else { + return *m_impl == *rhs.m_impl; + } + } + + bool operator!=(const ProxyAllocator& rhs) const { + return !((*this) == rhs); + } +}; + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/utils/local_ptr.h b/imperative/src/include/megbrain/imperative/utils/local_ptr.h index cc87fd29136a506eb79c10044044c3ab01c6b136..ef15ff22b00eac9dad31e1407e36ff270e50db8e 100644 --- a/imperative/src/include/megbrain/imperative/utils/local_ptr.h +++ b/imperative/src/include/megbrain/imperative/utils/local_ptr.h @@ -16,6 +16,8 @@ #include "megbrain/imperative/utils/mempool.h" #include "megbrain/utils/metahelper.h" +#define MGB_FAT_LOCAL_PTR 0 + namespace mgb::imperative { template @@ -52,6 +54,8 @@ private: } } + size_t ref_count() const { return m_ref_count; } + template friend class LocalPtr; @@ -88,14 +92,24 @@ public: using storage_t = LocalPtrStorage; using pool_t = MemPool; using weak_type = LocalWeakPtr; + using pointer_t = T*; private: storage_t* m_storage = nullptr; +#if MGB_FAT_LOCAL_PTR + pointer_t m_pointer = nullptr; +#endif + + // (m_storage == nullptr) == (m_pointer == nullptr) + void emplace(storage_t* ptr) { if (ptr) { ptr->inc_ref(); m_storage = ptr; +#if MGB_FAT_LOCAL_PTR + m_pointer = ptr->m_pointer; +#endif } } @@ -103,8 +117,22 @@ private: public: LocalPtr() = default; - LocalPtr(const LocalPtr& rhs) { (*this) = rhs; } - LocalPtr(LocalPtr&& rhs) { (*this) = std::move(rhs); } + LocalPtr(const LocalPtr& rhs) { + auto storage = rhs.m_storage; + if (storage) { + storage->inc_ref(); + } + m_storage = storage; +#if MGB_FAT_LOCAL_PTR + m_pointer = rhs.m_pointer; +#endif + } + LocalPtr(LocalPtr&& rhs) { + std::swap(m_storage, rhs.m_storage); +#if MGB_FAT_LOCAL_PTR + std::swap(m_pointer, rhs.m_pointer); +#endif + } LocalPtr& operator=(const LocalPtr& rhs) { if (this == &rhs) { return *this; @@ -115,9 +143,11 @@ public: } if (m_storage) { m_storage->dec_ref(); - // rhs.m_storage may be invalid here } m_storage = storage; +#if MGB_FAT_LOCAL_PTR + m_pointer = rhs.m_pointer; +#endif return *this; } LocalPtr& operator=(LocalPtr&& rhs) { @@ -125,6 +155,9 @@ public: return *this; } std::swap(m_storage, rhs.m_storage); +#if MGB_FAT_LOCAL_PTR + std::swap(m_pointer, rhs.m_pointer); +#endif rhs.reset(); return *this; } @@ -186,10 +219,11 @@ public: T& operator*() const { return *get(); } T* get() const { - if ((!m_storage) || !m_storage->m_pointer) { - return nullptr; - } - return m_storage->m_pointer; +#if MGB_FAT_LOCAL_PTR + return m_pointer; +#else + return m_storage ? m_storage->m_pointer : nullptr; +#endif } T* operator->() const { return get(); } @@ -202,6 +236,9 @@ public: if (m_storage) { m_storage->dec_ref(); m_storage = nullptr; +#if MGB_FAT_LOCAL_PTR + m_pointer = nullptr; +#endif } } diff --git a/imperative/src/include/megbrain/imperative/utils/mempool.h b/imperative/src/include/megbrain/imperative/utils/mempool.h index ca3b4778248eceefa1b3bec26468f7c1881f2480..099e38bccfd94617d80f047b7685030c65d2b8cb 100644 --- a/imperative/src/include/megbrain/imperative/utils/mempool.h +++ b/imperative/src/include/megbrain/imperative/utils/mempool.h @@ -49,8 +49,8 @@ public: instance = std::make_unique>(); sm_instance = instance.get(); } - mgb_assert(sm_instance); } + return *sm_instance; } }; @@ -62,9 +62,9 @@ std::unordered_map>> MemPoolUtils::sm_instances; template -thread_local MemPool* MemPoolUtils::tm_instance; +thread_local MemPool* MemPoolUtils::tm_instance = nullptr; template -MemPool* MemPoolUtils::sm_instance; +MemPool* MemPoolUtils::sm_instance = nullptr; -} // namespace mgb::imperative \ No newline at end of file +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/utils/value_shape.h b/imperative/src/include/megbrain/imperative/utils/value_shape.h index ecf911480c1b73dd382620c7cc5e78e0858ae34c..c33fbe6a83f128f72baaccd2bedbf7ff1f81740b 100644 --- a/imperative/src/include/megbrain/imperative/utils/value_shape.h +++ b/imperative/src/include/megbrain/imperative/utils/value_shape.h @@ -95,6 +95,8 @@ struct ValueShape { } return true; } + + bool operator!=(const ValueShape& rhs) const { return !operator==(rhs); } }; static_assert(sizeof(size_t) >= sizeof(int)); diff --git a/imperative/src/include/megbrain/imperative/value.h b/imperative/src/include/megbrain/imperative/value.h index 0fa7e978a024045b07fcff843b630d9bc1066d64..999d8bf1023bc350e5f3ceec295a11e66e9138ad 100644 --- a/imperative/src/include/megbrain/imperative/value.h +++ b/imperative/src/include/megbrain/imperative/value.h @@ -47,6 +47,17 @@ class StringValue; class Operator; +class ValueRefList; + +template +class Type { +private: + const size_t m_code = T::TYPE_CODE; + +public: + inline size_t code() const { return m_code; } +}; + /** * \brief an smart reference of value * @@ -64,8 +75,9 @@ public: protected: mutable storage_t m_storage; + size_t m_id = std::numeric_limits::max(); - ValueRef(storage_t storage) { m_storage = storage; } + inline ValueRef(storage_t storage); private: /** @@ -75,6 +87,10 @@ private: */ storage_t& storage() const; + const Value* as(size_t typecode) const; + + bool is(size_t typecode) const; + public: ValueRef() = default; @@ -86,7 +102,7 @@ public: * \return false if empty or type of value is not TValue */ template - bool is() const; + inline bool is(Type type = {}) const; /** * \brief try cast value as target type @@ -95,7 +111,7 @@ public: * \return TValue* raw pointer if success, otherwise nullptr */ template - const TValue* as() const; + inline const TValue* as(Type type = {}) const; /** * \brief cast value to target type @@ -104,7 +120,7 @@ public: * \return TValue& reference of value */ template - const TValue& cast() const; + inline const TValue& cast(Type type = {}) const; /** * \brief like as(), but returns TypedValueRef instead @@ -113,7 +129,13 @@ public: * \return TypedValueRef reference if success, otherwise empty reference */ template - inline TypedValueRef as_ref() const; + inline TypedValueRef as_ref(Type type = {}) const; + + template + inline TypedValueRef cast_ref(Type type = {}) const; + + template + void on_cast_failure() const; operator bool() const { return bool(m_storage); } @@ -132,7 +154,7 @@ public: ValueRef unwrap() const; std::string to_string() const; std::string raw_type() const; - uint64_t id() const; + uint64_t id() const { return m_id; } size_t hash() const { return id(); } static ValueRef make(storage_t storage); @@ -144,7 +166,7 @@ public: friend class TypedValueRef; template friend class ValueImpl; - friend std::vector apply(const Operator& op, Span inputs); + friend ValueRefList apply(const Operator& op, Span inputs); }; template <> @@ -244,7 +266,7 @@ public: using ref_t = TypedValueRef; using weak_ref_t = TypedValueWeakRef; - static inline size_t TYPE_CODE = [] { return register_type(typeid(T)); }(); + static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }(); /** * \brief helper function for construct a value @@ -254,7 +276,7 @@ public: * \return TypedValueRef reference of value */ template - static TypedValueRef make(TArgs&&... args) { + static MGB_NOINLINE TypedValueRef make(TArgs&&... args) { static_assert(std::is_final_v); return ValueRef::make(LocalPtr::make(std::forward(args)...)); } @@ -279,46 +301,60 @@ public: bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; } }; +inline ValueRef::ValueRef(storage_t storage) { + // mgb_assert(storage); + m_storage = storage; + m_id = m_storage->m_id; +} + template -const TValue* ValueRef::as() const { +inline const TValue* ValueRef::as(Type type) const { static_assert(std::is_base_of_v, TValue>); - auto storage = this->storage(); - if (!storage) { - return nullptr; - } - if (storage->m_typecode != TValue::TYPE_CODE) { - return nullptr; - } - return static_cast(storage.get()); + return static_cast(as(type.code())); } template -const TValue& ValueRef::cast() const { - auto* ptr = as(); - if (!ptr) { - // if this is ErrorValue, rethrow directly - storage()->try_rethrow(); - mgb_assert( - ptr, "expect type %s, got %s", typeid(TValue).name(), - to_string().c_str()); +inline const TValue& ValueRef::cast(Type type) const { + auto* ptr = as(type); + if (mgb_unlikely(!ptr)) { + on_cast_failure(); } - return *ptr; + return static_cast(*ptr); +} + +template +inline bool ValueRef::is(Type type) const { + return is(type.code()); } template -bool ValueRef::is() const { - auto* ptr = as(); - return ptr != nullptr; +inline TypedValueRef ValueRef::as_ref(Type type) const { + if (!is(type)) { + return {}; + } + return TypedValueRef(*this); } template -TypedValueRef ValueRef::as_ref() const { - if (!is()) { +inline TypedValueRef ValueRef::cast_ref(Type type) const { + if (!m_storage) { return {}; } + if (mgb_unlikely(!is(type))) { + on_cast_failure(); + } return TypedValueRef(*this); } +template +void ValueRef::on_cast_failure() const { + // if this is ErrorValue, rethrow directly + storage()->try_rethrow(); + mgb_assert( + storage()->m_typecode != TValue::TYPE_CODE, "expect type %s, got %s", + typeid(TValue).name(), to_string().c_str()); +} + /** * \brief ValueRef with concrete type, convenient for dereference * @@ -361,11 +397,87 @@ private: public: TypedValueWeakRef(ValueRef value) : ValueWeakRef(value) {} TypedValueWeakRef(ValueWeakRef value) : ValueWeakRef(value) {} - TypedValueRef lock() { return ValueWeakRef::lock().template as_ref(); } + TypedValueRef lock() { + auto value = ValueWeakRef::lock(); + if (value) { + return value.template as_ref(); + } else { + return {}; + } + } }; // TODO: add proxy value type, which is meant to be reset in the end +class ValueRefList { +private: + ValueRef* m_data = nullptr; + size_t m_size = 0; + std::aligned_storage_t m_storage; + +private: + void init(size_t nr_elems); + ValueRef* inline_storage() { return reinterpret_cast(&m_storage); } + +public: + ValueRefList() = default; + ValueRefList(size_t nr_elems); + ValueRefList(ValueRef item); + ValueRefList(std::initializer_list values); + template + ValueRefList(TIterator begin, TIterator end); + ValueRefList(const ValueRefList& rhs); + ValueRefList(ValueRefList&& rhs); + ValueRefList& operator=(const ValueRefList& rhs); + ValueRefList& operator=(ValueRefList&& rhs); + ~ValueRefList(); + void clear(); + + ValueRef* begin() { return m_data; } + ValueRef* end() { return m_data + m_size; } + const ValueRef* cbegin() const { return m_data; } + const ValueRef* cend() const { return m_data + m_size; } + size_t size() const { return m_size; } + ValueRef& at(size_t idx) { + mgb_assert(idx < m_size); + return m_data[idx]; + } + const ValueRef& at(size_t idx) const { + mgb_assert(idx < m_size); + return m_data[idx]; + } + ValueRef& operator[](size_t idx) { return m_data[idx]; } + const ValueRef& operator[](size_t idx) const { return m_data[idx]; } + ValueRef* data() { return m_data; } + const ValueRef* data() const { return m_data; } + bool empty() const { return m_size == 0; } + ValueRef& front() { + mgb_assert(m_size > 1); + return m_data[0]; + } + ValueRef& back() { + mgb_assert(m_size > 1); + return m_data[m_size - 1]; + } +}; + +template +ValueRefList::ValueRefList(TIterator begin, TIterator end) : ValueRefList(end - begin) { + for (size_t i = 0; i < m_size; ++i) { + m_data[i] = *(begin + i); + } +} + +inline ValueRefList::ValueRefList(ValueRef item) : m_data(inline_storage()), m_size(1) { + new (m_data) ValueRef(); + m_data[0] = std::move(item); +} + +/*class ValueRefList : public SmallVector { +public: + using SmallVector::SmallVector; +};*/ + } // namespace imperative } // namespace mgb