提交 e24fcd00 编写于 作者: M Megvii Engine Team

refactor(gopt): use graphcommonoptimizeoptions for graphopt

GitOrigin-RevId: dd8a93813ae7885bdc23e43f197a86c19e25ddc2
上级 e080dd3c
...@@ -83,7 +83,7 @@ R"__usage__( ...@@ -83,7 +83,7 @@ R"__usage__(
hard to profile host time. Use --profile-host to focus on host time hard to profile host time. Use --profile-host to focus on host time
profiling. profiling.
--input [ filepath | string] --input [ filepath | string]
Set up inputs for megbrain model. for example: --data image.ppm --data Set up inputs for megbrain model. for example: --data image.ppm --data
param.json --data bbox:bbox.npy@batchid:b.npy --data rect:[0,0,227,227]; param.json --data bbox:bbox.npy@batchid:b.npy --data rect:[0,0,227,227];
batchid:0,1,2,3. --io-dump or --bin-io-dump batchid:0,1,2,3. --io-dump or --bin-io-dump
should be enabled at the same time. should be enabled at the same time.
...@@ -974,7 +974,7 @@ Args Args::from_argv(int argc, char **argv) { ...@@ -974,7 +974,7 @@ Args Args::from_argv(int argc, char **argv) {
#endif #endif
if (!strcmp(argv[i], "--enable-chwn4")) { if (!strcmp(argv[i], "--enable-chwn4")) {
mgb_log_warn("enable chwn4 optimization"); mgb_log_warn("enable chwn4 optimization");
graph_opt.graph_opt.enable_chwn4 = true; graph_opt.graph_opt.enable_chwn4();
continue; continue;
} }
#if MGB_ENABLE_JSON #if MGB_ENABLE_JSON
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "megbrain/gopt/inference.h" #include "megbrain/gopt/inference.h"
#include "megbrain/gopt/basic_arith.h" #include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/misc.h" #include "megbrain/gopt/misc.h"
#include "megbrain/graph/cg.h"
#include "megbrain/graph/event.h" #include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h" #include "megbrain/graph/exc_extra_info.h"
#include "megbrain/graph/helper.h" #include "megbrain/graph/helper.h"
...@@ -457,14 +458,17 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( ...@@ -457,14 +458,17 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
} }
#endif #endif
if (options().graph_opt.enable_chwn4) {
options().graph_opt.enable_chwn4 = false;
gopt::reformat_to_chwn4_transform_dest_vars_inplace(dest_vars);
}
if (options().graph_opt.winograd_transform) { if (options().graph_opt.winograd_transform) {
options().graph_opt.winograd_transform = false; options().graph_opt.winograd_transform = false;
gopt::transform_vars_inplace_with_winograd(dest_vars); gopt::transform_vars_inplace_with_winograd(dest_vars);
} }
if (options().graph_opt.transform_chwn4()) {
gopt::GraphOptimizer optimizer;
optimizer.apply_optimize_options(options().graph_opt);
options().graph_opt.layout_transform =
cg::GraphCommonOptimizeOptions::LayoutTransform::DEFAULT;
optimizer.apply_inplace(dest_vars);
}
#if MGB_JIT #if MGB_JIT
if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) { if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
......
...@@ -81,6 +81,59 @@ public: ...@@ -81,6 +81,59 @@ public:
virtual size_t static_alloc_version(ComputingGraph* graph) const; virtual size_t static_alloc_version(ComputingGraph* graph) const;
}; };
/**
* \brief common optimize options, it both can be used for optimize for
* inference in graph dump but also used in graph optimization in runtime.
*/
struct GraphCommonOptimizeOptions {
//! whether to enable IO in float16 compute in float32
bool f16_io_f32_comp = false;
//! whether to enable tranform to pure float16 model
bool f16_io_comp = false;
//! whether to enable conv bias nonlinearity fusion
bool fuse_conv_bias_nonlinearity = false;
enum LayoutTransform : uint32_t {
DEFAULT,
NHWCD4, ///< compute using NHWCD4 tensor format
NCHW88, ///< compute using NCHW88 tensor format
NCHW44, ///< compute using NCHW44 tensor format
NCHW32, ///< compute using NCHW32 tensor format, used for
///< tensorcore
CHWN4, ///< compute using CHWN4 tensor format, transformed mainly
///< used for cuda
};
LayoutTransform layout_transform = LayoutTransform::DEFAULT;
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
//! + z -> conv_bias(x, w, b, z)
bool fuse_conv_bias_with_z = false;
#define SET(n) \
GraphCommonOptimizeOptions& enable_##n() { \
n = true; \
return *this; \
}
SET(f16_io_f32_comp);
SET(f16_io_comp);
SET(fuse_conv_bias_nonlinearity);
SET(fuse_conv_bias_with_z);
#undef SET
#define SET(_trans, _trans_capital) \
GraphCommonOptimizeOptions& enable_##_trans() { \
layout_transform = LayoutTransform::_trans_capital; \
return *this; \
} \
bool transform_##_trans() const { \
return layout_transform == LayoutTransform::_trans_capital; \
}
SET(nhwcd4, NHWCD4);
SET(nchw88, NCHW88);
SET(nchw44, NCHW44);
SET(nchw32, NCHW32);
SET(chwn4, CHWN4);
#undef SET
};
/*! /*!
* \brief Computing graph. * \brief Computing graph.
* *
...@@ -232,7 +285,7 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, ...@@ -232,7 +285,7 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
} seq_opt; } seq_opt;
//! graph optimization options //! graph optimization options
struct GraphOpt { struct GraphOpt : GraphCommonOptimizeOptions {
//! whether to enable JIT; JIT would also be enabled at O3 //! whether to enable JIT; JIT would also be enabled at O3
//! this value indicates JIT level: 1 for basic elemwise opr; 2 //! this value indicates JIT level: 1 for basic elemwise opr; 2
//! for including reduce oprs //! for including reduce oprs
...@@ -241,8 +294,6 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>, ...@@ -241,8 +294,6 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
bool tensorrt = false; bool tensorrt = false;
//! whether to enable fast-run profiled winograd opr replace //! whether to enable fast-run profiled winograd opr replace
bool winograd_transform = false; bool winograd_transform = false;
//! whether to enable nchw4->chwn4 opr replace
bool enable_chwn4 = false;
} graph_opt; } graph_opt;
//! get attribute for an operator //! get attribute for an operator
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "megbrain/gopt/basic_arith.h" #include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/misc.h" #include "megbrain/gopt/misc.h"
#include "megbrain/gopt/gtrans.h" #include "megbrain/gopt/gtrans.h"
#include "megbrain/graph/cg.h"
#include "megbrain/graph/event.h" #include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h" #include "megbrain/graph/exc_extra_info.h"
#include "megbrain/serialization/serializer.h" #include "megbrain/serialization/serializer.h"
...@@ -672,7 +673,11 @@ GraphOptimizer& GraphOptimizer::add_preset_passes( ...@@ -672,7 +673,11 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
} }
#endif #endif
apply_optimize_options(inference_opt); if (inference_opt) {
add_pass<ParamFusePass>();
apply_optimize_options(*inference_opt);
}
if (inference_opt) { if (inference_opt) {
// merge params to reduce loading time and graph overhead // merge params to reduce loading time and graph overhead
...@@ -699,32 +704,32 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { ...@@ -699,32 +704,32 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) {
} }
} }
void GraphOptimizer::apply_optimize_options( const GraphOptimizer& GraphOptimizer::apply_optimize_options(
const OptimizeOptions* options) { const cg::GraphCommonOptimizeOptions& options) {
if (!options) return; if (options.f16_io_comp) {
if (options->f16_io_comp) {
add_pass(ConvertF32ToF16Pass::make(false)); add_pass(ConvertF32ToF16Pass::make(false));
} }
if (options->f16_io_f32_comp) { if (options.f16_io_f32_comp) {
add_pass(ConvertF32ToF16Pass::make(true)); add_pass(ConvertF32ToF16Pass::make(true));
} }
if (options->transform_nhwcd4()) { if (options.transform_nhwcd4()) {
add_pass(ConvertFormatPass::make_nhwcd4_converter()); add_pass(ConvertFormatPass::make_nhwcd4_converter());
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
} }
if (options->transform_nchw88()) { if (options.transform_nchw88()) {
add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
} }
if (options->transform_nchw44()) { if (options.transform_nchw44()) {
add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
} }
if (options->transform_nchw32()) { if (options.transform_nchw32()) {
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
add_pass(EnableTensorCorePass::make_tensorcore_converter()); add_pass(EnableTensorCorePass::make_tensorcore_converter());
add_pass<ShuffleShuffleRemovePass>(); add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>(); add_pass<RemoveRedundantTypeCvtPass>();
} }
if (options->transform_chwn4()) { if (options.transform_chwn4()) {
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>(); add_pass<FuseConvBiasZPass>();
add_pass(EnableCHWN4Pass::make_chwn4_converter()); add_pass(EnableCHWN4Pass::make_chwn4_converter());
...@@ -732,14 +737,15 @@ void GraphOptimizer::apply_optimize_options( ...@@ -732,14 +737,15 @@ void GraphOptimizer::apply_optimize_options(
add_pass<RemoveRedundantTypeCvtPass>(); add_pass<RemoveRedundantTypeCvtPass>();
} }
if (options->fuse_conv_bias_nonlinearity) { if (options.fuse_conv_bias_nonlinearity) {
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
} }
if (options->fuse_conv_bias_with_z) { if (options.fuse_conv_bias_with_z) {
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>(); add_pass<FuseConvBiasZPass>();
} }
add_pass<ParamFusePass>(); add_pass<ParamFusePass>();
return *this;
} }
/* ================ ConstVarPropogateBase ================ */ /* ================ ConstVarPropogateBase ================ */
......
...@@ -2215,16 +2215,4 @@ void ShuffleShuffleRemovePass::apply(OptState& opt) const { ...@@ -2215,16 +2215,4 @@ void ShuffleShuffleRemovePass::apply(OptState& opt) const {
Impl{opt}; Impl{opt};
} }
void gopt::reformat_to_chwn4_transform_dest_vars_inplace(
mgb::cg::VarNodeArray& dest_vars) {
gopt::GraphOptimizer optimizer;
optimizer.add_pass<FuseConvBiasNonlinPass>();
optimizer.add_pass<FuseConvBiasZPass>();
optimizer.add_pass(EnableCHWN4Pass::make_chwn4_converter());
optimizer.add_pass<ShuffleShuffleRemovePass>();
optimizer.add_pass<RemoveRedundantTypeCvtPass>();
optimizer.add_pass<ParamFusePass>();
optimizer.apply_inplace(dest_vars);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "megbrain/graph.h" #include "megbrain/graph.h"
#include "megbrain/gopt/gtrans.h" #include "megbrain/gopt/gtrans.h"
#include "megbrain/graph/cg.h"
namespace mgb { namespace mgb {
namespace gopt { namespace gopt {
...@@ -377,60 +378,6 @@ namespace gopt { ...@@ -377,60 +378,6 @@ namespace gopt {
RecursiveSubGraphRewriteHelper(OptState &state); RecursiveSubGraphRewriteHelper(OptState &state);
}; };
/**
* \brief common optimize options, it both can be used for optimize for
* inference in graph dump but also used in graph optimization in runtime.
*/
struct OptimizeOptions {
//! whether to enable IO in float16 compute in float32
bool f16_io_f32_comp = false;
//! whether to enable tranform to pure float16 model
bool f16_io_comp = false;
//! whether to enable conv bias nonlinearity fusion
bool fuse_conv_bias_nonlinearity = false;
enum LayoutTransform : uint32_t {
DEFAULT,
NHWCD4, ///< compute using NHWCD4 tensor format
NCHW88, ///< compute using NCHW88 tensor format
NCHW44, ///< compute using NCHW44 tensor format
NCHW32, ///< compute using NCHW32 tensor format, used for
///< tensorcore
CHWN4, ///< compute using CHWN4 tensor format, transformed mainly
///< used for cuda
};
LayoutTransform layout_transform = LayoutTransform::DEFAULT;
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
//! + z -> conv_bias(x, w, b, z)
bool fuse_conv_bias_with_z = false;
#define SET(n) \
OptimizeOptions& enable_##n() { \
n = true; \
return *this; \
}
SET(f16_io_f32_comp);
SET(f16_io_comp);
SET(fuse_conv_bias_nonlinearity);
SET(fuse_conv_bias_with_z);
#undef SET
#define SET(_trans, _trans_capital) \
OptimizeOptions& enable_##_trans() { \
layout_transform = LayoutTransform::_trans_capital; \
return *this; \
} \
bool transform_##_trans() const { \
return layout_transform == LayoutTransform::_trans_capital; \
}
SET(nhwcd4, NHWCD4);
SET(nchw88, NCHW88);
SET(nchw44, NCHW44);
SET(nchw32, NCHW32);
SET(chwn4, CHWN4);
#undef SET
};
/*! /*!
* \brief manage passes and their applying on graphs * \brief manage passes and their applying on graphs
* *
...@@ -523,7 +470,8 @@ namespace gopt { ...@@ -523,7 +470,8 @@ namespace gopt {
/** /**
* \brief apply optimize options * \brief apply optimize options
*/ */
void apply_optimize_options(const OptimizeOptions* options); const GraphOptimizer& apply_optimize_options(
const cg::GraphCommonOptimizeOptions& options);
}; };
/*! /*!
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#pragma once #pragma once
#include "megbrain/gopt/framework.h" #include "megbrain/gopt/framework.h"
#include "megbrain/graph/cg.h"
namespace mgb { namespace mgb {
namespace gopt { namespace gopt {
...@@ -256,7 +257,7 @@ namespace gopt { ...@@ -256,7 +257,7 @@ namespace gopt {
size_t pack_c_size); size_t pack_c_size);
}; };
struct OptimizeForInferenceOptions : OptimizeOptions {}; struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {};
/*! /*!
* \brief optimize a computing graph for inference * \brief optimize a computing graph for inference
...@@ -325,13 +326,6 @@ namespace gopt { ...@@ -325,13 +326,6 @@ namespace gopt {
void apply(OptState& opt) const override; void apply(OptState& opt) const override;
}; };
/*!
* \brief transform tensor format in a network to c/4hwn4 format, and
* accelerate the inference speed on Nvidia platform
*/
void reformat_to_chwn4_transform_dest_vars_inplace(
mgb::cg::VarNodeArray& dest_vars);
} // namespace gopt } // namespace gopt
} // namespace mgb } // namespace mgb
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册