diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index d3f2c9fffa3f0e24e584f211f7403ca10b327abd..5860a5c3685a913cd826ebc3093ec869c82865f8 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -309,7 +309,7 @@ inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { for (size_t i = 0; i < nargs; ++i) { inputs.push_back(args[i]->shared_from_this()); } - auto apply_functor = [](std::shared_ptr op, SmallVector> inputs) { + auto apply_functor = [](std::shared_ptr op, SmallVector> inputs, size_t) { return apply(op, std::move(inputs)); }; return graph.apply(inputs, apply_functor, &make_const); @@ -317,7 +317,7 @@ inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { template auto apply(Subgraph graph, T&& tensors) - -> std::enable_if_t, + -> std::enable_if_t, Tensor*>, apply_result_t> { size_t nargs = tensors.size(); Tensor* args[nargs]; diff --git a/imperative/src/impl/subgraph.cpp b/imperative/src/impl/subgraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..087bd1fbcb0b3f2e10ab2aa41d77cf066738a077 --- /dev/null +++ b/imperative/src/impl/subgraph.cpp @@ -0,0 +1,105 @@ +/** + * \file imperative/src/impl/subgraph.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain/imperative/subgraph.h" + +namespace mgb { +namespace imperative { + +void Subgraph::remove_unused_exprs() { + std::unordered_set required_vars = {outputs.begin(), outputs.end()}; + required_vars.erase(0); + for (auto iter = exprs.rbegin(); iter != exprs.rend(); ++iter) { + auto& expr = *iter; + bool required = false; + for (auto output : expr.outputs) { + if (required_vars.count(output)) { + required = true; + break; + } + } + if (required) { + required_vars.insert(expr.inputs.begin(), expr.inputs.end()); + } else { + expr.op = nullptr; + } + } + exprs.erase(std::remove_if(exprs.begin(), exprs.end(), + [](auto expr) { return expr.op == nullptr; }), + exprs.end()); +} + +SmallVector Subgraph::gen_input_mask() { + std::unordered_set unused_inputs = {inputs.begin(), inputs.end()}; + for (auto&& expr : exprs) { + for (auto&& input : expr.inputs) { + unused_inputs.erase(input); + } + } + for (auto&& output : outputs) { + unused_inputs.erase(output); + } + unused_inputs.insert(0); + SmallVector mask(inputs.size(), true); + for (size_t i = 0; i < inputs.size(); ++i) { + if (unused_inputs.count(inputs[i])) { + mask[i] = false; + } + } + return mask; +} + +SmallVector Subgraph::gen_output_mask() { + std::unordered_set invalid_outputs = {outputs.begin(), + outputs.end()}; + for (auto&& input : inputs) { + invalid_outputs.erase(input); + } + for (auto&& expr : exprs) { + for (auto&& output : expr.outputs) { + invalid_outputs.erase(output); + } + } + for (auto&& constant: constants) { + invalid_outputs.erase(constant.first); + } + invalid_outputs.insert(0); + SmallVector mask(outputs.size(), true); + for (size_t i = 0; i < outputs.size(); ++i) { + if (invalid_outputs.count(outputs[i])) { + mask[i] = false; + } + } + return mask; +} + +void Subgraph::replace_vars( + const std::unordered_map& replace_map) { + // FIXME: preprocess replace_map + auto replace_var = [&](var_t& var) { + // TODO: detect infinite loop + while (replace_map.count(var)) { + var = replace_map.at(var); + } + }; + for (auto& expr : exprs) { + for (auto& input : expr.inputs) { + replace_var(input); + } + } + for (auto& output : outputs) { + replace_var(output); + } +} + +} // namespace imperative +} // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/op_def.h b/imperative/src/include/megbrain/imperative/op_def.h index 8b8641a51b66a3f189b12d4b517ed058054883f4..7ab725070f777b84f2759a298d7e6de396f9e0c5 100644 --- a/imperative/src/include/megbrain/imperative/op_def.h +++ b/imperative/src/include/megbrain/imperative/op_def.h @@ -14,6 +14,7 @@ #include "megbrain/graph.h" #include "megbrain/imperative/physical_tensor.h" #include "megbrain/imperative/utils/to_string.h" +#include "megbrain/imperative/subgraph.h" namespace mgb { namespace imperative { @@ -28,54 +29,6 @@ enum DispatchMode { using SharedOp = std::shared_ptr; -template -struct Expr { - std::shared_ptr op; - SmallVector inputs; - SmallVector outputs; -}; - -struct Subgraph { - SmallVector inputs; - SmallVector> constants; - SmallVector outputs; - SmallVector> exprs; - - template - SmallVector apply(SmallVector input_vars, F&& f, C&& c) const { - std::unordered_map idx2var; - mgb_assert(inputs.size() == input_vars.size(), "input size mismatch"); - for (size_t i = 0; i < inputs.size(); ++i) { - idx2var[inputs[i]] = input_vars[i]; - } - for (auto&& [idx, val]: constants) { - idx2var[idx] = c(val); - } - for (auto& expr: exprs) { - SmallVector expr_inputs; - for (auto idx: expr.inputs) { - expr_inputs.push_back(idx2var[idx]); - } - SmallVector expr_outputs = f(expr.op, std::move(expr_inputs)); - mgb_assert(expr_outputs.size() == expr.outputs.size(), "output size mismatch"); - for (size_t i = 0; i < expr_outputs.size(); ++i) { - idx2var[expr.outputs[i]] = expr_outputs[i]; - } - } - SmallVector output_vars; - for (auto idx: outputs) { - output_vars.push_back(idx2var[idx]); - } - return output_vars; - } - - bool empty() const { - return outputs.size() == 0; - } - - std::string repr() const; -}; - struct BackwardGraphResult { Subgraph backward; SmallVector save_for_backward; diff --git a/imperative/src/include/megbrain/imperative/subgraph.h b/imperative/src/include/megbrain/imperative/subgraph.h new file mode 100644 index 0000000000000000000000000000000000000000..68221e4538d1db09d5c32430ce6a1629fd66bae1 --- /dev/null +++ b/imperative/src/include/megbrain/imperative/subgraph.h @@ -0,0 +1,100 @@ +/** + * \file imperative/src/include/megbrain/imperative/subgraph.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include + +#include "megbrain/imperative/physical_tensor.h" +#include "megbrain/imperative/utils/to_string.h" +#include "megbrain/utils/small_vector.h" + +namespace mgb { +namespace imperative { + +class OpDef; + +template +struct Expr { + std::shared_ptr op; + SmallVector inputs; + SmallVector outputs; +}; + +template +struct ToStringTrait> { + std::string operator()(const Expr& expr) { + return ssprintf("%s = %s %s\n", to_string(expr.inputs).c_str(), to_string(expr.op.get()).c_str(), to_string(expr.outputs).c_str()); + } +}; + +struct Subgraph { + + template + class Builder; + + using var_t = size_t; + using vars_t = SmallVector; + using op_t = std::shared_ptr; + using expr_t = Expr; + + template + using builder_t = Builder; + + SmallVector inputs; + SmallVector> constants; + SmallVector outputs; + SmallVector exprs; + + template + SmallVector apply(SmallVector input_vars, F&& f, C&& c) const { + std::unordered_map idx2var; + mgb_assert(inputs.size() == input_vars.size(), "input size mismatch"); + for (size_t i = 0; i < inputs.size(); ++i) { + idx2var[inputs[i]] = input_vars[i]; + } + for (auto&& [idx, val] : constants) { + idx2var[idx] = c(val); + } + for (auto& expr : exprs) { + SmallVector expr_inputs; + for (auto idx : expr.inputs) { + expr_inputs.push_back(idx2var[idx]); + } + SmallVector expr_outputs = + f(expr.op, std::move(expr_inputs), expr.outputs.size()); + mgb_assert(expr_outputs.size() == expr.outputs.size(), + "output size mismatch"); + for (size_t i = 0; i < expr_outputs.size(); ++i) { + idx2var[expr.outputs[i]] = expr_outputs[i]; + } + } + SmallVector output_vars; + for (auto idx : outputs) { + output_vars.push_back(idx2var[idx]); + } + return output_vars; + } + + void remove_unused_exprs(); + SmallVector gen_input_mask(); + SmallVector gen_output_mask(); + bool empty() const { return outputs.size() == 0; } + void replace_vars(const std::unordered_map& replace_map); + std::string repr() const; + bool is_single() const; + std::shared_ptr as_single() const; + bool operator==(const Subgraph& rhs) const; +}; + +} // namespace imperative +} // namespace mgb \ No newline at end of file