提交 88b3c842 编写于 作者: M Megvii Engine Team

refactor(subgraph): move to subgraph.h

GitOrigin-RevId: 2791f335d4ac7fa42939e4bc8b9c96981649ad9a
上级 43a9e6e3
...@@ -309,7 +309,7 @@ inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { ...@@ -309,7 +309,7 @@ inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) {
for (size_t i = 0; i < nargs; ++i) { for (size_t i = 0; i < nargs; ++i) {
inputs.push_back(args[i]->shared_from_this()); inputs.push_back(args[i]->shared_from_this());
} }
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs) { auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs, size_t) {
return apply(op, std::move(inputs)); return apply(op, std::move(inputs));
}; };
return graph.apply(inputs, apply_functor, &make_const); return graph.apply(inputs, apply_functor, &make_const);
...@@ -317,7 +317,7 @@ inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { ...@@ -317,7 +317,7 @@ inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) {
template <typename T> template <typename T>
auto apply(Subgraph graph, T&& tensors) auto apply(Subgraph graph, T&& tensors)
-> std::enable_if_t<std::is_same_v<decltype(tensors[0]), Tensor*>, -> std::enable_if_t<std::is_same_v<std::decay_t<decltype(tensors[0])>, Tensor*>,
apply_result_t> { apply_result_t> {
size_t nargs = tensors.size(); size_t nargs = tensors.size();
Tensor* args[nargs]; Tensor* args[nargs];
......
/**
* \file imperative/src/impl/subgraph.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/imperative/subgraph.h"
namespace mgb {
namespace imperative {
void Subgraph::remove_unused_exprs() {
std::unordered_set<size_t> required_vars = {outputs.begin(), outputs.end()};
required_vars.erase(0);
for (auto iter = exprs.rbegin(); iter != exprs.rend(); ++iter) {
auto& expr = *iter;
bool required = false;
for (auto output : expr.outputs) {
if (required_vars.count(output)) {
required = true;
break;
}
}
if (required) {
required_vars.insert(expr.inputs.begin(), expr.inputs.end());
} else {
expr.op = nullptr;
}
}
exprs.erase(std::remove_if(exprs.begin(), exprs.end(),
[](auto expr) { return expr.op == nullptr; }),
exprs.end());
}
SmallVector<bool> Subgraph::gen_input_mask() {
std::unordered_set<size_t> unused_inputs = {inputs.begin(), inputs.end()};
for (auto&& expr : exprs) {
for (auto&& input : expr.inputs) {
unused_inputs.erase(input);
}
}
for (auto&& output : outputs) {
unused_inputs.erase(output);
}
unused_inputs.insert(0);
SmallVector<bool> mask(inputs.size(), true);
for (size_t i = 0; i < inputs.size(); ++i) {
if (unused_inputs.count(inputs[i])) {
mask[i] = false;
}
}
return mask;
}
SmallVector<bool> Subgraph::gen_output_mask() {
std::unordered_set<size_t> invalid_outputs = {outputs.begin(),
outputs.end()};
for (auto&& input : inputs) {
invalid_outputs.erase(input);
}
for (auto&& expr : exprs) {
for (auto&& output : expr.outputs) {
invalid_outputs.erase(output);
}
}
for (auto&& constant: constants) {
invalid_outputs.erase(constant.first);
}
invalid_outputs.insert(0);
SmallVector<bool> mask(outputs.size(), true);
for (size_t i = 0; i < outputs.size(); ++i) {
if (invalid_outputs.count(outputs[i])) {
mask[i] = false;
}
}
return mask;
}
void Subgraph::replace_vars(
const std::unordered_map<size_t, size_t>& replace_map) {
// FIXME: preprocess replace_map
auto replace_var = [&](var_t& var) {
// TODO: detect infinite loop
while (replace_map.count(var)) {
var = replace_map.at(var);
}
};
for (auto& expr : exprs) {
for (auto& input : expr.inputs) {
replace_var(input);
}
}
for (auto& output : outputs) {
replace_var(output);
}
}
} // namespace imperative
} // namespace mgb
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "megbrain/graph.h" #include "megbrain/graph.h"
#include "megbrain/imperative/physical_tensor.h" #include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/utils/to_string.h" #include "megbrain/imperative/utils/to_string.h"
#include "megbrain/imperative/subgraph.h"
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
...@@ -28,54 +29,6 @@ enum DispatchMode { ...@@ -28,54 +29,6 @@ enum DispatchMode {
using SharedOp = std::shared_ptr<OpDef>; using SharedOp = std::shared_ptr<OpDef>;
template <typename T>
struct Expr {
std::shared_ptr<OpDef> op;
SmallVector<T> inputs;
SmallVector<T> outputs;
};
struct Subgraph {
SmallVector<size_t> inputs;
SmallVector<std::pair<size_t, TensorPtr>> constants;
SmallVector<size_t> outputs;
SmallVector<Expr<size_t>> exprs;
template <typename T, typename F, typename C>
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
std::unordered_map<size_t, T> idx2var;
mgb_assert(inputs.size() == input_vars.size(), "input size mismatch");
for (size_t i = 0; i < inputs.size(); ++i) {
idx2var[inputs[i]] = input_vars[i];
}
for (auto&& [idx, val]: constants) {
idx2var[idx] = c(val);
}
for (auto& expr: exprs) {
SmallVector<T> expr_inputs;
for (auto idx: expr.inputs) {
expr_inputs.push_back(idx2var[idx]);
}
SmallVector<T> expr_outputs = f(expr.op, std::move(expr_inputs));
mgb_assert(expr_outputs.size() == expr.outputs.size(), "output size mismatch");
for (size_t i = 0; i < expr_outputs.size(); ++i) {
idx2var[expr.outputs[i]] = expr_outputs[i];
}
}
SmallVector<T> output_vars;
for (auto idx: outputs) {
output_vars.push_back(idx2var[idx]);
}
return output_vars;
}
bool empty() const {
return outputs.size() == 0;
}
std::string repr() const;
};
struct BackwardGraphResult { struct BackwardGraphResult {
Subgraph backward; Subgraph backward;
SmallVector<bool> save_for_backward; SmallVector<bool> save_for_backward;
......
/**
* \file imperative/src/include/megbrain/imperative/subgraph.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <list>
#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/utils/to_string.h"
#include "megbrain/utils/small_vector.h"
namespace mgb {
namespace imperative {
class OpDef;
template <typename T>
struct Expr {
std::shared_ptr<OpDef> op;
SmallVector<T> inputs;
SmallVector<T> outputs;
};
template <typename T>
struct ToStringTrait<Expr<T>> {
std::string operator()(const Expr<T>& expr) {
return ssprintf("%s = %s %s\n", to_string(expr.inputs).c_str(), to_string(expr.op.get()).c_str(), to_string(expr.outputs).c_str());
}
};
struct Subgraph {
template <typename TDesc>
class Builder;
using var_t = size_t;
using vars_t = SmallVector<size_t>;
using op_t = std::shared_ptr<OpDef>;
using expr_t = Expr<var_t>;
template <typename TDesc>
using builder_t = Builder<TDesc>;
SmallVector<var_t> inputs;
SmallVector<std::pair<var_t, TensorPtr>> constants;
SmallVector<var_t> outputs;
SmallVector<expr_t> exprs;
template <typename T, typename F, typename C>
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
std::unordered_map<size_t, T> idx2var;
mgb_assert(inputs.size() == input_vars.size(), "input size mismatch");
for (size_t i = 0; i < inputs.size(); ++i) {
idx2var[inputs[i]] = input_vars[i];
}
for (auto&& [idx, val] : constants) {
idx2var[idx] = c(val);
}
for (auto& expr : exprs) {
SmallVector<T> expr_inputs;
for (auto idx : expr.inputs) {
expr_inputs.push_back(idx2var[idx]);
}
SmallVector<T> expr_outputs =
f(expr.op, std::move(expr_inputs), expr.outputs.size());
mgb_assert(expr_outputs.size() == expr.outputs.size(),
"output size mismatch");
for (size_t i = 0; i < expr_outputs.size(); ++i) {
idx2var[expr.outputs[i]] = expr_outputs[i];
}
}
SmallVector<T> output_vars;
for (auto idx : outputs) {
output_vars.push_back(idx2var[idx]);
}
return output_vars;
}
void remove_unused_exprs();
SmallVector<bool> gen_input_mask();
SmallVector<bool> gen_output_mask();
bool empty() const { return outputs.size() == 0; }
void replace_vars(const std::unordered_map<size_t, size_t>& replace_map);
std::string repr() const;
bool is_single() const;
std::shared_ptr<OpDef> as_single() const;
bool operator==(const Subgraph& rhs) const;
};
} // namespace imperative
} // namespace mgb
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册