diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index 9a54c91379bcd5f5d4421c03691ca954c2d60d6b..c31aff881a80a0253a888e6b6f28c681fb89d013 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -30,7 +30,10 @@ #include "megbrain/tensorrt/opr_replace.h" #endif -#include "megbrain/gopt/global_layout_transform.h" +#include "megbrain/gopt/layout_transform_context.h" +#include "megbrain/gopt/layout_transform_pass.h" +#include "megbrain/gopt/profiler.h" +#include "megbrain/gopt/solver.h" using namespace mgb; using namespace gopt; diff --git a/src/gopt/impl/dynamic_programming_solver.cpp b/src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp similarity index 97% rename from src/gopt/impl/dynamic_programming_solver.cpp rename to src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp index 01a3627d90abff53b1f8ae206feb60145338a381..0110fdde801ac020d8ba06deac9666c33e35ced5 100644 --- a/src/gopt/impl/dynamic_programming_solver.cpp +++ b/src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp @@ -12,7 +12,9 @@ #include #include "./utils.h" -#include "megbrain/gopt/global_layout_transform.h" +#include "megbrain/gopt/layout_transform_context.h" +#include "megbrain/gopt/profiler.h" +#include "megbrain/gopt/solver.h" using namespace mgb; using namespace gopt; @@ -85,11 +87,11 @@ private: const SmallVector& available_tensor_formats; }; /*! - * \brief get the tensor formats configuration for the operator with particular op format - * \param[out] var2fmts hashmap that maps varnode to actual tensor formats of the op format configuration - * \param[in] opr given operator - * \param[in] opr_fmt given op format, an enum type argument which indicates the op format configuration. - * \param[in] ctx context + * \brief get the tensor formats configuration for the operator with + * particular op format \param[out] var2fmts hashmap that maps varnode to + * actual tensor formats of the op format configuration \param[in] opr given + * operator \param[in] opr_fmt given op format, an enum type argument which + * indicates the op format configuration. \param[in] ctx context */ TensorFormats get_io_formats(ThinHashMap& var2fmts, const OperatorNodeBase* opr, OprFormat opr_fmt, diff --git a/src/gopt/impl/layout_transform_context.cpp b/src/gopt/impl/global_layout_transform/layout_transform_context.cpp similarity index 98% rename from src/gopt/impl/layout_transform_context.cpp rename to src/gopt/impl/global_layout_transform/layout_transform_context.cpp index 3bae7b89bc4e21c49f4d3a2d824fd6aeed8d6cc6..aa25ba634fe190de7b88bd177532afd0743b9a75 100644 --- a/src/gopt/impl/layout_transform_context.cpp +++ b/src/gopt/impl/global_layout_transform/layout_transform_context.cpp @@ -11,7 +11,7 @@ */ #include "./utils.h" -#include "megbrain/gopt/global_layout_transform.h" +#include "megbrain/gopt/layout_transform_context.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/imgproc.h" #include "megbrain/opr/nn_int.h" diff --git a/src/gopt/impl/layout_transform_pass.cpp b/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp similarity index 77% rename from src/gopt/impl/layout_transform_pass.cpp rename to src/gopt/impl/global_layout_transform/layout_transform_pass.cpp index 5366bd0908a64111536dacc064150ce572512722..842f3c8a3e6a7eba28519d9b24de52b33caa37f1 100644 --- a/src/gopt/impl/layout_transform_pass.cpp +++ b/src/gopt/impl/global_layout_transform/layout_transform_pass.cpp @@ -10,9 +10,11 @@ * implied. */ +#include "megbrain/gopt/layout_transform_pass.h" #include "./opr_format_modifier.h" #include "./utils.h" -#include "megbrain/gopt/global_layout_transform.h" +#include "megbrain/gopt/profiler.h" +#include "megbrain/gopt/solver.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/imgproc.h" #include "megbrain/serialization/sereg.h" @@ -46,8 +48,7 @@ void LayoutTransformPass::apply(OptState& opt) const { auto&& opr_configs = m_ctx->opr_configs(); auto&& base_fmt = m_ctx->attribute().base_tensor_formats; - auto&& reformat_attribute = - ReformatManager::ReformatKey::Attribute::DEFAULT; + auto&& reformat_attribute = m_ctx->attribute().reformat_attribute; ThinHashMap var2fmts; static ThinHashSet format_aware_oprs = { #define cb(_Opr) opr::_Opr::typeinfo(), @@ -55,8 +56,8 @@ void LayoutTransformPass::apply(OptState& opt) const { #undef cb }; auto rewriter = opt.graph().make_rewriter(); - auto on_opr = [this, &opr_configs, &base_fmt, &reformat_attribute, - &rewriter, &solution, &var2fmts, + auto on_opr = [&opr_configs, &base_fmt, &reformat_attribute, &rewriter, + &solution, &var2fmts, &endpoint_vars](OperatorNodeBase* opr) { auto it = solution.find(opr); if (it != solution.end()) { @@ -122,19 +123,6 @@ void LayoutTransformPass::apply(OptState& opt) const { opr->config()) ->output(0); } - if (endpoint_vars.count(opr->output(0)) && out_fmt != base_fmt) { - ReformatManager::ReformatKey key{ - out_fmt, base_fmt, reformat_attribute, - opr->output(0)->dtype().enumv(), - opr->output(0)->dtype().enumv()}; - auto reformat = ReformatManager::instance() - .auto_aligned_reformat_featrue( - opr->output(0), base_fmt, key); - new_out = reformat({new_out}); - var2fmts[new_out] = base_fmt; - } else { - var2fmts[new_out] = out_fmt; - } auto &&out0 = opr->output(), &&out1 = new_out->owner_opr()->output(); mgb_assert(opr->usable_output().size() == @@ -146,20 +134,29 @@ void LayoutTransformPass::apply(OptState& opt) const { new_out->owner_opr()->cname(), new_out->owner_opr()->dyn_typeinfo()->name, out0.size(), out1.size()); - for (size_t i = 0; i < out0.size(); ++i) { - if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { - mgb_assert(!out1[i]->contain_flag( - VarNode::Flag::VOLATILE_CONTENT)); - auto src = out0[i]; - auto dst = out1[i]; - rewriter.replace_var( - src, dst, - mgb_cstr_log(ssprintf("replace opr(%s) to new opr " - "format(%s)", - opr->cname(), - opr_format_to_string(opr_fmt)) - .c_str())); + size_t nr_outs = opr->usable_output().size(); + for (size_t i = 0; i < nr_outs; ++i) { + const auto& ovar = out0[i]; + auto new_ovar = out1[i]; + if (endpoint_vars.count(ovar) && out_fmt != base_fmt) { + ReformatManager::ReformatKey key{ + out_fmt, base_fmt, reformat_attribute, + ovar->dtype().enumv(), ovar->dtype().enumv()}; + auto reformat = ReformatManager::instance() + .auto_aligned_reformat_featrue( + ovar, base_fmt, key); + new_ovar = reformat({new_ovar}); + var2fmts[new_ovar] = base_fmt; + } else { + var2fmts[new_ovar] = out_fmt; } + rewriter.replace_var( + ovar, new_ovar, + mgb_cstr_log(ssprintf("replace opr(%s) to new opr " + "format(%s)", + opr->cname(), + opr_format_to_string(opr_fmt)) + .c_str())); } } else { auto new_opr = rewriter.auto_replace_outputs(opr); diff --git a/src/gopt/impl/opr_format_modifier.cpp b/src/gopt/impl/global_layout_transform/opr_format_modifier.cpp similarity index 100% rename from src/gopt/impl/opr_format_modifier.cpp rename to src/gopt/impl/global_layout_transform/opr_format_modifier.cpp diff --git a/src/gopt/impl/opr_format_modifier.h b/src/gopt/impl/global_layout_transform/opr_format_modifier.h similarity index 100% rename from src/gopt/impl/opr_format_modifier.h rename to src/gopt/impl/global_layout_transform/opr_format_modifier.h diff --git a/src/gopt/impl/opr_tensor_formats_config.cpp b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp similarity index 99% rename from src/gopt/impl/opr_tensor_formats_config.cpp rename to src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp index 5a57054413e9a4e26c917adffe95aa5c8a29e419..23f4410ce96ab49460a724f48a346815d8299e46 100644 --- a/src/gopt/impl/opr_tensor_formats_config.cpp +++ b/src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp @@ -11,7 +11,7 @@ */ #include "./utils.h" -#include "megbrain/gopt/global_layout_transform.h" +#include "megbrain/gopt/layout_transform_context.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/imgproc.h" diff --git a/src/gopt/impl/profiler_impl.cpp b/src/gopt/impl/global_layout_transform/profiler_impl.cpp similarity index 99% rename from src/gopt/impl/profiler_impl.cpp rename to src/gopt/impl/global_layout_transform/profiler_impl.cpp index bc2a84eec325bb8bd221d24e45284e95c7bac909..2d3b1573d6d413f11dc8c84420696e682214c539 100644 --- a/src/gopt/impl/profiler_impl.cpp +++ b/src/gopt/impl/global_layout_transform/profiler_impl.cpp @@ -13,7 +13,7 @@ #include "./opr_format_modifier.h" #include "./utils.h" #include "megbrain/gopt/framework.h" -#include "megbrain/gopt/global_layout_transform.h" +#include "megbrain/gopt/profiler.h" #include "megbrain/graph/event.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/imgproc.h" diff --git a/src/gopt/impl/profiling_based_solver.cpp b/src/gopt/impl/global_layout_transform/profiling_based_solver.cpp similarity index 96% rename from src/gopt/impl/profiling_based_solver.cpp rename to src/gopt/impl/global_layout_transform/profiling_based_solver.cpp index 760d70b97df0cbcf7b56172ca30f220b6a4eade8..c327305edb3da4fa16b3452c41a6e02dbad2a495 100644 --- a/src/gopt/impl/profiling_based_solver.cpp +++ b/src/gopt/impl/global_layout_transform/profiling_based_solver.cpp @@ -10,7 +10,8 @@ * implied. */ -#include "megbrain/gopt/global_layout_transform.h" +#include "megbrain/gopt/profiler.h" +#include "megbrain/gopt/solver.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/imgproc.h" diff --git a/src/gopt/impl/reformat_emitter.cpp b/src/gopt/impl/global_layout_transform/reformat_emitter.cpp similarity index 100% rename from src/gopt/impl/reformat_emitter.cpp rename to src/gopt/impl/global_layout_transform/reformat_emitter.cpp diff --git a/src/gopt/impl/reformat_manager.cpp b/src/gopt/impl/global_layout_transform/reformat_manager.cpp similarity index 100% rename from src/gopt/impl/reformat_manager.cpp rename to src/gopt/impl/global_layout_transform/reformat_manager.cpp diff --git a/src/gopt/impl/subgraph_extractor.cpp b/src/gopt/impl/global_layout_transform/subgraph_extractor.cpp similarity index 100% rename from src/gopt/impl/subgraph_extractor.cpp rename to src/gopt/impl/global_layout_transform/subgraph_extractor.cpp diff --git a/src/gopt/impl/utils.h b/src/gopt/impl/global_layout_transform/utils.h similarity index 98% rename from src/gopt/impl/utils.h rename to src/gopt/impl/global_layout_transform/utils.h index 325f2c7ed8fe5d72eebf21bfe4ae0ba82b8b4e79..620688eb8fe5fbd0df1f45fc5413da8ad8869634 100644 --- a/src/gopt/impl/utils.h +++ b/src/gopt/impl/global_layout_transform/utils.h @@ -11,7 +11,7 @@ */ #pragma once -#include "megbrain/gopt/global_layout_transform.h" +#include "megbrain/gopt/layout_transform_context.h" namespace mgb { namespace gopt { diff --git a/src/gopt/include/megbrain/gopt/global_layout_transform.h b/src/gopt/include/megbrain/gopt/layout_transform_context.h similarity index 56% rename from src/gopt/include/megbrain/gopt/global_layout_transform.h rename to src/gopt/include/megbrain/gopt/layout_transform_context.h index 6e2a55864376c3ada91599148e9dfc10f3a1fefb..bde5b3ce3e765a88dad0c9cc79ba3e9c12cc29d3 100644 --- a/src/gopt/include/megbrain/gopt/global_layout_transform.h +++ b/src/gopt/include/megbrain/gopt/layout_transform_context.h @@ -1,5 +1,6 @@ /** - * \file src/gopt/include/megbrain/gopt/global_layout_transformation.h + * \file + * src/gopt/include/megbrain/gopt/layout_transform_context.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -12,11 +13,10 @@ #pragma once #include "megbrain/gopt/framework.h" +#include "megbrain/gopt/inference.h" #include "megbrain/gopt/reformat_manager.h" #include "megbrain/gopt/subgraph_extractor.h" -#include "megbrain/opr/dnn/convolution.h" #include "megbrain/plugin/opr_footprint.h" -#include "megbrain/gopt/inference.h" namespace mgb { namespace gopt { @@ -118,7 +118,7 @@ public: TensorFormats base_tensor_format = TensorFormats::NCHW); private: - OprList m_opr_list; /// supported operator list + OprList m_opr_list; /// supported operator list SmallVector m_available_tensor_formats; /// the available tensor formats, used /// for format agnostic operators (like @@ -180,164 +180,6 @@ private: const GraphPartition& m_graph_partition; /// the graph partition const LayoutTransformContext& m_ctx; }; - -/*! - * \brief A profiler that collects all the performance data to describe the - * global layout transform problem. - */ -class ProfilerBase { -public: - using OprFormat = Problem::OprFormat; - struct OperatorNodeRecord { - const cg::OperatorNodeBase* opr; ///< pointer to operator node - ThinHashMap - costs; ///< costs of operator node, i.e. the elapsed device - ///< time of the operator node on different opr format - ///< (layout configuration). - std::string to_string() const; - }; - struct VarNodeRecord { - struct KeyHash { - size_t operator()( - const std::pair& val) const { - size_t h1 = - std::hash()(static_cast(val.first)); - size_t h2 = std::hash()( - static_cast(val.second)); - return mgb::hash_pair_combine(h1, h2); - } - }; - const VarNode* var; ///< pointer to var node - std::unordered_map, float, - KeyHash> - costs; ///< costs of var node, i.e. the elapsed - ///< device time of the layout transform. - ///< Key of the hashmap indicates the - ///< source tensor format and the target - ///< tensor format. - std::string to_string() const; - }; - /*! - * \note the profiler assumes all the input and output var node are stored - * in contiguous layout in memory - */ - struct ProfilingResult { - /// A hashmap, that maps the operator node to the costs (device elapsed - /// time) of different layouts configuration - ThinHashMap opr_record; - /// A hashmap, that maps the var node to the costs of layout transform - ThinHashMap var_record; - }; - using OprFilter = thin_function; - using VarNodeFilter = - thin_function; - - ProfilerBase(float opr_threshold = 2.f, float var_node_threshold = 2.f); - ProfilerBase(OprFilter opr_filter, VarNodeFilter var_node_filter = {}) - : m_opr_filter{std::move(opr_filter)}, - m_var_node_filter{std::move(var_node_filter)} {} - virtual ~ProfilerBase() = default; - virtual ProfilingResult profile(const Problem& problem) const = 0; - static std::unique_ptr make_profiler(); - -protected: - OprFilter m_opr_filter; - VarNodeFilter m_var_node_filter; - float m_opr_threshold; - float m_var_node_threshold; - -private: - OprFootprint m_opr_footprint; -}; - -/*! - * \brief abstract solver - */ -class SolverBase { -public: - using OprFormat = Problem::OprFormat; - using Solution = ThinHashMap; - SolverBase() = default; - virtual ~SolverBase() = default; - /*! - * \brief solve the given problem - */ - virtual Solution solve(const Problem& problem) const = 0; - /*! - * \brief check whether the given problem can be solved by the - * algorithm(i.e. solver). - */ - virtual bool can_solve(const Problem& problem) const = 0; -}; - -/*! - * \brief solvers that will first collect the costs of operators in different op - * format and the costs of layout transform of varnode with a user provided - * profiler on the target device. This will lead to time consuming. - */ -class ProfilingBasedSolver : public SolverBase { -public: - using GraphPartitionFilter = - thin_function; - ProfilingBasedSolver(std::unique_ptr profiler); - /*! - * \note some graph partition (for example, graph partition without format - * aware operators like conv, deconv, warp, resize etc.) will be filtered by - * the GraphPartitionFilter, which can reduce the profiling time. */ - ProfilingBasedSolver(std::unique_ptr profiler, - GraphPartitionFilter graph_partition_filter) - : m_profiler{std::move(profiler)}, - m_graph_partition_filter{std::move(graph_partition_filter)} {} - virtual ~ProfilingBasedSolver() = default; - Solution solve(const Problem& problem) const override; - virtual Solution do_solve(const Problem& problem) const = 0; - -protected: - std::unique_ptr m_profiler; - -private: - GraphPartitionFilter m_graph_partition_filter; -}; - -/*! - * \brief A solver that solves the layout selection problem using dynamic - * programming algorithm (Markov decision process). - */ -class DynamicProgrammingSolver final : public ProfilingBasedSolver { -public: - DynamicProgrammingSolver(std::unique_ptr profiler) - : ProfilingBasedSolver(std::move(profiler)){}; - DynamicProgrammingSolver(std::unique_ptr profiler, - GraphPartitionFilter graph_partition_filter) - : ProfilingBasedSolver(std::move(profiler), - std::move(graph_partition_filter)){}; - ~DynamicProgrammingSolver() noexcept = default; - Solution do_solve(const Problem& problem) const override; - bool can_solve(const Problem& problem) const override; - -private: - class Impl; -}; - -/*! - * \brief A layout transform pass, which convert the operator's format to the - * optimal format using the results of the solver. - */ -class LayoutTransformPass final : public Pass { -public: - const char* name() const override { return "layout assignment pass"; } - void apply(OptState& opt) const override; - LayoutTransformPass(std::unique_ptr ctx, - std::unique_ptr solver) - : m_ctx{std::move(ctx)}, m_solver{std::move(solver)} {} - -private: - std::unique_ptr m_ctx; - std::unique_ptr m_solver; -}; - } // namespace gopt } // namespace mgb diff --git a/src/gopt/include/megbrain/gopt/layout_transform_pass.h b/src/gopt/include/megbrain/gopt/layout_transform_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..8dae10bad67f66f3c0e8eacd63e812d8e3027518 --- /dev/null +++ b/src/gopt/include/megbrain/gopt/layout_transform_pass.h @@ -0,0 +1,42 @@ +/** + * \file src/gopt/include/megbrain/gopt/global_layout_transformation.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "megbrain/gopt/framework.h" + +namespace mgb { +namespace gopt { + +class LayoutTransformContext; +class SolverBase; + +/*! + * \brief A layout transform pass, which convert the operator's format to the + * optimal format using the results of the solver. + */ +class LayoutTransformPass final : public Pass { +public: + const char* name() const override { return "layout assignment pass"; } + void apply(OptState& opt) const override; + LayoutTransformPass(std::unique_ptr ctx, + std::unique_ptr solver) + : m_ctx{std::move(ctx)}, m_solver{std::move(solver)} {} + +private: + std::unique_ptr m_ctx; + std::unique_ptr m_solver; +}; + +} // namespace gopt +} // namespace mgb + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/include/megbrain/gopt/profiler.h b/src/gopt/include/megbrain/gopt/profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..10df5e085995b401afc7c8cfae2498f32e69bb14 --- /dev/null +++ b/src/gopt/include/megbrain/gopt/profiler.h @@ -0,0 +1,101 @@ +/** + * \file src/gopt/include/megbrain/gopt/global_layout_transformation.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "megbrain/gopt/framework.h" +#include "megbrain/gopt/inference.h" +#include "megbrain/gopt/layout_transform_context.h" +#include "megbrain/gopt/reformat_manager.h" +#include "megbrain/gopt/subgraph_extractor.h" +#include "megbrain/opr/dnn/convolution.h" +#include "megbrain/plugin/opr_footprint.h" + +namespace mgb { +namespace gopt { + +class Problem; + +/*! + * \brief A profiler that collects all the performance data to describe the + * global layout transform problem. + */ +class ProfilerBase { +public: + using OprFormat = Problem::OprFormat; + struct OperatorNodeRecord { + const cg::OperatorNodeBase* opr; ///< pointer to operator node + ThinHashMap + costs; ///< costs of operator node, i.e. the elapsed device + ///< time of the operator node on different opr format + ///< (layout configuration). + std::string to_string() const; + }; + struct VarNodeRecord { + struct KeyHash { + size_t operator()( + const std::pair& val) const { + size_t h1 = + std::hash()(static_cast(val.first)); + size_t h2 = std::hash()( + static_cast(val.second)); + return mgb::hash_pair_combine(h1, h2); + } + }; + const VarNode* var; ///< pointer to var node + std::unordered_map, float, + KeyHash> + costs; ///< costs of var node, i.e. the elapsed + ///< device time of the layout transform. + ///< Key of the hashmap indicates the + ///< source tensor format and the target + ///< tensor format. + std::string to_string() const; + }; + /*! + * \note the profiler assumes all the input and output var node are stored + * in contiguous layout in memory + */ + struct ProfilingResult { + /// A hashmap, that maps the operator node to the costs (device elapsed + /// time) of different layouts configuration + ThinHashMap opr_record; + /// A hashmap, that maps the var node to the costs of layout transform + ThinHashMap var_record; + }; + using OprFilter = thin_function; + using VarNodeFilter = + thin_function; + + ProfilerBase(float opr_threshold = 2.f, float var_node_threshold = 2.f); + ProfilerBase(OprFilter opr_filter, VarNodeFilter var_node_filter = {}) + : m_opr_filter{std::move(opr_filter)}, + m_var_node_filter{std::move(var_node_filter)} {} + virtual ~ProfilerBase() = default; + virtual ProfilingResult profile(const Problem& problem) const = 0; + static std::unique_ptr make_profiler(); + +protected: + OprFilter m_opr_filter; + VarNodeFilter m_var_node_filter; + float m_opr_threshold; + float m_var_node_threshold; + +private: + OprFootprint m_opr_footprint; +}; + +} // namespace gopt +} // namespace mgb + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/include/megbrain/gopt/solver.h b/src/gopt/include/megbrain/gopt/solver.h new file mode 100644 index 0000000000000000000000000000000000000000..911a6305ee7a0473c54e5c731628f6bc930a895b --- /dev/null +++ b/src/gopt/include/megbrain/gopt/solver.h @@ -0,0 +1,97 @@ +/** + * \file src/gopt/include/megbrain/gopt/solver.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "megbrain/gopt/framework.h" +#include "megbrain/gopt/layout_transform_context.h" +#include "megbrain/opr/dnn/convolution.h" +#include "megbrain/plugin/opr_footprint.h" +#include "megbrain/gopt/inference.h" + +namespace mgb { +namespace gopt { + +class ProfilerBase; + +/*! + * \brief abstract solver + */ +class SolverBase { +public: + using OprFormat = Problem::OprFormat; + using Solution = ThinHashMap; + SolverBase() = default; + virtual ~SolverBase() = default; + /*! + * \brief solve the given problem + */ + virtual Solution solve(const Problem& problem) const = 0; + /*! + * \brief check whether the given problem can be solved by the + * algorithm(i.e. solver). + */ + virtual bool can_solve(const Problem& problem) const = 0; +}; + +/*! + * \brief solvers that will first collect the costs of operators in different op + * format and the costs of layout transform of varnode with a user provided + * profiler on the target device. This will lead to time consuming. + */ +class ProfilingBasedSolver : public SolverBase { +public: + using GraphPartitionFilter = + thin_function; + ProfilingBasedSolver(std::unique_ptr profiler); + /*! + * \note some graph partition (for example, graph partition without format + * aware operators like conv, deconv, warp, resize etc.) will be filtered by + * the GraphPartitionFilter, which can reduce the profiling time. */ + ProfilingBasedSolver(std::unique_ptr profiler, + GraphPartitionFilter graph_partition_filter) + : m_profiler{std::move(profiler)}, + m_graph_partition_filter{std::move(graph_partition_filter)} {} + virtual ~ProfilingBasedSolver() = default; + Solution solve(const Problem& problem) const override; + virtual Solution do_solve(const Problem& problem) const = 0; + +protected: + std::unique_ptr m_profiler; + +private: + GraphPartitionFilter m_graph_partition_filter; +}; + +/*! + * \brief A solver that solves the layout selection problem using dynamic + * programming algorithm (Markov decision process). + */ +class DynamicProgrammingSolver final : public ProfilingBasedSolver { +public: + DynamicProgrammingSolver(std::unique_ptr profiler) + : ProfilingBasedSolver(std::move(profiler)){}; + DynamicProgrammingSolver(std::unique_ptr profiler, + GraphPartitionFilter graph_partition_filter) + : ProfilingBasedSolver(std::move(profiler), + std::move(graph_partition_filter)){}; + ~DynamicProgrammingSolver() noexcept = default; + Solution do_solve(const Problem& problem) const override; + bool can_solve(const Problem& problem) const override; + +private: + class Impl; +}; + +} // namespace gopt +} // namespace mgb + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/test/layout_transform_pass.cpp b/src/gopt/test/layout_transform_pass.cpp index 40332feb083a94c9cf3db0963f1eac6957e35206..d43a35b988aad3155b8d032a5141ea6a22c54486 100644 --- a/src/gopt/test/layout_transform_pass.cpp +++ b/src/gopt/test/layout_transform_pass.cpp @@ -10,10 +10,13 @@ * implied. */ +#include "megbrain/gopt/layout_transform_pass.h" #include "./network.h" #include "megbrain/comp_node_env.h" -#include "megbrain/gopt/global_layout_transform.h" #include "megbrain/gopt/inference.h" +#include "megbrain/gopt/layout_transform_context.h" +#include "megbrain/gopt/profiler.h" +#include "megbrain/gopt/solver.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/imgproc.h" #include "megbrain/opr/nn_int.h" diff --git a/src/gopt/test/profiler.cpp b/src/gopt/test/profiler.cpp index ae153c16f6e1e641b43da12292e5141fc3ec1c7e..cb84fe1da2e0dd5db507a6915516d7da40cf6668 100644 --- a/src/gopt/test/profiler.cpp +++ b/src/gopt/test/profiler.cpp @@ -12,7 +12,7 @@ #include "megbrain/plugin/profiler.h" #include "./helper.h" -#include "megbrain/gopt/global_layout_transform.h" +#include "megbrain/gopt/profiler.h" #include "megbrain/gopt/inference.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/imgproc.h"