From 0213dbe55659bad5dba28602bc1efd2dc624f93b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 2 Aug 2021 14:49:57 +0800 Subject: [PATCH] feat(subgraph): add graph builder GitOrigin-RevId: f32cfc39e071da45068b4abf9aa585e6d1474150 --- .../megbrain/imperative/graph_builder.h | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 imperative/src/include/megbrain/imperative/graph_builder.h diff --git a/imperative/src/include/megbrain/imperative/graph_builder.h b/imperative/src/include/megbrain/imperative/graph_builder.h new file mode 100644 index 000000000..05185c60a --- /dev/null +++ b/imperative/src/include/megbrain/imperative/graph_builder.h @@ -0,0 +1,134 @@ +/** + * \file imperative/src/include/megbrain/imperative/graph_builder.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "megbrain/imperative/subgraph.h" + +namespace mgb { +namespace imperative { + +template +class Subgraph::Builder { + using graph_t = Subgraph; + using var_t = graph_t::var_t; + using vars_t = graph_t::vars_t; + using op_t = graph_t::op_t; + using expr_t = graph_t::expr_t; + using exprs_t = std::list; + using expr_iter_t = std::list::iterator; + using desc_t = TDesc; + using descs_t = SmallVector; + using infer_fn_t = std::function; + using encoded_graph_t = EncodedSubraph; + using var_map_t = std::unordered_map; + vars_t m_inputs; + SmallVector> m_constants; + vars_t m_outputs; + exprs_t m_exprs; + var_t m_last_var = 0; + std::unordered_map m_var2desc; + infer_fn_t m_infer_fn; + var_map_t m_var_replace_map; + +private: + var_t next_var() { return ++m_last_var; } + +public: + explicit Builder(std::function infer_function) + : m_infer_fn{infer_function} {} + vars_t write_expr(op_t op, vars_t inputs, size_t nr_outputs) { + return write_expr_before(m_exprs.end(), std::move(op), + std::move(inputs), std::move(nr_outputs)); + } + vars_t write_expr_before(expr_iter_t iter, op_t op, vars_t inputs, + size_t nr_outputs) { + vars_t outputs; + for (size_t i = 0; i < nr_outputs; ++i) { + outputs.push_back(next_var()); + } + m_exprs.insert(iter, {op, inputs, outputs}); + descs_t input_descs = get_descs(inputs); + descs_t output_descs = m_infer_fn(op, input_descs, nr_outputs); + mgb_assert(output_descs.size() == nr_outputs, + "bad infer_function: output descs size mismatch"); + for (size_t i = 0; i < nr_outputs; ++i) { + m_var2desc[outputs[i]] = output_descs[i]; + } + return outputs; + } + var_t write_constant(TensorPtr constant, desc_t desc) { + var_t constant_var = next_var(); + m_constants.emplace_back(constant_var, constant); + m_var2desc[constant_var] = std::move(desc); + return constant_var; + } + var_t write_input(desc_t input_desc) { + var_t input = next_var(); + m_var2desc[input] = input_desc; + m_inputs.push_back(input); + return input; + } + vars_t write_inputs(descs_t input_descs) { + vars_t inputs; + for (auto&& input_desc: input_descs) { + inputs.push_back(write_input(input_desc)); + } + return inputs; + } + void add_output(var_t var) { m_outputs.push_back(var); } + void add_outputs(vars_t vars) { + m_outputs.insert(m_outputs.begin(), vars.begin(), vars.end()); + } + desc_t get_desc(var_t var) { return m_var2desc.at(var); } + descs_t get_descs(vars_t vars) { + descs_t descs; + for (auto&& var : vars) { + descs.push_back(get_desc(var)); + } + return descs; + } + encoded_graph_t encode() const { + graph_t graph{m_inputs, + m_constants, + m_outputs, + {m_exprs.begin(), m_exprs.end()}}; + graph.replace_vars(m_var_replace_map); + graph.remove_unused_exprs(); + return encoded_graph_t::make(std::move(graph)); + } + void replace_var(var_t old_var, var_t new_var) { + mgb_assert(!m_var_replace_map.count(old_var), + "var cannot be replaced twice"); + m_var_replace_map[old_var] = new_var; + } + template + void iterate(TFunctor&& functor) { + for (expr_iter_t iter = m_exprs.begin(); iter != m_exprs.end(); + ++iter) { + functor(iter); + } + } + template + void reverse_iterate(TFunctor&& functor) { + for (expr_iter_t iter = --m_exprs.end();; --iter) { + functor(iter); + if (iter == m_exprs.begin()) { + break; + } + } + } + expr_iter_t begin() { return m_exprs.begin(); } + expr_iter_t end() { return m_exprs.end(); } +}; +} // namespace imperative +} // namespace mgb \ No newline at end of file -- GitLab