refactor(dispatch): implement symbol

GitOrigin-RevId: c7bd86f5c1a412205c801226fdd8ba01ce02d4c1
上级 e32929df
/**
* \file imperative/src/include/megbrain/imperative/symbol.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <future>
#include <variant>
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/opr/io.h"
namespace mgb::imperative {
class SymbolValue final : public ValueImpl<SymbolValue> {
private:
VarNode* m_node = nullptr;
public:
SymbolValue(VarNode* node) : m_node(node) {}
VarNode* node() const { return m_node; }
std::string to_string() const override { return ssprintf("VarNode{%p}", m_node); }
void clear() override { m_node = nullptr; }
};
/**
* \brief this transformation is used to handle VarNode.
*
* Unlike other transformations, this transformation is not used in Tensor evaluation.
* when user calls py_apply(SymbolVar), we'll switch current transformation context to a
* special symbol context. The advantage is that we can handle scalar by
* ScalarTransformation.
*/
class SymbolTransformation final : public Transformation {
private:
ComputingGraph* m_graph = nullptr;
public:
SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {}
std::vector<ValueRef> apply_transformation(
const Operator& op, Span<ValueRef> inputs) override {
if (auto* apply_op = op.as<ApplyOp>()) {
SmallVector<VarNode*> input_nodes;
for (auto&& input : inputs) {
input_nodes.push_back(input.cast<SymbolValue>().node());
}
auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes);
std::vector<ValueRef> outputs;
for (auto&& output_node : output_nodes) {
outputs.push_back(SymbolValue::make(output_node));
}
return outputs;
} else if (auto* create_tensor = op.as<CreateTensor>()) {
auto&& args = create_tensor->parse(inputs);
mgb_assert(
args.kind == CreateTensor::Const,
"only const value is allowed here");
auto* node = opr::ImmutableTensor::make(*m_graph, *args.host, {}).node();
return {SymbolValue::make(node)};
} else if (auto* get_attr = op.as<GetAttr>()) {
auto* node = inputs.as_array<1>()[0].cast<SymbolValue>().node();
switch (get_attr->attr()) {
case GetAttr::DType:
return {DTypeValue::make(node->dtype())};
case GetAttr::Device:
return {CompNodeValue::make(node->comp_node())};
case GetAttr::Shape: {
if (!cg::is_static_var_shape(node)) {
mgb_log_debug(
"shape inference invalid for %s", node->name().c_str());
return {ValueRef()};
}
auto shape = m_graph->static_infer_manager().infer_shape(node);
return {ShapeValue::make(ValueShape::from(shape))};
}
case GetAttr::Value: {
if (!cg::is_static_var_value(node)) {
mgb_log_debug(
"value inference invalid for %s", node->name().c_str());
return {ValueRef()};
}
auto inferred_value =
m_graph->static_infer_manager().infer_value(node);
HostTensorND host_value(node->comp_node(), node->dtype());
host_value.copy_from(inferred_value);
return {HostValue::make(host_value)};
}
case GetAttr::Data: {
if (!cg::is_static_var_value(node)) {
mgb_log_debug(
"value inference invalid for %s", node->name().c_str());
return {ValueRef()};
}
auto inferred_value =
m_graph->static_infer_manager().infer_value(node);
DeviceTensorND dev_value(node->comp_node(), node->dtype());
dev_value.copy_from(inferred_value);
return {DeviceValue::make(dev_value)};
}
default:
mgb_throw(
MegBrainError, "Symbol: malformed GetAttr: %s",
op.to_string().c_str());
}
} else {
return op.fallback(inputs);
}
}
ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<SymbolValue>(), "SymbolValue doesn't support unwrap");
return value;
}
std::string name() const override { return "SymbolTransformation"; }
};
} // namespace mgb::imperative
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部