From c609c031f19260ce53f3808c1df014d0965ecb57 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 14 Jan 2022 13:23:14 +0800 Subject: [PATCH] refactor(dispatch): implement symbol GitOrigin-RevId: c7bd86f5c1a412205c801226fdd8ba01ce02d4c1 --- .../imperative/transformations/symbol.h | 131 ++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 imperative/src/include/megbrain/imperative/transformations/symbol.h diff --git a/imperative/src/include/megbrain/imperative/transformations/symbol.h b/imperative/src/include/megbrain/imperative/transformations/symbol.h new file mode 100644 index 000000000..2032ef42a --- /dev/null +++ b/imperative/src/include/megbrain/imperative/transformations/symbol.h @@ -0,0 +1,131 @@ +/** + * \file imperative/src/include/megbrain/imperative/symbol.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include + +#include "megbrain/imperative/dispatch.h" +#include "megbrain/imperative/interpreter.h" +#include "megbrain/imperative/opr_utility.h" +#include "megbrain/imperative/utils/helper.h" +#include "megbrain/opr/io.h" + +namespace mgb::imperative { + +class SymbolValue final : public ValueImpl { +private: + VarNode* m_node = nullptr; + +public: + SymbolValue(VarNode* node) : m_node(node) {} + + VarNode* node() const { return m_node; } + + std::string to_string() const override { return ssprintf("VarNode{%p}", m_node); } + + void clear() override { m_node = nullptr; } +}; + +/** + * \brief this transformation is used to handle VarNode. + * + * Unlike other transformations, this transformation is not used in Tensor evaluation. + * when user calls py_apply(SymbolVar), we'll switch current transformation context to a + * special symbol context. The advantage is that we can handle scalar by + * ScalarTransformation. + */ +class SymbolTransformation final : public Transformation { +private: + ComputingGraph* m_graph = nullptr; + +public: + SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {} + std::vector apply_transformation( + const Operator& op, Span inputs) override { + if (auto* apply_op = op.as()) { + SmallVector input_nodes; + for (auto&& input : inputs) { + input_nodes.push_back(input.cast().node()); + } + auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); + std::vector outputs; + for (auto&& output_node : output_nodes) { + outputs.push_back(SymbolValue::make(output_node)); + } + return outputs; + } else if (auto* create_tensor = op.as()) { + auto&& args = create_tensor->parse(inputs); + mgb_assert( + args.kind == CreateTensor::Const, + "only const value is allowed here"); + auto* node = opr::ImmutableTensor::make(*m_graph, *args.host, {}).node(); + return {SymbolValue::make(node)}; + } else if (auto* get_attr = op.as()) { + auto* node = inputs.as_array<1>()[0].cast().node(); + switch (get_attr->attr()) { + case GetAttr::DType: + return {DTypeValue::make(node->dtype())}; + case GetAttr::Device: + return {CompNodeValue::make(node->comp_node())}; + case GetAttr::Shape: { + if (!cg::is_static_var_shape(node)) { + mgb_log_debug( + "shape inference invalid for %s", node->name().c_str()); + return {ValueRef()}; + } + auto shape = m_graph->static_infer_manager().infer_shape(node); + return {ShapeValue::make(ValueShape::from(shape))}; + } + case GetAttr::Value: { + if (!cg::is_static_var_value(node)) { + mgb_log_debug( + "value inference invalid for %s", node->name().c_str()); + return {ValueRef()}; + } + auto inferred_value = + m_graph->static_infer_manager().infer_value(node); + HostTensorND host_value(node->comp_node(), node->dtype()); + host_value.copy_from(inferred_value); + return {HostValue::make(host_value)}; + } + case GetAttr::Data: { + if (!cg::is_static_var_value(node)) { + mgb_log_debug( + "value inference invalid for %s", node->name().c_str()); + return {ValueRef()}; + } + auto inferred_value = + m_graph->static_infer_manager().infer_value(node); + DeviceTensorND dev_value(node->comp_node(), node->dtype()); + dev_value.copy_from(inferred_value); + return {DeviceValue::make(dev_value)}; + } + default: + mgb_throw( + MegBrainError, "Symbol: malformed GetAttr: %s", + op.to_string().c_str()); + } + } else { + return op.fallback(inputs); + } + } + + ValueRef unwrap(ValueRef value) override { + mgb_assert(!value.is(), "SymbolValue doesn't support unwrap"); + return value; + } + + std::string name() const override { return "SymbolTransformation"; } +}; + +} // namespace mgb::imperative -- GitLab