diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp
index 54deb79b40c8c702d3e7d4697f22ec2fa9bb4204..fd9c70e1a724f7550fb0e45ad6dfac6b168b58b0 100644
--- a/imperative/python/src/grad.cpp
+++ b/imperative/python/src/grad.cpp
@@ -19,6 +19,7 @@
 #include "range/v3/all.hpp"
 
+#include "./helper.h"
 #include "./transformation.h"
 
 namespace py = pybind11;
@@ -30,9 +31,7 @@ namespace {
 std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map;
 }
 
-GradKeyWrapper::GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {
-    grad_key_map[m_key] = this;
-}
+GradKeyWrapper::GradKeyWrapper() {}
 
 void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
     if (nargs != 2) {
@@ -77,8 +76,8 @@ pybind11::function GradKeyWrapper::get_backward_closure(
     for (auto&& tensor : tensors) {
         args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
     }
-    auto closure = imperative::apply(GetBackwardColsure(self->m_key), args)[0]
-                           .as<FunctionValue>();
+    auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0];
+    auto closure = closure_value.as_ref<FunctionValue>();
     auto py_function = [closure](std::vector<TensorWrapper*> tensors) {
         std::vector<ValueRef> args;
         for (auto* tw : tensors) {
@@ -90,11 +89,14 @@
 }
 
 PyObject* GradKeyWrapper::get_name() {
-    return py::cast(m_key->name()).release().ptr();
+    return py::cast(m_name).release().ptr();
 }
 
 void GradKeyWrapper::set_name(py::handle name) {
-    m_key->name(py::cast<std::string>(name));
+    m_name = py::cast<std::string>(name);
+    if (m_key) {
+        m_key->name(m_name);
+    }
 }
 
 PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
@@ -115,7 +117,10 @@
 }
 
 void GradKeyWrapper::enter() {
-    m_transformation = std::make_shared<GradTransformation>(m_key);
+    m_transformation = std::make_shared<GradTransformation>();
+    m_key = m_transformation->key();
+    m_key->name(m_name);
+    grad_key_map[m_key] = this;
     TransformationManager::get_instance().register_at<TransformationManager::Grad>(
             m_transformation);
 }
@@ -123,6 +128,8 @@
 void GradKeyWrapper::exit() {
     TransformationManager::get_instance().unregister<TransformationManager::Grad>(
             m_transformation);
+    grad_key_map.erase(m_key);
+    m_key = {};
     m_transformation.reset();
 }
 
@@ -138,8 +145,6 @@ GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) {
     return grad_key_map.at(key);
 }
 
-GradKeyWrapper::~GradKeyWrapper() {
-    grad_key_map.erase(m_key);
-}
+GradKeyWrapper::~GradKeyWrapper() {}
 
 }  // namespace mgb::imperative::python
diff --git a/imperative/python/src/grad.h b/imperative/python/src/grad.h
index d5175aaf87b2f8f56ffd8d352f2b8e6daddb4e74..914b0cedd465eb374e9a4c365453e672aa4b630f 100644
--- a/imperative/python/src/grad.h
+++ b/imperative/python/src/grad.h
@@ -26,6 +26,7 @@ struct GradKeyWrapper : NonCopyableObj {
     using wrap_t = pyext17::wrap<GradKeyWrapper>;
     static constexpr auto tp_name = pybind11::detail::_("GradKey");
 
+    std::string m_name;
     std::shared_ptr<GradKey> m_key;
     std::shared_ptr<GradTransformation> m_transformation;
 
diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp
index ce41a96f26a47cfde46b547c06267b24687ee306..0e37499d3b9a9429fc4ddbb957588580c638d898 100644
--- a/imperative/python/src/grad_override.cpp
+++ b/imperative/python/src/grad_override.cpp
@@ -117,7 +117,7 @@ std::optional<ValueRefList> elemwise_grad_rule(
     maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(2);
+        SmallVector<ValueRef> ret(2);
         if (!grad) {
             return ret;
         }
@@ -147,7 +147,7 @@ std::optional<ValueRefList> reshape_grad_rule(
     maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(2);
+        SmallVector<ValueRef> ret(2);
         if (!grad) {
             return ret;
         }
@@ -180,7 +180,7 @@ std::optional<ValueRefList> subtensor_grad_rule(
                     grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(1);
+        SmallVector<ValueRef> ret(1);
         if (grad && inputs[0]) {
             ValueRefList args_(inputs.size() + 1);
             auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
@@ -215,7 +215,7 @@ std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
                     grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(1);
+        SmallVector<ValueRef> ret(1);
         if (grad && inputs[0]) {
             ValueRefList args_(inputs.size() + 1);
             auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
@@ -251,7 +251,7 @@ std::optional<ValueRefList> reduce_grad_rule(
     maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(1);
+        SmallVector<ValueRef> ret(1);
         if (grad && shapes[0]) {
             ret[0] = broadcast_to(grad, shapes[0]);
         }
@@ -274,7 +274,7 @@ std::optional<ValueRefList> addAxis_grad_rule(
     maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(1);
+        SmallVector<ValueRef> ret(1);
         if (grad && flag_) {
             ret[0] = imperative::apply(*grad_op_, grad)[0];
         }
@@ -297,7 +297,7 @@ std::optional<ValueRefList> removeAxis_grad_rule(
     maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(1);
+        SmallVector<ValueRef> ret(1);
         if (grad && flag_) {
             ret[0] = imperative::apply(*grad_op_, grad)[0];
         }
@@ -316,7 +316,7 @@ std::optional<ValueRefList> fastpathcopy_grad_rule(
     maker.backward([](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(1);
+        SmallVector<ValueRef> ret(1);
         if (grad) {
             ret[0] = grad;
         }
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index b29b83cedc9463f969a07f021f34feaebde4be8c..a56c8ca65ab0936d0b3d32437d3c980459501f1a 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -56,42 +56,44 @@ WeakKeyMap module_trace_info_map;
 
 struct SymbolVarContext {
     TransformationContext context;
-    cg::ComputingGraph* graph;
+    std::shared_ptr<SymbolTransformation> symbol_tsf;
+    std::shared_ptr<ScalarTransformation> scalar_tsf;
 
-    SymbolVarContext(cg::ComputingGraph* graph) : graph(graph) {
+    SymbolVarContext(cg::ComputingGraph* graph) {
+        symbol_tsf = std::make_shared<SymbolTransformation>(graph);
+        scalar_tsf = std::make_shared<ScalarTransformation>();
         Transformation::swap_context(context);
     }
 
     void init() {
-        std::make_shared<SymbolTransformation>(graph)->register_at(
-                Transformation::top());
-        std::make_shared<ScalarTransformation>()->register_at(Transformation::top());
+        symbol_tsf->register_at(Transformation::top());
+        scalar_tsf->register_at(Transformation::top());
     }
 
-    ~SymbolVarContext() { Transformation::swap_context(context); }
-};
+    ValueRef symvar2val(py::handle py_symbol_var) {
+        auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
+        ValueRef value = symbol_tsf->value_type().make(symbol_var->m_node);
+        if (symbol_var->is_scalar) {
+            value = scalar_tsf->value_type().make(value);
+        }
+        return value;
+    }
 
-ValueRef symvar2val(py::handle py_symbol_var) {
-    auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
-    ValueRef value = SymbolValue::make(symbol_var->m_node);
-    if (symbol_var->is_scalar) {
-        value = ScalarValue::make(value);
+    py::object val2symvar(py::handle typeobj, ValueRef value) {
+        bool is_scalar = false;
+        if (auto*
scalar_value = value.as(scalar_tsf->value_type())) { + value = scalar_value->value(); + is_scalar = true; + } + auto* node = value.cast(symbol_tsf->value_type()).node(); + auto py_symbol_var = + typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic)); + py_symbol_var.cast()->is_scalar = is_scalar; + return py_symbol_var; } - return value; -} -py::object val2symvar(py::handle typeobj, ValueRef value) { - bool is_scalar = false; - if (auto* scalar_value = value.as()) { - value = scalar_value->value(); - is_scalar = true; - } - auto* node = value.cast().node(); - auto py_symbol_var = - typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic)); - py_symbol_var.cast()->is_scalar = is_scalar; - return py_symbol_var; -} + ~SymbolVarContext() { Transformation::swap_context(context); } +}; } // namespace @@ -130,19 +132,21 @@ PyObject* py_apply( auto op = py::handle(py_op).cast>(); SmallVector tensors(nargs); - if (py::isinstance(py::handle(args[0]))) { + bool is_symbol_var = (!TensorWrapper::try_cast(args[0])) && + py::isinstance(py::handle(args[0])); + if (is_symbol_var) { // swap to a special context to reuse scalar handle SymbolVarContext context( py::handle(args[0]).cast()->m_node->owner_graph()); context.init(); for (size_t i = 0; i < nargs; ++i) { - tensors[i] = symvar2val(args[i]); + tensors[i] = context.symvar2val(args[i]); } auto outputs = imperative::apply(*op, tensors); auto ret = pybind11::tuple(outputs.size()); auto typeobj = py::handle(args[0]).get_type(); for (size_t i = 0; i < outputs.size(); ++i) { - ret[i] = val2symvar(typeobj, outputs[i]); + ret[i] = context.val2symvar(typeobj, outputs[i]); } return ret.release().ptr(); } @@ -161,7 +165,7 @@ PyObject* py_apply( } } - auto outputs = imperative::apply(*op, tensors); + auto outputs = [&] { return imperative::apply(*op, tensors); }(); size_t nout = outputs.size(); auto ret = py::tuple(nout); for (size_t i = 0; i < nout; ++i) { @@ -1573,9 +1577,9 @@ void init_tensor(py::module m) { SymbolVarContext context(graph); context.init(); auto output = reduce_to_scalar( - *op.cast>(), symvar2val(tensor)); + *op.cast>(), context.symvar2val(tensor)); auto typeobj = tensor.get_type(); - return val2symvar(typeobj, output); + return context.val2symvar(typeobj, output); } else { auto* tw = TensorWrapper::try_cast(tensor.ptr()); auto output = reduce_to_scalar( diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h index 024578d39384db1a8ad21d2341d47cf8b5ad4489..37e66c9fe480c5cf304be7fb4f3c3c718559baaa 100644 --- a/imperative/python/src/transformation.h +++ b/imperative/python/src/transformation.h @@ -67,10 +67,9 @@ struct TransformationManager { } }; -class PyValue final - : public MixinValueImpl { +class PyValue final : public PrimitiveValue { public: - using MixinValueImpl::MixinValueImpl; + using PrimitiveValue::PrimitiveValue; std::string to_string() const { return pybind11::str((const pybind11::object&)*this).cast(); diff --git a/imperative/src/impl/basic_operators.cpp b/imperative/src/impl/basic_operators.cpp index 234da0d08cc2516c1bc3ca1278fa9ebff33d2d6c..d7b57e55cfc3a457dd1e244c6c8159ccf47662cc 100644 --- a/imperative/src/impl/basic_operators.cpp +++ b/imperative/src/impl/basic_operators.cpp @@ -63,7 +63,7 @@ auto CreateTensor::parse(Span inputs) const -> Args { MegBrainError, "unknown input type, expects HostStorage or DeviceStorage, got " "%s", - input.name()->c_str()); + input.to_string().c_str()); } } mgb_assert( diff --git a/imperative/src/impl/basic_values.cpp 
b/imperative/src/impl/basic_values.cpp index 19caa3859f9f0d67defa40e066e1ece5eaaea802..01b8b3e70ac22b40f57735d195f1406bef2486a3 100644 --- a/imperative/src/impl/basic_values.cpp +++ b/imperative/src/impl/basic_values.cpp @@ -12,7 +12,7 @@ std::string CompNodeValue::to_string() const { } std::string BoolValue::to_string() const { - return (*m_value) ? "true" : "false"; + return (*this) ? "true" : "false"; } std::string HostStorage::to_string() const { @@ -26,10 +26,10 @@ std::string DeviceStorage::to_string() const { std::string HostValue::to_string() const { return ssprintf( "HostValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(), - m_dtype.name(), m_shape.to_string().c_str()); + dtype().name(), shape().to_string().c_str()); } -HostTensorND HostValue::as_nd(bool allow_scalar) const { +HostTensorND HostTensor::as_nd(bool allow_scalar) const { HostTensorND nd; TensorShape tensor_shape; if (m_shape.is_scalar()) { @@ -45,10 +45,10 @@ HostTensorND HostValue::as_nd(bool allow_scalar) const { std::string DeviceValue::to_string() const { return ssprintf( "DeviceValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(), - m_dtype.name(), m_shape.to_string().c_str()); + dtype().name(), shape().to_string().c_str()); } -DeviceTensorND DeviceValue::as_nd(bool allow_scalar) const { +DeviceTensorND DeviceTensor::as_nd(bool allow_scalar) const { DeviceTensorND nd; TensorShape tensor_shape; if (m_shape.is_scalar()) { diff --git a/imperative/src/impl/dispatch.cpp b/imperative/src/impl/dispatch.cpp index 5b0d131430cb49006aea5d6fb17171a4cde3c96a..e132a8e10a3d63d6d57e311dfa8941cc40a9361e 100644 --- a/imperative/src/impl/dispatch.cpp +++ b/imperative/src/impl/dispatch.cpp @@ -19,46 +19,18 @@ namespace mgb { namespace imperative { -namespace { -MGB_NOINLINE void copy_outputs( - ForwardAllocator& allocator, ValueRefList& outputs) { - size_t nr_outputs = outputs.size(); - if (mgb_likely(nr_outputs == 1)) { - ValueRef output_copy; - output_copy = outputs[0]; - allocator.clear(); - outputs = ValueRefList({output_copy}); - } else if (!outputs.empty()) { - SmallVector outputs_copy(nr_outputs); - for (size_t i = 0; i < nr_outputs; ++i) { - outputs_copy[i] = outputs[i]; - } - outputs.clear(); - allocator.clear(); - outputs = {outputs_copy.begin(), outputs_copy.end()}; - } else { - allocator.clear(); - } -} -} // namespace - ValueRefList apply(const Operator& op, Span inputs) { auto& context = Transformation::get_context(); size_t& depth = context.next_transformation; - bool top = depth == 0; - auto outputs = ([&] { - if (mgb_unlikely(depth >= context.transformations.size())) { - return op.fallback(inputs); - } else { - auto& transformation = *context.transformations[depth++]; - CleanupGuard _{[&] { --depth; }}; - return transformation.apply_transformation(op, inputs); - } - })(); - if (mgb_unlikely(top)) { - copy_outputs(context.allocator, outputs); + // TODO: add fallback transformation + bool fallback = depth >= context.transformations.size(); + if (mgb_unlikely(fallback)) { + return op.fallback(inputs); + } else { + auto& transformation = *context.transformations[depth++]; + CleanupGuard _{[&] { --depth; }}; + return transformation.apply_transformation(op, inputs); } - return outputs; } ValueRefList apply(const OpDef& def, Span inputs) { @@ -66,12 +38,7 @@ ValueRefList apply(const OpDef& def, Span inputs) { } ValueRefList apply(const Subgraph& graph, Span inputs) { - SmallVector inputs_storage; - for (size_t i = 0; i < inputs.size(); ++i) { - inputs_storage.push_back(inputs[i]); - } - auto 
apply_functor = [](std::shared_ptr op, SmallVector inputs, - size_t) { + auto apply_functor = [](std::shared_ptr op, Span inputs, size_t) { auto outputs = imperative::apply(*op, inputs); return SmallVector(outputs.begin(), outputs.end()); }; @@ -93,7 +60,7 @@ ValueRefList apply(const Subgraph& graph, Span inputs) { HostStorage::make(host_value.storage()), DeviceStorage::make(device_value.storage()))[0]; }; - auto outputs = graph.apply(inputs_storage, apply_functor, make_const); + auto outputs = graph.apply(inputs, apply_functor, make_const); return ValueRefList{outputs.begin(), outputs.end()}; } diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 48a62b2d73aacd626a9282d80a60867fb7b57e61..738627bb506e7d214fc5ae9c28c63bbf2c5b0d3c 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -331,6 +331,7 @@ void ChannelImpl::dispatch_kernel( cmd.inputs = std::move(input_infos); cmd.outputs.reserve(output_descs.size()); outputs->reserve(output_descs.size()); + for (int i = 0; i < output_descs.size(); ++i) { auto&& desc = output_descs[i]; auto info = alloc(); @@ -730,7 +731,8 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { input_descs.push_back({{{}, input->dtype()}, input->comp_node()}); } auto forward_graph = OpDef::make_forward_graph(def, input_descs); - auto outputs = forward_graph.apply(inputs, apply_functor, const_functor); + auto outputs = forward_graph.apply( + inputs, apply_functor, const_functor); return outputs; } return OpDef::apply_on_physical_tensor(def, inputs); diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp index 0a73d1963528b1a05a6bf782721200ebf7428ad3..ba5b1e6a6c0ebe358a862e98a8cf1dab0dca214e 100644 --- a/imperative/src/impl/ops/elemwise.cpp +++ b/imperative/src/impl/ops/elemwise.cpp @@ -11,6 +11,7 @@ #include "megbrain/imperative/opr_utility.h" #include "megbrain/imperative/ops/autogen.h" +#include "megbrain/imperative/utils/stats.h" #include "megbrain/opr/basic_arith.h" #include "megbrain/opr/utility.h" @@ -101,7 +102,7 @@ void apply_on_device_tensornd( const OpDef& def, const SmallVector& inputs, SmallVector* outputs) { auto&& op_def = def.cast_final_safe(); - auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode); + auto&& trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode); mgb_assert( inputs.size() == trait.arity, "%s expects %u inputs; got %zu actually", trait.name, trait.arity, inputs.size()); diff --git a/imperative/src/impl/subgraph_detail.cpp b/imperative/src/impl/subgraph_detail.cpp index cb4578b07b2fc7f5ac85d697efa72d65204b24a0..13149de68b43981e653080e6d22f4fd5f9244851 100644 --- a/imperative/src/impl/subgraph_detail.cpp +++ b/imperative/src/impl/subgraph_detail.cpp @@ -36,7 +36,7 @@ VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { .node(); }; auto subgraph = def.trait()->make_forward_graph(def, input_descs); - auto outputs = subgraph.apply(inputs, apply_functor, const_functor); + auto outputs = subgraph.apply(inputs, apply_functor, const_functor); return outputs; } @@ -56,7 +56,8 @@ std::tuple, bool> infer_output_attrs_fallible( value->layout(), value->comp_node(), value->get_value().proxy_to_default_cpu()}; }; - auto outputs = subgraph.apply(inputs, apply_functor, const_functor); + auto outputs = + subgraph.apply(inputs, apply_functor, const_functor); return {outputs, all_validated}; } @@ -72,7 +73,7 @@ 
SmallVector apply_on_physical_tensor( return OpDef::apply_on_physical_tensor(*op, inputs); }; auto const_functor = [&](const TensorPtr& value) { return value; }; - auto outputs = subgraph.apply(inputs, apply_functor, const_functor); + auto outputs = subgraph.apply(inputs, apply_functor, const_functor); return outputs; } @@ -94,7 +95,7 @@ static EncodedSubgraph make_backward_graph_from_forward( }; GradContext grad_context{accum_grad}; auto input_vars = builder.write_inputs(inputs); - auto outputs = forward_graph.apply( + auto outputs = forward_graph.apply( input_vars, std::bind(&decltype(builder)::write_expr, &builder, _1, _2, _3), [&](TensorPtr constant) { return builder.write_constant( @@ -102,7 +103,7 @@ static EncodedSubgraph make_backward_graph_from_forward( }); size_t nr_outputs = outputs.size(); auto apply_mask = [](auto&& values, SmallVector mask) { - mgb_assert(mask.size() == values.size(), ""); + mgb_assert(mask.size() == values.size()); std::decay_t results; for (size_t i = 0; i < mask.size(); ++i) { if (mask[i]) { @@ -143,7 +144,7 @@ static EncodedSubgraph make_backward_graph_from_forward( return builder.write_constant( constant, {constant->layout(), constant->comp_node()}); }; - return bg.apply(grad_inputs, apply_functor, const_functor); + return bg.apply(grad_inputs, apply_functor, const_functor); }); builder.add_outputs(grad_context.get_grads(input_vars)); for (size_t i = 0; i < nr_outputs; ++i) { diff --git a/imperative/src/impl/transformations/eval.cpp b/imperative/src/impl/transformations/eval.cpp index 87a821aa0a9972b341d85dc2585f90969c470f9a..5e357e7589adaa34d97a13b2cf05a2c03db11467 100644 --- a/imperative/src/impl/transformations/eval.cpp +++ b/imperative/src/impl/transformations/eval.cpp @@ -10,20 +10,19 @@ */ #include "megbrain/imperative/transformations/eval.h" -#include "megbrain/imperative/transformations/grad.h" #include "megbrain/imperative/utils/stats.h" namespace mgb { namespace imperative { -DTypeValue::ref_t InterpreterInfo::dtype() const { +DTypeValue::ref_t InterpreterValue::dtype() const { if (!m_dtype) { m_dtype = DTypeValue::make(handle()->channel()->get_dtype(handle()->handle())); } return m_dtype; } -CompNodeValue::ref_t InterpreterInfo::comp_node() const { +CompNodeValue::ref_t InterpreterValue::comp_node() const { if (!m_comp_node) { m_comp_node = CompNodeValue::make( handle()->channel()->get_device(handle()->handle())); @@ -31,7 +30,7 @@ CompNodeValue::ref_t InterpreterInfo::comp_node() const { return m_comp_node; } -ShapeValue::ref_t InterpreterInfo::shape() const { +ShapeValue::ref_t InterpreterValue::shape() const { if (!m_shape) { m_shape = ShapeValue::make( ValueShape::from(handle()->channel()->get_shape(handle()->handle()))); @@ -51,21 +50,22 @@ ValueRefList InterpreterTransformation::apply_op( } }}; for (auto input : inputs) { - input_handles.push_back(input.cast().handle()->handle()); + input_handles.push_back(input.cast(m_value_type).handle()->handle()); } output_handles = m_channel->apply_op(apply_op.op().shared_from_this(), input_handles); ValueRefList outputs(output_handles.size()); for (size_t i = 0; i < output_handles.size(); ++i) { - outputs[i] = InterpreterValue::make(share_handle(output_handles[i])); + outputs[i] = m_value_type.make(share_handle(output_handles[i])); output_handles[i] = nullptr; } + output_handles.clear(); return outputs; } ValueRefList InterpreterTransformation::apply_get_attr( const GetAttr& get_attr, Span inputs) { - auto& input = inputs.item().cast(); + auto& input = inputs.item().cast(m_value_type); ValueRef 
output; switch (get_attr.attr()) { case GetAttr::DType: @@ -98,10 +98,10 @@ ValueRefList InterpreterTransformation::apply_create_tensor( if (!args.device) { // implies H2D mgb_assert(args.host, "neither host and device value is valid"); - return {InterpreterValue::make(share_handle( + return {m_value_type.make(share_handle( m_channel->put(*args.host, args.kind == CreateTensor::Unique)))}; } else { - return {InterpreterValue::make(share_handle(m_channel->put( + return {m_value_type.make(share_handle(m_channel->put( *args.device, args.host ? *args.host : HostTensorND())))}; } } @@ -119,7 +119,7 @@ ValueRefList InterpreterTransformation::apply_transformation( } else if (auto* create_tensor = op.as()) { return apply_create_tensor(*create_tensor, inputs); } else if (auto* dtr_command = op.as()) { - auto handle = inputs[0].cast().handle()->handle(); + auto handle = inputs[0].cast(m_value_type).handle()->handle(); switch (dtr_command->kind()) { case DTRCommand::Drop: m_channel->drop(handle); @@ -129,10 +129,10 @@ ValueRefList InterpreterTransformation::apply_transformation( } return {}; } else if (auto* rename_value = op.as()) { - auto& input = inputs[0].cast(); - return {InterpreterValue::make(input.handle(), rename_value->name())}; + auto& input = inputs[0].cast(m_value_type); + return {m_value_type.make(input.handle(), rename_value->name())}; } else if (op.is()) { - auto name = inputs[0].cast().name(); + auto name = inputs[0].cast(m_value_type).name(); if (!name.empty()) { return {StringValue::make(name)}; } else { diff --git a/imperative/src/impl/transformations/grad.cpp b/imperative/src/impl/transformations/grad.cpp index 50d0fad0c81ecb0a658bea49690bd3fc12c716f0..49b386ed54163be0cfd51c34a0baeacf4ddf43a9 100644 --- a/imperative/src/impl/transformations/grad.cpp +++ b/imperative/src/impl/transformations/grad.cpp @@ -68,7 +68,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( size_t count = std::count_if( save_for_backward.begin(), save_for_backward.end(), ranges::identity{}); if (!backward_graph->precomp.empty()) { - ValueRefList inputs_and_outputs(inputs.size() + outputs.size()); + SmallVector inputs_and_outputs(inputs.size() + outputs.size()); auto it = inputs_and_outputs.begin(); for (auto&& input : inputs) { *it++ = input; @@ -94,7 +94,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( } } void BackwardGraphWithClosure::operator()( - ValueRefList grads, std::function receiver) { + Span grads, std::function receiver) { ValueRef args[closure.size() + grads.size()]; size_t nargs = 0; for (auto&& value : closure) { @@ -114,7 +114,9 @@ void BackwardGraphWithClosure::operator()( if (null_grad) { return; } - auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs)); + auto igrads_ = imperative::apply(backward_graph->backward, Span(args, nargs)); + SmallVector igrads = {igrads_.begin(), igrads_.end()}; + igrads_.clear(); auto&& iter = igrads.begin(); for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) { if (p) { @@ -125,7 +127,7 @@ void BackwardGraphWithClosure::operator()( } void CustomBackward::operator()( - ValueRefList grads, std::function receiver) { + Span grads, std::function receiver) { size_t nargs = grads.size(); ValueRef args[nargs]; for (size_t i = 0; i < nargs; ++i) { @@ -206,7 +208,7 @@ void GradKey::backward() { mgb_throw(AssertionError, "invalid backward"); } else { mgb_assert(grad_fn->m_slots.size() > 0); - ValueRefList grads (grad_fn->m_slots.size()); + SmallVector grads (grad_fn->m_slots.size()); auto iter = grads.begin(); 
for (auto&& slot : grad_fn->m_slots) { *iter++ = slot.m_grad; @@ -231,11 +233,9 @@ void GradKey::backward() { GradValue::ref_t GradKey::attach( ValueRef tensor, std::function callback) { - auto grad_value = tensor.as_ref(); - if (grad_value && grad_value->has_key(shared_from_this())) { - mgb_assert( - !tensor.cast().slot_for(shared_from_this())->callback, - "callback exists"); + auto grad_value = tensor.as_ref(m_value_type); + if (grad_value) { + mgb_assert(!tensor.cast(m_value_type).slot()->callback, "callback exists"); } else { GradSlotPtr grad_slot; auto& grad_fn = grad_slot.m_fn; @@ -243,9 +243,9 @@ GradValue::ref_t GradKey::attach( grad_fn->m_key = shared_from_this(); grad_fn->m_slots.resize(1); grad_slot.m_index = 0; - grad_value = GradValue::make(tensor, shared_from_this(), grad_slot); + grad_value = m_value_type.make(tensor, shared_from_this(), grad_slot); } - grad_value->slot_for(shared_from_this()).m_fn->m_slots[0].callback = callback; + grad_value->slot().m_fn->m_slots[0].callback = callback; return grad_value; } @@ -263,7 +263,7 @@ void GradKey::freeze() { ValueRefList GradTransformation::apply_transformation( const Operator& op, Span inputs) { auto fallback = [&] { - ValueRefList unwrapped_inputs(inputs.size()); + SmallVector unwrapped_inputs(inputs.size()); { // overhead for (size_t i = 0; i < inputs.size(); ++i) { @@ -367,7 +367,7 @@ ValueRefList GradTransformation::apply_transformation( for (size_t i = 0; i < inputs.size(); ++i) { if (backward.input_has_grad(i) && require_grads[i]) { auto& input_grad_slot = - inputs[i].cast().slot_for(m_key); + inputs[i].cast(m_value_type).slot(); grad_fn->m_dests.emplace_back(input_grad_slot); grad_fn->m_dests.back().m_producer_record.insert_after( input_grad_slot->m_producer_head); @@ -378,7 +378,7 @@ ValueRefList GradTransformation::apply_transformation( for (size_t i = 0; i < outputs.size(); ++i) { if (backward.output_requires_grad(i)) { // little overhead: Value::make - auto grad_value = GradValue::make(outputs[i], m_key, GradSlotPtr{grad_fn, i}); + auto grad_value = m_value_type.make(outputs[i], m_key, GradSlotPtr{grad_fn, i}); outputs[i] = record_grad(grad_value); } } @@ -435,7 +435,10 @@ ValueRefList GradTransformation::apply_transformation( backward.m_input_has_grad = SmallVector(nr_inputs, true); backward.m_output_attrs = SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true}); - backward.m_backward = set_grad->grad_fn(); + backward.m_backward = [fn = set_grad->grad_fn()](Span inputs) { + auto result = fn(inputs); + return SmallVector(result.begin(), result.end()); + }; ValueRefList outputs(nr_outputs); grad_fn->m_key = m_key; grad_fn->m_slots.resize(nr_outputs); @@ -454,10 +457,10 @@ ValueRefList GradTransformation::apply_transformation( auto& output = outputs_[i]; auto grad_value = as_grad_value(output); if (grad_value) { - grad_value = GradValue::make( + grad_value = m_value_type.make( grad_value->m_value, m_key, GradSlotPtr(grad_fn, i)); } else { - grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i)); + grad_value = m_value_type.make(output, m_key, GradSlotPtr(grad_fn, i)); } outputs[i] = record_grad(grad_value); } @@ -485,8 +488,7 @@ ValueRefList GradTransformation::apply_transformation( mgb_assert(inputs.size() == 1); if (auto&& grad_value = as_grad_value(inputs[0])) { auto output = imperative::apply(op, grad_value->m_value)[0]; - auto grad_output = GradValue::make( - output, grad_value->key(), grad_value->slot_for(m_key)); + auto grad_output = m_value_type.make(output, m_key, grad_value->slot()); 
return {record_grad(grad_output)}; } else { return imperative::apply(op, inputs); @@ -502,7 +504,7 @@ GenericFunction GradTransformation::make_backward_closure(Span ys) { std::vector y_slots; for (auto&& y : ys) { if (auto&& grad_value = as_grad_value(y)) { - y_slots.push_back(grad_value->slot_for(grad_key)); + y_slots.push_back(grad_value->slot()); } else { y_slots.emplace_back(); } diff --git a/imperative/src/impl/transformations/lazy.cpp b/imperative/src/impl/transformations/lazy.cpp index 403c198729ee38efb56d5f539cdeded7a6e2485a..89c4b35520624cef9e303158505f6a2614fabc14 100644 --- a/imperative/src/impl/transformations/lazy.cpp +++ b/imperative/src/impl/transformations/lazy.cpp @@ -32,7 +32,7 @@ ValueRefList LazyEvalTransformation::apply_transformation( bool require_link = mm_io_ops.count(op_val->op().dyn_typeinfo()); VarNodeArray input_nodes; for (auto&& input : inputs) { - if (auto* input_node = input.as()) { + if (auto* input_node = input.as(m_value_type)) { input_nodes.push_back(input_node->node()); } else { // ImmutableTensor has empty shape issues @@ -112,7 +112,7 @@ ValueRefList LazyEvalTransformation::apply_transformation( return {record_var(node)}; } } else if (auto* get_attr = op.as()) { - if (auto* lazy_val = inputs.item().as()) { + if (auto* lazy_val = inputs.item().as(m_value_type)) { switch (get_attr->attr()) { case GetAttr::DType: return {DTypeValue::make(lazy_val->node()->dtype())}; @@ -167,14 +167,14 @@ ValueRefList LazyEvalTransformation::apply_transformation( return imperative::apply(op, inputs); } } else if (auto* rename_value = op.as()) { - if (auto* lazy_val = inputs.item().as()) { + if (auto* lazy_val = inputs.item().as(m_value_type)) { return {record_var( lazy_val->node(), lazy_val->bound_data(), rename_value->name())}; } else { return imperative::apply(op, inputs); } } else if (op.is()) { - if (auto* lazy_val = inputs.item().as()) { + if (auto* lazy_val = inputs.item().as(m_value_type)) { auto name = lazy_val->name(); if (!name.empty()) { return {StringValue::make(lazy_val->name())}; @@ -255,7 +255,7 @@ void LazyEvalTransformation::on_unregister() noexcept { DeviceStorage::make(data.storage()))[0]); } for (auto&& lazy_val : lazy_vals) { - if (lazy_val.is()) { + if (lazy_val.is(m_value_type)) { std::string repr = ssprintf("lazy eval failed for %s", lazy_val->to_string().c_str()); mgb_log_debug("%s", repr.c_str()); diff --git a/imperative/src/impl/transformations/scalar.cpp b/imperative/src/impl/transformations/scalar.cpp index 338089bc11b7de5cd92bc95aa239f96f76f7ee2f..ea1ba6725b8a762748e2af2cf211407a6b56252a 100644 --- a/imperative/src/impl/transformations/scalar.cpp +++ b/imperative/src/impl/transformations/scalar.cpp @@ -20,7 +20,8 @@ namespace imperative { namespace { -using ScalarRule = ValueRefList (*)(const OpDef&, Span, Span); +using ScalarRule = ValueRefList (*)( + const OpDef&, Span, Span, const Type&); static std::unordered_map scalar_rules; ValueRef make_scalar_shape(CompNode device) { @@ -41,17 +42,22 @@ bool is_scalar_shape(ValueRef shape) { return *shape_of_shape == ValueShape{0}; } -template , Span)> +template < + typename T, + ValueRefList (*rule)( + const T&, Span, Span, const Type&)> void register_scalar_rule() { scalar_rules[T::typeinfo()] = [](const OpDef& def, Span inputs, - Span inputs_mask) { - return (*rule)(def.cast_final_safe(), inputs, inputs_mask); + Span inputs_mask, + const Type& value_type) { + return (*rule)(def.cast_final_safe(), inputs, inputs_mask, value_type); }; } template ValueRefList elemwise_rule( - const TOpDef& op_def, 
Span inputs, Span inputs_mask) { + const TOpDef& op_def, Span inputs, Span inputs_mask, + const Type& scalar_type) { if constexpr (nr_inputs != 0) { mgb_assert(inputs.size() == inputs.size(), "inputs size mismatch"); } @@ -63,27 +69,29 @@ ValueRefList elemwise_rule( } auto outputs = imperative::apply(op_def, inputs); if (all_scalar) { - outputs[0] = ScalarValue::make(outputs[0]); + outputs[0] = scalar_type.make(outputs[0]); } return outputs; } ValueRefList remove_axis_rule( - const RemoveAxis& remove_axis, Span inputs, Span inputs_mask) { + const RemoveAxis& remove_axis, Span inputs, Span inputs_mask, + const Type& scalar_type) { mgb_assert(!inputs_mask.item()); bool is_scalar = inputs.item().shape()->ndim == remove_axis.axis.size(); if (is_scalar && remove_axis.axis.size() == 1) { - return {ScalarValue::make(inputs.item())}; + return {scalar_type.make(inputs.item())}; } auto outputs = imperative::apply(remove_axis, inputs); if (is_scalar) { - outputs[0] = ScalarValue::make(outputs[0]); + outputs[0] = scalar_type.make(outputs[0]); } return outputs; } ValueRefList reduce_rule( - const Reduce& reduce, Span inputs, Span inputs_mask) { + const Reduce& reduce, Span inputs, Span inputs_mask, + const Type& scalar_type) { if (inputs.size() == 1) { return imperative::apply(reduce, inputs); } @@ -91,7 +99,7 @@ ValueRefList reduce_rule( bool is_scalar = is_scalar_shape(inputs[1]); if (is_scalar) { CompNode device = *inputs[0].device(); - return {ScalarValue::make( + return {scalar_type.make( imperative::apply(reduce, inputs[0], make_scalar_shape(device))[0])}; } return imperative::apply(reduce, inputs); @@ -99,7 +107,7 @@ ValueRefList reduce_rule( ValueRefList collective_comm_rule( const CollectiveComm& collective_comm, Span inputs, - Span inputs_mask) { + Span inputs_mask, const Type& scalar_type) { mgb_assert(inputs.size() == 1); static std::unordered_set modes = { CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN, @@ -110,7 +118,7 @@ ValueRefList collective_comm_rule( return imperative::apply(collective_comm, inputs); } if (inputs_mask.item()) { - return {ScalarValue::make(imperative::apply(collective_comm, inputs[0])[0])}; + return {scalar_type.make(imperative::apply(collective_comm, inputs[0])[0])}; } else { return imperative::apply(collective_comm, inputs); } @@ -118,24 +126,27 @@ ValueRefList collective_comm_rule( ValueRefList param_pack_split_rule( const ParamPackSplit& param_pack_split, Span inputs, - Span inputs_mask) { + Span inputs_mask, const Type& scalar_type) { auto outputs = imperative::apply(param_pack_split, inputs); size_t nr_outputs = outputs.size(); mgb_assert(nr_outputs == param_pack_split.shapes.size()); for (size_t i = 0; i < nr_outputs; ++i) { if (param_pack_split.shapes[i].empty()) { - outputs[i] = ScalarValue::make(outputs[i]); + outputs[i] = scalar_type.make(outputs[i]); } } return outputs; } -ValueRefList dot_rule(const Dot& dot, Span inputs, Span inputs_mask) { - return {ScalarValue::make(imperative::apply(dot, inputs)[0])}; +ValueRefList dot_rule( + const Dot& dot, Span inputs, Span inputs_mask, + const Type& scalar_type) { + return {scalar_type.make(imperative::apply(dot, inputs)[0])}; } ValueRefList add_axis_rule( - const AddAxis& add_axis, Span inputs, Span inputs_mask) { + const AddAxis& add_axis, Span inputs, Span inputs_mask, + const Type& scalar_type) { mgb_assert(inputs.size() == 1); if (inputs_mask.item()) { mgb_assert(add_axis.axis[0] == 0); @@ -151,7 +162,8 @@ ValueRefList add_axis_rule( } ValueRefList remote_recv_rule( - const 
RemoteRecv& remote_recv, Span inputs, Span inputs_mask) { + const RemoteRecv& remote_recv, Span inputs, Span inputs_mask, + const Type& scalar_type) { if (remote_recv.shape.empty()) { std::vector shape = {1}; auto remote_recv_no_scalar = RemoteRecv::make( @@ -167,20 +179,21 @@ ValueRefList remote_recv_rule( ValueRefList check_no_finite_rule( const CheckNonFinite& check_no_finite, Span inputs, - Span inputs_mask) { + Span inputs_mask, const Type& scalar_type) { auto outputs = imperative::apply(check_no_finite, inputs); mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch"); - outputs.back() = ScalarValue::make(outputs.back()); + outputs.back() = scalar_type.make(outputs.back()); for (size_t i = 0; i < inputs.size(); ++i) { if (inputs_mask[i]) { - outputs[i] = ScalarValue::make(outputs[i]); + outputs[i] = scalar_type.make(outputs[i]); } } return outputs; } ValueRefList subtensor_rule( - const Subtensor& subtensor, Span inputs, Span inputs_mask) { + const Subtensor& subtensor, Span inputs, Span inputs_mask, + const Type& scalar_type) { mgb_assert(inputs.size() >= 1); auto input = inputs[0]; bool is_scalar; @@ -199,14 +212,14 @@ ValueRefList subtensor_rule( } auto outputs = imperative::apply(subtensor, inputs); if (is_scalar) { - outputs[0] = ScalarValue::make(outputs[0]); + outputs[0] = scalar_type.make(outputs[0]); } return outputs; } ValueRefList get_var_shape_rule( - const GetVarShape& get_var_shape, Span inputs, - Span inputs_mask) { + const GetVarShape& get_var_shape, Span inputs, Span inputs_mask, + const Type& scalar_type) { bool all_scalar = true; mgb_assert(inputs.size() >= 1); for (auto&& input_mask : inputs_mask) { @@ -228,11 +241,12 @@ ValueRefList get_var_shape_rule( } ValueRefList reshape_rule( - const Reshape& reshape, Span inputs, Span inputs_mask) { + const Reshape& reshape, Span inputs, Span inputs_mask, + const Type& scalar_type) { mgb_assert(inputs.size() == 2); bool is_scalar = is_scalar_shape(inputs[1]); if (is_scalar) { - return {ScalarValue::make(imperative::apply( + return {scalar_type.make(imperative::apply( reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; } else { return imperative::apply(reshape, inputs); @@ -240,11 +254,12 @@ ValueRefList reshape_rule( } ValueRefList broadcast_rule( - const Broadcast& broadcast, Span inputs, Span inputs_mask) { + const Broadcast& broadcast, Span inputs, Span inputs_mask, + const Type& scalar_type) { mgb_assert(inputs.size() == 2); bool is_scalar = is_scalar_shape(inputs[1]); if (is_scalar) { - return {ScalarValue::make(imperative::apply( + return {scalar_type.make(imperative::apply( broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; } else { return imperative::apply(broadcast, inputs); @@ -299,11 +314,11 @@ struct ScalarRuleRegistry { ValueRefList ScalarTransformation::apply_get_attr( const GetAttr& get_attr, Span inputs) { auto&& input = inputs.item(); - bool is_scalar = input.is(); + bool is_scalar = input.is(m_value_type); if (!is_scalar) { return imperative::apply(get_attr, input); } - auto unwrapped_input = input.cast().value(); + auto unwrapped_input = input.cast(m_value_type).value(); if (get_attr.attr() == GetAttr::Shape) { if (!m_empty_shape) { m_empty_shape = ShapeValue::make(); @@ -352,7 +367,7 @@ ValueRefList ScalarTransformation::apply_transformation( ValueRefList unwrapped_inputs(nr_inputs); SmallVector inputs_mask(nr_inputs); for (size_t i = 0; i < inputs.size(); ++i) { - if (auto&& scalar_value = inputs[i].as_ref()) { + if (auto&& scalar_value = 
inputs[i].as_ref(m_value_type)) { unwrapped_inputs[i] = scalar_value->value(); inputs_mask[i] = true; } else { @@ -364,7 +379,8 @@ ValueRefList ScalarTransformation::apply_transformation( if (auto apply_op = op.as()) { auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo()); if (iter != scalar_rules.end()) { - return iter->second(apply_op->op(), unwrapped_inputs, inputs_mask); + return iter->second( + apply_op->op(), unwrapped_inputs, inputs_mask, m_value_type); } else { // TODO: repeat op return fallback(); @@ -375,7 +391,7 @@ ValueRefList ScalarTransformation::apply_transformation( CreateTensor scalar_op( create_tensor->kind(), create_tensor->device(), create_tensor->dtype(), scalar_shape); - return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])}; + return {m_value_type.make(imperative::apply(scalar_op, inputs)[0])}; } else { return imperative::apply(op, inputs); } @@ -387,7 +403,7 @@ ValueRefList ScalarTransformation::apply_transformation( bool is_scalar = inputs_mask[0]; auto outputs = fallback(); if (is_scalar) { - outputs[0] = ScalarValue::make(outputs[0]); + outputs[0] = m_value_type.make(outputs[0]); } return outputs; } else { diff --git a/imperative/src/impl/transformations/trace.cpp b/imperative/src/impl/transformations/trace.cpp index afce8017050c03ae4fd31476bfc72db40690d9d9..c74597a266a3831f4137c65d7f1a454081c22a8a 100644 --- a/imperative/src/impl/transformations/trace.cpp +++ b/imperative/src/impl/transformations/trace.cpp @@ -160,7 +160,7 @@ ValueRefList TracingTransformation::apply_transformation( SmallVector wrapped_inputs; SmallVector input_ids; for (auto input : inputs) { - auto tracing_value = input.as_ref(); + auto tracing_value = input.as_ref(m_value_type); if (!tracing_value) { tracing_value = record_var(input, m_capture_as_const, VarKind::External); @@ -208,7 +208,7 @@ ValueRefList TracingTransformation::apply_transformation( } else if (auto* get_attr = op.as()) { auto unwrapped_input = unwrap_var(inputs[0]); auto outputs = imperative::apply(op, unwrapped_input); - if (auto* tracing_value = inputs[0].as()) { + if (auto* tracing_value = inputs[0].as(m_value_type)) { auto& var_info = m_vars[tracing_value->id()]; switch (get_attr->attr()) { case GetAttr::Shape: @@ -228,7 +228,7 @@ ValueRefList TracingTransformation::apply_transformation( } else if (auto* trace_mark_var = op.as()) { mgb_assert(inputs.size() == 1, "TraceMarkVar expects exactly one input"); auto input = inputs[0]; - auto tracing_var = input.as_ref(); + auto tracing_var = input.as_ref(m_value_type); if (!tracing_var) { bool is_input = trace_mark_var->mark().substr(0, 4) == "arg_" || trace_mark_var->mark().substr(0, 6) == "kwarg_"; @@ -247,7 +247,7 @@ ValueRefList TracingTransformation::apply_transformation( } else if (auto* trace_name_var = op.as()) { mgb_assert(inputs.size() == 1, "RenameValue expects exactly one input"); auto input = inputs[0]; - auto tracing_var = input.as_ref(); + auto tracing_var = input.as_ref(m_value_type); if (!tracing_var) { tracing_var = record_var(input, m_capture_as_const, VarKind::External); } else { @@ -260,7 +260,7 @@ ValueRefList TracingTransformation::apply_transformation( } else if (op.is()) { mgb_assert(inputs.size() == 1, "GetName expects exactly one input"); auto input = inputs[0]; - if (auto tracing_var = input.as_ref()) { + if (auto tracing_var = input.as_ref(m_value_type)) { auto name = m_vars[tracing_var->id()].name; if (!name.empty()) { return {StringValue::make(name)}; @@ -425,26 +425,12 @@ void CompiledTransformation::compile() { } auto& node = 
var_accessors[input].node; if (input_vars.empty() && require_link && mm_io_link.node()) { - /*mgb_assert( - !input_vars.empty(), - "io-mm operator should have at least one input");*/ auto comp_node = mm_io_link.node()->comp_node(); - // auto comp_node = input_vars[0]->comp_node(); node = opr::VirtualDep::make({SymbolVar(node), mm_io_link}, comp_node) .node(); } input_vars.push_back(node); } - /*if (require_link && mm_io_link.node()) { - mgb_assert( - !input_vars.empty(), - "io-mm operator should have at least one input"); - auto comp_node = mm_io_link.node()->comp_node(); - // auto comp_node = input_vars[0]->comp_node(); - input_vars[0] = opr::VirtualDep::make( - {SymbolVar(input_vars[0]), mm_io_link}, comp_node) - .node(); - }*/ VarNodeArray output_vars; if (item.op) { output_vars = OpDef::apply_on_var_node(*item.op, input_vars); @@ -520,7 +506,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { switch (var.kind) { case VarKind::External: { trace_assert( - !value.is(), "expect external node, got internal"); + !value.is(m_value_type), "expect external node, got internal"); if (var.bound_data) { assert_tensor_equal(var.bound_data, value); } else { @@ -545,8 +531,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { } case VarKind::Internal: { trace_assert( - value.is(), "expect internal node, got external"); - auto& traced_value = value.cast(); + value.is(m_value_type), "expect internal node, got external"); + auto& traced_value = value.cast(m_value_type); trace_assert(traced_value.id() == id, "input id mismatch"); break; } @@ -559,7 +545,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { } auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t { - auto traced_value = TracedValue::make(id, &m_vars[id], &m_var_accessors[id]); + auto traced_value = m_value_type.make(id, &m_vars[id], &m_var_accessors[id]); m_weak_values.push_back(traced_value); return traced_value; } @@ -569,7 +555,7 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() { return m_seq[m_pc++]; } -ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const { +ShapeValue::ref_t CompiledTransformation::TracedValue::shape() const { if (!m_shape) { trace_assert(m_accessor->shape_getter, "shape unreadable"); m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter())); @@ -577,14 +563,14 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const { return m_shape; } -DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const { +DTypeValue::ref_t CompiledTransformation::TracedValue::dtype() const { return m_var->dtype; } -CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const { +CompNodeValue::ref_t CompiledTransformation::TracedValue::comp_node() const { return m_var->device; } -auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& { +auto CompiledTransformation::TracedValue::accessor() const -> const VarAccessor& { return *m_accessor; } @@ -605,7 +591,7 @@ ValueRefList CompiledTransformation::apply_op( ValueRefList CompiledTransformation::apply_get_attr( const GetAttr& get_attr, Span inputs) { - if (auto* traced_value = inputs[0].as()) { + if (auto* traced_value = inputs[0].as(m_value_type)) { ValueRef output; auto& var_accessor = traced_value->accessor(); switch (get_attr.attr()) { @@ -718,15 +704,11 @@ void CompiledTransformation::on_unregister() noexcept { void CompiledTransformation::execute() { mgb_assert(m_executable != nullptr); - 
m_graph_executor = std::thread([&] { - try { - m_executable->execute(); - m_executable->wait(); - } catch (...) { - auto exc = std::current_exception(); - set_exception(exc); - } - }); + { + MGB_LOCK_GUARD(m_mutex); + m_graph_status = 1; + } + m_cv.notify_all(); } void CompiledTransformation::wait() { @@ -735,8 +717,9 @@ void CompiledTransformation::wait() { } catch (...) { } mgb_assert(m_executable != nullptr); - m_graph_executor.join(); - m_graph_executor = {}; + std::unique_lock lock{m_mutex}; + m_cv.wait(lock, [&] { return m_graph_status == 0; }); + lock.unlock(); for (auto&& box : m_boxes) { box->reset(); } diff --git a/imperative/src/impl/value.cpp b/imperative/src/impl/value.cpp index 47a0918049e16db12fd75956982e2f92a77d9bf7..8b4cb4dd3f863b6e5fe07cc22d724348e0c5aa01 100644 --- a/imperative/src/impl/value.cpp +++ b/imperative/src/impl/value.cpp @@ -25,16 +25,16 @@ ValueRef::storage_t& ValueRef::storage() const { return m_storage; } -const Value* ValueRef::as(size_t typecode) const { +const Value* ValueRef::as(const IType& type) const { auto&& storage = this->storage(); - if (storage->m_typecode != typecode) { + if (storage->type() != type) { return nullptr; } return static_cast(storage.get()); } -bool ValueRef::is(size_t typecode) const { - return this->storage()->m_typecode == typecode; +bool ValueRef::is(const IType& type) const { + return this->storage()->type() == type; } TypedValueRef ValueRef::dev_tensor() const { @@ -106,9 +106,7 @@ std::string ValueRef::raw_type() const { if (!m_storage) { return "null"; } - auto& types = Value::registered_types(); - mgb_assert(types.size() > m_storage->m_typecode); - return types[m_storage->m_typecode].name(); + return m_storage->type().name(); } bool ValueRef::watching() const { @@ -137,7 +135,7 @@ ValueRef ValueWeakRef::lock() { return {strong_storage}; } -Value::Value(size_t typecode) : m_typecode{typecode} { +Value::Value() { m_id = nr_values++; } @@ -147,17 +145,6 @@ Value::~Value() { } } -size_t Value::register_type(std::type_index type) { - auto& types = const_cast&>(registered_types()); - types.push_back(type); - return types.size() - 1; -} - -const std::vector& Value::registered_types() { - static std::vector sm_registered_types; - return sm_registered_types; -} - void Value::register_value(ValueRef value) { registered_values[value.id()] = ValueWeakRef(value); } @@ -188,7 +175,7 @@ std::vector Value::end_record_values() { } void Value::try_rethrow() { - if (m_typecode == ErrorValue::TYPE_CODE) { + if (type() == PrimitiveType::instance) { auto message = static_cast(this)->message(); mgb_throw(MegBrainError, "invalid value: %s", message.c_str()); } @@ -198,13 +185,9 @@ inline void ValueRefList::init(size_t nr_elems) { m_size = nr_elems; if (m_size > 0) { if (m_size == 1) { - m_data = inline_storage(); + m_data = new (inline_storage()) ValueRef(); } else { - auto& context = Transformation::get_context(); - m_data = context.allocator.allocate(m_size); - } - for (size_t i = 0; i < m_size; ++i) { - new (m_data + i) ValueRef(); + m_data = new ValueRef[m_size]; } } else { m_data = nullptr; @@ -215,9 +198,6 @@ ValueRefList::ValueRefList(size_t nr_elems) { init(nr_elems); } -/*ValueRefList::ValueRefList(std::initializer_list values) - : ValueRefList(values.begin(), values.end()) {}*/ - ValueRefList::ValueRefList(const ValueRefList& rhs) : ValueRefList(rhs.cbegin(), rhs.cend()) {} @@ -271,14 +251,12 @@ ValueRefList::~ValueRefList() { } void ValueRefList::clear() { - for (size_t i = 0; i < m_size; ++i) { - m_data[i].~ValueRef(); - } if (m_data) 
{ if (m_size != 1) { - Transformation::get_context().allocator.deallocate(m_data, m_size); + delete[] m_data; } else { mgb_assert(m_data == inline_storage()); + m_data->~ValueRef(); } } m_data = nullptr; diff --git a/imperative/src/include/megbrain/imperative/basic_values.h b/imperative/src/include/megbrain/imperative/basic_values.h index 02e528edc5e0b4015cdb1a68f1c54543f3da38a3..bbe57874f03900b230e3cb6aff2d94e7f84d2768 100644 --- a/imperative/src/include/megbrain/imperative/basic_values.h +++ b/imperative/src/include/megbrain/imperative/basic_values.h @@ -25,79 +25,68 @@ class GradKey; using GenericFunction = std::function)>; -class ShapeValue final - : public MixinValueImpl { +class ShapeValue final : public PrimitiveValue { public: - using MixinValueImpl::MixinValueImpl; + using PrimitiveValue::PrimitiveValue; std::string to_string() const override; }; -class CompNodeValue final - : public MixinValueImpl { +class CompNodeValue final : public PrimitiveValue { public: - using MixinValueImpl::MixinValueImpl; + using PrimitiveValue::PrimitiveValue; std::string to_string() const override; }; -// TODO: override factory method -class BoolValue final : public ValueImpl { +class Boolean { private: - std::optional m_value; + bool m_value; public: - BoolValue(bool value) : m_value{value} {} - operator bool() const { return *m_value; } + Boolean() = default; + Boolean(bool value) : m_value(value) {} - std::string to_string() const override; + operator bool() const { return m_value; } +}; - void clear() override { m_value.reset(); } +// TODO: override factory method +class BoolValue final : public PrimitiveValue { +public: + using PrimitiveValue::PrimitiveValue; + + std::string to_string() const override; }; -class HostStorage final - : public MixinValueImpl { +class HostStorage final : public PrimitiveValue { public: - using MixinValueImpl::MixinValueImpl; + using PrimitiveValue::PrimitiveValue; std::string to_string() const override; }; -class DeviceStorage final - : public MixinValueImpl< - DeviceStorage, ValueKind::Primitive, DeviceTensorStorage> { +class DeviceStorage final : public PrimitiveValue { public: - using MixinValueImpl::MixinValueImpl; + using PrimitiveValue::PrimitiveValue; std::string to_string() const override; }; -/** - * \brief like HostTensorND mixin, but allow scalar value - * - */ -class HostValue final : public ValueImpl { +class HostTensor { private: DType m_dtype; ValueShape m_shape; HostTensorStorage m_storage; public: - HostValue(DType dtype, ValueShape shape, HostTensorStorage storage) + HostTensor() = default; + HostTensor(DType dtype, ValueShape shape, HostTensorStorage storage) : m_dtype(dtype), m_shape(shape), m_storage(storage) {} - HostValue(HostTensorND value) - : HostValue( + HostTensor(HostTensorND value) + : HostTensor( value.dtype(), ValueShape::from(value.shape()), value.storage()) { } - std::string to_string() const override; - - void clear() override { - m_dtype = {}; - m_shape = {}; - m_storage = {}; - } - DType dtype() const { return m_dtype; } const ValueShape& shape() const { return m_shape; } CompNode device() const { return m_storage.comp_node(); } @@ -112,31 +101,31 @@ public: }; /** - * \brief like DeviceTensorND mixin, but allow scalar value + * \brief like HostTensorND mixin, but allow scalar value * */ -class DeviceValue final : public ValueImpl { +class HostValue final : public PrimitiveValue { +public: + using PrimitiveValue::PrimitiveValue; + + std::string to_string() const override; +}; + +class DeviceTensor { private: DType m_dtype; 
ValueShape m_shape; DeviceTensorStorage m_storage; public: - DeviceValue(DType dtype, ValueShape shape, DeviceTensorStorage storage) + DeviceTensor() = default; + DeviceTensor(DType dtype, ValueShape shape, DeviceTensorStorage storage) : m_dtype(dtype), m_shape(shape), m_storage(std::move(storage)) {} - DeviceValue(const DeviceTensorND& value) - : DeviceValue( + DeviceTensor(const DeviceTensorND& value) + : DeviceTensor( value.dtype(), ValueShape::from(value.shape()), value.storage()) { } - std::string to_string() const override; - - void clear() override { - m_dtype = {}; - m_shape = {}; - m_storage = {}; - } - DType dtype() const { return m_dtype; } const ValueShape& shape() const { return m_shape; } CompNode device() const { return m_storage.comp_node(); } @@ -145,26 +134,34 @@ public: DeviceTensorND as_nd(bool allow_scalar = false) const; }; -class FunctionValue final - : public MixinValueImpl { +/** + * \brief like DeviceTensorND mixin, but allow scalar value + * + */ +class DeviceValue final : public PrimitiveValue { +public: + using PrimitiveValue::PrimitiveValue; + + std::string to_string() const override; +}; + +class FunctionValue final : public PrimitiveValue { public: - using MixinValueImpl::MixinValueImpl; + using PrimitiveValue::PrimitiveValue; std::string to_string() const override; }; -class DTypeValue final - : public MixinValueImpl { +class DTypeValue final : public PrimitiveValue { public: - using MixinValueImpl::MixinValueImpl; + using PrimitiveValue::PrimitiveValue; std::string to_string() const override; }; -class StringValue final - : public MixinValueImpl { +class StringValue final : public PrimitiveValue { public: - using MixinValueImpl::MixinValueImpl; + using PrimitiveValue::PrimitiveValue; std::string to_string() const override; }; @@ -180,10 +177,9 @@ public: std::string message() const { return m_message; } }; -class ErrorValue final - : public MixinValueImpl { +class ErrorValue final : public PrimitiveValue { public: - using MixinValueImpl::MixinValueImpl; + using PrimitiveValue::PrimitiveValue; std::string to_string() const override; }; diff --git a/imperative/src/include/megbrain/imperative/subgraph.h b/imperative/src/include/megbrain/imperative/subgraph.h index b54a069bed85847f99890f21ae29cab22860463d..56444372617831dec8f3bfd2ab82618599ee36e7 100644 --- a/imperative/src/include/megbrain/imperative/subgraph.h +++ b/imperative/src/include/megbrain/imperative/subgraph.h @@ -57,7 +57,7 @@ struct Subgraph { SmallVector exprs; template - SmallVector apply(SmallVector input_vars, F&& f, C&& c) const { + SmallVector apply(Span input_vars, F&& f, C&& c) const { std::unordered_map idx2var; mgb_assert(inputs.size() == input_vars.size(), "input size mismatch"); for (size_t i = 0; i < inputs.size(); ++i) { @@ -71,8 +71,7 @@ struct Subgraph { for (auto idx : expr.inputs) { expr_inputs.push_back(idx2var[idx]); } - SmallVector expr_outputs = - f(expr.op, std::move(expr_inputs), expr.outputs.size()); + SmallVector expr_outputs = f(expr.op, expr_inputs, expr.outputs.size()); mgb_assert( expr_outputs.size() == expr.outputs.size(), "output size mismatch"); for (size_t i = 0; i < expr_outputs.size(); ++i) { @@ -102,9 +101,9 @@ struct EncodedSubgraph { SmallVector input_mask; SmallVector output_mask; - template - TContainer encode_inputs(TContainer inputs) const { - TContainer encoded_inputs; + template + SmallVector encode_inputs(Span inputs) const { + SmallVector encoded_inputs; size_t index = 0; for (auto&& input : inputs) { mgb_assert(index < input_mask.size(), "index out 
of range"); @@ -116,9 +115,9 @@ struct EncodedSubgraph { return encoded_inputs; } - template - TContainer encode_outputs(TContainer outputs) const { - TContainer encoded_outputs; + template + SmallVector encode_outputs(Span outputs) const { + SmallVector encoded_outputs; size_t index = 0; for (auto&& output : outputs) { mgb_assert(index < output_mask.size(), "index out of range"); @@ -130,9 +129,9 @@ struct EncodedSubgraph { return encoded_outputs; } - template - TContainer decode_outputs(TContainer outputs) const { - TContainer decoded_outputs; + template + SmallVector decode_outputs(Span outputs) const { + SmallVector decoded_outputs; size_t index = 0; for (size_t i = 0; i < output_mask.size(); i++) { mgb_assert(index < output_mask.size(), "index out of range"); @@ -150,8 +149,8 @@ struct EncodedSubgraph { EncodedSubgraph result; result.input_mask = graph.gen_input_mask(); result.output_mask = graph.gen_output_mask(); - graph.inputs = result.encode_inputs(graph.inputs); - graph.outputs = result.encode_outputs(graph.outputs); + graph.inputs = result.encode_inputs(graph.inputs); + graph.outputs = result.encode_outputs(graph.outputs); result.graph = graph; return result; } @@ -179,11 +178,11 @@ struct EncodedSubgraph { } template - SmallVector apply(SmallVector input_vars, F&& f, C&& c) const { - auto encoded_inputs = encode_inputs(input_vars); + SmallVector apply(Span input_vars, F&& f, C&& c) const { + auto encoded_inputs = encode_inputs(input_vars); auto encoded_outputs = - graph.apply(encoded_inputs, std::forward(f), std::forward(c)); - return decode_outputs(encoded_outputs); + graph.apply(encoded_inputs, std::forward(f), std::forward(c)); + return decode_outputs(encoded_outputs); } std::string repr() const; @@ -280,4 +279,4 @@ public: }; } // namespace imperative -} // namespace mgb \ No newline at end of file +} // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/transformations/eval.h b/imperative/src/include/megbrain/imperative/transformations/eval.h index 8fbd16a71ca13aeb6dafc75d13a9e261c644ae6c..93319af9e7a8ee2b31aa244cafd39e69487549db 100644 --- a/imperative/src/include/megbrain/imperative/transformations/eval.h +++ b/imperative/src/include/megbrain/imperative/transformations/eval.h @@ -18,7 +18,7 @@ namespace mgb::imperative { -struct InterpreterInfo { +class InterpreterValue final : public ObjectValue { public: using Handle = interpreter::Interpreter::Handle; using Channel = interpreter::Interpreter::Channel; @@ -46,8 +46,7 @@ private: mutable ShapeValue::ref_t m_shape; public: - InterpreterInfo() = default; - InterpreterInfo(LocalPtr handle, std::string name = {}) + InterpreterValue(LocalPtr handle, std::string name = {}) : m_handle(handle), m_name(name) {} const LocalPtr& handle() const { return m_handle; } @@ -57,18 +56,14 @@ public: ShapeValue::ref_t shape() const; std::string name() const { return m_name; } -}; - -class InterpreterValue final - : public MixinValueImpl { -public: - using MixinValueImpl::MixinValueImpl; std::string to_string() const override { return ssprintf( "Handle{ptr=%p, name=%s}", handle().get(), imperative::quoted(name()).c_str()); } + + void clear() override { m_handle = {}; } }; /** @@ -82,11 +77,12 @@ class InterpreterTransformation final : public Transformation { public: using Interpreter = interpreter::Interpreter; using Handle = Interpreter::Handle; - using SharedHandle = LocalPtr; + using SharedHandle = LocalPtr; using Channel = Interpreter::Channel; private: std::shared_ptr m_channel; + ObjectType 
m_value_type{"InterpreterValue"}; public: explicit InterpreterTransformation(std::shared_ptr channel) @@ -105,7 +101,7 @@ public: const Operator& op, Span inputs) override; ValueRef unwrap(ValueRef value) override { - mgb_assert(!value.is()); + mgb_assert(!value.is(m_value_type)); return value; } diff --git a/imperative/src/include/megbrain/imperative/transformations/grad.h b/imperative/src/include/megbrain/imperative/transformations/grad.h index 93055a0b873069b021a39c48969fba5139f95f03..0a1970ae689c328a65aaf65ecb033c88a4cdafbc 100644 --- a/imperative/src/include/megbrain/imperative/transformations/grad.h +++ b/imperative/src/include/megbrain/imperative/transformations/grad.h @@ -34,7 +34,8 @@ struct BackwardGraphWithClosure { std::shared_ptr backward_graph, std::shared_ptr op, Span inputs, Span outputs); - void operator()(ValueRefList grads, std::function receiver); + void operator()( + Span grads, std::function receiver); bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; } @@ -51,7 +52,7 @@ struct CustomBackward; using GradRuleFn = std::function inputs, CustomBackward&)>; struct CustomBackward { - using BackwardFn = std::function)>; + using BackwardFn = std::function(Span)>; using BackwardRule = std::function( const OpDef&, Span, Span, CustomBackward&)>; BackwardFn m_backward; @@ -62,7 +63,8 @@ struct CustomBackward { SmallVector m_output_attrs; public: - void operator()(ValueRefList grads, std::function receiver); + void operator()( + Span grads, std::function receiver); bool input_has_grad(size_t i) { return m_input_has_grad[i]; } bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; } @@ -175,7 +177,7 @@ inline GradSlot* GradSlotPtr::operator->() const { return &m_fn->m_slots[m_index]; } -class GradValue final : public ValueImpl { +class GradValue final : public ObjectValue { private: ValueRef m_value; std::shared_ptr m_key; @@ -187,14 +189,9 @@ public: std::string to_string() const override; - bool has_key(const std::shared_ptr& key) const { return m_key == key; } + const GradSlotPtr& slot() const { return m_slot; } - const GradSlotPtr& slot_for(std::shared_ptr key) const { - mgb_assert(m_key == key); - return m_slot; - } - - std::shared_ptr key() const { return m_key; } + // std::shared_ptr key() const { return m_key; } void clear() override { m_slot = {}; @@ -216,9 +213,12 @@ private: std::vector, std::shared_ptr>> m_tape; std::vector, std::shared_ptr>> m_frozen_tape; bool m_frozen = false; + const Type& m_value_type; public: - GradKey() { m_tape.reserve(4 * 1024); } + GradKey(const Type& value_type) : m_value_type(value_type) { + m_tape.reserve(4 * 1024); + } void backward(); GradValue::ref_t attach(ValueRef tensor, std::function callback); @@ -230,10 +230,9 @@ public: }; class GradKeyValue final - : public MixinValueImpl< - GradKeyValue, ValueKind::Primitive, std::shared_ptr> { + : public PrimitiveValue> { public: - using MixinValueImpl::MixinValueImpl; + using PrimitiveValue::PrimitiveValue; std::string to_string() const override { return ssprintf("GradKey{%s}", (*this)->name().c_str()); @@ -242,26 +241,20 @@ public: class GradTransformation final : public Transformation { private: + ObjectType m_value_type{"GradValue"}; std::shared_ptr m_key; std::vector m_weak_values; size_t m_suppressed = 0; public: - GradTransformation(std::shared_ptr key) : m_key(key) {} + GradTransformation() { m_key = std::make_shared(m_value_type); } auto record_grad(GradValue::ref_t tensor) { m_weak_values.push_back(tensor); return tensor; } - bool 
is_grad_value(const ValueRef& value) { - if (auto* grad_value = value.as()) { - if (grad_value->has_key(m_key)) { - return true; - } - } - return false; - } + bool is_grad_value(const ValueRef& value) { return value.is(m_value_type); } /** * \brief test whether value is related to this GradTransformation @@ -273,13 +266,7 @@ public: * \return GradValue::ref_t */ const GradValue::ref_t& as_grad_value(const ValueRef& value) { - auto&& grad_value = value.as_ref(); - if (grad_value) { - if (grad_value->has_key(m_key)) { - return grad_value; - } - } - return GradValue::ref_t::nil; + return value.as_ref(m_value_type); } bool has_key(std::shared_ptr key) { @@ -299,6 +286,8 @@ public: return value; } + const std::shared_ptr& key() const { return m_key; } + std::string name() const override { return "GradTransformation"; } GenericFunction make_backward_closure(Span ys); diff --git a/imperative/src/include/megbrain/imperative/transformations/lazy.h b/imperative/src/include/megbrain/imperative/transformations/lazy.h index f4855b7f4e6bc389100bcccb7f3b7b6faa2640b5..5c03b9b971aa79a3f52b947d7d943b7b3f52a27e 100644 --- a/imperative/src/include/megbrain/imperative/transformations/lazy.h +++ b/imperative/src/include/megbrain/imperative/transformations/lazy.h @@ -22,32 +22,27 @@ namespace mgb::imperative { -class LazyEvalInfo { +class LazyEvalValue final : public ObjectValue { private: VarNode* m_node = nullptr; ValueRef m_bound_data; std::string m_name; public: - LazyEvalInfo() = default; - LazyEvalInfo(VarNode* node, ValueRef bound_data, std::string name) + LazyEvalValue(VarNode* node, ValueRef bound_data, std::string name) : m_node(node), m_bound_data(bound_data), m_name(name) {} VarNode* node() const { return m_node; } ValueRef bound_data() const { return m_bound_data; } std::string name() const { return m_name; } -}; - -class LazyEvalValue final - : public MixinValueImpl { -public: - using MixinValueImpl::MixinValueImpl; std::string to_string() const override { return ssprintf( "LazyEvalValue{node=%p, name=%s}", node(), node()->name().c_str()); } + + void clear() override {} }; /** @@ -67,6 +62,7 @@ private: std::vector m_weak_vars; SymbolVar m_io_link = nullptr; std::exception_ptr m_graph_exc; + ObjectType m_value_type{"LazyEvalValue"}; public: LazyEvalTransformation(bool no_exec) : m_no_exec(no_exec) { @@ -75,7 +71,7 @@ public: LazyEvalValue::ref_t record_var( VarNode* node, ValueRef bound_data = {}, std::string name = {}) { - auto lazy_eval_val = LazyEvalValue::make(node, bound_data, name); + auto lazy_eval_val = m_value_type.make(node, bound_data, name); m_weak_vars.push_back(lazy_eval_val); return lazy_eval_val; } @@ -86,7 +82,7 @@ public: const Operator& op, Span inputs) override; ValueRef unwrap(ValueRef value) override { - mgb_assert(!value.is()); + mgb_assert(!value.is(m_value_type)); return value; } diff --git a/imperative/src/include/megbrain/imperative/transformations/scalar.h b/imperative/src/include/megbrain/imperative/transformations/scalar.h index 142cafabaef747f0697fcd12d693a3f5cbebb28b..496de0f611ca75f6dfb5da316180cd9e174c4729 100644 --- a/imperative/src/include/megbrain/imperative/transformations/scalar.h +++ b/imperative/src/include/megbrain/imperative/transformations/scalar.h @@ -17,7 +17,7 @@ namespace mgb::imperative { -class ScalarValue final : public ValueImpl { +class ScalarValue final : public ObjectValue { private: ValueRef m_value; @@ -47,17 +47,21 @@ public: class ScalarTransformation final : public Transformation { private: ShapeValue::ref_t m_empty_shape; // [] + 
ObjectType m_value_type{"ScalarValue"}; + public: ValueRefList apply_get_attr(const GetAttr& get_attr, Span inputs); ValueRefList apply_transformation( const Operator& op, Span inputs) override; ValueRef unwrap(ValueRef value) override { - mgb_assert(!value.is()); + mgb_assert(!value.is(m_value_type)); return value; } std::string name() const override { return "ScalarTransformation"; } + + const Type& value_type() const { return m_value_type; } }; } // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/transformations/symbol.h b/imperative/src/include/megbrain/imperative/transformations/symbol.h index 976df56b0c6b8a408af4241bc006996e97b51587..27f92fb0c6d70edec3a269e8780d0929cd7d7e87 100644 --- a/imperative/src/include/megbrain/imperative/transformations/symbol.h +++ b/imperative/src/include/megbrain/imperative/transformations/symbol.h @@ -22,7 +22,7 @@ namespace mgb::imperative { -class SymbolValue final : public ValueImpl { +class SymbolValue final : public ObjectValue { private: VarNode* m_node = nullptr; @@ -47,6 +47,7 @@ public: class SymbolTransformation final : public Transformation { private: ComputingGraph* m_graph = nullptr; + ObjectType m_value_type{"SymbolValue"}; public: SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {} @@ -55,12 +56,12 @@ public: if (auto* apply_op = op.as()) { SmallVector input_nodes; for (auto&& input : inputs) { - input_nodes.push_back(input.cast().node()); + input_nodes.push_back(input.cast(m_value_type).node()); } auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); ValueRefList outputs(output_nodes.size()); for (size_t i = 0; i < output_nodes.size(); ++i) { - outputs[i] = SymbolValue::make(output_nodes[i]); + outputs[i] = m_value_type.make(output_nodes[i]); } return outputs; } else if (auto* create_tensor = op.as()) { @@ -69,9 +70,9 @@ public: args.kind == CreateTensor::Const, "only const value is allowed here"); auto* node = opr::ImmutableTensor::make(*m_graph, *args.host, {}).node(); - return {SymbolValue::make(node)}; + return {m_value_type.make(node)}; } else if (auto* get_attr = op.as()) { - auto* node = inputs.as_array<1>()[0].cast().node(); + auto* node = inputs.item().cast(m_value_type).node(); switch (get_attr->attr()) { case GetAttr::DType: return {DTypeValue::make(node->dtype())}; @@ -121,11 +122,13 @@ public: } ValueRef unwrap(ValueRef value) override { - mgb_assert(!value.is(), "SymbolValue doesn't support unwrap"); + mgb_assert(!value.is(m_value_type), "SymbolValue doesn't support unwrap"); return value; } std::string name() const override { return "SymbolTransformation"; } + + const Type& value_type() const { return m_value_type; } }; } // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/transformations/trace.h b/imperative/src/include/megbrain/imperative/transformations/trace.h index 36432107ddd32cc6d4f97960882ed6f761cd8d96..3378b333736ed830c8b6b478bcd0fc3ea911b2a6 100644 --- a/imperative/src/include/megbrain/imperative/transformations/trace.h +++ b/imperative/src/include/megbrain/imperative/transformations/trace.h @@ -100,22 +100,15 @@ public: } }; -class TracingInfo { +class TracingValue final : public ObjectValue { private: ValueRef m_value = {}; size_t m_id = 0; public: - TracingInfo() = default; - TracingInfo(ValueRef value, size_t id) : m_value(value), m_id(id) {} + TracingValue(ValueRef value, size_t id) : m_value(value), m_id(id) {} ValueRef value() const { return m_value; } size_t id() const { return m_id; } -}; - -class 
TracingValue final - : public MixinValueImpl { -public: - using MixinValueImpl::MixinValueImpl; std::string to_string() const override { return ssprintf( @@ -126,6 +119,8 @@ public: void on_watch() override { value().watch(); } void on_unwatch() override { value().unwatch(); } + + void clear() override { m_value = {}; } }; /** @@ -146,6 +141,7 @@ private: std::vector m_weak_vars; bool m_capture_as_const = false; bool m_record_input_shapes = false; + ObjectType m_value_type{"TracingValue"}; public: TracingTransformation(bool capture_as_const, bool record_input_shapes) @@ -162,7 +158,7 @@ public: */ TypedValueRef record_var(ValueRef value, bool capture, VarKind kind) { size_t id = m_vars.size(); - auto wrapped_value = TracingValue::make(value, id); + auto wrapped_value = m_value_type.make(value, id); m_vars.push_back({id, value.dtype(), value.device()}); auto& var = m_vars.back(); if (capture) { @@ -179,7 +175,7 @@ public: return wrapped_value; } ValueRef unwrap_var(ValueRef value) { - if (auto* tracing_value = value.as()) { + if (auto* tracing_value = value.as(m_value_type)) { return tracing_value->value(); } return value; @@ -189,7 +185,7 @@ public: const Operator& op, Span inputs) override; ValueRef unwrap(ValueRef value) override { - if (auto* tracing_value = value.as()) { + if (auto* tracing_value = value.as(m_value_type)) { return tracing_value->value(); } return value; @@ -234,7 +230,7 @@ public: std::function exc_setter; }; - class TracedInfo { + class TracedValue final : public ObjectValue { private: size_t m_id = 0; VarInfo* m_var = nullptr; @@ -244,8 +240,7 @@ public: mutable CompNodeValue::ref_t m_comp_node; public: - TracedInfo() = default; - TracedInfo(size_t id, VarInfo* var, VarAccessor* accessor) + TracedValue(size_t id, VarInfo* var, VarAccessor* accessor) : m_id(id), m_var(var), m_accessor(accessor) {} size_t id() const { return m_id; } ShapeValue::ref_t shape() const; @@ -256,16 +251,12 @@ public: void set_exception(std::exception_ptr exc) const { m_accessor->exc_setter(exc); } - }; - - class TracedValue final - : public MixinValueImpl { - public: - using MixinValueImpl::MixinValueImpl; std::string to_string() const override { return ssprintf("TracedValue{\"id\"=%zu}", id()); } + + void clear() override {} }; private: @@ -280,9 +271,12 @@ private: std::function m_value_comparator; bool m_input_shape_static; std::mutex m_mutex; + std::condition_variable m_cv; std::exception_ptr m_graph_exc; + int m_graph_status = 0; // 0 = stop, 1 = running, 2 = finalizing std::vector> m_boxes; ComputingGraph::OutputSpec m_output_spec; + ObjectType m_value_type{"TracedValue"}; public: CompiledTransformation(TraceResult result, bool input_shape_static) @@ -292,6 +286,27 @@ public: m_graph = ComputingGraph::make(); options().no_force_inplace = true; options().async_exec_level = 0b100; + m_graph_executor = std::thread([&] { + while (true) { + std::unique_lock lock{m_mutex}; + m_cv.wait(lock, [&] { return m_graph_status != 0; }); + lock.unlock(); + if (m_graph_status == 2) { + break; + } + try { + m_executable->execute(); + m_executable->wait(); + } catch (...) 
{ + auto exc = std::current_exception(); + set_exception(exc); + } + lock.lock(); + m_graph_status = 0; + lock.unlock(); + m_cv.notify_all(); + } + }); } ComputingGraph& graph() { return *m_graph; } @@ -350,7 +365,7 @@ public: void on_unregister() noexcept override; ValueRef unwrap(ValueRef value) override { - mgb_assert(!value.is()); + mgb_assert(!value.is(m_value_type)); return value; } @@ -368,6 +383,15 @@ public: m_boxes.push_back(box); return box; } + + ~CompiledTransformation() { + { + MGB_LOCK_GUARD(m_mutex); + m_graph_status = 2; + } + m_cv.notify_all(); + m_graph_executor.join(); + } }; } // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/utils/allocator.h b/imperative/src/include/megbrain/imperative/utils/allocator.h index f800ec256226fdf16350365466eee0ddfff7f866..afc9b21fddd43c87cf32297af4b2816902269450 100644 --- a/imperative/src/include/megbrain/imperative/utils/allocator.h +++ b/imperative/src/include/megbrain/imperative/utils/allocator.h @@ -11,7 +11,9 @@ #pragma once +#include #include +#include #include "megbrain/utils/mempool.h" #include "megbrain/utils/metahelper.h" diff --git a/imperative/src/include/megbrain/imperative/utils/span.h b/imperative/src/include/megbrain/imperative/utils/span.h index bb5efe52ac3b82f0cafc6b9be1ee4ae5abbb8f2f..97d130548d351ad097ef8f37330568797bf3347f 100644 --- a/imperative/src/include/megbrain/imperative/utils/span.h +++ b/imperative/src/include/megbrain/imperative/utils/span.h @@ -34,7 +34,7 @@ public: Span(const T* begin, const T* end) : m_begin{begin}, m_end{end} {} Span(const T* begin, size_t size) : Span(begin, begin + size) {} template - Span(TContainer& container) : Span(container.data(), container.size()) {} + Span(const TContainer& container) : Span(container.data(), container.size()) {} const T* begin() const { return m_begin; } const T* end() const { return m_end; } const T* data() const { return m_begin; } diff --git a/imperative/src/include/megbrain/imperative/utils/stats.h b/imperative/src/include/megbrain/imperative/utils/stats.h index 9bab330af37159ed027a2ffc60ef24d8b1a56207..0e82a3f3808ff555d398f05e4adbf89518847996 100644 --- a/imperative/src/include/megbrain/imperative/utils/stats.h +++ b/imperative/src/include/megbrain/imperative/utils/stats.h @@ -2,7 +2,10 @@ #include #include +#include +#include #include +#include #include namespace mgb { @@ -18,7 +21,7 @@ public: private: clock_t::duration m_duration = clock_t::duration{0}; size_t m_timing = 0; - const char* m_name = nullptr; + std::string m_name; uint64_t m_count = 0; size_t m_enabled = 1; bool m_default_enabled = true; @@ -42,7 +45,8 @@ private: } if (timer.m_enabled) { if (!--timer.m_timing) { - timer.m_duration += (clock_t::now() - start); + auto duration = (clock_t::now() - start); + timer.m_duration += duration; } timer.m_count++; } @@ -67,13 +71,10 @@ private: } }; - using TimeScope = TimeScopeRecursive; - public: - Timer(const char* name, bool default_enabled); + Timer(std::string name, bool default_enabled = true); - const char* name() { return m_name; } - auto time_scope() { return TimeScope(*this); } + std::string name() { return m_name; } auto time_scope_recursive() { return TimeScopeRecursive(*this); }; auto enable_scope() { return EnableScope(*this); } void reset() { @@ -88,7 +89,14 @@ public: } // namespace stats struct Stats { - static inline std::vector sm_timers; + struct TimerNode { + std::map> children; + stats::Timer* timer = nullptr; + + TimerNode() {} + }; + + static inline TimerNode sm_root; // register your 
timers here // for example: @@ -97,33 +105,84 @@ struct Stats { // // then use MGE_TIMER_SCOPE(mytimer) to collect durations in your code - static void print() { - std::vector unused_timers; - - for (auto* timer : sm_timers) { - if (timer->count() == 0) { - unused_timers.push_back(timer->name()); - } else { - printf("%s costs %ld ns, happens %ld times\n", timer->name(), - timer->get().count(), timer->count()); + static std::pair print_node( + std::string name, TimerNode& node, size_t indent = 0) { + auto print_indent = [&] { + for (size_t i = 0; i < indent; ++i) { + printf(" "); + } + }; + long ns = 0, count = 0; + if (auto* timer = node.timer) { + print_indent(); + printf("%s costs %'ld ns, hits %'ld times\n", name.c_str(), + (long)timer->get().count(), (long)timer->count()); + ns = timer->get().count(); + count = timer->count(); + } + if (!node.children.empty()) { + bool collect_children = node.timer == nullptr; + if (collect_children) { + print_indent(); + printf("%s:\n", name.c_str()); + } + long ns = 0, count = 0; + for (auto&& child : node.children) { + auto [child_ns, child_count] = + print_node(child.first, *child.second, indent + 4); + if (collect_children) { + ns += child_ns; + count += child_count; + } + } + if (collect_children) { + print_indent(); + printf("total costs %'ld ns, hits %'ld times\n", ns, count); } } + return {ns, count}; + } - if (!unused_timers.empty()) { - printf("%zu timers unused\n", unused_timers.size()); + static void print() { + for (auto&& child : sm_root.children) { + print_node(child.first, *child.second); } } static void reset() { - for (auto* timer : sm_timers) { - timer->reset(); - } + auto reset_node = [](TimerNode& node, auto&& reset_node) -> void { + if (auto* timer = node.timer) { + timer->reset(); + } + for (auto&& child : node.children) { + reset_node(*child.second, reset_node); + } + }; + reset_node(sm_root, reset_node); } }; -inline stats::Timer::Timer(const char* name, bool default_enabled) +inline stats::Timer::Timer(std::string name, bool default_enabled) : m_name(name), m_default_enabled(default_enabled) { - Stats::sm_timers.push_back(this); + std::vector terms; + Stats::TimerNode* node = &Stats::sm_root; + while (true) { + auto pos = name.find("."); + if (pos == std::string::npos) { + auto& child = node->children[name]; + child = std::make_unique(); + node = child.get(); + node->timer = this; + break; + } else { + auto& child = node->children[name.substr(0, pos)]; + if (!child) { + child = std::make_unique(); + } + node = child.get(); + name = name.substr(pos + 1); + } + } } #if MGE_ENABLE_STATS diff --git a/imperative/src/include/megbrain/imperative/value.h b/imperative/src/include/megbrain/imperative/value.h index cf4381df97168e7bffd82c8dc2e9d32a06135405..21541443041ea2b709ca9441211f02bd7ed8f31e 100644 --- a/imperative/src/include/megbrain/imperative/value.h +++ b/imperative/src/include/megbrain/imperative/value.h @@ -50,18 +50,70 @@ class Operator; class ValueRefList; +/** + * \brief base class of all value types + */ +class IType : public NonCopyableObj { +private: + std::string m_name; + // TODO: count values, or make an linkedlist + +public: + IType(std::string name) : m_name(std::move(name)) {} + + const std::string& name() const { return m_name; } + + bool operator==(const IType& rhs) const { return this == &rhs; } + + bool operator!=(const IType& rhs) const { return this != &rhs; } +}; + +/** + * \brief type of values. 
+ * + * \tparam T ctype of value + */ +template +class Type : public IType { +protected: + Type(std::string name) : IType(std::move(name)) {} + // TODO: each type owns an allocator + +public: + /** + * \brief helper function for construct a value + * + * \tparam TArgs types of arguments + * \param args arguments + * \return TypedValueRef reference of value + */ + template + TypedValueRef make(TArgs&&... args) const; +}; + +/** + * \brief type of primitive values. + * + * \tparam T ctype of value + */ template -class Type { +class PrimitiveType : public Type { private: - const size_t m_code = T::TYPE_CODE; + PrimitiveType(); public: - inline size_t code() const { return m_code; } + static inline PrimitiveType instance; }; -enum class ValueKind { - Primitive, - Object, +/** + * \brief type of object values. + * + * \tparam T ctype of value + */ +template +class ObjectType : public Type { +public: + ObjectType(std::string name) : Type(name) {} }; /** @@ -71,9 +123,8 @@ enum class ValueKind { * and only the tail node is valid. ValueRef stores a value node, and it may be * an invalid internal node. When you dereference it, it will check its successor, * automatically find the tail node and return. This list would be modified to reduce - * list length by change value's successor, but a ValueRef always has steady m_storage - * when not explicitly modified. - * So we use m_storage to identify a ValueRef ( hash / equility / id ). + * list length by change value's successor, but a steady id was kept in ValueRef + * so we can use it for identify a ValueRef ( hash / equility / id ). */ class ValueRef { public: @@ -93,9 +144,7 @@ private: */ storage_t& storage() const; - const Value* as(size_t typecode) const; - - bool is(size_t typecode) const; + const Value* as(const IType& type) const; public: ValueRef() = default; @@ -103,45 +152,76 @@ public: /** * \brief whether value is instance of target type or not * - * \tparam TValue target type - * \return true if type of value is TValue - * \return false if empty or type of value is not TValue + * \param type target type + * \return true if type of value is instance of type + * \return false if empty or type of value is not instance of type */ - template - inline bool is(Type type = {}) const; + bool is(const IType& type) const; /** * \brief try cast value as target type * - * \tparam TValue target type + * \tparam type target type * \return TValue* raw pointer if success, otherwise nullptr */ template - inline const TValue* as(Type type = {}) const; + inline const TValue* as(const Type& type) const; /** * \brief cast value to target type * - * \tparam TValue target type + * \param type target type * \return TValue& reference of value */ template - inline const TValue& cast(Type type = {}) const; + inline const TValue& cast(const Type& type) const; /** * \brief like as(), but returns TypedValueRef instead * - * \tparam TValue target type + * \param type target type * \return TypedValueRef reference if success, otherwise empty reference */ template - inline const TypedValueRef& as_ref(Type type = {}) const; + inline const TypedValueRef& as_ref(const Type& type) const; + + /** + * \brief like cast(), but allow empty value and returns TypedValueRef instead + * + * \param type target type + * \return TypedValueRef reference if success, otherwise empty reference + */ + template + inline const TypedValueRef& cast_ref(const Type& type) const; + + template + inline std::enable_if_t is() const { + return is(PrimitiveType::instance); + } + + template + inline 
std::enable_if_t as() const { + return as(PrimitiveType::instance); + } + + template + inline std::enable_if_t cast() const { + return cast(PrimitiveType::instance); + } template - inline const TypedValueRef& cast_ref(Type type = {}) const; + inline std::enable_if_t&> as_ref() + const { + return as_ref(PrimitiveType::instance); + } template - void on_cast_failure() const; + inline std::enable_if_t&> + cast_ref() const { + return cast_ref(PrimitiveType::instance); + } + + void on_cast_failure(const IType& type) const; operator bool() const { return bool(m_storage); } @@ -172,8 +252,6 @@ public: friend class ValueWeakRef; template friend class TypedValueRef; - template - friend class ValueImpl; friend ValueRefList apply(const Operator& op, Span inputs); }; @@ -195,7 +273,8 @@ protected: public: ValueWeakRef() = default; - ValueWeakRef(ValueRef value) : m_id(value.id()), m_storage(value.m_storage) {} + ValueWeakRef(const ValueRef& value) + : m_id(value.id()), m_storage(value.m_storage) {} /** * \brief try promote to ValueRef @@ -218,19 +297,15 @@ public: class Value : public NonCopyableObj { private: uint64_t m_id = std::numeric_limits::max(); - size_t m_typecode = 0; + const IType* m_type = nullptr; ValueRef m_successor; size_t m_watching = 0; protected: - Value(size_t typecode); + Value(); public: - size_t typecode() const { return m_typecode; } - const std::type_index type() const { return registered_types()[m_typecode]; } - - static size_t register_type(std::type_index type); - static const std::vector& registered_types(); + const IType& type() const { return *m_type; } static void register_value(ValueRef value); static ValueRef get_value_by_id(uint64_t id); @@ -251,11 +326,12 @@ public: friend class ValueRef; friend class ValueWeakRef; - template - friend class ValueImpl; template friend class TypedValueRef; + template + friend class Type; + ~Value(); private: @@ -267,30 +343,17 @@ private: * * \tparam T type of value */ -template -class ValueImpl : public Value { +template +class ObjectValue : public Value { protected: - ValueImpl() : Value(TYPE_CODE) {} + ObjectValue() {} public: using ref_t = TypedValueRef; using weak_ref_t = TypedValueWeakRef; - static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }(); - static constexpr ValueKind KIND = Kind; - - /** - * \brief helper function for construct a value - * - * \tparam TArgs types of arguments - * \param args arguments - * \return TypedValueRef reference of value - */ - template - static MGB_NOINLINE TypedValueRef make(TArgs&&... 
args) { - static_assert(std::is_final_v); - return ValueRef::make(LocalPtr::make(std::forward(args)...)); - } + static constexpr bool is_primitive = false; + static constexpr bool is_object = true; }; /** @@ -299,74 +362,89 @@ public: * \tparam T type of value * \tparam TMixin type of mixin class */ -template -class MixinValueImpl : public ValueImpl, public TMixin { +template +class PrimitiveValue : public Value, public TMixin { public: + using ref_t = TypedValueRef; + using weak_ref_t = TypedValueWeakRef; + using TMixin::TMixin; - MixinValueImpl(TMixin mixin) : TMixin(std::move(mixin)) {} + PrimitiveValue(TMixin&& mixin) : TMixin(std::move(mixin)) {} + PrimitiveValue(const TMixin& mixin) : TMixin(mixin) {} public: void clear() override final { ((TMixin&)*this) = {}; } bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; } + + /** + * \brief helper function for construct a value + * + * \tparam TArgs types of arguments + * \param args arguments + * \return TypedValueRef reference of value + */ + template + static TypedValueRef make(TArgs&&... args) { + return PrimitiveType::instance.make(std::forward(args)...); + } + + static constexpr bool is_primitive = true; + static constexpr bool is_object = false; }; +template +PrimitiveType::PrimitiveType() : Type(typeid(T).name()) { + static_assert(std::is_base_of_v); + static_assert(!std::is_base_of_v, T>); +} + inline ValueRef::ValueRef(storage_t storage) { - // mgb_assert(storage); m_storage = storage; m_id = m_storage->m_id; } template -inline const TValue* ValueRef::as(Type type) const { - // auto _ = Stats::time_value_as.time_scope(); +inline const TValue* ValueRef::as(const Type& type) const { static_assert(std::is_base_of_v); - return static_cast(as(type.code())); + return static_cast(as((const IType&)type)); } template -inline const TValue& ValueRef::cast(Type type) const { - // auto _ = Stats::time_value_cast.time_scope(); +inline const TValue& ValueRef::cast(const Type& type) const { auto* ptr = as(type); if (mgb_unlikely(!ptr)) { - on_cast_failure(); + on_cast_failure(type); } return static_cast(*ptr); } template -inline bool ValueRef::is(Type type) const { - // auto _ = Stats::time_value_is.time_scope(); - return is(type.code()); -} - -template -inline const TypedValueRef& ValueRef::as_ref(Type type) const { - if (!is(type)) { +inline const TypedValueRef& ValueRef::as_ref(const Type& type) const { + if (!is(type)) { return TypedValueRef::nil; } return *reinterpret_cast*>(this); } template -inline const TypedValueRef& ValueRef::cast_ref(Type type) const { +inline const TypedValueRef& ValueRef::cast_ref(const Type& type) const { if (!m_storage) { return TypedValueRef::nil; } - if (mgb_unlikely(!is(type))) { - on_cast_failure(); + if (mgb_unlikely(!is(type))) { + on_cast_failure(type); } return *reinterpret_cast*>(this); } -template -void ValueRef::on_cast_failure() const { +inline void ValueRef::on_cast_failure(const IType& type) const { // if this is ErrorValue, rethrow directly storage()->try_rethrow(); mgb_assert( - storage()->m_typecode != TValue::TYPE_CODE, "expect type %s, got %s", - typeid(TValue).name(), to_string().c_str()); + storage()->type() != type, "expect type %s, got %s", type.name().c_str(), + to_string().c_str()); } /** @@ -382,26 +460,10 @@ private: public: TypedValueRef() = default; const T& operator*() const { - if constexpr (T::KIND == ValueKind::Object) { - return this->template cast(); - } else if constexpr (T::KIND == ValueKind::Primitive) { - if (!m_storage) { - on_cast_failure(); - } - 
return static_cast(*m_storage); - } else { - static_assert(!std::is_same_v); - } - } - const T* operator->() const { - if constexpr (T::KIND == ValueKind::Object) { - return this->template as(); - } else if constexpr (T::KIND == ValueKind::Primitive) { - return static_cast(m_storage.get()); - } else { - static_assert(!std::is_same_v); - } + mgb_assert(m_storage, "empty storage"); + return static_cast(*m_storage); } + const T* operator->() const { return static_cast(m_storage.get()); } /** * \brief reset underlying value to another value @@ -409,7 +471,7 @@ public: * \param successor new value */ inline void reset(ValueRef successor) { - static_assert(T::KIND == ValueKind::Object); + static_assert(std::is_base_of_v, T>); mgb_assert(m_storage); mgb_assert(!m_storage->m_successor); if (m_storage->m_watching) { @@ -422,25 +484,19 @@ public: static inline const TypedValueRef nil; friend class ValueRef; - - template - friend class ValueImpl; + friend class Type; + friend class TypedValueWeakRef; }; template class TypedValueWeakRef : public ValueWeakRef { private: + TypedValueWeakRef(const ValueRef& value) : ValueWeakRef(value) {} + TypedValueWeakRef(const ValueWeakRef& value) : ValueWeakRef(value) {} + public: - TypedValueWeakRef(ValueRef value) : ValueWeakRef(value) {} - TypedValueWeakRef(ValueWeakRef value) : ValueWeakRef(value) {} - TypedValueRef lock() { - auto value = ValueWeakRef::lock(); - if (value) { - return value.template as_ref(); - } else { - return {}; - } - } + TypedValueWeakRef(const TypedValueRef& value) : ValueWeakRef(value) {} + TypedValueRef lock() { return (TypedValueRef)ValueWeakRef::lock(); } }; // TODO: add proxy value type, which is meant to be reset in the end @@ -509,10 +565,14 @@ inline ValueRefList::ValueRefList(ValueRef item) : m_data(inline_storage()), m_s m_data[0] = std::move(item); } -/*class ValueRefList : public SmallVector { -public: - using SmallVector::SmallVector; -};*/ +template +template +TypedValueRef Type::make(TArgs&&... 
args) const { + static_assert(std::is_final_v); + auto storage = LocalPtr::make(std::forward(args)...); + storage->m_type = this; + return ValueRef::make(std::move(storage)); +} } // namespace imperative } // namespace mgb diff --git a/imperative/src/test/backward_graph.cpp b/imperative/src/test/backward_graph.cpp index a678e51d973a551e1cc519afa4b943133e44eeed..632b53254dafe7106ccd6d919400a4c1dba74f3e 100644 --- a/imperative/src/test/backward_graph.cpp +++ b/imperative/src/test/backward_graph.cpp @@ -123,7 +123,7 @@ TEST(TestImperative, BackwardGraphBasic) { } } inputs.clear(); - auto input_grads = result.graph.apply( + auto input_grads = result.graph.apply( backward_graph_inputs, apply_shared_on_physical_tensor, [&](auto&& x) { return x; }); mgb_assert(input_grads.size() == input_has_grad.size()); @@ -177,7 +177,7 @@ TEST(TestImperative, BackwardGraphIdentity) { } } inputs.clear(); - auto input_grads = result.graph.apply( + auto input_grads = result.graph.apply( backward_graph_inputs, apply_shared_on_physical_tensor, [&](auto&& x) { return x; }); mgb_assert(input_grads.size() == input_has_grad.size()); @@ -244,11 +244,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); auto grads = expand_grads( bg.output_mask, - bg.graph.apply( + bg.graph.apply( backward_graph_inputs, apply_shared_on_physical_tensor, [&](auto&& x) { return x; })); - auto precomp = obg.precomp.apply( + auto precomp = obg.precomp.apply( SmallVector{a_tn, b_tn, c_tn}, apply_shared_on_physical_tensor, [&](auto&& x) { return x; }); ASSERT_EQ(precomp.size(), 2); @@ -261,7 +261,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); auto grads2 = expand_grads( obg.input_has_grad, - obg.backward.apply( + obg.backward.apply( backward_inputs, apply_shared_on_physical_tensor, [&](auto&& x) { return x; }));
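
Editor's note on the EncodedSubgraph hunk above: encode_inputs/encode_outputs/decode_outputs now take a Span and always return a SmallVector instead of being generic over the container, but the underlying behaviour is unchanged mask-based filtering and expansion. The following is a minimal standalone sketch of that logic, with std::vector standing in for SmallVector/Span; it is an illustration only, not the MegEngine implementation.

// Sketch: encode drops entries whose mask bit is false,
// decode re-inserts empty placeholders at those positions.
#include <cassert>
#include <string>
#include <vector>

std::vector<std::string> encode(const std::vector<std::string>& values,
                                const std::vector<bool>& mask) {
    assert(values.size() == mask.size());
    std::vector<std::string> encoded;
    for (size_t i = 0; i < values.size(); ++i)
        if (mask[i]) encoded.push_back(values[i]);
    return encoded;
}

std::vector<std::string> decode(const std::vector<std::string>& encoded,
                                const std::vector<bool>& mask) {
    std::vector<std::string> decoded;
    size_t index = 0;
    for (size_t i = 0; i < mask.size(); ++i)
        decoded.push_back(mask[i] ? encoded[index++] : std::string{});
    return decoded;
}

int main() {
    std::vector<bool> mask = {true, false, true};
    auto enc = encode({"a", "b", "c"}, mask);   // {"a", "c"}
    auto dec = decode(enc, mask);               // {"a", "", "c"}
    assert(enc.size() == 2 && dec.size() == 3 && dec[1].empty());
    return 0;
}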
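
Editor's note on the CompiledTransformation changes in trace.h: the constructor now launches a dedicated graph-executor thread driven by a mutex, a condition variable and an integer status (0 = idle, 1 = run requested, 2 = shutting down), and the new destructor flips the status to 2 and joins. Below is a self-contained sketch of that start/stop protocol with a placeholder work function and hypothetical names; it illustrates the pattern, not the actual trace code.

// Sketch of the worker-thread protocol: status 0 = idle, 1 = execute once, 2 = shut down.
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

class Executor {
    std::mutex m_mutex;
    std::condition_variable m_cv;
    int m_status = 0;                    // guarded by m_mutex
    std::thread m_worker;

    void run_once() { std::puts("executing compiled graph (placeholder)"); }

public:
    Executor() {
        m_worker = std::thread([this] {
            while (true) {
                std::unique_lock lock{m_mutex};
                m_cv.wait(lock, [this] { return m_status != 0; });
                if (m_status == 2) break;    // shutdown requested
                lock.unlock();
                run_once();                  // heavy work outside the lock
                lock.lock();
                m_status = 0;                // back to idle
                lock.unlock();
                m_cv.notify_all();           // wake anyone waiting for completion
            }
        });
    }

    void execute() {                         // request one run and wait for it
        std::unique_lock lock{m_mutex};
        m_status = 1;
        m_cv.notify_all();
        m_cv.wait(lock, [this] { return m_status == 0; });
    }

    ~Executor() {
        {
            std::lock_guard lock{m_mutex};
            m_status = 2;                    // ask the worker to exit
        }
        m_cv.notify_all();
        m_worker.join();
    }
};

int main() {
    Executor ex;
    ex.execute();
    ex.execute();
    return 0;
}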
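
Editor's note on the stats.h hunk: timer registration moves from a flat sm_timers vector to a tree keyed by dot-separated names, so print() can indent children under their parent and report a subtree total where the parent has no timer of its own. A rough standalone sketch of that registration and printing scheme follows, with hypothetical names; the real Timer additionally tracks hit counts, enable scopes and recursive timing.

// Sketch: registering timers under dotted names ("grad.backward") in a tree,
// then printing them with indentation and per-subtree totals.
#include <cstdio>
#include <map>
#include <memory>
#include <string>

struct TimerNode {
    std::map<std::string, std::unique_ptr<TimerNode>> children;
    long nanos = -1;   // -1 means "no timer attached here, only children"
};

static TimerNode g_root;

// Walk "a.b.c" down the tree, creating intermediate nodes as needed.
void register_timer(std::string name, long nanos) {
    TimerNode* node = &g_root;
    for (size_t pos; (pos = name.find('.')) != std::string::npos; name = name.substr(pos + 1)) {
        auto& child = node->children[name.substr(0, pos)];
        if (!child) child = std::make_unique<TimerNode>();
        node = child.get();
    }
    auto& leaf = node->children[name];
    if (!leaf) leaf = std::make_unique<TimerNode>();
    leaf->nanos = nanos;
}

long print_node(const std::string& name, const TimerNode& node, int indent) {
    long total = 0;
    if (node.nanos >= 0) {
        std::printf("%*s%s costs %ld ns\n", indent, "", name.c_str(), node.nanos);
        total += node.nanos;
    } else {
        std::printf("%*s%s:\n", indent, "", name.c_str());
    }
    for (auto&& [child_name, child] : node.children)
        total += print_node(child_name, *child, indent + 4);
    if (node.nanos < 0)
        std::printf("%*stotal %ld ns\n", indent + 4, "", total);
    return total;
}

int main() {
    register_timer("grad.backward", 1200);
    register_timer("grad.record", 300);
    register_timer("interpreter", 5000);
    for (auto&& [name, child] : g_root.children) print_node(name, *child, 0);
    return 0;
}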
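
Editor's note on the value.h rework: the old integer typecode/ValueKind machinery is replaced by type objects. Every value stores a pointer to an IType; each transformation owns an ObjectType<T> tag and constructs, tests and downcasts values through it (m_value_type.make(...), value.is(m_value_type), value.as_ref(m_value_type)), while primitive values keep a static make() path via a process-wide PrimitiveType<T>::instance. The standalone sketch below mirrors only the core idea, identifying a value's dynamic type by the identity of the tag object that created it, using simplified hypothetical names (IType, ObjectType, Value, ScalarBox); it is not the MegEngine header.

// Sketch of the "type object as identity" pattern.
#include <cassert>
#include <memory>
#include <string>
#include <utility>

class IType {                        // base of all type tags; compared by address
    std::string m_name;
public:
    explicit IType(std::string name) : m_name(std::move(name)) {}
    const std::string& name() const { return m_name; }
    bool operator==(const IType& rhs) const { return this == &rhs; }
    bool operator!=(const IType& rhs) const { return this != &rhs; }
};

class Value {                        // every value remembers which tag created it
    const IType* m_type = nullptr;
public:
    virtual ~Value() = default;
    const IType& type() const { return *m_type; }
    template <typename T>
    friend class ObjectType;
};

template <typename T>
class ObjectType : public IType {    // a transformation owns one tag per value kind
public:
    using IType::IType;
    template <typename... Args>
    std::shared_ptr<T> make(Args&&... args) const {
        auto value = std::make_shared<T>(std::forward<Args>(args)...);
        value->m_type = this;        // stamp the value with the tag that created it
        return value;
    }
};

struct ScalarBox : Value {           // hypothetical payload, stands in for e.g. TracingValue
    int payload;
    explicit ScalarBox(int v) : payload(v) {}
};

int main() {
    // Two transformations can each own their own tag; a value made by one does not
    // test positive against the other, even though the C++ type is identical.
    ObjectType<ScalarBox> tag_a{"ScalarBox@A"}, tag_b{"ScalarBox@B"};
    std::shared_ptr<Value> v = tag_a.make(7);
    assert(v->type() == tag_a);
    assert(v->type() != tag_b);
    return 0;
}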