From 4fa6162027031866a650aefc77c31df8c76fe7c7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 27 Jan 2022 23:53:24 +0800 Subject: [PATCH] perf(dispatch): improve performance of dispatch system GitOrigin-RevId: 860028e1af63936e7b4edefbed90d8244e7cb8d2 --- imperative/python/src/module_trace.h | 1 + imperative/python/src/tensor.cpp | 121 +++++++++------ imperative/python/src/transformation.h | 3 +- imperative/src/impl/dispatch.cpp | 6 +- .../src/impl/interpreter/interpreter_impl.cpp | 1 + imperative/src/impl/transformation.cpp | 1 + imperative/src/impl/transformations/eval.cpp | 10 +- imperative/src/impl/transformations/grad.cpp | 100 +++++++------ .../src/impl/transformations/scalar.cpp | 15 +- imperative/src/impl/value.cpp | 4 +- .../megbrain/imperative/basic_values.h | 43 +++--- .../include/megbrain/imperative/dispatch.h | 9 +- .../include/megbrain/imperative/graph_cache.h | 5 + .../imperative/transformations/eval.h | 2 +- .../imperative/transformations/grad.h | 77 +++++----- .../imperative/transformations/lazy.h | 3 +- .../imperative/transformations/scalar.h | 2 +- .../imperative/transformations/symbol.h | 2 +- .../imperative/transformations/trace.h | 6 +- .../include/megbrain/imperative/utils/stats.h | 140 ++++++++++++++++++ .../src/include/megbrain/imperative/value.h | 76 +++++++--- 21 files changed, 442 insertions(+), 185 deletions(-) create mode 100644 imperative/src/include/megbrain/imperative/utils/stats.h diff --git a/imperative/python/src/module_trace.h b/imperative/python/src/module_trace.h index 835219fec..7ed0be906 100644 --- a/imperative/python/src/module_trace.h +++ b/imperative/python/src/module_trace.h @@ -13,6 +13,7 @@ #include "megbrain/imperative/transformations/trace.h" #include "megbrain/imperative/utils/map.h" +#include "megbrain/imperative/utils/stats.h" #include "./tensor.h" diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 36f69ba4c..b3a30d12b 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -21,6 +21,7 @@ #include "megbrain/imperative/transformations/symbol.h" #include "megbrain/imperative/transformations/trace.h" #include "megbrain/imperative/utils/map.h" +#include "megbrain/imperative/utils/stats.h" #include "megbrain/opr/io.h" #include "megbrain/plugin/profiler.h" @@ -52,8 +53,48 @@ namespace mgb::imperative::python { namespace { WeakKeyMap module_trace_info_map; + +struct SymbolVarContext { + TransformationContext context; + cg::ComputingGraph* graph; + + SymbolVarContext(cg::ComputingGraph* graph) : graph(graph) { + Transformation::swap_context(context); + } + + void init() { + std::make_shared(graph)->register_at( + Transformation::top()); + std::make_shared()->register_at(Transformation::top()); + } + + ~SymbolVarContext() { Transformation::swap_context(context); } +}; + +ValueRef symvar2val(py::handle py_symbol_var) { + auto* symbol_var = py_symbol_var.cast(); + ValueRef value = SymbolValue::make(symbol_var->m_node); + if (symbol_var->is_scalar) { + value = ScalarValue::make(value); + } + return value; +} + +py::object val2symvar(py::handle typeobj, ValueRef value) { + bool is_scalar = false; + if (auto* scalar_value = value.as()) { + value = scalar_value->value(); + is_scalar = true; + } + auto* node = value.cast().node(); + auto py_symbol_var = + typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic)); + py_symbol_var.cast()->is_scalar = is_scalar; + return py_symbol_var; } +} // namespace + interpreter::Interpreter::Channel* interpreter_for_py = 
nullptr; PyTypeObject* py_tensor_type = nullptr; PyObject *cpp_use_symbolic_shape, *cpp_astensor1d; @@ -91,36 +132,17 @@ PyObject* py_apply( if (py::isinstance(py::handle(args[0]))) { // swap to a special context to reuse scalar handle - TransformationContext symbol_var_context; - Transformation::swap_context(symbol_var_context); - CleanupGuard _{[&] { Transformation::swap_context(symbol_var_context); }}; - auto* graph = - py::handle(args[0]).cast()->m_node->owner_graph(); - std::make_shared(graph)->register_at( - Transformation::top()); - std::make_shared()->register_at( - Transformation::top()); + SymbolVarContext context( + py::handle(args[0]).cast()->m_node->owner_graph()); + context.init(); for (size_t i = 0; i < nargs; ++i) { - auto* py_input = py::handle(args[i]).cast(); - ValueRef input = SymbolValue::make(py_input->m_node); - if (py_input->is_scalar) { - input = ScalarValue::make(input); - } - tensors[i] = input; + tensors[i] = symvar2val(args[i]); } auto outputs = imperative::apply(*op, tensors); auto ret = pybind11::tuple(outputs.size()); auto typeobj = py::handle(args[0]).get_type(); for (size_t i = 0; i < outputs.size(); ++i) { - bool is_scalar = false; - if (auto* scalar_value = outputs[i].as()) { - outputs[i] = scalar_value->value(); - is_scalar = true; - } - auto* node = outputs[i].cast().node(); - ret[i] = typeobj( - pybind11::cast(node, pybind11::return_value_policy::automatic)); - py::handle(ret[i]).cast()->is_scalar = is_scalar; + ret[i] = val2symvar(typeobj, outputs[i]); } return ret.release().ptr(); } @@ -1537,17 +1559,29 @@ void init_tensor(py::module m) { } }); - m.def("reduce_to_scalar", [](py::object op, py::object tensor) { - auto* tw = TensorWrapper::try_cast(tensor.ptr()); - auto make_scalar_shape = [&](CompNode device) { - return imperative::apply( - CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}), - HostStorage::make(device))[0]; + m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object { + auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) { + auto make_scalar_shape = [&](CompNode device) { + return imperative::apply( + CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}), + HostStorage::make(device))[0]; + }; + return imperative::apply(op, input, make_scalar_shape(*input.device()))[0]; }; - auto output = imperative::apply( - *op.cast>(), tw->m_tensor->data(), - make_scalar_shape(tw->m_tensor->comp_node()))[0]; - return TensorWrapper::make(py_tensor_type, output); + if (py::isinstance(tensor)) { + auto* graph = tensor.cast()->m_node->owner_graph(); + SymbolVarContext context(graph); + context.init(); + auto output = reduce_to_scalar( + *op.cast>(), symvar2val(tensor)); + auto typeobj = tensor.get_type(); + return val2symvar(typeobj, output); + } else { + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + auto output = reduce_to_scalar( + *op.cast>(), tw->m_tensor->data()); + return TensorWrapper::make(py_tensor_type, output); + } }); m.def("name_tensor", [](std::string name, py::object tensor) { @@ -1557,7 +1591,7 @@ void init_tensor(py::module m) { }); m.def("is_grad_attached", [](std::vector tensors) -> bool { - ValueRefList values(tensors.size()); + SmallVector values(tensors.size()); for (size_t i = 0; i < tensors.size(); ++i) { values[i] = tensors[i].cast().m_tensor->data(); } @@ -1570,17 +1604,16 @@ void init_tensor(py::module m) { }); m.def("get_grad_key", [](std::vector tensors) -> py::object { - ValueRefList values(tensors.size()); + SmallVector values(tensors.size()); for (size_t i = 0; 
i < tensors.size(); ++i) { values[i] = tensors[i].cast().m_tensor->data(); } - auto outputs = imperative::apply(GetGradKey(), values); - if (auto* grad_key_val = outputs[0].as()) { - return py::reinterpret_borrow( - GradKeyWrapper::wrap_t::pycast(GradKeyWrapper::get(*grad_key_val))); - } else { + auto output = imperative::apply(GetGradKey(), values)[0]; + if (!output) { return py::none(); } + return py::reinterpret_borrow(GradKeyWrapper::wrap_t::pycast( + GradKeyWrapper::get(output.cast()))); }); m.def("set_grad", [](py::object py_key, py::function backward_fn, @@ -1612,7 +1645,7 @@ void init_tensor(py::module m) { } return input_grads; }; - ValueRefList values(inputs.size() + outputs.size()); + SmallVector values(inputs.size() + outputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { values[i] = inputs[i].cast().m_tensor->data(); } @@ -1669,6 +1702,10 @@ void init_tensor(py::module m) { return reprs; }); + m.def("print_stats", [] { imperative::Stats::print(); }); + + m.def("reset_stats", [] { imperative::Stats::reset(); }); + py::register_exception(m, "TraceError"); } diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h index d07666b70..024578d39 100644 --- a/imperative/python/src/transformation.h +++ b/imperative/python/src/transformation.h @@ -67,7 +67,8 @@ struct TransformationManager { } }; -class PyValue final : public MixinValueImpl { +class PyValue final + : public MixinValueImpl { public: using MixinValueImpl::MixinValueImpl; diff --git a/imperative/src/impl/dispatch.cpp b/imperative/src/impl/dispatch.cpp index 2a6071375..5b0d13143 100644 --- a/imperative/src/impl/dispatch.cpp +++ b/imperative/src/impl/dispatch.cpp @@ -14,13 +14,9 @@ #include "megbrain/imperative/utils/debug.h" #include "megbrain/imperative/utils/helper.h" #include "megbrain/imperative/utils/map.h" +#include "megbrain/imperative/utils/stats.h" namespace mgb { - -void imperative_log_profile_begin(const char* message); -void imperative_log_profile(const char* message); -void imperative_log_profile_end(const char* message); - namespace imperative { namespace { diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 3971073c0..48a62b2d7 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -19,6 +19,7 @@ #include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/utility.h" +#include "megbrain/imperative/utils/stats.h" #include "megbrain/imperative/utils/to_string.h" #include "../blob_manager_impl.h" diff --git a/imperative/src/impl/transformation.cpp b/imperative/src/impl/transformation.cpp index 2b3b326f9..32c0c4b3c 100644 --- a/imperative/src/impl/transformation.cpp +++ b/imperative/src/impl/transformation.cpp @@ -1,4 +1,5 @@ #include "megbrain/imperative/transformation.h" +#include "megbrain/imperative/utils/stats.h" namespace mgb { namespace imperative { diff --git a/imperative/src/impl/transformations/eval.cpp b/imperative/src/impl/transformations/eval.cpp index 35ec7e76e..87a821aa0 100644 --- a/imperative/src/impl/transformations/eval.cpp +++ b/imperative/src/impl/transformations/eval.cpp @@ -11,6 +11,7 @@ #include "megbrain/imperative/transformations/eval.h" #include "megbrain/imperative/transformations/grad.h" +#include "megbrain/imperative/utils/stats.h" namespace mgb { namespace imperative { @@ -40,9 +41,6 @@ ShapeValue::ref_t 
InterpreterInfo::shape() const { ValueRefList InterpreterTransformation::apply_op( const ApplyOp& apply_op, Span inputs) { - if (apply_op.op().same_type()) { - return {inputs[0]}; - } SmallVector input_handles; SmallVector output_handles; CleanupGuard _{[&] { @@ -111,7 +109,11 @@ ValueRefList InterpreterTransformation::apply_create_tensor( ValueRefList InterpreterTransformation::apply_transformation( const Operator& op, Span inputs) { if (auto* op_val = op.as()) { - return apply_op(*op_val, inputs); + if (op_val->op().same_type()) { + return inputs[0]; + } else { + return apply_op(*op_val, inputs); + } } else if (auto* get_attr = op.as()) { return apply_get_attr(*get_attr, inputs); } else if (auto* create_tensor = op.as()) { diff --git a/imperative/src/impl/transformations/grad.cpp b/imperative/src/impl/transformations/grad.cpp index 50d774398..50d0fad0c 100644 --- a/imperative/src/impl/transformations/grad.cpp +++ b/imperative/src/impl/transformations/grad.cpp @@ -11,8 +11,11 @@ #include "megbrain/imperative/transformations/grad.h" +#include + #include "megbrain/imperative/graph_cache.h" #include "megbrain/imperative/resource_manager.h" +#include "megbrain/imperative/utils/stats.h" #include @@ -20,20 +23,21 @@ namespace mgb { namespace imperative { static std::shared_ptr make_optimized_backward_graph( - std::shared_ptr op, Span inputs, Span outputs, + const OpDef& op, Span inputs, Span outputs, Span inputs_require_grad) { // hash using OptimizedBackwardGraphCache = OpMethResultCache< std::shared_ptr, SmallVector>; thread_local auto& cache = *ResourceManager::create_local(); - OptimizedBackwardGraphCache::key_t cache_key{op}; + OptimizedBackwardGraphCache::key_t cache_key{op.shared_from_this()}; SmallVector& input_descs = cache_key.inputs; - std::get<0>(cache_key.extras) = inputs_require_grad.copy_into>(); + cache_key.extra<0>() = inputs_require_grad.copy_into>(); input_descs.resize(inputs.size()); + // some overhead, consider simplify LogicalTensorDesc for (size_t i = 0; i < inputs.size(); ++i) { - input_descs[i].layout.dtype = inputs[i].dtype().cast(); - input_descs[i].comp_node = inputs[i].device().cast(); + input_descs[i].layout.dtype = *inputs[i].dtype(); + input_descs[i].comp_node = *inputs[i].device(); } auto iter = cache.find(cache_key); @@ -45,7 +49,7 @@ static std::shared_ptr make_optimized_backward_gra SmallVector output_has_grad(outputs.size(), true); std::shared_ptr ret; auto bg = OpDef::make_backward_graph( - *op, input_descs, std::get<0>(cache_key.extras), output_has_grad); + op, input_descs, std::get<0>(cache_key.extras), output_has_grad); if (!bg.graph.empty()) { ret = std::make_shared(bg); } @@ -235,7 +239,7 @@ GradValue::ref_t GradKey::attach( } else { GradSlotPtr grad_slot; auto& grad_fn = grad_slot.m_fn; - grad_fn = std::make_shared(); + grad_fn = LocalPtr::make(); grad_fn->m_key = shared_from_this(); grad_fn->m_slots.resize(1); grad_slot.m_index = 0; @@ -260,17 +264,21 @@ ValueRefList GradTransformation::apply_transformation( const Operator& op, Span inputs) { auto fallback = [&] { ValueRefList unwrapped_inputs(inputs.size()); - for (size_t i = 0; i < inputs.size(); ++i) { - if (auto grad_value = as_grad_value(inputs[i])) { - unwrapped_inputs[i] = grad_value->m_value; - } else { - unwrapped_inputs[i] = inputs[i]; + { + // overhead + for (size_t i = 0; i < inputs.size(); ++i) { + if (auto&& grad_value = as_grad_value(inputs[i])) { + unwrapped_inputs[i] = grad_value->m_value; + } else { + unwrapped_inputs[i] = inputs[i]; + } } } return imperative::apply(op, 
unwrapped_inputs); }; - if (auto* get_attr = op.as()) { - if (auto grad_value = as_grad_value(inputs.item())) { + if (op.is()) { + // overhead + if (auto&& grad_value = as_grad_value(inputs.item())) { return imperative::apply(op, grad_value->m_value); } else { return imperative::apply(op, inputs); @@ -281,28 +289,29 @@ ValueRefList GradTransformation::apply_transformation( } if (auto* op_val = op.as()) { size_t nr_require_grad = 0; - SmallVector require_grads; - for (auto&& input : inputs) { - if (is_grad_value(input)) { + SmallVector require_grads(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + if (is_grad_value(inputs[i])) { nr_require_grad++; - require_grads.push_back(true); + require_grads[i] = true; } else { - require_grads.push_back(false); + require_grads[i] = false; } } if (nr_require_grad == 0) { return imperative::apply(op, inputs); } - ValueRefList captured_inputs(inputs.size()); + SmallVector captured_inputs(inputs.size()); SmallVector inputs_require_grad(inputs.size()); // capture value so that trace could assume input as same - auto capture_value = [](ValueRef value) { + auto capture_value = [](const ValueRef& value) { // TODO: fastpath copy shouldn't be an OpDef - return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0]; + static auto fastpath_copy = FastpathCopy::make(); + return imperative::apply(ApplyOp(*fastpath_copy), value)[0]; }; for (size_t i = 0; i < inputs.size(); ++i) { auto& input = inputs[i]; - if (auto grad_value = as_grad_value(input)) { + if (auto&& grad_value = as_grad_value(input)) { captured_inputs[i] = capture_value(grad_value->m_value); inputs_require_grad[i] = true; } else { @@ -310,32 +319,28 @@ ValueRefList GradTransformation::apply_transformation( inputs_require_grad[i] = false; } } - decltype(std::declval().m_backward) backward_storage; + // copy grad_fn->m_backward is expensive + auto grad_fn = LocalPtr::make(); + auto& backward_storage = grad_fn->m_backward; auto outputs = [&] { auto backward_rule = CustomBackward::lookup_grad_rule(op_val->op().dyn_typeinfo()); if (backward_rule) { CustomBackward backward; auto optional_outputs = backward_rule( - op_val->op(), {captured_inputs.data(), captured_inputs.size()}, - {inputs_require_grad.data(), inputs_require_grad.size()}, - backward); + op_val->op(), captured_inputs, inputs_require_grad, backward); if (optional_outputs) { backward_storage = backward; // backward by rule return *optional_outputs; } } - auto outputs = imperative::apply( - op, {captured_inputs.begin(), captured_inputs.end()}); + auto outputs = imperative::apply(op, captured_inputs); auto backward_graph = make_optimized_backward_graph( - op.cast().op().shared_from_this(), - {captured_inputs.begin(), captured_inputs.end()}, - {outputs.data(), outputs.size()}, - {inputs_require_grad.data(), inputs_require_grad.size()}); + op_val->op(), captured_inputs, outputs, inputs_require_grad); if (backward_graph) { backward_storage = BackwardGraphWithClosure( - backward_graph, op.cast().op().shared_from_this(), + backward_graph, op_val->op().shared_from_this(), {captured_inputs.begin(), captured_inputs.end()}, {outputs.data(), outputs.size()}); // backward by make_backward_graph @@ -348,18 +353,17 @@ ValueRefList GradTransformation::apply_transformation( if (std::holds_alternative(backward_storage)) { return outputs; } - auto grad_fn = std::make_shared(); grad_fn->m_key = m_key; grad_fn->m_slots.resize(outputs.size()); - grad_fn->m_backward = backward_storage; mgb_assert(!outputs.empty()); 
grad_fn->m_dests.reserve(inputs.size()); // clang-format off - std::visit([&](auto& backward) { + auto visitor = [&](auto& backward) { using T = std::decay_t; if constexpr (std::is_same_v) { mgb_throw(AssertionError, "invalid backward"); } else { + // little overhead for (size_t i = 0; i < inputs.size(); ++i) { if (backward.input_has_grad(i) && require_grads[i]) { auto& input_grad_slot = @@ -373,19 +377,23 @@ ValueRefList GradTransformation::apply_transformation( } for (size_t i = 0; i < outputs.size(); ++i) { if (backward.output_requires_grad(i)) { + // little overhead: Value::make auto grad_value = GradValue::make(outputs[i], m_key, GradSlotPtr{grad_fn, i}); outputs[i] = record_grad(grad_value); } } } - }, grad_fn->m_backward); + }; + // std::visit may be slightly slower than direct if + std::visit(visitor, backward_storage); // clang-format on mgb_assert(!grad_fn->m_slots.empty()); m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()}); return outputs; } else if (op.is()) { return imperative::apply(op, inputs); - } else if (auto* attach_grad = op.as()) { + } + if (auto* attach_grad = op.as()) { if (!has_key(attach_grad->key())) { return fallback(); } @@ -408,7 +416,7 @@ ValueRefList GradTransformation::apply_transformation( return {}; } else if (auto* is_attached_to = op.as()) { if (has_key(is_attached_to->key())) { - if (auto grad_value = as_grad_value(inputs[0])) { + if (auto&& grad_value = as_grad_value(inputs[0])) { // TODO: assert grad_fn return {BoolValue::make(true)}; } @@ -416,7 +424,7 @@ ValueRefList GradTransformation::apply_transformation( return {BoolValue::make(false)}; } else if (auto* set_grad = op.as()) { // TODO: merge SetGrad and ApplyOp - auto grad_fn = std::make_shared(); + auto grad_fn = LocalPtr::make(); auto& backward = std::get(grad_fn->m_backward = CustomBackward()); size_t nr_inputs = set_grad->nr_inputs(); @@ -433,7 +441,7 @@ ValueRefList GradTransformation::apply_transformation( grad_fn->m_slots.resize(nr_outputs); grad_fn->m_dests.reserve(nr_inputs); for (size_t i = 0; i < nr_inputs; ++i) { - if (auto grad_value = as_grad_value(inputs_[i])) { + if (auto&& grad_value = as_grad_value(inputs_[i])) { auto& input_grad_slot = grad_value->m_slot; grad_fn->m_dests.emplace_back(grad_value->m_slot); grad_fn->m_dests.back().m_producer_record.insert_after( @@ -461,21 +469,21 @@ ValueRefList GradTransformation::apply_transformation( } return {FunctionValue::make(make_backward_closure(inputs))}; } else if (op.is()) { - if (auto grad_value = as_grad_value(inputs[0])) { + if (auto&& grad_value = as_grad_value(inputs[0])) { return {grad_value->m_value}; } else { return {inputs[0]}; } } else if (op.is()) { for (auto&& input : inputs) { - if (auto grad_value = as_grad_value(input)) { + if (auto&& grad_value = as_grad_value(input)) { return {GradKeyValue::make(grad_value->m_key)}; } } return imperative::apply(op, inputs); } else if (op.kind() == Operator::IdentityLike) { mgb_assert(inputs.size() == 1); - if (auto grad_value = as_grad_value(inputs[0])) { + if (auto&& grad_value = as_grad_value(inputs[0])) { auto output = imperative::apply(op, grad_value->m_value)[0]; auto grad_output = GradValue::make( output, grad_value->key(), grad_value->slot_for(m_key)); @@ -493,7 +501,7 @@ GenericFunction GradTransformation::make_backward_closure(Span ys) { auto grad_key = m_key; std::vector y_slots; for (auto&& y : ys) { - if (auto grad_value = as_grad_value(y)) { + if (auto&& grad_value = as_grad_value(y)) { y_slots.push_back(grad_value->slot_for(grad_key)); } else { 
y_slots.emplace_back(); diff --git a/imperative/src/impl/transformations/scalar.cpp b/imperative/src/impl/transformations/scalar.cpp index 7250847a3..338089bc1 100644 --- a/imperative/src/impl/transformations/scalar.cpp +++ b/imperative/src/impl/transformations/scalar.cpp @@ -13,6 +13,7 @@ #include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/utility.h" +#include "megbrain/imperative/utils/stats.h" namespace mgb { namespace imperative { @@ -185,7 +186,7 @@ ValueRefList subtensor_rule( bool is_scalar; mgb_assert(!inputs_mask[0], "subtensor shouldn't have scalar input"); if (auto shape = input.shape()) { - size_t ndim = input.shape()->ndim; + size_t ndim = shape->ndim; for (auto&& [axis, begin, end, step, idx] : subtensor.items) { if (idx) { ndim--; @@ -193,6 +194,7 @@ ValueRefList subtensor_rule( } is_scalar = ndim == 0; } else { + // assume not scalar is_scalar = false; } auto outputs = imperative::apply(subtensor, inputs); @@ -341,12 +343,16 @@ ValueRefList ScalarTransformation::apply_transformation( if (auto* get_attr = op.as()) { // fastpath for GetAttr return apply_get_attr(*get_attr, inputs); + } else if (auto* apply_op = op.as()) { + if (apply_op->op().same_type()) { + return inputs[0]; + } } size_t nr_inputs = inputs.size(); ValueRefList unwrapped_inputs(nr_inputs); - bool inputs_mask[nr_inputs]; + SmallVector inputs_mask(nr_inputs); for (size_t i = 0; i < inputs.size(); ++i) { - if (auto scalar_value = inputs[i].as_ref()) { + if (auto&& scalar_value = inputs[i].as_ref()) { unwrapped_inputs[i] = scalar_value->value(); inputs_mask[i] = true; } else { @@ -358,8 +364,7 @@ ValueRefList ScalarTransformation::apply_transformation( if (auto apply_op = op.as()) { auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo()); if (iter != scalar_rules.end()) { - return iter->second( - apply_op->op(), unwrapped_inputs, {inputs_mask, nr_inputs}); + return iter->second(apply_op->op(), unwrapped_inputs, inputs_mask); } else { // TODO: repeat op return fallback(); diff --git a/imperative/src/impl/value.cpp b/imperative/src/impl/value.cpp index 857f295ab..47a091804 100644 --- a/imperative/src/impl/value.cpp +++ b/imperative/src/impl/value.cpp @@ -215,8 +215,8 @@ ValueRefList::ValueRefList(size_t nr_elems) { init(nr_elems); } -ValueRefList::ValueRefList(std::initializer_list values) - : ValueRefList(values.begin(), values.end()) {} +/*ValueRefList::ValueRefList(std::initializer_list values) + : ValueRefList(values.begin(), values.end()) {}*/ ValueRefList::ValueRefList(const ValueRefList& rhs) : ValueRefList(rhs.cbegin(), rhs.cend()) {} diff --git a/imperative/src/include/megbrain/imperative/basic_values.h b/imperative/src/include/megbrain/imperative/basic_values.h index 5789b66c4..4b97bb2fc 100644 --- a/imperative/src/include/megbrain/imperative/basic_values.h +++ b/imperative/src/include/megbrain/imperative/basic_values.h @@ -25,14 +25,16 @@ class GradKey; using GenericFunction = std::function)>; -class ShapeValue final : public MixinValueImpl { +class ShapeValue final + : public MixinValueImpl { public: using MixinValueImpl::MixinValueImpl; std::string to_string() const override; }; -class CompNodeValue final : public MixinValueImpl { +class CompNodeValue final + : public MixinValueImpl { public: using MixinValueImpl::MixinValueImpl; @@ -40,7 +42,7 @@ public: }; // TODO: override factory method -class BoolValue final : public ValueImpl { +class BoolValue final : public ValueImpl { private: std::optional m_value; @@ -53,14 +55,17 @@ public: void clear() override { 
m_value.reset(); } }; -class HostStorage final : public MixinValueImpl { +class HostStorage final + : public MixinValueImpl { public: using MixinValueImpl::MixinValueImpl; std::string to_string() const override; }; -class DeviceStorage final : public MixinValueImpl { +class DeviceStorage final + : public MixinValueImpl< + DeviceStorage, ValueKind::Primitive, DeviceTensorStorage> { public: using MixinValueImpl::MixinValueImpl; @@ -71,7 +76,7 @@ public: * \brief like HostTensorND mixin, but allow scalar value * */ -class HostValue final : public ValueImpl { +class HostValue final : public ValueImpl { private: DType m_dtype; ValueShape m_shape; @@ -94,9 +99,9 @@ public: } DType dtype() const { return m_dtype; } - ValueShape shape() const { return m_shape; } + const ValueShape& shape() const { return m_shape; } CompNode device() const { return m_storage.comp_node(); } - HostTensorStorage storage() const { return m_storage; } + const HostTensorStorage& storage() const { return m_storage; } DTypeScalar item() const { mgb_assert(m_shape.is_scalar()); return DTypeScalar::make_from_raw(m_dtype, m_storage.ptr()); @@ -109,7 +114,7 @@ public: * \brief like DeviceTensorND mixin, but allow scalar value * */ -class DeviceValue final : public ValueImpl { +class DeviceValue final : public ValueImpl { private: DType m_dtype; ValueShape m_shape; @@ -117,8 +122,8 @@ private: public: DeviceValue(DType dtype, ValueShape shape, DeviceTensorStorage storage) - : m_dtype(dtype), m_shape(shape), m_storage(storage) {} - DeviceValue(DeviceTensorND value) + : m_dtype(dtype), m_shape(shape), m_storage(std::move(storage)) {} + DeviceValue(const DeviceTensorND& value) : DeviceValue( value.dtype(), ValueShape::from(value.shape()), value.storage()) { } @@ -132,28 +137,31 @@ public: } DType dtype() const { return m_dtype; } - ValueShape shape() const { return m_shape; } + const ValueShape& shape() const { return m_shape; } CompNode device() const { return m_storage.comp_node(); } - DeviceTensorStorage storage() const { return m_storage; } + const DeviceTensorStorage& storage() const { return m_storage; } DeviceTensorND as_nd(bool allow_scalar = false) const; }; -class FunctionValue final : public MixinValueImpl { +class FunctionValue final + : public MixinValueImpl { public: using MixinValueImpl::MixinValueImpl; std::string to_string() const override; }; -class DTypeValue final : public MixinValueImpl { +class DTypeValue final + : public MixinValueImpl { public: using MixinValueImpl::MixinValueImpl; std::string to_string() const override; }; -class StringValue final : public MixinValueImpl { +class StringValue final + : public MixinValueImpl { public: using MixinValueImpl::MixinValueImpl; @@ -171,7 +179,8 @@ public: std::string message() const { return m_message; } }; -class ErrorValue final : public MixinValueImpl { +class ErrorValue final + : public MixinValueImpl { public: using MixinValueImpl::MixinValueImpl; diff --git a/imperative/src/include/megbrain/imperative/dispatch.h b/imperative/src/include/megbrain/imperative/dispatch.h index a81facb13..d39a952a8 100644 --- a/imperative/src/include/megbrain/imperative/dispatch.h +++ b/imperative/src/include/megbrain/imperative/dispatch.h @@ -47,9 +47,14 @@ constexpr bool is_all_value_ref_v = (... && (std::is_base_of_v> || std::is_same_v>)); +template +static ValueRefList apply(T&& op, const ValueRef& arg) { + return imperative::apply(std::forward(op), Span{&arg, 1}); +} + template -static auto apply(T&& op, TArgs&&... 
args) - -> std::enable_if_t, ValueRefList> { +static auto apply(T&& op, TArgs&&... args) -> std::enable_if_t< + is_all_value_ref_v && sizeof...(args) != 1, ValueRefList> { ValueRef args_arr[sizeof...(TArgs)] = {std::forward(args)...}; return imperative::apply( std::forward(op), diff --git a/imperative/src/include/megbrain/imperative/graph_cache.h b/imperative/src/include/megbrain/imperative/graph_cache.h index a12293ea3..8a1b38e9d 100644 --- a/imperative/src/include/megbrain/imperative/graph_cache.h +++ b/imperative/src/include/megbrain/imperative/graph_cache.h @@ -54,6 +54,11 @@ struct OpMethArgs { return extras == rhs.extras; } + template + auto& extra() { + return std::get(extras); + } + struct hash_t { size_t operator()(const OpMethArgs& key) const { return key.hash(); } }; diff --git a/imperative/src/include/megbrain/imperative/transformations/eval.h b/imperative/src/include/megbrain/imperative/transformations/eval.h index 58d874dd0..8fbd16a71 100644 --- a/imperative/src/include/megbrain/imperative/transformations/eval.h +++ b/imperative/src/include/megbrain/imperative/transformations/eval.h @@ -60,7 +60,7 @@ public: }; class InterpreterValue final - : public MixinValueImpl { + : public MixinValueImpl { public: using MixinValueImpl::MixinValueImpl; diff --git a/imperative/src/include/megbrain/imperative/transformations/grad.h b/imperative/src/include/megbrain/imperative/transformations/grad.h index 0f1db4ce8..93055a0b8 100644 --- a/imperative/src/include/megbrain/imperative/transformations/grad.h +++ b/imperative/src/include/megbrain/imperative/transformations/grad.h @@ -104,37 +104,15 @@ struct ToStringTrait { std::string operator()(const GradSlot& value) const { return value.to_string(); } }; -class GradFn { -private: - std::weak_ptr m_key; - std::vector m_slots; - std::vector m_dests; - std::variant m_backward; - -public: - void clear() { - m_key.reset(); - m_slots.clear(); - m_dests.clear(); - m_backward.emplace(); - } - - std::string to_string() const; - - friend class GradSlotPtr; - friend class GradKey; - friend class GradTransformation; -}; - class GradSlotPtr { private: - std::shared_ptr m_fn; + LocalPtr m_fn; size_t m_index = 0; public: - GradSlotPtr(std::shared_ptr fn, size_t index) : m_fn(fn), m_index(index) {} + GradSlotPtr(LocalPtr fn, size_t index) : m_fn(fn), m_index(index) {} GradSlotPtr() = default; - GradSlot* operator->() const { return &m_fn->m_slots[m_index]; } + GradSlot* operator->() const; operator bool() const { return bool(m_fn); } @@ -171,7 +149,33 @@ struct ToStringTrait { } }; -class GradValue final : public ValueImpl { +class GradFn { +private: + std::weak_ptr m_key; + SmallVector m_slots; + SmallVector m_dests; + std::variant m_backward; + +public: + void clear() { + m_key.reset(); + m_slots.clear(); + m_dests.clear(); + m_backward.emplace(); + } + + std::string to_string() const; + + friend class GradSlotPtr; + friend class GradKey; + friend class GradTransformation; +}; + +inline GradSlot* GradSlotPtr::operator->() const { + return &m_fn->m_slots[m_index]; +} + +class GradValue final : public ValueImpl { private: ValueRef m_value; std::shared_ptr m_key; @@ -179,7 +183,7 @@ private: public: GradValue(ValueRef value, std::shared_ptr key, GradSlotPtr slot = {}) - : m_value(value), m_key(key), m_slot(slot) {} + : m_value(std::move(value)), m_key(std::move(key)), m_slot(slot) {} std::string to_string() const override; @@ -209,12 +213,13 @@ public: class GradKey : public std::enable_shared_from_this { private: std::string m_name; - std::vector, 
std::shared_ptr>> m_tape; - std::vector, std::shared_ptr>> - m_frozen_tape; + std::vector, std::shared_ptr>> m_tape; + std::vector, std::shared_ptr>> m_frozen_tape; bool m_frozen = false; public: + GradKey() { m_tape.reserve(4 * 1024); } + void backward(); GradValue::ref_t attach(ValueRef tensor, std::function callback); const std::string& name() const { return m_name; } @@ -225,7 +230,8 @@ public: }; class GradKeyValue final - : public MixinValueImpl> { + : public MixinValueImpl< + GradKeyValue, ValueKind::Primitive, std::shared_ptr> { public: using MixinValueImpl::MixinValueImpl; @@ -248,7 +254,7 @@ public: return tensor; } - bool is_grad_value(ValueRef value) { + bool is_grad_value(const ValueRef& value) { if (auto* grad_value = value.as()) { if (grad_value->has_key(m_key)) { return true; @@ -266,13 +272,14 @@ public: * \param value * \return GradValue::ref_t */ - GradValue::ref_t as_grad_value(ValueRef value) { - if (auto grad_value = value.as_ref()) { + const GradValue::ref_t& as_grad_value(const ValueRef& value) { + auto&& grad_value = value.as_ref(); + if (grad_value) { if (grad_value->has_key(m_key)) { return grad_value; } } - return {}; + return GradValue::ref_t::nil; } bool has_key(std::shared_ptr key) { diff --git a/imperative/src/include/megbrain/imperative/transformations/lazy.h b/imperative/src/include/megbrain/imperative/transformations/lazy.h index 07ce55019..f4855b7f4 100644 --- a/imperative/src/include/megbrain/imperative/transformations/lazy.h +++ b/imperative/src/include/megbrain/imperative/transformations/lazy.h @@ -39,7 +39,8 @@ public: std::string name() const { return m_name; } }; -class LazyEvalValue final : public MixinValueImpl { +class LazyEvalValue final + : public MixinValueImpl { public: using MixinValueImpl::MixinValueImpl; diff --git a/imperative/src/include/megbrain/imperative/transformations/scalar.h b/imperative/src/include/megbrain/imperative/transformations/scalar.h index 5db4bfda6..142cafaba 100644 --- a/imperative/src/include/megbrain/imperative/transformations/scalar.h +++ b/imperative/src/include/megbrain/imperative/transformations/scalar.h @@ -17,7 +17,7 @@ namespace mgb::imperative { -class ScalarValue final : public ValueImpl { +class ScalarValue final : public ValueImpl { private: ValueRef m_value; diff --git a/imperative/src/include/megbrain/imperative/transformations/symbol.h b/imperative/src/include/megbrain/imperative/transformations/symbol.h index 56f76df2c..976df56b0 100644 --- a/imperative/src/include/megbrain/imperative/transformations/symbol.h +++ b/imperative/src/include/megbrain/imperative/transformations/symbol.h @@ -22,7 +22,7 @@ namespace mgb::imperative { -class SymbolValue final : public ValueImpl { +class SymbolValue final : public ValueImpl { private: VarNode* m_node = nullptr; diff --git a/imperative/src/include/megbrain/imperative/transformations/trace.h b/imperative/src/include/megbrain/imperative/transformations/trace.h index 134e01107..e33af27cd 100644 --- a/imperative/src/include/megbrain/imperative/transformations/trace.h +++ b/imperative/src/include/megbrain/imperative/transformations/trace.h @@ -111,7 +111,8 @@ public: size_t id() const { return m_id; } }; -class TracingValue final : public MixinValueImpl { +class TracingValue final + : public MixinValueImpl { public: using MixinValueImpl::MixinValueImpl; @@ -256,7 +257,8 @@ public: } }; - class TracedValue final : public MixinValueImpl { + class TracedValue final + : public MixinValueImpl { public: using MixinValueImpl::MixinValueImpl; diff --git 
a/imperative/src/include/megbrain/imperative/utils/stats.h b/imperative/src/include/megbrain/imperative/utils/stats.h new file mode 100644 index 000000000..9bab330af --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/stats.h @@ -0,0 +1,140 @@ +#pragma once + +#include +#include +#include +#include + +namespace mgb { +namespace imperative { +namespace stats { + +#define MGE_ENABLE_STATS 0 + +class Timer { +public: + using clock_t = std::chrono::system_clock; + +private: + clock_t::duration m_duration = clock_t::duration{0}; + size_t m_timing = 0; + const char* m_name = nullptr; + uint64_t m_count = 0; + size_t m_enabled = 1; + bool m_default_enabled = true; + + struct TimeScopeRecursive { + Timer& timer; + clock_t::time_point start; + bool released = false; + + TimeScopeRecursive(Timer& timer) : timer(timer) { + if (timer.m_enabled && !timer.m_timing++) { + start = clock_t::now(); + } + } + + ~TimeScopeRecursive() { release(); } + + void release() { + if (released) { + return; + } + if (timer.m_enabled) { + if (!--timer.m_timing) { + timer.m_duration += (clock_t::now() - start); + } + timer.m_count++; + } + released = true; + } + }; + + struct EnableScope { + Timer& timer; + bool released = false; + + EnableScope(Timer& timer) : timer(timer) { timer.m_enabled++; } + + ~EnableScope() { release(); } + + void release() { + if (released) { + return; + } + timer.m_enabled--; + released = true; + } + }; + + using TimeScope = TimeScopeRecursive; + +public: + Timer(const char* name, bool default_enabled); + + const char* name() { return m_name; } + auto time_scope() { return TimeScope(*this); } + auto time_scope_recursive() { return TimeScopeRecursive(*this); }; + auto enable_scope() { return EnableScope(*this); } + void reset() { + m_duration = clock_t::duration{0}; + m_count = 0; + m_enabled = m_default_enabled ? 
1 : 0; + } + + clock_t::duration get() const { return m_duration; } + uint64_t count() const { return m_count; } +}; +} // namespace stats + +struct Stats { + static inline std::vector sm_timers; + + // register your timers here + // for example: + // + // static inline stats::Timer mytimer; + // + // then use MGE_TIMER_SCOPE(mytimer) to collect durations in your code + + static void print() { + std::vector unused_timers; + + for (auto* timer : sm_timers) { + if (timer->count() == 0) { + unused_timers.push_back(timer->name()); + } else { + printf("%s costs %ld ns, happens %ld times\n", timer->name(), + timer->get().count(), timer->count()); + } + } + + if (!unused_timers.empty()) { + printf("%zu timers unused\n", unused_timers.size()); + } + } + + static void reset() { + for (auto* timer : sm_timers) { + timer->reset(); + } + } +}; + +inline stats::Timer::Timer(const char* name, bool default_enabled) + : m_name(name), m_default_enabled(default_enabled) { + Stats::sm_timers.push_back(this); +} + +#if MGE_ENABLE_STATS +#define MGE_TIMER_SCOPE(name) auto name = Stats::name.time_scope() +#define MGE_TIMER_SCOPE_RELEASE(name) name.release() +#define MGE_TIMER_SCOPE_ENABLE(name) auto name = Stats::name.enable_scope() +#else +#define MGE_TIMER_SCOPE(name) (void)0 +#define MGE_TIMER_SCOPE_RELEASE(name) (void)0 +#define MGE_TIMER_SCOPE_ENABLE(name) (void)0 +#endif + +} // namespace imperative +} // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/value.h b/imperative/src/include/megbrain/imperative/value.h index 999d8bf10..cf4381df9 100644 --- a/imperative/src/include/megbrain/imperative/value.h +++ b/imperative/src/include/megbrain/imperative/value.h @@ -23,6 +23,7 @@ #include "megbrain/imperative/utils/debug.h" #include "megbrain/imperative/utils/local_ptr.h" #include "megbrain/imperative/utils/span.h" +#include "megbrain/imperative/utils/stats.h" namespace mgb { namespace imperative { @@ -58,6 +59,11 @@ public: inline size_t code() const { return m_code; } }; +enum class ValueKind { + Primitive, + Object, +}; + /** * \brief an smart reference of value * @@ -129,10 +135,10 @@ public: * \return TypedValueRef reference if success, otherwise empty reference */ template - inline TypedValueRef as_ref(Type type = {}) const; + inline const TypedValueRef& as_ref(Type type = {}) const; template - inline TypedValueRef cast_ref(Type type = {}) const; + inline const TypedValueRef& cast_ref(Type type = {}) const; template void on_cast_failure() const; @@ -161,14 +167,18 @@ public: static bool any_watching(); + static const ValueRef nil; + friend class ValueWeakRef; - template + template friend class TypedValueRef; - template + template friend class ValueImpl; friend ValueRefList apply(const Operator& op, Span inputs); }; +inline const ValueRef ValueRef::nil; + template <> struct ToStringTrait { public: @@ -241,7 +251,7 @@ public: friend class ValueRef; friend class ValueWeakRef; - template + template friend class ValueImpl; template friend class TypedValueRef; @@ -257,7 +267,7 @@ private: * * \tparam T type of value */ -template +template class ValueImpl : public Value { protected: ValueImpl() : Value(TYPE_CODE) {} @@ -267,6 +277,7 @@ public: using weak_ref_t = TypedValueWeakRef; static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }(); + static constexpr ValueKind KIND = Kind; /** * \brief helper function for construct a value @@ -288,8 +299,8 @@ public: * \tparam T type of value * \tparam TMixin type of mixin class */ -template -class MixinValueImpl : public 
ValueImpl, public TMixin { +template +class MixinValueImpl : public ValueImpl, public TMixin { public: using TMixin::TMixin; @@ -309,12 +320,14 @@ inline ValueRef::ValueRef(storage_t storage) { template inline const TValue* ValueRef::as(Type type) const { - static_assert(std::is_base_of_v, TValue>); + // auto _ = Stats::time_value_as.time_scope(); + static_assert(std::is_base_of_v); return static_cast(as(type.code())); } template inline const TValue& ValueRef::cast(Type type) const { + // auto _ = Stats::time_value_cast.time_scope(); auto* ptr = as(type); if (mgb_unlikely(!ptr)) { on_cast_failure(); @@ -324,26 +337,27 @@ inline const TValue& ValueRef::cast(Type type) const { template inline bool ValueRef::is(Type type) const { + // auto _ = Stats::time_value_is.time_scope(); return is(type.code()); } template -inline TypedValueRef ValueRef::as_ref(Type type) const { +inline const TypedValueRef& ValueRef::as_ref(Type type) const { if (!is(type)) { - return {}; + return TypedValueRef::nil; } - return TypedValueRef(*this); + return *reinterpret_cast*>(this); } template -inline TypedValueRef ValueRef::cast_ref(Type type) const { +inline const TypedValueRef& ValueRef::cast_ref(Type type) const { if (!m_storage) { - return {}; + return TypedValueRef::nil; } if (mgb_unlikely(!is(type))) { on_cast_failure(); } - return TypedValueRef(*this); + return *reinterpret_cast*>(this); } template @@ -363,12 +377,31 @@ void ValueRef::on_cast_failure() const { template class TypedValueRef : public ValueRef { private: - TypedValueRef(ValueRef value) : ValueRef(value) {} + TypedValueRef(ValueRef value) : ValueRef(std::move(value)) {} public: TypedValueRef() = default; - const T& operator*() const { return this->template cast(); } - const T* operator->() const { return this->template as(); } + const T& operator*() const { + if constexpr (T::KIND == ValueKind::Object) { + return this->template cast(); + } else if constexpr (T::KIND == ValueKind::Primitive) { + if (!m_storage) { + on_cast_failure(); + } + return static_cast(*m_storage); + } else { + static_assert(!std::is_same_v); + } + } + const T* operator->() const { + if constexpr (T::KIND == ValueKind::Object) { + return this->template as(); + } else if constexpr (T::KIND == ValueKind::Primitive) { + return static_cast(m_storage.get()); + } else { + static_assert(!std::is_same_v); + } + } /** * \brief reset underlying value to another value @@ -376,6 +409,7 @@ public: * \param successor new value */ inline void reset(ValueRef successor) { + static_assert(T::KIND == ValueKind::Object); mgb_assert(m_storage); mgb_assert(!m_storage->m_successor); if (m_storage->m_watching) { @@ -385,9 +419,11 @@ public: m_storage->m_successor = ValueRef(successor.storage()); } + static inline const TypedValueRef nil; + friend class ValueRef; - template + template friend class ValueImpl; }; @@ -423,7 +459,7 @@ public: ValueRefList() = default; ValueRefList(size_t nr_elems); ValueRefList(ValueRef item); - ValueRefList(std::initializer_list values); + // ValueRefList(std::initializer_list values); template ValueRefList(TIterator begin, TIterator end); ValueRefList(const ValueRefList& rhs); -- GitLab