From d2b67c2a880a509d218319b2262819c701c9c013 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 14 Jan 2022 13:21:46 +0800 Subject: [PATCH] refactor(dispatch): implement trace GitOrigin-RevId: f8d3005732dad0f941d963e8e529f1c11d2d3ca5 --- imperative/src/impl/transformations/trace.cpp | 679 ++++++++++++++++++ .../imperative/transformations/trace.h | 348 +++++++++ 2 files changed, 1027 insertions(+) create mode 100644 imperative/src/impl/transformations/trace.cpp create mode 100644 imperative/src/include/megbrain/imperative/transformations/trace.h diff --git a/imperative/src/impl/transformations/trace.cpp b/imperative/src/impl/transformations/trace.cpp new file mode 100644 index 000000000..9a8074c83 --- /dev/null +++ b/imperative/src/impl/transformations/trace.cpp @@ -0,0 +1,679 @@ +/** + * \file imperative/src/impl/transformations/trace.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/imperative/transformations/trace.h" + +#include +#include + +#include "megbrain/gopt/inference.h" +#include "megbrain/graph/helper.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/io.h" +#include "megbrain/opr/utility.h" +#include "megbrain/serialization/serializer.h" + +#include "../event_pool.h" + +#define trace_assert(_cond, _msg...) 
\ + do { \ + if (mgb_unlikely(!(_cond))) { \ + auto exc = std::make_exception_ptr(TraceError(ssprintf(_msg))); \ + set_exception(exc); \ + std::rethrow_exception(exc); \ + } \ + } while (0) + +namespace mgb { +namespace imperative { + +VarNodeArray TraceResult::dump( + ComputingGraph& graph, + std::vector> inputs, + std::vector> outputs, bool prefer_input_names) { + // var -> VarNode + std::vector nodes(vars.size(), nullptr); + // make h2d node for each input + for (auto&& [input, name, shape] : inputs) { + auto& var = vars[input]; + auto& node = nodes[input]; + // TODO: cambricon CompNode + auto host = std::make_shared( + CompNode::load("xpux"), shape, var.dtype); + OperatorNodeConfig config; + // if prefer_input_names, prefer names from dump args + // else prefer names got from trace procedure + if (prefer_input_names && !name.empty()) { + config.name(name); + } else if (!var.name.empty()) { + config.name(var.name); + } else if (!name.empty()) { + config.name(name); + } + node = opr::Host2DeviceCopy::make(graph, host, {}, config).node(); + } + // make const node for each constant + for (size_t i = 0; i < vars.size(); ++i) { + auto& var = vars[i]; + auto& node = nodes[i]; + if (!node) { + if (var.kind != VarKind::Internal) { + if (!var.bound_data) { + continue; + } + if (!var.name.empty()) { + node = opr::ImmutableTensor::make( + graph, var.bound_data.numpy()->as_nd(), {var.name}) + .node(); + } else { + node = opr::ImmutableTensor::make( + graph, var.bound_data.numpy()->as_nd()) + .node(); + } + } + } + } + std::unordered_map> name2ops; + // iterate over opr_seq + for (auto&& item : seq) { + auto&& [op, inputs, outputs] = item; + VarNodeArray input_nodes; + for (auto&& input : inputs) { + auto& node = nodes[input]; + input_nodes.push_back(node); + } + VarNodeArray output_nodes; + if (op) { + if (auto* bn = op->try_cast_final()) { + mgb_assert( + bn->fwd_mode == BatchNorm::FwdMode::INFERENCE, + "can not dump BatchNorm in training mode, maybe you forget to " + "do 
model.eval()?"); + } + output_nodes = OpDef::apply_on_var_node(*op, input_nodes); + name2ops[output_nodes[0]->owner_opr()->name()].push_back( + output_nodes[0]->owner_opr()); + } else { + // no opr, just forward VarNode + mgb_assert( + inputs.size() == outputs.size(), + "output size not equals to input size when forwarding"); + output_nodes = input_nodes; + } + mgb_assert(output_nodes.size() == outputs.size(), "output size mismatch"); + for (size_t i = 0; i < outputs.size(); ++i) { + auto output = outputs[i]; + auto& var = vars[output]; + auto& node = nodes[output]; + mgb_assert(var.kind == VarKind::Internal, "output node should be internal"); + if (!node) { + node = output_nodes[i]; + } + if (!var.name.empty()) { + node->name(var.name); + } + } + } + for (auto&& [name, ops] : name2ops) { + if (ops.size() <= 1) { + continue; + } + // ops.size() > 1, need dedup (rename op) + for (size_t i = 0; i < ops.size(); ++i) { + auto& op = ops[i]; + auto new_name = ssprintf("%s[%zu]", name.c_str(), i); + for (auto&& output : op->output()) { + auto output_name = output->name(); + auto pos = output_name.find(name); + if (pos != std::string::npos) { + output_name.replace(pos, name.length(), new_name); + } + output->name(output_name); + } + op->name(new_name); + } + } + VarNodeArray output_nodes; + for (auto&& [output, name] : outputs) { + mgb_assert(output < vars.size(), "invalid output id %zu", output); + mgb_assert(nodes[output], "output node invalid"); + if (!name.empty()) { + nodes[output]->name(name); + } + output_nodes.push_back(nodes[output]); + } + return output_nodes; +} + +std::vector TracingTransformation::apply_transformation( + const Operator& op, Span inputs) { + if (auto* op_value = op.as()) { + SmallVector unwrapped_inputs; + SmallVector wrapped_inputs; + SmallVector input_ids; + for (auto input : inputs) { + auto tracing_value = input.as_ref(); + if (!tracing_value) { + tracing_value = + record_var(input, m_capture_as_const, VarKind::External); + } + 
unwrapped_inputs.push_back(tracing_value->value()); + wrapped_inputs.push_back(tracing_value); + input_ids.push_back(tracing_value->id()); + } + // TODO: remove OpDef::set_scope + auto scopes = Transformation::scopes(); + std::string scopes_join; + for (auto&& scope : scopes) { + if (!scopes_join.empty()) { + scopes_join.push_back('.'); + } + scopes_join.append(scope); + } + const_cast(op_value->op()).set_scope(scopes_join); + auto unwrapped_outputs = imperative::apply(op, unwrapped_inputs); + std::vector wrapped_outputs; + SmallVector output_ids; + for (auto&& output : unwrapped_outputs) { + auto wrapped_output = record_var(output, false, VarKind::Internal); + wrapped_outputs.push_back(wrapped_output); + output_ids.push_back(wrapped_output->id()); + } + m_seq.push_back({op_value->op().shared_from_this(), input_ids, output_ids}); + return wrapped_outputs; + } else if (auto* create_tensor = op.as()) { + auto outputs = imperative::apply(op, inputs); + if (create_tensor->kind() == CreateTensor::NoTrace) { + return outputs; + } + bool is_const = create_tensor->kind() == CreateTensor::Const; + auto wrapped_input = record_var( + outputs[0], is_const || m_capture_as_const, + is_const ? 
VarKind::Constant : VarKind::External); + auto wrapped_output = record_var(outputs[0], false, VarKind::Internal); + auto input_id = wrapped_input->id(); + auto output_id = wrapped_output->id(); + m_seq.push_back({{}, {input_id}, {output_id}}); + return {wrapped_output}; + } else if (auto* get_attr = op.as()) { + auto unwrapped_input = unwrap_var(inputs[0]); + auto outputs = imperative::apply(op, unwrapped_input); + if (auto* tracing_value = inputs[0].as()) { + auto& var_info = m_vars[tracing_value->id()]; + switch (get_attr->attr()) { + case GetAttr::Shape: + // TODO: reduce h2d when data or value is available + var_info.shape_required = true; + break; + case GetAttr::Data: + var_info.data_required = true; + break; + case GetAttr::Value: + var_info.value_required = true; + break; + default: + break; + } + } + return outputs; + } else if (auto* trace_mark_var = op.as()) { + mgb_assert(inputs.size() == 1, "TraceMarkVar expects exactly one input"); + auto input = inputs[0]; + auto tracing_var = input.as_ref(); + if (!tracing_var) { + bool is_input = trace_mark_var->mark().substr(0, 4) == "arg_" || + trace_mark_var->mark().substr(0, 6) == "kwarg_"; + if (is_input) { + tracing_var = record_var(input, false, VarKind::External); + } else { + tracing_var = record_var(input, m_capture_as_const, VarKind::External); + } + } else { + input = tracing_var->value(); + } + auto output = record_var(input, false, VarKind::Internal); + m_vars[output->id()].mark = trace_mark_var->mark(); + m_seq.push_back({{}, {tracing_var->id()}, {output->id()}}); + return {output}; + } else if (auto* trace_name_var = op.as()) { + mgb_assert(inputs.size() == 1, "RenameValue expects exactly one input"); + auto input = inputs[0]; + auto tracing_var = input.as_ref(); + if (!tracing_var) { + tracing_var = record_var(input, m_capture_as_const, VarKind::External); + } else { + input = tracing_var->value(); + } + auto output = record_var(input, false, VarKind::Internal); + m_vars[output->id()].name = 
trace_name_var->name(); + m_seq.push_back({{}, {tracing_var->id()}, {output->id()}}); + return {output}; + } else if (op.is()) { + mgb_assert(inputs.size() == 1, "GetName expects exactly one input"); + auto input = inputs[0]; + if (auto tracing_var = input.as_ref()) { + auto name = m_vars[tracing_var->id()].name; + if (!name.empty()) { + return {StringValue::make(name)}; + } else { + return {ValueRef()}; + } + } + return imperative::apply(op, inputs); + } else { + // TODO: handle DTRCommand and ... + return op.fallback(inputs); + } +} + +void TracingTransformation::on_unregister() noexcept { + for (auto&& weak_var : m_weak_vars) { + if (auto tracing_value = weak_var.lock()) { + auto& var_info = m_vars[tracing_value->id()]; + var_info.data_required = true; + tracing_value.reset(tracing_value->value()); + } + } + m_weak_vars.clear(); +} + +void CompiledTransformation::compile() { + // these ops require seq order, so we link them to an mm_io_link to ensure order + static std::unordered_set mm_io_ops = { + CollectiveComm::typeinfo(), RemoteSend::typeinfo(), RemoteRecv::typeinfo()}; + mgb_assert(!m_executable, "already compiled"); + // FIXME: mm_io_link and io_links should be merged + SymbolVarArray io_links; + SymbolVar mm_io_link; + auto make_input = [&](VarInfo* var_info) { + mgb_assert( + var_info->kind == VarKind::External, "input node should be external"); + VarAccessor accessor; + auto box = make_box(); + // TODO: attach ref count, release early + auto outputs = opr::InputCallback::make( + *m_graph, [box] { return box->take_value(); }, var_info->device, + var_info->dtype, var_info->shape, io_links, m_input_shape_static); + // attach input_callback to io_links + accessor.node = outputs[0].node(); + io_links = {outputs[1]}; + accessor.data_setter = [box](DeviceTensorND data) { box->try_set_value(data); }; + return accessor; + }; + auto make_output = [&](TraceResult::VarInfo* var_info, SymbolVar node) { + VarAccessor accessor; + accessor.node = node.node(); + if 
(var_info->shape_required) { + // TODO: use static infer manager for some vars? + auto box = make_box(); + auto callback = [box](DeviceTensorND data) { + box->try_set_value(data.shape()); + }; + SymbolVarArray inputs = io_links; + inputs.insert(inputs.begin(), node); + auto output = opr::OutputCallback::make({callback, true, false}, inputs); + io_links = {output}; + accessor.shape_getter = [box]() -> TensorShape { return box->get_value(); }; + } + if (var_info->data_required) { + auto box = make_box(); + auto callback = [box](DeviceTensorND data) { box->try_set_value(data); }; + SymbolVarArray inputs = io_links; + inputs.insert(inputs.begin(), node); + auto output = opr::OutputCallback::make({callback, false, false}, inputs); + io_links = {output}; + accessor.data_getter = [box]() -> DeviceTensorND { + return box->get_value(); + }; + } + if (var_info->value_required) { + struct ValueWithEvent { + HostTensorND value; + CompNode::Event* event = nullptr; + }; + auto box = make_box(); + auto event = EventPool::without_timer().alloc_shared(var_info->device); + auto callback = [box, event](DeviceTensorND data) { + HostTensorND host_val; + host_val.copy_from(data); + if (data.comp_node() != CompNode::default_cpu()) { + mgb_assert(data.comp_node() == event->comp_node()); + event->record(); + box->try_set_value({host_val, event.get()}); + } else { + box->try_set_value({host_val}); + } + }; + SymbolVarArray inputs = io_links; + inputs.insert(inputs.begin(), node); + auto output = opr::OutputCallback::make({callback, false, true}, inputs); + io_links = {output}; + accessor.value_getter = [box]() -> HostTensorND { + auto&& [value, event] = box->get_value(); + if (event) { + event->host_wait(); + } + return value; + }; + } + return accessor; + }; + auto make_const = [&](TraceResult::VarInfo* var_info) { + VarAccessor accessor; + mgb_assert( + var_info->kind == VarKind::Constant, "const node should be constant"); + HostTensorND host_val = var_info->bound_data.numpy()->as_nd(); + 
accessor.node = opr::ImmutableTensor::make(*m_graph, host_val).node(); + return accessor; + }; + std::vector var_accessors(m_vars.size()); + for (auto&& item : m_seq) { + bool require_link = bool(item.op) && mm_io_ops.count(item.op->dyn_typeinfo()); + VarNodeArray input_vars; + for (auto&& input : item.inputs) { + auto& var = m_vars[input]; + if (!var_accessors[input].node) { + switch (var.kind) { + case VarKind::External: + var_accessors[input] = make_input(&var); + break; + case VarKind::Constant: + var_accessors[input] = make_const(&var); + break; + default: + mgb_throw( + AssertionError, + "internal node should be valid when used as input"); + } + } + input_vars.push_back(var_accessors[input].node); + } + if (require_link && mm_io_link.node()) { + mgb_assert( + !input_vars.empty(), + "io-mm operator should have at least one input"); + input_vars[0] = + opr::VirtualDep::make({SymbolVar(input_vars[0]), mm_io_link}) + .node(); + } + VarNodeArray output_vars; + if (item.op) { + output_vars = OpDef::apply_on_var_node(*item.op, input_vars); + } else { + // forward inputs to outputs + mgb_assert( + item.inputs.size() == item.outputs.size(), + "output size not equals to input size when forwarding"); + for (auto&& input_var : input_vars) { + output_vars.push_back(input_var); + } + } + if (require_link) { + mgb_assert( + !item.outputs.empty(), + "io-mm operator should have at least one output"); + mm_io_link = SymbolVar(output_vars[0]); + } + // init output accessors + for (size_t i = 0; i < output_vars.size(); ++i) { + auto output = item.outputs[i]; + auto& node = output_vars[i]; + auto& var = m_vars[output]; + var_accessors[output] = make_output(&var, node); + } + } + ComputingGraph::OutputSpec output_specs; + // avoid input/output/callback from being optimized + for (auto&& io_link : io_links) { + output_specs.push_back({io_link, {}}); + } + // avoid remote io ops from being optimized + if (mm_io_link.node()) { + output_specs.push_back({mm_io_link, {}}); + } + { + // 
set_priority_to_id + // workaround for having mm_io_link and io_links separated + auto on_opr = [](mgb::cg::OperatorNodeBase* opr) { + if (opr->node_prop().attribute().priority == 0) { + opr->node_prop().attribute().priority = opr->id(); + } + }; + mgb::cg::DepOprIter dep_iter{on_opr}; + for (const auto& output_spec : output_specs) { + dep_iter.add(output_spec.first); + } + } + m_executable = m_graph->compile(output_specs); + m_var_accessors = var_accessors; + m_output_spec = output_specs; +} + +void CompiledTransformation::recompile() { + mgb_assert(m_executable); + m_executable = m_graph->compile(m_output_spec); +} + +void CompiledTransformation::assert_tensor_equal(ValueRef lhs, ValueRef rhs) { + trace_assert(m_value_comparator(lhs, rhs), "tensors not equals"); +} + +void CompiledTransformation::trace_input(size_t id, ValueRef value) { + try { + auto& var = m_vars[id]; + auto& var_accessor = m_var_accessors[id]; + switch (var.kind) { + case VarKind::External: { + trace_assert( + !value.is(), "expect external node, got internal"); + if (var.bound_data) { + assert_tensor_equal(var.bound_data, value); + } else { + DType dtype = *value.dtype(); + CompNode device = *value.device(); + trace_assert( + var.dtype == dtype, "dtype mismatch: %s vs %s", + var.dtype.name(), dtype.name()); + trace_assert( + var.device == device, "comp_node mismatch: %s vs %s", + var.device.to_string().c_str(), device.to_string().c_str()); + } + var_accessor.data_setter(value.dev_tensor()->as_nd()); + break; + } + case VarKind::Constant: { + mgb_assert(var.bound_data, "const var without data bound"); + assert_tensor_equal(var.bound_data, value); + break; + } + case VarKind::Internal: { + trace_assert( + value.is(), "expect internal node, got external"); + auto& traced_value = value.cast(); + trace_assert(traced_value.id() == id, "input id mismatch"); + break; + } + } + } catch (TraceError&) { + throw; + } catch (...) 
{ + mgb_assert(false, "unexpected error"); + } +} + +TracedValue::ref_t CompiledTransformation::trace_output(size_t id) { + auto traced_value = TracedValue::make(id); + m_weak_values.push_back(traced_value); + return traced_value; +} + +TraceResult::SeqItem& CompiledTransformation::next_instruction() { + trace_assert(m_pc < m_seq.size(), "too many instructions"); + return m_seq[m_pc++]; +} + +std::vector CompiledTransformation::apply_transformation( + const Operator& op, Span inputs) { + if (auto* op_value = op.as()) { + auto& item = next_instruction(); + SmallVector unwrapped_inputs; + SmallVector wrapped_inputs; + trace_assert(inputs.size() == item.inputs.size(), "input size mismatch"); + trace_assert(op_value->op().is_same(*item.op), "operator mismatch"); + for (size_t i = 0; i < inputs.size(); ++i) { + trace_input(item.inputs[i], inputs[i]); + } + std::vector outputs; + for (auto&& output_id : item.outputs) { + outputs.push_back(trace_output(output_id)); + } + return outputs; + } else if (auto* create_tensor = op.as()) { + if (create_tensor->kind() == CreateTensor::NoTrace) { + return imperative::apply(op, inputs); + } + auto& item = next_instruction(); + trace_assert(item.op == nullptr, "operator mismatch"); + auto input_id = item.inputs[0]; + auto output_id = item.outputs[0]; + auto tensor = imperative::apply(op, inputs)[0]; + trace_input(input_id, tensor); + return {trace_output(output_id)}; + } else if (auto* get_attr = op.as()) { + if (auto* traced_value = inputs[0].as()) { + ValueRef output; + auto& var = m_vars[traced_value->id()]; + auto& var_accessor = m_var_accessors[traced_value->id()]; + switch (get_attr->attr()) { + case GetAttr::Shape: + trace_assert(var_accessor.shape_getter, "shape unreadable"); + output = ShapeValue::make( + ValueShape::from(var_accessor.shape_getter())); + break; + case GetAttr::Data: + trace_assert(var_accessor.data_getter, "data unreadable"); + output = DeviceValue::make(var_accessor.data_getter()); + break; + case 
GetAttr::Value: + trace_assert(var_accessor.value_getter, "value unreadable"); + output = HostValue::make(var_accessor.value_getter()); + break; + case GetAttr::DType: + output = DTypeValue::make(var.dtype); + break; + case GetAttr::Device: + output = CompNodeValue::make(var.device); + default: + break; + } + return {output}; + } else { + return imperative::apply(op, inputs); + } + } else if (auto* trace_mark_var = op.as()) { + auto& item = next_instruction(); + trace_assert(item.op == nullptr, "operator mismatch"); + trace_assert(item.inputs.size() == 1, "inputs size mismatch"); + trace_assert(item.outputs.size() == 1, "inputs output mismatch"); + trace_input(item.inputs[0], inputs[0]); + trace_assert( + trace_mark_var->mark() == m_vars[item.outputs[0]].mark, + "mark mismatch"); + return {trace_output(item.outputs[0])}; + } else if (auto* trace_name_var = op.as()) { + auto& item = next_instruction(); + trace_assert(item.op == nullptr, "operator mismatch"); + trace_assert(item.inputs.size() == 1, "inputs size mismatch"); + trace_assert(item.outputs.size() == 1, "outputs size mismatch"); + trace_input(item.inputs[0], inputs[0]); + trace_assert( + trace_name_var->name() == m_vars[item.outputs[0]].name, + "name mismatch"); + return {trace_output(item.outputs[0])}; + } else { + return op.fallback(inputs); + } +} + +void CompiledTransformation::on_unregister() noexcept { + // resolve pending values + for (auto&& weak_value : m_weak_values) { + if (auto traced_value = weak_value.lock()) { + auto& var_accessor = m_var_accessors[traced_value->id()]; + auto value = ([&]() -> ValueRef { + try { + trace_assert(var_accessor.data_getter, "data unreadable"); + auto dev_value = DeviceValue::make(var_accessor.data_getter()); + return imperative::apply( + CreateTensor( + CreateTensor::Common, dev_value->device(), + dev_value->dtype(), dev_value->shape()), + DeviceStorage::make(dev_value->storage()))[0]; + } catch (...) 
{ + set_exception(std::current_exception()); + return ErrorValue::make("trace exit failed"); + } + })(); + traced_value.reset(value); + } + } + m_weak_values.clear(); +} + +void CompiledTransformation::execute() { + mgb_assert(m_executable != nullptr); + m_graph_executor = std::thread([&] { + try { + m_executable->execute(); + m_executable->wait(); + } catch (...) { + auto exc = std::current_exception(); + set_exception(exc); + } + }); +} + +void CompiledTransformation::wait() { + try { + trace_assert(m_pc == m_seq.size(), "mismature end"); + } catch (...) { + } + mgb_assert(m_executable != nullptr); + m_graph_executor.join(); + m_graph_executor = {}; + for (auto&& box : m_boxes) { + box->reset(); + } + m_pc = 0; + std::exception_ptr graph_exc; + std::swap(m_graph_exc, graph_exc); + if (graph_exc) { + // graph with exception cannot be reused + recompile(); + std::rethrow_exception(graph_exc); + } +} + +std::exception_ptr CompiledTransformation::set_exception( + std::exception_ptr exc) noexcept { + MGB_LOCK_GUARD(m_mutex); + if (m_graph_exc) { + return m_graph_exc; + } + for (auto&& box : m_boxes) { + box->try_set_exception(exc); + } + m_graph_exc = exc; + return m_graph_exc; +} + +} // namespace imperative +} // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/transformations/trace.h b/imperative/src/include/megbrain/imperative/transformations/trace.h new file mode 100644 index 000000000..fceb3ec2f --- /dev/null +++ b/imperative/src/include/megbrain/imperative/transformations/trace.h @@ -0,0 +1,348 @@ +/** + * \file imperative/src/include/megbrain/imperative/transformations/trace.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include +#include +#include + +#include "megbrain/gopt/inference.h" +#include "megbrain/imperative/dispatch.h" +#include "megbrain/imperative/interpreter.h" +#include "megbrain/imperative/opr_utility.h" +#include "megbrain/imperative/utils/box.h" +#include "megbrain/imperative/utils/helper.h" +#include "megbrain/opr/io.h" +#include "megbrain/serialization/serializer.h" + +namespace mgb::imperative { + +struct TraceResult { + struct SeqItem { + std::shared_ptr op; + SmallVector inputs; + SmallVector outputs; + }; + + struct VarInfo { + enum Kind { + External, // End point of traced graph, its value is received from + // environment + Constant, // Also end point, but its value is constant in all executions, + // so we don't need to get from env every time, just capture it + Internal, // Not end point, produced by some op (or just forwarded) from + // op_seq + }; + + size_t id; + DType dtype; + CompNode device; + + // if exists, assert equal when meet + ValueRef bound_data; + std::string mark; + std::string name; + + Kind kind; + bool value_required = false; + bool data_required = false; + bool shape_required = false; + + TensorShape shape; + }; + + using VarKind = VarInfo::Kind; + + std::vector seq; + std::vector vars; + + /** + * \brief dump to mgb computing graph + * + * \param graph mgb computing graph + * \param inputs (input_id, input_name, input_shape) + * \param outputs (output_id, output_name) + * \param prefer_input_names + * \return VarNodeArray output nodes + */ + VarNodeArray dump( + ComputingGraph& graph, + std::vector> inputs, + std::vector> outputs, + bool prefer_input_names); +}; + +/** + * \brief mark a var as arg/kwarg/output + * + */ +class TraceMarkVar : public OperatorImpl { +private: + std::string m_mark; + +public: + TraceMarkVar(std::string mark) : m_mark(mark) {} + + std::string mark() const { return m_mark; } + + std::string to_string() const override { + return ssprintf("TraceMarkVar{mark=%s}", 
imperative::quoted(m_mark).c_str()); + } +}; + +class TracingInfo { +private: + ValueRef m_value = {}; + size_t m_id = 0; + +public: + TracingInfo() = default; + TracingInfo(ValueRef value, size_t id) : m_value(value), m_id(id) {} + ValueRef value() const { return m_value; } + size_t id() const { return m_id; } +}; + +class TracingValue final : public MixinValueImpl { +public: + using MixinValueImpl::MixinValueImpl; + + std::string to_string() const override { + return ssprintf( + "TracingValue{\"id\"=%zu, \"value\"=%s}", id(), + value().to_string().c_str()); + } + + void on_watch() override { value().watch(); } + + void on_unwatch() override { value().unwatch(); } +}; + +class TracedInfo { +private: + size_t m_id = 0; + +public: + TracedInfo() = default; + TracedInfo(size_t id) : m_id(id) {} + size_t id() const { return m_id; } +}; + +class TracedValue final : public MixinValueImpl { +public: + using MixinValueImpl::MixinValueImpl; + + std::string to_string() const override { + return ssprintf("TracedValue{\"id\"=%zu}", id()); + } +}; + +/** + * \brief trace operation sequence to TraceResult + * + * TracingTransformation records and forwards all operations to next layer, + * as if it's transparent. When execution ends, it exports an operation sequence, + * which is usually used to build CompiledTransformation. 
+ */ +class TracingTransformation final : public Transformation { +public: + using VarInfo = TraceResult::VarInfo; + using VarKind = VarInfo::Kind; + +private: + std::vector m_seq; + std::vector m_vars; + std::vector m_weak_vars; + bool m_capture_as_const = false; + bool m_record_input_shapes = false; + +public: + TracingTransformation(bool capture_as_const, bool record_input_shapes) + : m_capture_as_const(capture_as_const), + m_record_input_shapes(record_input_shapes) {} + + /** + * \brief record values for trace + * + * \param value value to be traced + * \param capture whether capture value or not + * \param kind External, Constant or Internal + * \return TypedValueRef traced value + */ + TypedValueRef record_var(ValueRef value, bool capture, VarKind kind) { + size_t id = m_vars.size(); + auto wrapped_value = TracingValue::make(value, id); + m_vars.push_back({id, *value.dtype(), *value.device()}); + auto& var = m_vars.back(); + if (capture) { + var.bound_data = value; + } + var.kind = kind; + if (m_record_input_shapes && kind != VarKind::Internal) { + var.shape = value.shape()->as_tensor_shape(); + } + if (auto name = value.name()) { + var.name = *name; + } + m_weak_vars.push_back(wrapped_value); + return wrapped_value; + } + ValueRef unwrap_var(ValueRef value) { + if (auto* tracing_value = value.as()) { + return tracing_value->value(); + } + return value; + } + + std::vector apply_transformation( + const Operator& op, Span inputs) override; + + ValueRef unwrap(ValueRef value) override { + if (auto* tracing_value = value.as()) { + return tracing_value->value(); + } + return value; + } + + std::string name() const override { return "TracingTransformation"; } + + void on_unregister() noexcept override; + + TraceResult get_result() { return {m_seq, m_vars}; } +}; + +class TraceError : public std::exception { +private: + std::string m_message; + +public: + TraceError(std::string reason) { + m_message = ssprintf("trace error because %s", reason.c_str()); + } + const 
char* what() const noexcept override { return m_message.c_str(); } +}; + +/** + * \brief boost with traced result from TracingTransformation + * + * CompiledTransformation is built with an operation sequence. It compiles a megbrain + * graph with the sequence and handles operation requests with this graph. Besides that, + * it also checks whether the current operation is the same as the recorded one in seq. + */ +class CompiledTransformation final : public Transformation { +public: + using VarInfo = TraceResult::VarInfo; + using VarKind = VarInfo::Kind; + + struct VarAccessor { + VarNode* node; + std::function shape_getter; + std::function data_getter; + std::function value_getter; + std::function data_setter; + }; + +private: + std::vector m_seq; + std::vector m_vars; + std::vector m_var_accessors; + size_t m_pc = 0; + std::shared_ptr m_graph; + std::unique_ptr m_executable; + std::vector m_weak_values; + std::thread m_graph_executor; + std::function m_value_comparator; + bool m_input_shape_static; + std::mutex m_mutex; + std::exception_ptr m_graph_exc; + std::vector> m_boxes; + ComputingGraph::OutputSpec m_output_spec; + +public: + CompiledTransformation(TraceResult result, bool input_shape_static) + : m_seq(result.seq), + m_vars(result.vars), + m_input_shape_static(input_shape_static) { + m_graph = ComputingGraph::make(); + options().no_force_inplace = true; + options().async_exec_level = 0b100; + } + + ComputingGraph& graph() { return *m_graph; } + + ComputingGraph::Options& options() { return m_graph->options(); } + + /** + * \brief Set the value comparator object (usually from python) + * + * \param comparator + */ + void set_value_comparator(std::function comparator) { + m_value_comparator = comparator; + } + + void compile(); + + void recompile(); + + void assert_tensor_equal(ValueRef lhs, ValueRef rhs); + + /** + * \brief handle input for trace + * + * 1. For external, set input value to data_setter; + * 2. For const, do nothing; + * 3. 
For internal, assert var id; + * *. Always assert data equals if there are data bound. + * + * \param id + * \param value + */ + void trace_input(size_t id, ValueRef value); + + /** + * \brief make a placeholder for output. + * + * \param id trace_id + * \return TracedValue::ref_t output placeholder, would be reset to real value when + * trace exits + */ + TracedValue::ref_t trace_output(size_t id); + + TraceResult::SeqItem& next_instruction(); + + std::vector apply_transformation( + const Operator& op, Span inputs) override; + + void on_unregister() noexcept override; + + ValueRef unwrap(ValueRef value) override { + mgb_assert(!value.is()); + return value; + } + + std::string name() const override { return "CompiledTransformation"; } + + void execute(); + + void wait(); + + std::exception_ptr set_exception(std::exception_ptr exc) noexcept; + + template + std::shared_ptr> make_box() { + auto box = Box::make(); + m_boxes.push_back(box); + return box; + } +}; + +} // namespace mgb::imperative -- GitLab