diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp
index 98ddaa39c93b12721b1f9086cff5a6b78ababc53..d4a96f32d7a6802206beb3dd187f59097f15e171 100644
--- a/imperative/src/impl/op_def.cpp
+++ b/imperative/src/impl/op_def.cpp
@@ -100,6 +100,19 @@ std::vector<std::pair<const char*, std::string>> OpDef::props(
     return def.trait()->props(def);
 }
 
+EncodedSubraph OpDef::make_forward_graph(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs){
+    using ForwardGraphCache = OpMethResultCache<EncodedSubraph, SmallVector<bool>>;
+    thread_local ForwardGraphCache cache;
+    decltype(cache)::key_t cache_key{const_cast<OpDef&>(def).shared_from_this(), inputs};
+    auto iter = cache.find(cache_key);
+    if (iter == cache.end()) {
+        iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)}).first;
+    }
+    return iter->second;
+}
+
 std::string OpDef::to_string() const {
     std::string builder = trait()->make_name(*this) + "{";
     for (auto&& [name, value]: props(*this)) {
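Note: the new OpDef::make_forward_graph memoizes the decomposition per thread, keyed by the op instance plus the input descriptors, so repeated applications of the same op reuse one forward graph. A minimal standalone sketch of the same caching pattern, using std::map and made-up Graph/CacheKey types in place of MegEngine's OpMethResultCache:

```cpp
#include <map>
#include <string>
#include <utility>
#include <vector>

// Stand-ins for OpDef / EncodedSubraph, just to show the caching shape.
struct Graph { std::string repr; };
using CacheKey = std::pair<const void*, std::vector<int>>;  // op identity + input signature

static Graph build_graph(const void*, const std::vector<int>& shapes) {
    // Expensive in real life; trivial here.
    return Graph{"graph with " + std::to_string(shapes.size()) + " inputs"};
}

Graph make_forward_graph_cached(const void* op, const std::vector<int>& shapes) {
    // One cache per thread: no locking, entries live for the thread's lifetime.
    thread_local std::map<CacheKey, Graph> cache;
    CacheKey key{op, shapes};
    auto iter = cache.find(key);
    if (iter == cache.end()) {
        // Miss: build once, then reuse on every later call with the same key.
        iter = cache.insert({key, build_graph(op, shapes)}).first;
    }
    return iter->second;
}
```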
diff --git a/imperative/src/impl/op_trait.cpp b/imperative/src/impl/op_trait.cpp
index 59f1befa4ab165b1ecc66b2fab64e7be32d6032c..fff666ef6e15729dec7f8f190256a1774729dffc 100644
--- a/imperative/src/impl/op_trait.cpp
+++ b/imperative/src/impl/op_trait.cpp
@@ -16,6 +16,7 @@
 #include "megbrain/imperative/op_def.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/imperative/proxy_graph_detail.h"
+#include "megbrain/imperative/subgraph_detail.h"
 #include "megbrain/tensor.h"
 
 #include "./op_trait.h"
@@ -38,24 +39,45 @@ StaticData& static_data() {
     return data;
 }
 
-void OpMethFallback::impl(ApplyOnPhysicalTensor& func,
+void OpMethFallbackByProxyGraph::impl(ApplyOnPhysicalTensor& func,
                           op_meth_tag::ApplyOnPhysicalTensor) {
     func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor);
 }
-void OpMethFallback::impl(Execute& func, op_meth_tag::Execute) {
+void OpMethFallbackByProxyGraph::impl(Execute& func, op_meth_tag::Execute) {
     func.Base::operator=(proxy_graph_detail::execute);
 }
-void OpMethFallback::impl(InferOutputMemDesc& func,
+void OpMethFallbackByProxyGraph::impl(InferOutputMemDesc& func,
                           op_meth_tag::InferOutputMemDesc) {
     func.Base::operator=(proxy_graph_detail::infer_output_mem_desc);
 }
-void OpMethFallback::impl(InferOutputAttrsFallible& func,
+void OpMethFallbackByProxyGraph::impl(InferOutputAttrsFallible& func,
                           op_meth_tag::InferOutputAttrsFallible) {
     func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible);
 }
-void OpMethFallback::impl(GradMaker& func, op_meth_tag::GradMaker) {
+void OpMethFallbackByProxyGraph::impl(GradMaker& func, op_meth_tag::GradMaker) {
     func.Base::operator=(proxy_graph_detail::make_backward_graph);
 }
+
+void OpMethFallbackFromSubgraph::impl(ApplyOnPhysicalTensor& func,
+                                      op_meth_tag::ApplyOnPhysicalTensor) {
+    func.Base::operator=(subgraph_detail::apply_on_physical_tensor);
+}
+void OpMethFallbackFromSubgraph::impl(InferOutputMemDesc& func,
+                                      op_meth_tag::InferOutputMemDesc) {
+    func.Base::operator=(subgraph_detail::infer_output_mem_desc);
+}
+void OpMethFallbackFromSubgraph::impl(ApplyOnVarNode& func,
+                                      op_meth_tag::ApplyOnVarNode) {
+    func.Base::operator=(subgraph_detail::apply_on_var_node);
+}
+void OpMethFallbackFromSubgraph::impl(InferOutputAttrsFallible& func,
+                                      op_meth_tag::InferOutputAttrsFallible) {
+    func.Base::operator=(subgraph_detail::infer_output_attrs_fallible);
+}
+void OpMethFallbackFromSubgraph::impl(GradMaker& func, op_meth_tag::GradMaker) {
+    func.Base::operator=(subgraph_detail::make_backward_graph);
+}
+
 void OpMethFallback::impl(DecideDispatchMode& func,
                           op_meth_tag::DecideDispatchMode) {
     static auto decide_dispatch_mode =
@@ -99,16 +121,20 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
 }
 
 OpTraitRegistry& OpTraitRegistry::fallback() {
+    using Mode = detail::OpMethFallbackMode;
+    uint64_t mode = Mode::None;
+    if (trait->make_forward_graph) {
+        mode |= Mode::FromSubgraph;
+    }
     if (trait->apply_on_var_node) {
-        // fallback to proxy graph impl
-        trait->apply_on_physical_tensor.allow_fallback = true;
-        trait->execute.allow_fallback = true;
-        trait->infer_output_mem_desc.allow_fallback = true;
-        trait->infer_output_attrs_fallible.allow_fallback = true;
-        trait->make_backward_graph.allow_fallback = true;
+        mode |= Mode::ByProxyGraph;
     }
-    trait->decide_dispatch_mode.allow_fallback = true;
-    trait->make_name.allow_fallback = true;
+    mode |= Mode::Default;
+#define SET_FALLBACK_MODE(meth) \
+    trait->meth.fallback_mode = mode;
+    FOR_EACH_OP_METH(SET_FALLBACK_MODE)
+#undef SET_FALLBACK_MODE
+
     return *this;
 }
 
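Note: OpTraitRegistry::fallback() now computes one bit set and stamps it onto every op-method through the FOR_EACH_OP_METH X-macro, instead of flipping individual allow_fallback flags. A toy, self-contained expansion of that pattern, with hypothetical method and macro names:

```cpp
#include <cstdio>

// Toy X-macro listing the per-op methods, like FOR_EACH_OP_METH above.
#define FOR_EACH_METH(cb) \
    cb(apply)             \
    cb(infer)             \
    cb(make_name)

struct Meth { unsigned long long fallback_mode = 0; };
struct Trait { Meth apply, infer, make_name; };

int main() {
    Trait trait;
    unsigned long long mode = 1 /*Default*/ | 2 /*ByProxyGraph*/;
    // Expands to one assignment per listed method, exactly like
    // FOR_EACH_OP_METH(SET_FALLBACK_MODE) in fallback().
#define SET_FALLBACK_MODE(meth) trait.meth.fallback_mode = mode;
    FOR_EACH_METH(SET_FALLBACK_MODE)
#undef SET_FALLBACK_MODE
    std::printf("apply=%llu infer=%llu make_name=%llu\n",
                trait.apply.fallback_mode, trait.infer.fallback_mode,
                trait.make_name.fallback_mode);
    return 0;
}
```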
diff --git a/imperative/src/impl/op_trait.h b/imperative/src/impl/op_trait.h
index 7c00ae0d59283e1a2b61b18a21265ace26152e6a..de134fb41e767ce978a88daa6f85409fc9393514 100644
--- a/imperative/src/impl/op_trait.h
+++ b/imperative/src/impl/op_trait.h
@@ -95,9 +95,18 @@ OpMethType(IsSame,
 OpMethType(MakeNameFunc,
            std::string(const OpDef&));
+
+OpMethType(GraphMaker,
+           decltype(OpDef::make_forward_graph));
 // clang-format on
 
 namespace detail {
 
+struct OpMethImplBase {
+    template <typename Tag, typename RType, typename... Args>
+    static void impl(thin_function<RType(Args...)>& func, Tag) {}
+};
+
 struct OpMethNotImpl {
     template <typename Tag, typename RType, typename... Args>
     static void impl(thin_function<RType(Args...)>& func, Tag) {
@@ -106,8 +115,15 @@ struct OpMethNotImpl {
         };
     }
 };
-struct OpMethFallback : public OpMethNotImpl {
-    using OpMethNotImpl::impl;
+
+struct OpMethFallback: OpMethImplBase {
+    using OpMethImplBase::impl;
+    static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode);
+    static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc);
+};
+
+struct OpMethFallbackByProxyGraph: OpMethImplBase {
+    using OpMethImplBase::impl;
     static void impl(ApplyOnPhysicalTensor& func,
                      op_meth_tag::ApplyOnPhysicalTensor);
     static void impl(Execute& func, op_meth_tag::Execute);
@@ -115,18 +131,48 @@ struct OpMethFallback : public OpMethNotImpl {
     static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
     static void impl(InferOutputAttrsFallible& func,
                      op_meth_tag::InferOutputAttrsFallible);
     static void impl(GradMaker& func, op_meth_tag::GradMaker);
-    static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode);
-    static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc);
 };
+
+struct OpMethFallbackFromSubgraph: OpMethImplBase {
+    using OpMethImplBase::impl;
+    static void impl(ApplyOnPhysicalTensor& func,
+                     op_meth_tag::ApplyOnPhysicalTensor);
+    static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
+    static void impl(ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode);
+    static void impl(InferOutputAttrsFallible& func,
+                     op_meth_tag::InferOutputAttrsFallible);
+    static void impl(GradMaker& func, op_meth_tag::GradMaker);
+};
+
+struct OpMethFallbackMode {
+    static constexpr uint64_t None = 0;
+    static constexpr uint64_t Default = 1;
+    static constexpr uint64_t ByProxyGraph = 2;
+    static constexpr uint64_t FromSubgraph = 4;
+};
+
 template <typename Tag, typename RType, typename... Args>
 struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
     using Base = thin_function<RType(Args...)>;
-    OpMeth() : Base{}, allow_fallback(false){};
+    OpMeth() : Base{}{};
     explicit OpMeth(const Base& base) { this->Base::operator=(base); }
     using Base::operator bool;
     RType operator()(Args... args) const {
-        if (!this->Base::operator bool()) {
-            if (allow_fallback) {
+        uint64_t mode_mask = ~uint64_t(0);
+        auto match_mode = [&](uint64_t mode){
+            if ((fallback_mode & mode_mask) & mode) {
+                mode_mask &= ~mode;
+                return true;
+            }
+            return false;
+        };
+        while (!this->Base::operator bool()) {
+            using Mode = OpMethFallbackMode;
+            if (match_mode(Mode::FromSubgraph)) {
+                OpMethFallbackFromSubgraph::impl(*const_cast<OpMeth*>(this), Tag{});
+            } else if (match_mode(Mode::ByProxyGraph)) {
+                OpMethFallbackByProxyGraph::impl(*const_cast<OpMeth*>(this), Tag{});
+            } else if (match_mode(Mode::Default)) {
                 OpMethFallback::impl(*const_cast<OpMeth*>(this), Tag{});
             } else {
                 OpMethNotImpl::impl(*const_cast<OpMeth*>(this), Tag{});
@@ -134,7 +180,7 @@ struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
         }
         return this->Base::operator()(std::forward<Args>(args)...);
     }
-    bool allow_fallback = false;
+    uint64_t fallback_mode = OpMethFallbackMode::None;
 };
 
 } // namespace detail
@@ -153,6 +199,7 @@ struct OpTrait {
     HashFunc hash;
     IsSame is_same_st;
     MakeNameFunc make_name;
+    GraphMaker make_forward_graph;
     OpTrait(const char* name);
     static OpTrait* find_by_name(const char* name);
     static OpTrait* find_by_typeinfo(Typeinfo* type);
@@ -173,7 +220,9 @@ struct OpTrait {
     cb(props) \
     cb(hash) \
     cb(is_same_st) \
-    cb(make_name)
+    cb(make_name) \
+    cb(make_forward_graph) \
+
 // clang-format on
 
 struct OpTraitRegistry {
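Note: the rewritten OpMeth::operator() tries fallback sources in priority order (subgraph first, then proxy graph, then default), and match_mode clears each bit as it consumes it, so no source is retried. A runnable sketch of just this selection loop, simulating an op for which no source ever succeeds in binding an implementation:

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors OpMethFallbackMode: one bit per fallback source.
struct Mode {
    static constexpr uint64_t None = 0, Default = 1, ByProxyGraph = 2, FromSubgraph = 4;
};

int main() {
    // An op that registered both make_forward_graph and apply_on_var_node.
    uint64_t fallback_mode = Mode::FromSubgraph | Mode::ByProxyGraph | Mode::Default;
    uint64_t mode_mask = ~uint64_t(0);
    // Consume one priority bit per successful match so a source is never retried.
    auto match_mode = [&](uint64_t mode) {
        if ((fallback_mode & mode_mask) & mode) {
            mode_mask &= ~mode;
            return true;
        }
        return false;
    };
    // In the real code the loop exits once the function becomes bound; here
    // binding never succeeds, so every source is visited exactly once.
    while (true) {
        if (match_mode(Mode::FromSubgraph)) {
            std::puts("try subgraph_detail impl");
        } else if (match_mode(Mode::ByProxyGraph)) {
            std::puts("try proxy_graph_detail impl");
        } else if (match_mode(Mode::Default)) {
            std::puts("try default impl");
        } else {
            std::puts("not implemented");
            break;
        }
    }
    return 0;
}
```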
diff --git a/imperative/src/impl/subgraph_detail.cpp b/imperative/src/impl/subgraph_detail.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ee4971645026acd5e7e4f229895006fd6d87e525
--- /dev/null
+++ b/imperative/src/impl/subgraph_detail.cpp
@@ -0,0 +1,169 @@
+/**
+ * \file imperative/src/impl/subgraph_detail.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/subgraph_detail.h"
+#include "megbrain/imperative/graph_builder.h"
+
+#include "megbrain/opr/io.h"
+#include "megbrain/imperative/ops/autogen.h"
+
+#include "./op_trait.h"
+
+namespace mgb {
+namespace imperative {
+namespace subgraph_detail {
+
+VarNodeArray apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    SmallVector<LogicalTensorDesc> input_descs;
+    for (auto&& input: inputs) {
+        input_descs.push_back({TensorLayout{input->dtype()}, input->comp_node()});
+    }
+    auto apply_functor = [](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){
+        return OpDef::apply_on_var_node(*op, inputs);
+    };
+    auto const_functor = [&](const TensorPtr& value) {
+        return opr::ImmutableTensor::make(*inputs[0]->owner_graph(), value->get_value()).node();
+    };
+    auto subgraph = def.trait()->make_forward_graph(def, input_descs);
+    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
+    return outputs;
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    auto subgraph = def.trait()->make_forward_graph(def, inputs);
+    bool all_validated = true;
+    auto apply_functor = [&](const std::shared_ptr<OpDef>& op, const SmallVector<LogicalTensorDesc>& inputs, size_t nr_outputs){
+        auto [outputs, validated] = OpDef::infer_output_attrs_fallible(*op, inputs);
+        all_validated = all_validated && validated;
+        return outputs;
+    };
+    auto const_functor = [&](const TensorPtr& value) {
+        return LogicalTensorDesc{value->layout(), value->comp_node(), value->get_value().proxy_to_default_cpu()};
+    };
+    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
+    return { outputs, all_validated };
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def,
+        SmallVector<TensorPtr> inputs) {
+    SmallVector<LogicalTensorDesc> input_descs;
+    for (auto&& input: inputs) {
+        input_descs.push_back({input->layout(), input->comp_node()});
+    }
+    auto subgraph = def.trait()->make_forward_graph(def, input_descs);
+    auto apply_functor = [](const std::shared_ptr<OpDef>& op, const SmallVector<TensorPtr>& inputs, size_t nr_outputs){
+        return OpDef::apply_on_physical_tensor(*op, inputs);
+    };
+    auto const_functor = [&](const TensorPtr& value) {
+        return value;
+    };
+    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
+    return outputs;
+}
+
+static EncodedSubraph make_backward_graph_from_forward(
+        const SmallVector<LogicalTensorDesc>& inputs,
+        const SmallVector<bool>& input_requires_grad,
+        const SmallVector<bool>& output_has_grad,
+        EncodedSubraph forward_graph) {
+    using namespace std::placeholders;
+    using var_t = Subgraph::var_t;
+    using vars_t = Subgraph::vars_t;
+    Subgraph::Builder<LogicalTensorDesc> builder([](auto&& op, auto&& input_descs, size_t nr_outputs){
+        auto [descs, _] = OpDef::infer_output_attrs_fallible(*op, input_descs);
+        return descs;
+    });
+    auto accum_grad = [&](var_t lhs, var_t rhs) {
+        return builder.write_expr(Elemwise::make(Elemwise::Mode::ADD), {lhs, rhs}, 1)[0];
+    };
+    GradContext<var_t> grad_context{accum_grad};
+    auto input_vars = builder.write_inputs(inputs);
+    auto outputs = forward_graph.apply(input_vars, std::bind(&decltype(builder)::write_expr, &builder, _1, _2, _3), [&](TensorPtr constant){
+        return builder.write_constant(constant, {constant->layout(), constant->comp_node()});
+    });
+    size_t nr_outputs = outputs.size();
+    auto apply_mask = [](auto&& values, SmallVector<bool> mask) {
+        mgb_assert(mask.size() == values.size(), "");
+        std::decay_t<decltype(values)> results;
+        for (size_t i = 0; i < mask.size(); ++i) {
+            if (mask[i]) {
+                results.push_back(values[i]);
+            }
+        }
+        return results;
+    };
+    grad_context.mark_require_grads(apply_mask(input_vars, input_requires_grad));
+    builder.iterate([&](std::list<Subgraph::expr_t>::iterator iter){
+        grad_context.record_expr(iter->op, iter->inputs, iter->outputs);
+    });
+    auto output_descs = builder.get_descs(outputs);
+    auto computed_outputs = builder.write_inputs(output_descs);
+    auto output_grads = builder.write_inputs(output_descs);
+
+    grad_context.backward(
+            apply_mask(outputs, output_has_grad),
+            apply_mask(output_grads, output_has_grad),
+            [&](Subgraph::expr_t expr, vars_t output_grads) {
+                auto bg = OpDef::make_backward_graph(
+                        *expr.op, builder.get_descs(expr.inputs),
+                        grad_context.get_require_grads(expr.inputs),
+                        grad_context.get_has_grads(expr.outputs));
+                if (bg.graph.empty()) {
+                    return vars_t(expr.inputs.size(), 0);
+                }
+                vars_t grad_inputs;
+                grad_inputs.insert(grad_inputs.end(), expr.inputs.begin(),
+                                   expr.inputs.end());
+                grad_inputs.insert(grad_inputs.end(), expr.outputs.begin(),
+                                   expr.outputs.end());
+                grad_inputs.insert(grad_inputs.end(), output_grads.begin(),
+                                   output_grads.end());
+                auto apply_functor = std::bind(&decltype(builder)::write_expr,
+                                               &builder, _1, _2, _3);
+                auto const_functor = [&](TensorPtr constant) {
+                    return builder.write_constant(constant, {constant->layout(),
+                                                             constant->comp_node()});
+                };
+                return bg.apply(grad_inputs, apply_functor, const_functor);
+            });
+    builder.add_outputs(grad_context.get_grads(input_vars));
+    for (size_t i = 0; i < nr_outputs; ++i) {
+        builder.replace_var(outputs[i], computed_outputs[i]);
+    }
+    auto backward_graph = builder.encode();
+    return backward_graph;
+}
+
+EncodedSubraph make_backward_graph(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs,
+        const SmallVector<bool>& input_requires_grad,
+        const SmallVector<bool>& output_has_grad) {
+    auto forward_graph = OpDef::make_forward_graph(def, inputs);
+    return make_backward_graph_from_forward(inputs, input_requires_grad, output_has_grad, forward_graph);
+}
+
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    return {{}, {}};
+}
+
+} // namespace subgraph_detail
+} // namespace imperative
+} // namespace mgb
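Note: make_backward_graph_from_forward above is reverse-mode autodiff over a straight-line expression list: replay the forward exprs, then walk them in reverse, asking each op for its local backward rule and accumulating per-variable gradients (the role of GradContext and the accum_grad ADD). A self-contained scalar sketch of that recipe, with a made-up Expr type rather than MegEngine's Subgraph/GradContext API:

```cpp
#include <cstdio>
#include <vector>

// One straight-line expression: vals[out] = vals[a] op vals[b].
struct Expr { char op; int a, b, out; };

int main() {
    // Forward graph for f(x, y) = (x + y) * y, with x = v0, y = v1.
    std::vector<Expr> exprs{{'+', 0, 1, 2}, {'*', 2, 1, 3}};
    std::vector<double> vals{3.0, 4.0, 0.0, 0.0};
    for (auto&& e : exprs)
        vals[e.out] = e.op == '+' ? vals[e.a] + vals[e.b] : vals[e.a] * vals[e.b];

    // Backward pass: seed d(output) = 1, visit exprs in reverse, and
    // accumulate each input's gradient contribution (what accum_grad does).
    std::vector<double> grads(vals.size(), 0.0);
    grads[3] = 1.0;
    for (auto iter = exprs.rbegin(); iter != exprs.rend(); ++iter) {
        double g = grads[iter->out];
        if (iter->op == '+') {      // local grads of a+b are (1, 1)
            grads[iter->a] += g;
            grads[iter->b] += g;
        } else {                    // local grads of a*b are (b, a)
            grads[iter->a] += g * vals[iter->b];
            grads[iter->b] += g * vals[iter->a];
        }
    }
    std::printf("df/dx=%g df/dy=%g\n", grads[0], grads[1]);  // prints 4 and 11
    return 0;
}
```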
diff --git a/imperative/src/include/megbrain/imperative/op_def.h b/imperative/src/include/megbrain/imperative/op_def.h
index 591a2323ec61a17c4f8e37c1a5d55cc1218cf592..e244f00c9ff5c0fc2b1f76968056b4fe9d0ded26 100644
--- a/imperative/src/include/megbrain/imperative/op_def.h
+++ b/imperative/src/include/megbrain/imperative/op_def.h
@@ -13,6 +13,7 @@
 
 #include "megbrain/graph.h"
 #include "megbrain/imperative/physical_tensor.h"
+#include "megbrain/imperative/subgraph.h"
 #include "megbrain/imperative/utils/to_string.h"
 
 #include "megbrain/imperative/subgraph.h"
@@ -94,6 +95,10 @@ public:
     static std::vector<std::pair<const char*, std::string>> props(
             const OpDef& def);
 
+    static EncodedSubraph make_forward_graph(
+            const OpDef& def,
+            const SmallVector<LogicalTensorDesc>& inputs);
+
     const OpTrait* trait() const;
 
     std::string to_string() const;
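Note: this header change publishes make_forward_graph as a static OpDef entry point. A hedged usage sketch against the declaration above (the wrapper name is made up, and the surrounding setup that produces an OpDef and descriptors is assumed, not shown):

```cpp
#include "megbrain/imperative/op_def.h"

using namespace mgb::imperative;

EncodedSubraph trace_forward(const OpDef& op,
                             const SmallVector<LogicalTensorDesc>& input_descs) {
    // Dispatches to trait()->make_forward_graph and memoizes the result in
    // the thread_local cache added in op_def.cpp above, so repeated calls
    // with an equal (op, input_descs) key are cheap.
    return OpDef::make_forward_graph(op, input_descs);
}
```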
diff --git a/imperative/src/include/megbrain/imperative/subgraph_detail.h b/imperative/src/include/megbrain/imperative/subgraph_detail.h
new file mode 100644
index 0000000000000000000000000000000000000000..d0cb6641cda6451d3985d37b3bb624b11d9ceff6
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/subgraph_detail.h
@@ -0,0 +1,46 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/subgraph_detail.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megbrain/imperative/op_def.h"
+
+namespace mgb {
+namespace imperative {
+namespace subgraph_detail {
+
+SmallVector<TensorPtr>
+apply_on_physical_tensor(const OpDef& def,
+                         SmallVector<TensorPtr> inputs);
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs);
+
+EncodedSubraph
+make_backward_graph(const OpDef& def,
+                    const SmallVector<LogicalTensorDesc>& inputs,
+                    const SmallVector<bool>& input_requires_grad,
+                    const SmallVector<bool>& output_has_grad);
+
+cg::VarNodeArray
+apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs);
+
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems);
+
+} // namespace subgraph_detail
+} // namespace imperative
+} // namespace mgb
\ No newline at end of file
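Note: because make_forward_graph is appended to FOR_EACH_OP_METH, OpTraitRegistry should grow a matching fluent setter alongside the existing ones, and fallback() will then OR Mode::FromSubgraph into every other op-method. A heavily hedged registration sketch (the op and function names are made up; the OP_TRAIT_REG style follows existing MegEngine op registrations):

```cpp
// Hypothetical composite op that is fully described by its forward graph.
// With only make_forward_graph registered, fallback() lets
// apply_on_physical_tensor, apply_on_var_node, infer_output_attrs_fallible,
// and make_backward_graph all come from subgraph_detail.
EncodedSubraph my_forward_graph(const OpDef& def,
                                const SmallVector<LogicalTensorDesc>& inputs);

OP_TRAIT_REG(MyCompositeOp, MyCompositeOp)
        .make_forward_graph(my_forward_graph)
        .fallback();
```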