From 59084fa857010a5d1f6ff7b448ac2875eff9151b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Fri, 14 Jan 2022 13:22:11 +0800
Subject: [PATCH] refactor(dispatch): implement lazy_eval

GitOrigin-RevId: 4e3f3a1c443fb91755bd2395adb83b2d6d3eda94
---
 imperative/src/impl/transformations/lazy.cpp  | 273 ++++++++++++++++++
 .../imperative/transformations/lazy.h         |  99 +++++++
 2 files changed, 372 insertions(+)
 create mode 100644 imperative/src/impl/transformations/lazy.cpp
 create mode 100644 imperative/src/include/megbrain/imperative/transformations/lazy.h

diff --git a/imperative/src/impl/transformations/lazy.cpp b/imperative/src/impl/transformations/lazy.cpp
new file mode 100644
index 000000000..332645da3
--- /dev/null
+++ b/imperative/src/impl/transformations/lazy.cpp
@@ -0,0 +1,273 @@
+/**
+ * \file imperative/src/impl/transformations/lazy.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/transformations/lazy.h"
+#include "megbrain/imperative/opr_utility.h"
+#include "megbrain/imperative/ops/autogen.h"
+
+#include "megbrain/opr/utility.h"
+
+#include "../async_releaser.h"
+#include "../mgb_cg_impl.h"
+
+namespace mgb {
+namespace imperative {
+
+std::vector<ValueRef> LazyEvalTransformation::apply_transformation(
+        const Operator& op, Span<ValueRef> inputs) {
+    if (auto* op_val = op.as<ApplyOp>()) {
+        static std::unordered_set<Typeinfo*> mm_io_ops = {
+                CollectiveComm::typeinfo(),
+                RemoteSend::typeinfo(),
+                RemoteRecv::typeinfo(),
+        };
+        bool require_link = mm_io_ops.count(op_val->op().dyn_typeinfo());
+        VarNodeArray input_nodes;
+        for (auto&& input : inputs) {
+            if (auto* input_node = input.as<LazyEvalValue>()) {
+                input_nodes.push_back(input_node->node());
+            } else {
+                // ImmutableTensor has empty shape issues
+                auto dev_val = input.dev_tensor()->as_nd();
+                auto dev_val_provider = [dev_val]() mutable {
+                    return std::move(dev_val);
+                };
+                auto* node = opr::InputCallback::make(
+                                     *m_graph, dev_val_provider, *input.device(),
+                                     *input.dtype(), input.shape()->as_tensor_shape(),
+                                     {}, true)[0]
+                                     .node();
+                input_nodes.push_back(node);
+            }
+        }
+        if (require_link && m_io_link.node()) {
+            mgb_assert(!input_nodes.empty());
+            input_nodes[0] =
+                    opr::VirtualDep::make({SymbolVar(input_nodes[0]), m_io_link})
+                            .node();
+        }
+        VarNodeArray output_nodes = OpDef::apply_on_var_node(op_val->op(), input_nodes);
+        if (require_link) {
+            mgb_assert(!output_nodes.empty());
+            m_io_link = SymbolVar(output_nodes[0]);
+        }
+        std::vector<ValueRef> outputs;
+        for (auto&& output_node : output_nodes) {
+            outputs.push_back(record_var(output_node));
+        }
+        return outputs;
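        // The require_link chain above serializes multi-machine I/O ops
        // (CollectiveComm/RemoteSend/RemoteRecv): each such op takes a
        // VirtualDep on the previous one's first output, so the graph
        // optimizer cannot reorder them. A hypothetical sequence of two
        // sends (illustrative only; `x` and `y` are assumed inputs):
        //
        //     s0 = RemoteSend(x)                  // m_io_link = s0
        //     s1 = RemoteSend(VirtualDep(y, s0))  // m_io_link = s1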
+    } else if (auto* create_tensor = op.as<CreateTensor>()) {
+        auto&& args = create_tensor->parse(inputs);
+        auto get_dev_val = [&] {
+            if (!args.device) {
+                mgb_assert(args.host);
+                args.device.emplace();
+                args.device->copy_from(*args.host);
+                // every h2d in imperative runtime should notify AsyncReleaser
+                AsyncReleaser::inst()->add(*args.host);
+            }
+            return *args.device;
+        };
+        if (args.kind == CreateTensor::Const) {
+            VarNode* node;
+            if (args.host) {
+                node = opr::ImmutableTensor::make(*m_graph, *args.host).node();
+            } else {
+                node = opr::SharedDeviceTensor::make(
+                               *m_graph, std::make_shared<DeviceTensorND>(*args.device),
+                               true, {})
+                               .node();
+            }
+            if (m_no_exec) {
+                // TODO: record args instead of value
+                auto output = apply(op, inputs)[0];
+                auto name = output.name();
+                if (name) {
+                    return {record_var(node, output, *name)};
+                } else {
+                    return {record_var(node, output)};
+                }
+            } else {
+                return {record_var(node)};
+            }
+        } else {
+            // FIXME: reason for sync
+            auto dev_val = get_dev_val();
+            auto callback = [dev_val]() mutable -> DeviceTensorND {
+                return std::move(dev_val);
+            };
+            auto* node = opr::InputCallback::make(
+                                 *m_graph, callback, dev_val.comp_node(),
+                                 dev_val.dtype(), dev_val.shape(), {}, true)[0]
+                                 .node();
+            return {record_var(node)};
+        }
+    } else if (auto* get_attr = op.as<GetAttr>()) {
+        if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) {
+            switch (get_attr->attr()) {
+                case GetAttr::DType:
+                    return {DTypeValue::make(lazy_val->node()->dtype())};
+                case GetAttr::Device:
+                    return {CompNodeValue::make(lazy_val->node()->comp_node())};
+                case GetAttr::Shape: {
+                    if (!cg::is_static_var_shape(lazy_val->node())) {
+                        mgb_log_debug("LazyEval: get_shape failed");
+                        return {ValueRef()};
+                    }
+                    auto shape = m_graph->static_infer_manager().infer_shape(
+                            lazy_val->node());
+                    return {ShapeValue::make(ValueShape::from(shape))};
+                }
+                case GetAttr::Value: {
+                    if (!cg::is_static_var_value(lazy_val->node())) {
+                        mgb_log_debug("LazyEval: get_value failed");
+                        return {ValueRef()};
+                    }
+                    auto inferred_value = m_graph->static_infer_manager().infer_value(
+                            lazy_val->node());
+                    mgb_assert(inferred_value.comp_node() == CompNode::default_cpu());
+                    HostTensorND host_value(
+                            lazy_val->node()->comp_node(), lazy_val->node()->dtype());
+                    host_value.copy_from(inferred_value);
+                    // TODO: use proxy instead?
+                    return {HostValue::make(host_value)};
+                }
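                // GetAttr::Value above and GetAttr::Data below run the same
                // static value inference; Value hands the result back as a
                // HostValue, while Data additionally copies it into a
                // DeviceTensorND, because callers of Data expect device
                // storage. Both return an empty ValueRef when the value
                // cannot be inferred statically, so callers must be prepared
                // for a miss.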
+                case GetAttr::Data: {
+                    if (!cg::is_static_var_value(lazy_val->node())) {
+                        mgb_log_debug("LazyEval: get_data failed");
+                        return {ValueRef()};
+                    }
+                    auto inferred_value = m_graph->static_infer_manager().infer_value(
+                            lazy_val->node());
+                    mgb_assert(inferred_value.comp_node() == CompNode::default_cpu());
+                    // TODO: use proxy instead?
+                    HostTensorND host_value(
+                            lazy_val->node()->comp_node(), lazy_val->node()->dtype());
+                    host_value.copy_from(inferred_value);
+                    DeviceTensorND dev_value;
+                    dev_value.copy_from(host_value);
+                    AsyncReleaser::inst()->add(host_value);
+                    return {DeviceValue::make(dev_value)};
+                }
+                default:
+                    mgb_throw(
+                            MegBrainError, "LazyEval: malformed GetAttr: %s",
+                            op.to_string().c_str());
+            }
+        } else {
+            return imperative::apply(op, inputs);
+        }
+    } else if (auto* rename_value = op.as<RenameValue>()) {
+        if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) {
+            return {record_var(
+                    lazy_val->node(), lazy_val->bound_data(), rename_value->name())};
+        } else {
+            return imperative::apply(op, inputs);
+        }
+    } else if (op.is<GetName>()) {
+        if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) {
+            auto name = lazy_val->name();
+            if (!name.empty()) {
+                return {StringValue::make(lazy_val->name())};
+            } else {
+                return {ValueRef()};
+            }
+        } else {
+            return imperative::apply(op, inputs);
+        }
+    } else {
+        return op.fallback(inputs);
+    }
+}
+
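// on_unregister() below is the flush point of the transformation: every
// LazyEvalValue still alive when the transformation is popped gets an
// OutputCallback attached, the graph is compiled and executed once, and each
// stub is then replaced by the tensor its callback captured. A minimal
// standalone sketch of that pattern (`some_host_tensor` and the doubling of
// `x` are assumed purely for illustration):
//
//     auto graph = ComputingGraph::make();
//     SymbolVar x = opr::ImmutableTensor::make(*graph, some_host_tensor);
//     DeviceTensorND result;
//     ComputingGraph::OutputSpec spec;
//     spec.push_back(
//             {opr::OutputCallback::make(
//                      {[&](DeviceTensorND data) { result = data; }}, x * 2),
//              {}});
//     auto func = graph->compile(spec);
//     func->execute();
//     func->wait();  // `result` now holds the value of x * 2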
+void LazyEvalTransformation::on_unregister() noexcept {
+    std::vector<LazyEvalValue::ref_t> lazy_vals;
+    for (auto&& weak_var : m_weak_vars) {
+        if (auto lazy_val = weak_var.lock()) {
+            lazy_vals.push_back(lazy_val);
+        }
+    }
+    CleanupGuard _{[this] {
+        m_graph.reset();
+        m_weak_vars.clear();
+    }};
+    if (m_no_exec) {
+        for (auto&& lazy_val : lazy_vals) {
+            if (lazy_val->bound_data()) {
+                auto value = lazy_val->bound_data();
+                lazy_val.reset(value);
+            } else {
+                lazy_val.reset(ErrorValue::make("no data bound"));
+            }
+        }
+        return;
+    }
+    std::mutex mtx;
+    std::vector<std::pair<LazyEvalValue::ref_t, DeviceTensorND>> values;
+    ComputingGraph::OutputSpec output_specs;
+    for (auto&& lazy_val : lazy_vals) {
+        auto* output = opr::OutputCallback::make(
+                               {[lazy_val, &mtx, &values](DeviceTensorND data) {
+                                   MGB_LOCK_GUARD(mtx);
+                                   values.push_back({lazy_val, data});
+                               }},
+                               lazy_val->node())
+                               .node();
+        output_specs.push_back({output, {}});
+    }
+    if (m_io_link.node()) {
+        output_specs.push_back({m_io_link, {}});
+    }
+    if (output_specs.empty()) {
+        return;
+    }
+    {
+        // set_priority_to_id
+        auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
+            if (opr->node_prop().attribute().priority == 0) {
+                opr->node_prop().attribute().priority = opr->id();
+            }
+        };
+        mgb::cg::DepOprIter dep_iter{on_opr};
+        for (auto&& output_spec : output_specs) {
+            dep_iter.add(output_spec.first);
+        }
+    }
+    try {
+        auto executable = m_graph->compile(output_specs);
+        executable->execute();
+        executable->wait();
+    } catch (...) {
+        m_graph_exc = std::current_exception();
+    }
+    for (auto&& [var, data] : values) {
+        var.reset(imperative::apply(
+                CreateTensor(CreateTensor::Common, data.comp_node(), data.layout()),
+                DeviceStorage::make(data.storage()))[0]);
+    }
+    for (auto&& lazy_val : lazy_vals) {
+        if (lazy_val.is<LazyEvalValue>()) {
+            std::string repr =
+                    ssprintf("lazy eval failed for %s", lazy_val->to_string().c_str());
+            mgb_log_debug("%s", repr.c_str());
+            lazy_val.reset(ErrorValue::make(repr.c_str()));
+        }
+    }
+}
+
+void LazyEvalTransformation::check_exception() {
+    if (m_graph_exc) {
+        std::rethrow_exception(m_graph_exc);
+    }
+}
+
+}  // namespace imperative
+}  // namespace mgb
diff --git a/imperative/src/include/megbrain/imperative/transformations/lazy.h b/imperative/src/include/megbrain/imperative/transformations/lazy.h
new file mode 100644
index 000000000..6a29f41fc
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/transformations/lazy.h
@@ -0,0 +1,99 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/transformations/lazy.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include <future>
+#include <variant>
+
+#include "megbrain/imperative/dispatch.h"
+#include "megbrain/imperative/interpreter.h"
+#include "megbrain/imperative/opr_utility.h"
+#include "megbrain/imperative/utils/helper.h"
+#include "megbrain/opr/io.h"
+
+namespace mgb::imperative {
+
+class LazyEvalInfo {
+private:
+    VarNode* m_node = nullptr;
+    ValueRef m_bound_data;
+    std::string m_name;
+
+public:
+    LazyEvalInfo() = default;
+    LazyEvalInfo(VarNode* node, ValueRef bound_data, std::string name)
+            : m_node(node), m_bound_data(bound_data), m_name(name) {}
+    VarNode* node() const { return m_node; }
+
+    ValueRef bound_data() const { return m_bound_data; }
+
+    std::string name() const { return m_name; }
+};
+
+class LazyEvalValue final : public MixinValueImpl<LazyEvalValue, LazyEvalInfo> {
+public:
+    using MixinValueImpl::MixinValueImpl;
+
+    std::string to_string() const override {
+        return ssprintf(
+                "LazyEvalValue{node=%p, name=%s}", node(), node()->name().c_str());
+    }
+};
+
+/**
+ * \brief lazy evaluate on megbrain graph
+ *
+ * 1. Make a varnode for each external value (HostToDeviceCopy/ImmutableTensor);
+ * 2. Invoke apply_on_var_node when handling ApplyOp, return LazyEvalValue(VarNode) as
+ * stub;
+ * 3. Try to infer value/shape when handling GetAttr;
+ * 4. Compile and execute the graph, fetch the values, and replace each LazyEvalValue
+ * by the concrete value it produced.
+ */
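// A typical lifecycle, sketched under the assumption of a push/pop-style
// registration API (the actual mechanism lives in dispatch.h and is not part
// of this patch):
//
//     auto lazy = std::make_shared<LazyEvalTransformation>(/*no_exec=*/false);
//     // while `lazy` is registered on the dispatch stack, every op builds
//     // VarNodes and returns LazyEvalValue stubs instead of executing;
//     // unregistering triggers on_unregister(), which compiles and runs the
//     // graph once and swaps concrete tensors into the recorded stubs
//     lazy->check_exception();  // rethrows any failure from graph execution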
+class LazyEvalTransformation final : public Transformation {
+private:
+    bool m_no_exec;
+    std::shared_ptr<ComputingGraph> m_graph;
+    std::vector<LazyEvalValue::weak_ref_t> m_weak_vars;
+    SymbolVar m_io_link = nullptr;
+    std::exception_ptr m_graph_exc;
+
+public:
+    LazyEvalTransformation(bool no_exec) : m_no_exec(no_exec) {
+        m_graph = ComputingGraph::make();
+    }
+
+    LazyEvalValue::ref_t record_var(
+            VarNode* node, ValueRef bound_data = {}, std::string name = {}) {
+        auto lazy_eval_val = LazyEvalValue::make(node, bound_data, name);
+        m_weak_vars.push_back(lazy_eval_val);
+        return lazy_eval_val;
+    }
+
+    ComputingGraph::Options& options() { return m_graph->options(); }
+
+    std::vector<ValueRef> apply_transformation(
+            const Operator& op, Span<ValueRef> inputs) override;
+
+    ValueRef unwrap(ValueRef value) override {
+        mgb_assert(!value.is<LazyEvalValue>());
+        return value;
+    }
+
+    std::string name() const override { return "LazyEvalTransformation"; }
+
+    void on_unregister() noexcept override;
+
+    void check_exception();
+};
+
+}  // namespace mgb::imperative

--
GitLab