提交 59084fa8 编写于 作者: M Megvii Engine Team

refactor(dispatch): implement lazy_eval

GitOrigin-RevId: 4e3f3a1c443fb91755bd2395adb83b2d6d3eda94
上级 d2b67c2a
/**
* \file imperative/src/impl/transformations/lazy.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/utility.h"
#include "../async_releaser.h"
#include "../mgb_cg_impl.h"
namespace mgb {
namespace imperative {
/**
 * \brief Dispatch an operator onto the lazy-eval graph instead of executing it.
 *
 * ApplyOp builds var nodes via apply_on_var_node; CreateTensor lifts external
 * tensors into the graph; GetAttr answers via static inference when possible;
 * other ops fall back to the default handling.
 */
std::vector<ValueRef> LazyEvalTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* op_val = op.as<ApplyOp>()) {
        // Multi-machine I/O ops: their relative execution order must be kept,
        // which is enforced below by chaining them through m_io_link.
        static std::unordered_set<Typeinfo*> mm_io_ops = {
                CollectiveComm::typeinfo(),
                RemoteSend::typeinfo(),
                RemoteRecv::typeinfo(),
        };
        bool require_link = mm_io_ops.count(op_val->op().dyn_typeinfo());
        VarNodeArray input_nodes;
        for (auto&& input : inputs) {
            if (auto* input_node = input.as<LazyEvalValue>()) {
                // Input already lives in the lazy graph; reuse its var node.
                input_nodes.push_back(input_node->node());
            } else {
                // External (already-computed) tensor: feed it into the graph
                // through InputCallback.
                // ImmutableTensor has empty shape issues
                auto dev_val = input.dev_tensor()->as_nd();
                auto dev_val_provider = [dev_val]() mutable {
                    return std::move(dev_val);
                };
                auto* node = opr::InputCallback::make(
                                     *m_graph, dev_val_provider, *input.device(),
                                     *input.dtype(), input.shape()->as_tensor_shape(),
                                     {}, true)[0]
                                     .node();
                input_nodes.push_back(node);
            }
        }
        if (require_link && m_io_link.node()) {
            // Add a fake dependency on the previous I/O op so this op is
            // scheduled after it.
            mgb_assert(!input_nodes.empty());
            input_nodes[0] =
                    opr::VirtualDep::make({SymbolVar(input_nodes[0]), m_io_link})
                            .node();
        }
        VarNodeArray output_nodes = OpDef::apply_on_var_node(op_val->op(), input_nodes);
        if (require_link) {
            // Remember the tail of the I/O chain for the next linked op.
            mgb_assert(!output_nodes.empty());
            m_io_link = SymbolVar(output_nodes[0]);
        }
        // Wrap each produced var node as a LazyEvalValue stub.
        std::vector<ValueRef> outputs;
        for (auto&& output_node : output_nodes) {
            outputs.push_back(record_var(output_node));
        }
        return outputs;
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        auto&& args = create_tensor->parse(inputs);
        // Lazily materialize a device tensor from host data when needed.
        auto get_dev_val = [&] {
            if (!args.device) {
                mgb_assert(args.host);
                args.device.emplace();
                args.device->copy_from(*args.host);
                // every h2d in imperative runtime should notify AsyncReleaser
                AsyncReleaser::inst()->add(*args.host);
            }
            return *args.device;
        };
        if (args.kind == CreateTensor::Const) {
            // Constants become graph-owned nodes (foldable by the graph).
            VarNode* node;
            if (args.host) {
                node = opr::ImmutableTensor::make(*m_graph, *args.host).node();
            } else {
                node = opr::SharedDeviceTensor::make(
                               *m_graph, std::make_shared<DeviceTensorND>(*args.device),
                               true, {})
                               .node();
            }
            if (m_no_exec) {
                // TODO: record args instead of value
                // Forward to the inner transformations to get a concrete value
                // and bind it to the node, since the graph won't be executed.
                auto output = apply(op, inputs)[0];
                auto name = output.name();
                if (name) {
                    return {record_var(node, output, *name)};
                } else {
                    return {record_var(node, output)};
                }
            } else {
                return {record_var(node)};
            }
        } else {
            // Non-const tensor: inject it via InputCallback.
            // FIXME: reason for sync
            auto dev_val = get_dev_val();
            auto callback = [dev_val]() mutable -> DeviceTensorND {
                return std::move(dev_val);
            };
            auto* node = opr::InputCallback::make(
                                 *m_graph, callback, dev_val.comp_node(),
                                 dev_val.dtype(), dev_val.shape(), {}, true)[0]
                                 .node();
            return {record_var(node)};
        }
    } else if (auto* get_attr = op.as<GetAttr>()) {
        if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) {
            // Answer attribute queries from graph metadata / static inference;
            // return an invalid ValueRef when the attribute cannot be inferred.
            switch (get_attr->attr()) {
                case GetAttr::DType:
                    return {DTypeValue::make(lazy_val->node()->dtype())};
                case GetAttr::Device:
                    return {CompNodeValue::make(lazy_val->node()->comp_node())};
                case GetAttr::Shape: {
                    if (!cg::is_static_var_shape(lazy_val->node())) {
                        mgb_log_debug("LazyEval: get_shape_failed");
                        return {ValueRef()};
                    }
                    auto shape = m_graph->static_infer_manager().infer_shape(
                            lazy_val->node());
                    return {ShapeValue::make(ValueShape::from(shape))};
                }
                case GetAttr::Value: {
                    if (!cg::is_static_var_value(lazy_val->node())) {
                        mgb_log_debug("LazyEval: get_value failed");
                        return {ValueRef()};
                    }
                    auto inferred_value = m_graph->static_infer_manager().infer_value(
                            lazy_val->node());
                    // Statically inferred values live on the default CPU node.
                    mgb_assert(inferred_value.comp_node() == CompNode::default_cpu());
                    HostTensorND host_value(
                            lazy_val->node()->comp_node(), lazy_val->node()->dtype());
                    host_value.copy_from(inferred_value);
                    // TODO: use proxy instead?
                    return {HostValue::make(host_value)};
                }
                case GetAttr::Data: {
                    if (!cg::is_static_var_value(lazy_val->node())) {
                        mgb_log_debug("LazyEval get_data failed");
                        return {ValueRef()};
                    }
                    auto inferred_value = m_graph->static_infer_manager().infer_value(
                            lazy_val->node());
                    mgb_assert(inferred_value.comp_node() == CompNode::default_cpu());
                    // TODO: use proxy instead?
                    // Copy host -> device; the host buffer is handed to the
                    // AsyncReleaser so it outlives the async d2h/h2d copy.
                    HostTensorND host_value(
                            lazy_val->node()->comp_node(), lazy_val->node()->dtype());
                    host_value.copy_from(inferred_value);
                    DeviceTensorND dev_value;
                    dev_value.copy_from(host_value);
                    AsyncReleaser::inst()->add(host_value);
                    return {DeviceValue::make(dev_value)};
                }
                default:
                    mgb_throw(
                            MegBrainError, "LazyEval: malformed GetAttr: %s",
                            op.to_string().c_str());
            }
        } else {
            // Not a lazy value: let inner transformations handle it.
            return imperative::apply(op, inputs);
        }
    } else if (auto* rename_value = op.as<RenameValue>()) {
        if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) {
            // Re-record the same node with the new name.
            return {record_var(
                    lazy_val->node(), lazy_val->bound_data(), rename_value->name())};
        } else {
            return imperative::apply(op, inputs);
        }
    } else if (op.is<GetName>()) {
        if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) {
            auto name = lazy_val->name();
            if (!name.empty()) {
                return {StringValue::make(lazy_val->name())};
            } else {
                // Unnamed: return an invalid ValueRef.
                return {ValueRef()};
            }
        } else {
            return imperative::apply(op, inputs);
        }
    } else {
        return op.fallback(inputs);
    }
}
/**
 * \brief Flush the lazy graph when the transformation is removed.
 *
 * Compiles and executes the accumulated graph, then replaces every live
 * LazyEvalValue with its computed tensor (or an ErrorValue on failure).
 * With m_no_exec, values are substituted by their bound data instead.
 * Exceptions from graph execution are stored and re-raised later by
 * check_exception(); this function itself is noexcept.
 */
