#include "megbrain/imperative/transformations/lazy.h" #include "megbrain/imperative/opr_utility.h" #include "megbrain/imperative/ops/autogen.h" #include "megbrain/opr/utility.h" #include "../mgb_cg_impl.h" namespace mgb { namespace imperative { ValueRefList LazyEvalTransformation::apply_transformation( const Operator& op, Span inputs) { if (auto* op_val = op.as()) { static std::unordered_set mm_io_ops = { CollectiveComm::typeinfo(), RemoteSend::typeinfo(), RemoteRecv::typeinfo(), }; bool require_link = mm_io_ops.count(op_val->op().dyn_typeinfo()); VarNodeArray input_nodes; for (auto&& input : inputs) { if (auto* input_node = input.as(m_value_type)) { input_nodes.push_back(input_node->node()); } else { // ImmutableTensor has empty shape issues auto dev_val = input.dev_tensor()->as_nd(); auto dev_val_provider = [dev_val]() mutable { return std::move(dev_val); }; auto* node = opr::InputCallback::make( *m_graph, dev_val_provider, *input.device(), *input.dtype(), input.shape()->as_tensor_shape(), {}, true)[0] .node(); input_nodes.push_back(node); } } if (require_link && m_io_link.node()) { mgb_assert(!input_nodes.empty()); auto comp_node = m_io_link.node()->comp_node(); input_nodes[0] = opr::VirtualDep::make( {SymbolVar(input_nodes[0]), m_io_link}, comp_node) .node(); } VarNodeArray output_nodes = OpDef::apply_on_var_node(op_val->op(), input_nodes); if (require_link) { mgb_assert(!output_nodes.empty()); m_io_link = SymbolVar(output_nodes[0]); } ValueRefList outputs(output_nodes.size()); for (size_t i = 0; i < output_nodes.size(); ++i) { outputs[i] = record_var(output_nodes[i]); } return outputs; } else if (op.is()) { int64_t id = inputs[0].id(); return {IntegerValue::make(id)}; } else if (auto* create_tensor = op.as()) { auto&& args = create_tensor->parse(inputs); auto get_dev_val = [&] { if (!args.device) { mgb_assert(args.host); args.device.emplace(); args.device->copy_from(*args.host); // every h2d in imperative runtime should notify AsyncReleaser async_release(*args.host); } return *args.device; }; if (args.kind == CreateTensor::Const) { VarNode* node; if (args.host) { node = opr::ImmutableTensor::make(*m_graph, *args.host).node(); } else { node = opr::SharedDeviceTensor::make( *m_graph, std::make_shared(*args.device), true, {}) .node(); } if (m_no_exec) { // TODO: record args instead of value auto output = apply(op, inputs)[0]; auto name = output.name(); if (name) { return {record_var(node, output, *name)}; } else { return {record_var(node, output)}; } } else { return {record_var(node)}; } } else { // FIXME: reason for sync auto dev_val = get_dev_val(); auto callback = [dev_val]() mutable -> DeviceTensorND { return std::move(dev_val); }; auto* node = opr::InputCallback::make( *m_graph, callback, dev_val.comp_node(), dev_val.dtype(), dev_val.shape(), {}, true)[0] .node(); return {record_var(node)}; } } else if (auto* get_attr = op.as()) { if (auto* lazy_val = inputs.item().as(m_value_type)) { switch (get_attr->attr()) { case GetAttr::DType: return {DTypeValue::make(lazy_val->node()->dtype())}; case GetAttr::Device: return {CompNodeValue::make(lazy_val->node()->comp_node())}; case GetAttr::Shape: { if (!cg::is_static_var_shape(lazy_val->node())) { mgb_log_debug("LazyEval: get_shape_failed"); return {ValueRef()}; } auto shape = m_graph->static_infer_manager().infer_shape( lazy_val->node()); return {ShapeValue::make(ValueShape::from(shape))}; } case GetAttr::Value: { if (!cg::is_static_var_value(lazy_val->node())) { mgb_log_debug("LazyEval: get_value failed"); return {ValueRef()}; } auto inferred_value = m_graph->static_infer_manager().infer_value( lazy_val->node()); mgb_assert(inferred_value.comp_node() == CompNode::default_cpu()); HostTensorND host_value( lazy_val->node()->comp_node(), lazy_val->node()->dtype()); host_value.copy_from(inferred_value); // TODO: use proxy instead? return {HostValue::make(host_value)}; } case GetAttr::Data: { if (!cg::is_static_var_value(lazy_val->node())) { mgb_log_debug("LazyEval get_data failed"); return {ValueRef()}; } auto inferred_value = m_graph->static_infer_manager().infer_value( lazy_val->node()); mgb_assert(inferred_value.comp_node() == CompNode::default_cpu()); // TODO: use proxy instead? HostTensorND host_value( lazy_val->node()->comp_node(), lazy_val->node()->dtype()); host_value.copy_from(inferred_value); DeviceTensorND dev_value; dev_value.copy_from(host_value); async_release(host_value); return {DeviceValue::make(dev_value)}; } default: mgb_throw( MegBrainError, "LazyEval: malformed GetAttr: %s", op.to_string().c_str()); } } else { return imperative::apply(op, inputs); } } else if (auto* rename_value = op.as()) { if (auto* lazy_val = inputs.item().as(m_value_type)) { return {record_var( lazy_val->node(), lazy_val->bound_data(), rename_value->name())}; } else { return imperative::apply(op, inputs); } } else if (op.is()) { if (auto* lazy_val = inputs.item().as(m_value_type)) { auto name = lazy_val->name(); if (!name.empty()) { return {StringValue::make(lazy_val->name())}; } else { return {ValueRef()}; } } else { return imperative::apply(op, inputs); } } else if (op.is() || op.is()) { return {}; } else { return op.fallback(inputs); } } void LazyEvalTransformation::on_unregister() noexcept { std::vector lazy_vals; for (auto&& weak_var : m_weak_vars) { if (auto lazy_val = weak_var.lock()) { lazy_vals.push_back(lazy_val); } } CleanupGuard _{[this] { m_graph.reset(); m_weak_vars.clear(); }}; if (m_no_exec) { for (auto&& lazy_val : lazy_vals) { if (lazy_val->bound_data()) { auto value = lazy_val->bound_data(); lazy_val.reset(value); } else { lazy_val.reset(ErrorValue::make("no data bound")); } } return; } std::mutex mtx; std::vector> values; ComputingGraph::OutputSpec output_specs; for (auto&& lazy_val : lazy_vals) { auto* output = opr::OutputCallback::make( {[lazy_val, &mtx, &values](DeviceTensorND data) { MGB_LOCK_GUARD(mtx); values.push_back({lazy_val, data}); }}, lazy_val->node()) .node(); output_specs.push_back({output, {}}); } if (m_io_link.node()) { output_specs.push_back({m_io_link, {}}); } if (output_specs.empty()) { return; } { // set_priority_to_id auto on_opr = [](mgb::cg::OperatorNodeBase* opr) { if (opr->node_prop().attribute().priority == 0) { opr->node_prop().attribute().priority = opr->id(); } }; mgb::cg::DepOprIter dep_iter{on_opr}; for (auto&& output_spec : output_specs) { dep_iter.add(output_spec.first); } } try { auto exectuble = m_graph->compile(output_specs); exectuble->execute(); exectuble->wait(); } catch (...) { m_graph_exc = std::current_exception(); } for (auto&& [var, data] : values) { var.reset(imperative::apply( CreateTensor(CreateTensor::Common, data.comp_node(), data.layout()), DeviceStorage::make(data.storage()))[0]); } for (auto&& lazy_val : lazy_vals) { if (lazy_val.is(m_value_type)) { std::string repr = ssprintf("lazy eval failed for %s", lazy_val->to_string().c_str()); mgb_log_debug("%s", repr.c_str()); lazy_val.reset(ErrorValue::make(repr.c_str())); } } } void LazyEvalTransformation::check_exception() { if (m_graph_exc) { std::rethrow_exception(m_graph_exc); } } } // namespace imperative } // namespace mgb