diff --git a/src/gopt/impl/subgraph_extractor.cpp b/src/gopt/impl/subgraph_extractor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffced9ea46ee3331a4e1753b5592343d682f4142 --- /dev/null +++ b/src/gopt/impl/subgraph_extractor.cpp @@ -0,0 +1,101 @@ +/** + * \file src/gopt/impl/subgraph_extractor.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain/gopt/subgraph_extractor.h" + +using namespace mgb; +using namespace cg; +using namespace gopt; + +/* ================== SubGraphExtractor =================*/ +std::vector SubGraphExtractor::extract( + const SymbolVarArray& endpoint_vars) const { + ThinHashMap> parent; + thin_function union_find; + auto union_find = [&parent, &union_find](OperatorNodeBase* o) { + if (parent[o].first == o) + return o; + else { + auto p = union_find(parent[o].first); + parent[o].first = p; + return p; + } + }; + auto union_merge = [&parent, &union_find](OperatorNodeBase* x, + OperatorNodeBase* y) { + auto root_x = union_find(x), root_y = union_find(y); + if (root_x != root_y) { + OperatorNodeBase *large, small; + if (parent[root_x].second < parent[root_y].second) { + small = root_x, large = root_y; + } else { + small = root_y, large = root_x; + } + parent[small].first = large; + if (parent[large].second == parent[small].second) { + parend[large].second += 1; + } + } + }; + + std::vector topo; + auto cb = [&topo](OperatorNodeBase* opr) { + topo.push_back(opr); + if (opr_list.count(opr->dyn_typeinfo()) == 0) + return; + auto find = parent.find(opr); + if (find == parent.end()) { + auto insert = + parent.insert(std::make_pair(opr, std::make_pair(opr, 0))); + find = insert.first; + } + for (auto&& i : opr->input()) { + auto&& o = i->owner_opr(); + if (opr_list.count(o->dyn_typeinfo()) == 0) + continue; + union_merge(opr, o); + } + }; + cg::DepOprIter iter{cb}; + for (const auto& v : endpoint_vars) + iter.add(v.node()->owner_opr()); + + std::vector partitions; + ThinHashMap roots; + for (const auto& opr : reverse_adaptor(topo)) { + auto root = union_find(opr); + auto find = roots.find(root); + InternalGraph* internal_graph = nullptr; + if (find == roots.end()) { + partitions.emplace_back(InternalGraph{}); + auto insert = + roots.insert(std::make_pair(root, &partitions.back())); + internal_graph = insert.first->second; + internal_graph->m_outputs.insert(opr->output(0)); + } else { + internal_graph = find->second; + auto erase = internal_graph->m_inputs.erase(opr->output(0)); + if (erase > 0) { + internal_graph->m_internals.insert(opr->output(0)); + } else { + internal_graph->m_outputs.insert(opr->output(0)); + } + } + for (const auto& i : opr->input()) + internal_graph->m_inputs.insert(i); + } + return partitions; +} + +/* ============= SubGraphExtractor =================*/ + +// vim: syntax=cpp.doxygen diff --git a/src/gopt/include/megbrain/gopt/subgraph_extractor.h b/src/gopt/include/megbrain/gopt/subgraph_extractor.h new file mode 100644 index 0000000000000000000000000000000000000000..e443c253df008cf2993d0d60ac50aecc099f6f56 --- /dev/null +++ b/src/gopt/include/megbrain/gopt/subgraph_extractor.h @@ -0,0 +1,40 @@ +/** + * \file src/gopt/include/megbrain/gopt/subgraph_extractor.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "megbrain/graph.h" + +namespace mgb { +namespace gopt { + +struct InternalGraph { + ThinHashSet m_internals; + ThinHashSet m_inputs; + ThinHashSet m_outputs; +}; + +class SubGraphExtractor { +public: + using OprList = ThinHashSet; + SubGraphExtractor(OprList opr_list) : m_opr_list{opr_list} {}; + std::vector extract( + const SymbolVarArray& endpoint_vars) const; + +private: + class Impl; + OprList m_opr_list; +}; + +} // namespace gopt +} // namespace mgb + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}