void LazyEvalTransformation::on_unregister() noexcept {
    // Collect the lazy values that are still referenced somewhere.
    std::vector<LazyEvalValue::ref_t> lazy_vals;
    for (auto&& weak_var : m_weak_vars) {
        if (auto lazy_val = weak_var.lock()) {
            lazy_vals.push_back(lazy_val);
        }
    }
    // Release the graph and bookkeeping on every exit path.
    CleanupGuard _{[this] {
        m_graph.reset();
        m_weak_vars.clear();
    }};
    if (m_no_exec) {
        // Graph will not run: substitute each value with its bound data, or
        // an error placeholder when nothing was bound.
        for (auto&& lazy_val : lazy_vals) {
            if (lazy_val->bound_data()) {
                auto value = lazy_val->bound_data();
                lazy_val.reset(value);
            } else {
                lazy_val.reset(ErrorValue::make("no data bound"));
            }
        }
        return;
    }
    // Build an output spec that captures the computed tensor of each value.
    std::mutex mtx;
    std::vector<std::pair<LazyEvalValue::ref_t, DeviceTensorND>> values;
    ComputingGraph::OutputSpec output_specs;
    for (auto&& lazy_val : lazy_vals) {
        auto* output = opr::OutputCallback::make(
                               {[lazy_val, &mtx, &values](DeviceTensorND data) {
                                   // Callbacks may fire concurrently; guard the
                                   // shared result vector.
                                   MGB_LOCK_GUARD(mtx);
                                   values.push_back({lazy_val, data});
                               }},
                               lazy_val->node())
                               .node();
        output_specs.push_back({output, {}});
    }
    if (m_io_link.node()) {
        // Keep the tail of the multi-machine I/O chain in the compiled graph.
        output_specs.push_back({m_io_link, {}});
    }
    if (output_specs.empty()) {
        return;
    }
    {
        // set_priority_to_id: give every op a deterministic priority so the
        // scheduler follows creation order.
        auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
            if (opr->node_prop().attribute().priority == 0) {
                opr->node_prop().attribute().priority = opr->id();
            }
        };
        mgb::cg::DepOprIter dep_iter{on_opr};
        for (auto&& output_spec : output_specs) {
            dep_iter.add(output_spec.first);
        }
    }
    try {
        auto executable = m_graph->compile(output_specs);
        executable->execute();
        executable->wait();
    } catch (...) {
        // Remember the error; re-raised later by check_exception().
        m_graph_exc = std::current_exception();
    }
    // Hand the computed tensors back to the imperative runtime.
    for (auto&& [var, data] : values) {
        var.reset(imperative::apply(
                CreateTensor(CreateTensor::Common, data.comp_node(), data.layout()),
                DeviceStorage::make(data.storage()))[0]);
    }
    // Any value not replaced above failed to evaluate; mark it as an error.
    for (auto&& lazy_val : lazy_vals) {
        if (lazy_val.is<LazyEvalValue>()) {
            std::string repr =
                    ssprintf("lazy eval failed for %s", lazy_val->to_string().c_str());
            mgb_log_debug("%s", repr.c_str());
            lazy_val.reset(ErrorValue::make(repr.c_str()));
        }
    }
}
/**
 * \brief Re-raise any exception captured during deferred graph execution.
 *
 * No-op when the graph ran successfully.
 */
