提交 19d7412a 编写于 作者: M Megvii Engine Team

refactor(mgb/gopt): reorganize code of global layout transform

GitOrigin-RevId: 4973820e0269162c4c13d98bf13566e24a98144b
上级 8ef12bdf
......@@ -30,7 +30,10 @@
#include "megbrain/tensorrt/opr_replace.h"
#endif
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/gopt/layout_transform_context.h"
#include "megbrain/gopt/layout_transform_pass.h"
#include "megbrain/gopt/profiler.h"
#include "megbrain/gopt/solver.h"
using namespace mgb;
using namespace gopt;
......
......@@ -12,7 +12,9 @@
#include <queue>
#include "./utils.h"
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/gopt/layout_transform_context.h"
#include "megbrain/gopt/profiler.h"
#include "megbrain/gopt/solver.h"
using namespace mgb;
using namespace gopt;
......@@ -85,11 +87,11 @@ private:
const SmallVector<TensorFormats>& available_tensor_formats;
};
/*!
* \brief get the tensor formats configuration for the operator with particular op format
* \param[out] var2fmts hashmap that maps varnode to actual tensor formats of the op format configuration
* \param[in] opr given operator
* \param[in] opr_fmt given op format, an enum type argument which indicates the op format configuration.
* \param[in] ctx context
* \brief get the tensor formats configuration for the operator with
* particular op format \param[out] var2fmts hashmap that maps varnode to
* actual tensor formats of the op format configuration \param[in] opr given
* operator \param[in] opr_fmt given op format, an enum type argument which
* indicates the op format configuration. \param[in] ctx context
*/
TensorFormats get_io_formats(ThinHashMap<VarNode*, TensorFormats>& var2fmts,
const OperatorNodeBase* opr, OprFormat opr_fmt,
......
......@@ -11,7 +11,7 @@
*/
#include "./utils.h"
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/gopt/layout_transform_context.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
......
......@@ -10,9 +10,11 @@
* implied.
*/
#include "megbrain/gopt/layout_transform_pass.h"
#include "./opr_format_modifier.h"
#include "./utils.h"
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/gopt/profiler.h"
#include "megbrain/gopt/solver.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/serialization/sereg.h"
......@@ -46,8 +48,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
auto&& opr_configs = m_ctx->opr_configs();
auto&& base_fmt = m_ctx->attribute().base_tensor_formats;
auto&& reformat_attribute =
ReformatManager::ReformatKey::Attribute::DEFAULT;
auto&& reformat_attribute = m_ctx->attribute().reformat_attribute;
ThinHashMap<VarNode*, TensorFormats> var2fmts;
static ThinHashSet<Typeinfo*> format_aware_oprs = {
#define cb(_Opr) opr::_Opr::typeinfo(),
......@@ -55,8 +56,8 @@ void LayoutTransformPass::apply(OptState& opt) const {
#undef cb
};
auto rewriter = opt.graph().make_rewriter();
auto on_opr = [this, &opr_configs, &base_fmt, &reformat_attribute,
&rewriter, &solution, &var2fmts,
auto on_opr = [&opr_configs, &base_fmt, &reformat_attribute, &rewriter,
&solution, &var2fmts,
&endpoint_vars](OperatorNodeBase* opr) {
auto it = solution.find(opr);
if (it != solution.end()) {
......@@ -122,19 +123,6 @@ void LayoutTransformPass::apply(OptState& opt) const {
opr->config())
->output(0);
}
if (endpoint_vars.count(opr->output(0)) && out_fmt != base_fmt) {
ReformatManager::ReformatKey key{
out_fmt, base_fmt, reformat_attribute,
opr->output(0)->dtype().enumv(),
opr->output(0)->dtype().enumv()};
auto reformat = ReformatManager::instance()
.auto_aligned_reformat_featrue(
opr->output(0), base_fmt, key);
new_out = reformat({new_out});
var2fmts[new_out] = base_fmt;
} else {
var2fmts[new_out] = out_fmt;
}
auto &&out0 = opr->output(),
&&out1 = new_out->owner_opr()->output();
mgb_assert(opr->usable_output().size() ==
......@@ -146,20 +134,29 @@ void LayoutTransformPass::apply(OptState& opt) const {
new_out->owner_opr()->cname(),
new_out->owner_opr()->dyn_typeinfo()->name, out0.size(),
out1.size());
for (size_t i = 0; i < out0.size(); ++i) {
if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
mgb_assert(!out1[i]->contain_flag(
VarNode::Flag::VOLATILE_CONTENT));
auto src = out0[i];
auto dst = out1[i];
rewriter.replace_var(
src, dst,
mgb_cstr_log(ssprintf("replace opr(%s) to new opr "
"format(%s)",
opr->cname(),
opr_format_to_string(opr_fmt))
.c_str()));
size_t nr_outs = opr->usable_output().size();
for (size_t i = 0; i < nr_outs; ++i) {
const auto& ovar = out0[i];
auto new_ovar = out1[i];
if (endpoint_vars.count(ovar) && out_fmt != base_fmt) {
ReformatManager::ReformatKey key{
out_fmt, base_fmt, reformat_attribute,
ovar->dtype().enumv(), ovar->dtype().enumv()};
auto reformat = ReformatManager::instance()
.auto_aligned_reformat_featrue(
ovar, base_fmt, key);
new_ovar = reformat({new_ovar});
var2fmts[new_ovar] = base_fmt;
} else {
var2fmts[new_ovar] = out_fmt;
}
rewriter.replace_var(
ovar, new_ovar,
mgb_cstr_log(ssprintf("replace opr(%s) to new opr "
"format(%s)",
opr->cname(),
opr_format_to_string(opr_fmt))
.c_str()));
}
} else {
auto new_opr = rewriter.auto_replace_outputs(opr);
......
......@@ -11,7 +11,7 @@
*/
#include "./utils.h"
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/gopt/layout_transform_context.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
......
......@@ -13,7 +13,7 @@
#include "./opr_format_modifier.h"
#include "./utils.h"
#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/gopt/profiler.h"
#include "megbrain/graph/event.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
......
......@@ -10,7 +10,8 @@
* implied.
*/
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/gopt/profiler.h"
#include "megbrain/gopt/solver.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
......
......@@ -11,7 +11,7 @@
*/
#pragma once
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/gopt/layout_transform_context.h"
namespace mgb {
namespace gopt {
......
/**
* \file src/gopt/include/megbrain/gopt/global_layout_transformation.h
* \file
* src/gopt/include/megbrain/gopt/layout_transform_context.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -12,11 +13,10 @@
#pragma once
#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/reformat_manager.h"
#include "megbrain/gopt/subgraph_extractor.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/plugin/opr_footprint.h"
#include "megbrain/gopt/inference.h"
namespace mgb {
namespace gopt {
......@@ -118,7 +118,7 @@ public:
TensorFormats base_tensor_format = TensorFormats::NCHW);
private:
OprList m_opr_list; /// supported operator list
OprList m_opr_list; /// supported operator list
SmallVector<TensorFormats>
m_available_tensor_formats; /// the available tensor formats, used
/// for format agnostic operators (like
......@@ -180,164 +180,6 @@ private:
const GraphPartition& m_graph_partition; /// the graph partition
const LayoutTransformContext& m_ctx;
};
/*!
* \brief A profiler that collects all the performance data to describe the
* global layout transform problem.
*/
class ProfilerBase {
public:
using OprFormat = Problem::OprFormat;
struct OperatorNodeRecord {
const cg::OperatorNodeBase* opr; ///< pointer to operator node
ThinHashMap<OprFormat, float>
costs; ///< costs of operator node, i.e. the elapsed device
///< time of the operator node on different opr format
///< (layout configuration).
std::string to_string() const;
};
struct VarNodeRecord {
struct KeyHash {
size_t operator()(
const std::pair<TensorFormats, TensorFormats>& val) const {
size_t h1 =
std::hash<uint32_t>()(static_cast<uint32_t>(val.first));
size_t h2 = std::hash<uint32_t>()(
static_cast<uint32_t>(val.second));
return mgb::hash_pair_combine(h1, h2);
}
};
const VarNode* var; ///< pointer to var node
std::unordered_map<std::pair<TensorFormats, TensorFormats>, float,
KeyHash>
costs; ///< costs of var node, i.e. the elapsed
///< device time of the layout transform.
///< Key of the hashmap indicates the
///< source tensor format and the target
///< tensor format.
std::string to_string() const;
};
/*!
* \note the profiler assumes all the input and output var node are stored
* in contiguous layout in memory
*/
struct ProfilingResult {
/// A hashmap, that maps the operator node to the costs (device elapsed
/// time) of different layouts configuration
ThinHashMap<cg::OperatorNodeBase*, OperatorNodeRecord> opr_record;
/// A hashmap, that maps the var node to the costs of layout transform
ThinHashMap<VarNode*, VarNodeRecord> var_record;
};
using OprFilter = thin_function<bool(const cg::OperatorNodeBase*,
cg::OperatorNodeBase*)>;
using VarNodeFilter =
thin_function<bool(const VarNode*, TensorShape, TensorShape,
ReformatManager::ReformatKey)>;
ProfilerBase(float opr_threshold = 2.f, float var_node_threshold = 2.f);
ProfilerBase(OprFilter opr_filter, VarNodeFilter var_node_filter = {})
: m_opr_filter{std::move(opr_filter)},
m_var_node_filter{std::move(var_node_filter)} {}
virtual ~ProfilerBase() = default;
virtual ProfilingResult profile(const Problem& problem) const = 0;
static std::unique_ptr<ProfilerBase> make_profiler();
protected:
OprFilter m_opr_filter;
VarNodeFilter m_var_node_filter;
float m_opr_threshold;
float m_var_node_threshold;
private:
OprFootprint m_opr_footprint;
};
/*!
* \brief abstract solver
*/
class SolverBase {
public:
using OprFormat = Problem::OprFormat;
using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormat>;
SolverBase() = default;
virtual ~SolverBase() = default;
/*!
* \brief solve the given problem
*/
virtual Solution solve(const Problem& problem) const = 0;
/*!
* \brief check whether the given problem can be solved by the
* algorithm(i.e. solver).
*/
virtual bool can_solve(const Problem& problem) const = 0;
};
/*!
* \brief solvers that will first collect the costs of operators in different op
* format and the costs of layout transform of varnode with a user provided
* profiler on the target device. This will lead to time consuming.
*/
class ProfilingBasedSolver : public SolverBase {
public:
using GraphPartitionFilter =
thin_function<bool(const GraphPartition& graph_partition)>;
ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler);
/*!
* \note some graph partition (for example, graph partition without format
* aware operators like conv, deconv, warp, resize etc.) will be filtered by
* the GraphPartitionFilter, which can reduce the profiling time. */
ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler,
GraphPartitionFilter graph_partition_filter)
: m_profiler{std::move(profiler)},
m_graph_partition_filter{std::move(graph_partition_filter)} {}
virtual ~ProfilingBasedSolver() = default;
Solution solve(const Problem& problem) const override;
virtual Solution do_solve(const Problem& problem) const = 0;
protected:
std::unique_ptr<ProfilerBase> m_profiler;
private:
GraphPartitionFilter m_graph_partition_filter;
};
/*!
* \brief A solver that solves the layout selection problem using dynamic
* programming algorithm (Markov decision process).
*/
class DynamicProgrammingSolver final : public ProfilingBasedSolver {
public:
DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler)
: ProfilingBasedSolver(std::move(profiler)){};
DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler,
GraphPartitionFilter graph_partition_filter)
: ProfilingBasedSolver(std::move(profiler),
std::move(graph_partition_filter)){};
~DynamicProgrammingSolver() noexcept = default;
Solution do_solve(const Problem& problem) const override;
bool can_solve(const Problem& problem) const override;
private:
class Impl;
};
/*!
* \brief A layout transform pass, which convert the operator's format to the
* optimal format using the results of the solver.
*/
class LayoutTransformPass final : public Pass {
public:
const char* name() const override { return "layout assignment pass"; }
void apply(OptState& opt) const override;
LayoutTransformPass(std::unique_ptr<LayoutTransformContext> ctx,
std::unique_ptr<SolverBase> solver)
: m_ctx{std::move(ctx)}, m_solver{std::move(solver)} {}
private:
std::unique_ptr<LayoutTransformContext> m_ctx;
std::unique_ptr<SolverBase> m_solver;
};
} // namespace gopt
} // namespace mgb
......
/**
* \file src/gopt/include/megbrain/gopt/global_layout_transformation.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain/gopt/framework.h"
namespace mgb {
namespace gopt {
class LayoutTransformContext;
class SolverBase;
/*!
* \brief A layout transform pass, which convert the operator's format to the
* optimal format using the results of the solver.
*/
class LayoutTransformPass final : public Pass {
public:
const char* name() const override { return "layout assignment pass"; }
void apply(OptState& opt) const override;
LayoutTransformPass(std::unique_ptr<LayoutTransformContext> ctx,
std::unique_ptr<SolverBase> solver)
: m_ctx{std::move(ctx)}, m_solver{std::move(solver)} {}
private:
std::unique_ptr<LayoutTransformContext> m_ctx;
std::unique_ptr<SolverBase> m_solver;
};
} // namespace gopt
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file src/gopt/include/megbrain/gopt/global_layout_transformation.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/layout_transform_context.h"
#include "megbrain/gopt/reformat_manager.h"
#include "megbrain/gopt/subgraph_extractor.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/plugin/opr_footprint.h"
namespace mgb {
namespace gopt {
class Problem;
/*!
* \brief A profiler that collects all the performance data to describe the
* global layout transform problem.
*/
class ProfilerBase {
public:
using OprFormat = Problem::OprFormat;
struct OperatorNodeRecord {
const cg::OperatorNodeBase* opr; ///< pointer to operator node
ThinHashMap<OprFormat, float>
costs; ///< costs of operator node, i.e. the elapsed device
///< time of the operator node on different opr format
///< (layout configuration).
std::string to_string() const;
};
struct VarNodeRecord {
struct KeyHash {
size_t operator()(
const std::pair<TensorFormats, TensorFormats>& val) const {
size_t h1 =
std::hash<uint32_t>()(static_cast<uint32_t>(val.first));
size_t h2 = std::hash<uint32_t>()(
static_cast<uint32_t>(val.second));
return mgb::hash_pair_combine(h1, h2);
}
};
const VarNode* var; ///< pointer to var node
std::unordered_map<std::pair<TensorFormats, TensorFormats>, float,
KeyHash>
costs; ///< costs of var node, i.e. the elapsed
///< device time of the layout transform.
///< Key of the hashmap indicates the
///< source tensor format and the target
///< tensor format.
std::string to_string() const;
};
/*!
* \note the profiler assumes all the input and output var node are stored
* in contiguous layout in memory
*/
struct ProfilingResult {
/// A hashmap, that maps the operator node to the costs (device elapsed
/// time) of different layouts configuration
ThinHashMap<cg::OperatorNodeBase*, OperatorNodeRecord> opr_record;
/// A hashmap, that maps the var node to the costs of layout transform
ThinHashMap<VarNode*, VarNodeRecord> var_record;
};
using OprFilter = thin_function<bool(const cg::OperatorNodeBase*,
cg::OperatorNodeBase*)>;
using VarNodeFilter =
thin_function<bool(const VarNode*, TensorShape, TensorShape,
ReformatManager::ReformatKey)>;
ProfilerBase(float opr_threshold = 2.f, float var_node_threshold = 2.f);
ProfilerBase(OprFilter opr_filter, VarNodeFilter var_node_filter = {})
: m_opr_filter{std::move(opr_filter)},
m_var_node_filter{std::move(var_node_filter)} {}
virtual ~ProfilerBase() = default;
virtual ProfilingResult profile(const Problem& problem) const = 0;
static std::unique_ptr<ProfilerBase> make_profiler();
protected:
OprFilter m_opr_filter;
VarNodeFilter m_var_node_filter;
float m_opr_threshold;
float m_var_node_threshold;
private:
OprFootprint m_opr_footprint;
};
} // namespace gopt
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file src/gopt/include/megbrain/gopt/solver.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/layout_transform_context.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/plugin/opr_footprint.h"
#include "megbrain/gopt/inference.h"
namespace mgb {
namespace gopt {
class ProfilerBase;
/*!
* \brief abstract solver
*/
class SolverBase {
public:
using OprFormat = Problem::OprFormat;
using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormat>;
SolverBase() = default;
virtual ~SolverBase() = default;
/*!
* \brief solve the given problem
*/
virtual Solution solve(const Problem& problem) const = 0;
/*!
* \brief check whether the given problem can be solved by the
* algorithm(i.e. solver).
*/
virtual bool can_solve(const Problem& problem) const = 0;
};
/*!
* \brief solvers that will first collect the costs of operators in different op
* format and the costs of layout transform of varnode with a user provided
* profiler on the target device. This will lead to time consuming.
*/
class ProfilingBasedSolver : public SolverBase {
public:
using GraphPartitionFilter =
thin_function<bool(const GraphPartition& graph_partition)>;
ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler);
/*!
* \note some graph partition (for example, graph partition without format
* aware operators like conv, deconv, warp, resize etc.) will be filtered by
* the GraphPartitionFilter, which can reduce the profiling time. */
ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profiler,
GraphPartitionFilter graph_partition_filter)
: m_profiler{std::move(profiler)},
m_graph_partition_filter{std::move(graph_partition_filter)} {}
virtual ~ProfilingBasedSolver() = default;
Solution solve(const Problem& problem) const override;
virtual Solution do_solve(const Problem& problem) const = 0;
protected:
std::unique_ptr<ProfilerBase> m_profiler;
private:
GraphPartitionFilter m_graph_partition_filter;
};
/*!
* \brief A solver that solves the layout selection problem using dynamic
* programming algorithm (Markov decision process).
*/
class DynamicProgrammingSolver final : public ProfilingBasedSolver {
public:
DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler)
: ProfilingBasedSolver(std::move(profiler)){};
DynamicProgrammingSolver(std::unique_ptr<ProfilerBase> profiler,
GraphPartitionFilter graph_partition_filter)
: ProfilingBasedSolver(std::move(profiler),
std::move(graph_partition_filter)){};
~DynamicProgrammingSolver() noexcept = default;
Solution do_solve(const Problem& problem) const override;
bool can_solve(const Problem& problem) const override;
private:
class Impl;
};
} // namespace gopt
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -10,10 +10,13 @@
* implied.
*/
#include "megbrain/gopt/layout_transform_pass.h"
#include "./network.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/layout_transform_context.h"
#include "megbrain/gopt/profiler.h"
#include "megbrain/gopt/solver.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
......
......@@ -12,7 +12,7 @@
#include "megbrain/plugin/profiler.h"
#include "./helper.h"
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/gopt/profiler.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册