From 81d8c73a4136765ae58bfb1477a9bddca685c152 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 7 Feb 2022 22:48:57 +0800 Subject: [PATCH] perf(dispatch/trace): serval tricks to speed up trace GitOrigin-RevId: 2bdd70cde2d19f43055218804abd65f4e5a54b89 --- imperative/src/impl/transformations/trace.cpp | 41 +++++++++++-------- .../imperative/transformations/trace.h | 6 +-- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/imperative/src/impl/transformations/trace.cpp b/imperative/src/impl/transformations/trace.cpp index d7c0cd993..f13ec73f6 100644 --- a/imperative/src/impl/transformations/trace.cpp +++ b/imperative/src/impl/transformations/trace.cpp @@ -47,7 +47,7 @@ VarNodeArray TraceResult::dump( auto& node = nodes[input]; // TODO: cambricon CompNode auto host = std::make_shared( - CompNode::load("xpux"), shape, var.dtype); + CompNode::load("xpux"), shape, *var.dtype); OperatorNodeConfig config; // if prefer_input_names, prefer names from dump args // else prefer names got from trace procedure @@ -211,7 +211,6 @@ ValueRefList TracingTransformation::apply_transformation( auto& var_info = m_vars[tracing_value->id()]; switch (get_attr->attr()) { case GetAttr::Shape: - // TODO: reduce h2d when data or value is available var_info.shape_required = true; break; case GetAttr::Data: @@ -301,8 +300,8 @@ void CompiledTransformation::compile() { auto box = make_box(); // TODO: attach ref count, release early auto outputs = opr::InputCallback::make( - *m_graph, [box] { return box->take_value(); }, var_info->device, - var_info->dtype, var_info->shape, io_links, m_input_shape_static); + *m_graph, [box] { return box->take_value(); }, *var_info->device, + *var_info->dtype, var_info->shape, io_links, m_input_shape_static); // attach input_callback to io_links accessor.node = outputs[0].node(); io_links = {outputs[1]}; @@ -312,6 +311,11 @@ void CompiledTransformation::compile() { auto make_output = [&](TraceResult::VarInfo* var_info, SymbolVar node) { VarAccessor accessor; accessor.node = node.node(); + if (var_info->data_required) { + // reduce d2h when data is available + // FIXME: compile should not change var_info in-place + var_info->shape_required = false; + } if (var_info->shape_required) { // TODO: use static infer manager for some vars? auto box = make_box(); @@ -334,6 +338,12 @@ void CompiledTransformation::compile() { accessor.data_getter = [box]() -> DeviceTensorND { return box->get_value(); }; + if (!accessor.shape_getter) { + // also implement shape_getter + accessor.shape_getter = [box]() -> TensorShape { + return box->get_value().shape(); + }; + } } if (var_info->value_required) { struct ValueWithEvent { @@ -341,7 +351,7 @@ void CompiledTransformation::compile() { CompNode::Event* event = nullptr; }; auto box = make_box(); - auto event = EventPool::without_timer().alloc_shared(var_info->device); + auto event = EventPool::without_timer().alloc_shared(*var_info->device); auto callback = [box, event](DeviceTensorND data) { HostTensorND host_val; host_val.copy_from(data); @@ -355,7 +365,7 @@ void CompiledTransformation::compile() { }; SymbolVarArray inputs = io_links; inputs.insert(inputs.begin(), node); - auto output = opr::OutputCallback::make({callback, false, true}, inputs); + auto output = opr::OutputCallback::make({callback, true, true}, inputs); io_links = {output}; accessor.value_getter = [box]() -> HostTensorND { auto&& [value, event] = box->get_value(); @@ -486,11 +496,12 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { DType dtype = *value.dtype(); CompNode device = *value.device(); trace_assert( - var.dtype == dtype, "dtype mismatch: %s vs %s", - var.dtype.name(), dtype.name()); + *var.dtype == dtype, "dtype mismatch: %s vs %s", + var.dtype->name(), dtype.name()); trace_assert( - var.device == device, "comp_node mismatch: %s vs %s", - var.device.to_string().c_str(), device.to_string().c_str()); + *var.device == device, "comp_node mismatch: %s vs %s", + var.device->to_string().c_str(), + device.to_string().c_str()); } var_accessor.data_setter(value.dev_tensor()->as_nd()); break; @@ -535,17 +546,11 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const { } DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const { - if (!m_dtype) { - m_dtype = DTypeValue::make(m_var->dtype); - } - return m_dtype; + return m_var->dtype; } CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const { - if (!m_comp_node) { - m_comp_node = CompNodeValue::make(m_var->device); - } - return m_comp_node; + return m_var->device; } auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& { return *m_accessor; diff --git a/imperative/src/include/megbrain/imperative/transformations/trace.h b/imperative/src/include/megbrain/imperative/transformations/trace.h index e33af27cd..c73b4ce47 100644 --- a/imperative/src/include/megbrain/imperative/transformations/trace.h +++ b/imperative/src/include/megbrain/imperative/transformations/trace.h @@ -44,8 +44,8 @@ struct TraceResult { }; size_t id; - DType dtype; - CompNode device; + DTypeValue::ref_t dtype; + CompNodeValue::ref_t device; // if exists, assert equal when meet ValueRef bound_data; @@ -162,7 +162,7 @@ public: TypedValueRef record_var(ValueRef value, bool capture, VarKind kind) { size_t id = m_vars.size(); auto wrapped_value = TracingValue::make(value, id); - m_vars.push_back({id, *value.dtype(), *value.device()}); + m_vars.push_back({id, value.dtype(), value.device()}); auto& var = m_vars.back(); if (capture) { var.bound_data = value; -- GitLab