void LazyEvalTransformation::check_exception() {
    if (!m_graph_exc) {
        return;
    }
    std::rethrow_exception(m_graph_exc);
}
} // namespace imperative
} // namespace mgb
/**
* \file imperative/src/include/megbrain/imperative/transformations/lazy.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <future>
#include <utility>
#include <variant>

#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/utils/helper.h"

#include "megbrain/opr/io.h"
namespace mgb::imperative {
/**
 * \brief Payload of a LazyEvalValue: the graph var node it stands for,
 * optionally the concrete data bound to it and a user-visible name.
 */
class LazyEvalInfo {
private:
    VarNode* m_node = nullptr;  // non-owning; owned by the computing graph
    ValueRef m_bound_data;      // concrete value, only set in no-exec mode
    std::string m_name;

public:
    LazyEvalInfo() = default;
    // By-value params are moved into the members instead of copied
    // (modernize-pass-by-value); call sites are unaffected.
    LazyEvalInfo(VarNode* node, ValueRef bound_data, std::string name)
            : m_node(node),
              m_bound_data(std::move(bound_data)),
              m_name(std::move(name)) {}

    VarNode* node() const { return m_node; }

    ValueRef bound_data() const { return m_bound_data; }

    std::string name() const { return m_name; }
};
/**
 * \brief ValueRef stub standing for a var node in the lazy-eval graph.
 *
 * Carries a LazyEvalInfo payload; note that to_string prints the var node's
 * own name, not the LazyEvalInfo name.
 */
class LazyEvalValue final : public MixinValueImpl<LazyEvalValue, LazyEvalInfo> {
public:
    using MixinValueImpl::MixinValueImpl;

    std::string to_string() const override {
        return ssprintf(
                "LazyEvalValue{node=%p, name=%s}", node(), node()->name().c_str());
    }
};
/**
* \brief lazy evaluate on megbrain graph
*
* 1. Make a varnode for each external value (HostToDeviceCopy/ImmutableTensor);
* 2. Invoke apply_on_var_node when handling ApplyOp, return LazyEvalValue(VarNode) as
* stub;
* 3. Try infer value/shape when handling GetAttr;
* 4. Compile and execute graph, get values and replace LazyEvalValues by concrete
* values.
*/
class LazyEvalTransformation final : public Transformation {
private:
    bool m_no_exec;                                    // skip graph execution on unregister
    std::shared_ptr<ComputingGraph> m_graph;           // graph being lazily built
    std::vector<LazyEvalValue::weak_ref_t> m_weak_vars;  // all recorded lazy values
    SymbolVar m_io_link = nullptr;                     // tail of the multi-machine I/O chain
    std::exception_ptr m_graph_exc;                    // error from deferred execution

public:
    LazyEvalTransformation(bool no_exec) : m_no_exec(no_exec) {
        m_graph = ComputingGraph::make();
    }

    /**
     * \brief Wrap a var node as a LazyEvalValue and track it for later
     * replacement by its computed tensor.
     */
    LazyEvalValue::ref_t record_var(
            VarNode* node, ValueRef bound_data = {}, std::string name = {}) {
        // Move the by-value args into the payload instead of copying them.
        auto lazy_eval_val =
                LazyEvalValue::make(node, std::move(bound_data), std::move(name));
        m_weak_vars.push_back(lazy_eval_val);
        return lazy_eval_val;
    }

    ComputingGraph::Options& options() { return m_graph->options(); }

    std::vector<ValueRef> apply_transformation(
            const Operator& op, Span<ValueRef> inputs) override;

    // Lazy values must never leak past this transformation.
    ValueRef unwrap(ValueRef value) override {
        mgb_assert(!value.is<LazyEvalValue>());
        return value;
    }

    std::string name() const override { return "LazyEvalTransformation"; }

    // Compiles and executes the recorded graph (unless m_no_exec).
    void on_unregister() noexcept override;

    // Re-raises any exception captured during on_unregister().
    void check_exception();
};
} // namespace mgb::imperative
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册