提交 e080dd3c 编写于 作者: M Megvii Engine Team

refactor(gopt): rename nchw2xxx to xxx

GitOrigin-RevId: fcb08c09e0e48afbad99230cb67ceb7fa1c9119f
上级 a3560fa1
...@@ -588,11 +588,11 @@ def optimize_for_inference( ...@@ -588,11 +588,11 @@ def optimize_for_inference(
layout_tranform = None layout_tranform = None
for k, v in { for k, v in {
"use_nhwcd4": "nchw2nhwcd4", "use_nhwcd4": "nhwcd4",
"use_nchw32": "nchw2nchw32", "use_nchw32": "nchw32",
"use_nchw88": "nchw2nchw88", "use_nchw88": "nchw88",
"use_nchw44": "nchw2nchw44", "use_nchw44": "nchw44",
"use_chwn4": "nchw42chwn4", "use_chwn4": "chwn4",
}.items(): }.items():
if settings[k]: if settings[k]:
assert ( assert (
......
...@@ -80,11 +80,11 @@ struct _OptimizeForInferenceOptions { ...@@ -80,11 +80,11 @@ struct _OptimizeForInferenceOptions {
#define SET(_trans, _trans_capital) \ #define SET(_trans, _trans_capital) \
void enable_##_trans(); \ void enable_##_trans(); \
SET(nchw2nhwcd4, NCHW2NHWCD4); SET(nhwcd4, NHWCD4);
SET(nchw2nchw88, NCHW2NCHW88); SET(nchw88, NCHW88);
SET(nchw2nchw44, NCHW2NCHW44); SET(nchw44, NCHW44);
SET(nchw2nchw32, NCHW2NCHW32); SET(nchw32, NCHW32);
SET(nchw42chwn4, NCHW42CHWN4); SET(chwn4, CHWN4);
#undef SET #undef SET
}; };
......
...@@ -708,23 +708,23 @@ void GraphOptimizer::apply_optimize_options( ...@@ -708,23 +708,23 @@ void GraphOptimizer::apply_optimize_options(
if (options->f16_io_f32_comp) { if (options->f16_io_f32_comp) {
add_pass(ConvertF32ToF16Pass::make(true)); add_pass(ConvertF32ToF16Pass::make(true));
} }
if (options->transform_nchw2nhwcd4()) { if (options->transform_nhwcd4()) {
add_pass(ConvertFormatPass::make_nhwcd4_converter()); add_pass(ConvertFormatPass::make_nhwcd4_converter());
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
} }
if (options->transform_nchw2nchw88()) { if (options->transform_nchw88()) {
add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
} }
if (options->transform_nchw2nchw44()) { if (options->transform_nchw44()) {
add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
} }
if (options->transform_nchw2nchw32()) { if (options->transform_nchw32()) {
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
add_pass(EnableTensorCorePass::make_tensorcore_converter()); add_pass(EnableTensorCorePass::make_tensorcore_converter());
add_pass<ShuffleShuffleRemovePass>(); add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>(); add_pass<RemoveRedundantTypeCvtPass>();
} }
if (options->transform_nchw42chwn4()) { if (options->transform_chwn4()) {
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>(); add_pass<FuseConvBiasZPass>();
add_pass(EnableCHWN4Pass::make_chwn4_converter()); add_pass(EnableCHWN4Pass::make_chwn4_converter());
......
...@@ -390,13 +390,13 @@ namespace gopt { ...@@ -390,13 +390,13 @@ namespace gopt {
bool fuse_conv_bias_nonlinearity = false; bool fuse_conv_bias_nonlinearity = false;
enum LayoutTransform : uint32_t { enum LayoutTransform : uint32_t {
DEFAULT, DEFAULT,
NCHW2NHWCD4, ///< compute using NHWCD4 tensor format NHWCD4, ///< compute using NHWCD4 tensor format
NCHW2NCHW88, ///< compute using NCHW88 tensor format NCHW88, ///< compute using NCHW88 tensor format
NCHW2NCHW44, ///< compute using NCHW44 tensor format NCHW44, ///< compute using NCHW44 tensor format
NCHW2NCHW32, ///< compute using NCHW32 tensor format, used for NCHW32, ///< compute using NCHW32 tensor format, used for
///< tensorcore ///< tensorcore
NCHW42CHWN4, ///< compute using CHWN4 tensor format, transformed CHWN4, ///< compute using CHWN4 tensor format, transformed mainly
///< from NCHW4, mainly used for cuda ///< used for cuda
}; };
LayoutTransform layout_transform = LayoutTransform::DEFAULT; LayoutTransform layout_transform = LayoutTransform::DEFAULT;
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
...@@ -422,11 +422,11 @@ namespace gopt { ...@@ -422,11 +422,11 @@ namespace gopt {
return layout_transform == LayoutTransform::_trans_capital; \ return layout_transform == LayoutTransform::_trans_capital; \
} }
SET(nchw2nhwcd4, NCHW2NHWCD4); SET(nhwcd4, NHWCD4);
SET(nchw2nchw88, NCHW2NCHW88); SET(nchw88, NCHW88);
SET(nchw2nchw44, NCHW2NCHW44); SET(nchw44, NCHW44);
SET(nchw2nchw32, NCHW2NCHW32); SET(nchw32, NCHW32);
SET(nchw42chwn4, NCHW42CHWN4); SET(chwn4, CHWN4);
#undef SET #undef SET
}; };
......
...@@ -992,7 +992,7 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) { ...@@ -992,7 +992,7 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) {
SymbolVar y_opt; SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw2nhwcd4(); options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4, ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
...@@ -1051,7 +1051,7 @@ TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) { ...@@ -1051,7 +1051,7 @@ TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) {
SymbolVar y_opt; SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw2nhwcd4(); options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4, ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
...@@ -1102,7 +1102,7 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Deconv) { ...@@ -1102,7 +1102,7 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Deconv) {
SymbolVar y_opt; SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw2nhwcd4(); options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
...@@ -1147,7 +1147,7 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Qint8) { ...@@ -1147,7 +1147,7 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Qint8) {
SymbolVar y_opt; SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw2nhwcd4(); options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(opr::ConvBias::Param::Format::NHWCD4, ASSERT_EQ(opr::ConvBias::Param::Format::NHWCD4,
...@@ -1199,7 +1199,7 @@ TEST(TestGoptInference, ConvertFormatPadIC) { ...@@ -1199,7 +1199,7 @@ TEST(TestGoptInference, ConvertFormatPadIC) {
auto y = opr::Convolution::make(concat, w1, param); auto y = opr::Convolution::make(concat, w1, param);
SymbolVar y_opt; SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw2nhwcd4(); options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
HostTensorND host_y_opt, host_y; HostTensorND host_y_opt, host_y;
...@@ -1285,7 +1285,7 @@ TEST(TestGoptInference, ConvBiasNonlinearityFusePass) { ...@@ -1285,7 +1285,7 @@ TEST(TestGoptInference, ConvBiasNonlinearityFusePass) {
y_y = opr::Convolution::make(y_expand, w3, param), y = y_y + y_tmp; y_y = opr::Convolution::make(y_expand, w3, param), y = y_y + y_tmp;
SymbolVar y_opt; SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw2nhwcd4().enable_fuse_conv_bias_nonlinearity(); options.enable_nhwcd4().enable_fuse_conv_bias_nonlinearity();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(3u, find_opr<opr::ConvBias>(y_opt).input().size()); ASSERT_EQ(3u, find_opr<opr::ConvBias>(y_opt).input().size());
graph->compile({{y_opt, {}}}) graph->compile({{y_opt, {}}})
...@@ -1516,7 +1516,7 @@ TEST(TestEnableTensorCore, SmallInputShape) { ...@@ -1516,7 +1516,7 @@ TEST(TestEnableTensorCore, SmallInputShape) {
SymbolVar y_no_tc; SymbolVar y_no_tc;
{ {
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw2nchw32().enable_fuse_conv_bias_nonlinearity(); options.enable_nchw32().enable_fuse_conv_bias_nonlinearity();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
} }
{ {
...@@ -1581,7 +1581,7 @@ TEST(TestEnableTensorCore, ConvBiasWithZ) { ...@@ -1581,7 +1581,7 @@ TEST(TestEnableTensorCore, ConvBiasWithZ) {
SymbolVar y_no_tc; SymbolVar y_no_tc;
{ {
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_conv_bias_nonlinearity().enable_nchw2nchw32(); options.enable_fuse_conv_bias_nonlinearity().enable_nchw32();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
} }
{ {
...@@ -1649,12 +1649,12 @@ TEST(TestGoptInference, EnableTensorCore) { ...@@ -1649,12 +1649,12 @@ TEST(TestGoptInference, EnableTensorCore) {
SymbolVar y_no_tc; SymbolVar y_no_tc;
{ {
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_conv_bias_nonlinearity().enable_nchw2nchw32(); options.enable_fuse_conv_bias_nonlinearity().enable_nchw32();
unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt);
} }
{ {
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_conv_bias_nonlinearity().enable_nchw2nchw32(); options.enable_fuse_conv_bias_nonlinearity().enable_nchw32();
unpack_vector(gopt::optimize_for_inference({y4}, options), y_no_tc); unpack_vector(gopt::optimize_for_inference({y4}, options), y_no_tc);
} }
auto nr_dimshuffle = find_opr_num<mgb::opr::Dimshuffle>(y_opt); auto nr_dimshuffle = find_opr_num<mgb::opr::Dimshuffle>(y_opt);
...@@ -1855,7 +1855,7 @@ TEST(TestEnableTensorCore, ShuffleMerge) { ...@@ -1855,7 +1855,7 @@ TEST(TestEnableTensorCore, ShuffleMerge) {
SymbolVar y_no_tc; SymbolVar y_no_tc;
{ {
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_conv_bias_nonlinearity().enable_nchw2nchw32(); options.enable_fuse_conv_bias_nonlinearity().enable_nchw32();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
} }
{ {
...@@ -1923,7 +1923,7 @@ TEST(FuseConvBiasZPass, Basic) { ...@@ -1923,7 +1923,7 @@ TEST(FuseConvBiasZPass, Basic) {
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_conv_bias_nonlinearity() options.enable_fuse_conv_bias_nonlinearity()
.enable_fuse_conv_bias_with_z() .enable_fuse_conv_bias_with_z()
.enable_nchw2nchw32(); .enable_nchw32();
unpack_vector(gopt::optimize_for_inference({y1}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y1}, options), y_opt);
} }
auto nr_elemwisemultitype = find_opr_num<opr::ElemwiseMultiType>(y_opt); auto nr_elemwisemultitype = find_opr_num<opr::ElemwiseMultiType>(y_opt);
...@@ -1940,7 +1940,7 @@ TEST(FuseConvBiasZPass, Basic) { ...@@ -1940,7 +1940,7 @@ TEST(FuseConvBiasZPass, Basic) {
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_conv_bias_nonlinearity() options.enable_fuse_conv_bias_nonlinearity()
.enable_fuse_conv_bias_with_z() .enable_fuse_conv_bias_with_z()
.enable_nchw2nchw32(); .enable_nchw32();
unpack_vector(gopt::optimize_for_inference({y2}, options), unpack_vector(gopt::optimize_for_inference({y2}, options),
y_opt); y_opt);
} }
...@@ -2013,7 +2013,7 @@ TEST(TestGoptInference, EnableCHWN4) { ...@@ -2013,7 +2013,7 @@ TEST(TestGoptInference, EnableCHWN4) {
SymbolVar y_cudnn; SymbolVar y_cudnn;
{ {
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw42chwn4(); options.enable_chwn4();
unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt);
} }
unpack_vector(gopt::GraphOptimizer{} unpack_vector(gopt::GraphOptimizer{}
...@@ -2099,7 +2099,7 @@ TEST(TestGoptInference, EnableCHWN4WarpPespective) { ...@@ -2099,7 +2099,7 @@ TEST(TestGoptInference, EnableCHWN4WarpPespective) {
SymbolVar y_cudnn; SymbolVar y_cudnn;
{ {
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw42chwn4(); options.enable_chwn4();
unpack_vector(gopt::optimize_for_inference({y2}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y2}, options), y_opt);
} }
unpack_vector(gopt::GraphOptimizer{} unpack_vector(gopt::GraphOptimizer{}
...@@ -2386,7 +2386,7 @@ TEST(TestGoptInference, ConvertFormatNCHW88) { ...@@ -2386,7 +2386,7 @@ TEST(TestGoptInference, ConvertFormatNCHW88) {
SymbolVar y_opt; SymbolVar y_opt;
{ {
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw2nchw88(); options.enable_nchw88();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
} }
...@@ -2467,7 +2467,7 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { ...@@ -2467,7 +2467,7 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
SymbolVar y_opt; SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw2nchw44(); options.enable_nchw44();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
......
...@@ -501,7 +501,7 @@ TEST(TestOprIO, MultipleDeviceTensorWithFormatHolderCpu) { ...@@ -501,7 +501,7 @@ TEST(TestOprIO, MultipleDeviceTensorWithFormatHolderCpu) {
auto y = opr::Elemwise::make({conv2}, opr::Elemwise::Param::Mode::RELU); auto y = opr::Elemwise::make({conv2}, opr::Elemwise::Param::Mode::RELU);
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw2nhwcd4(); options.enable_nhwcd4();
SymbolVar y_opt = SymbolVar y_opt =
gopt::optimize_for_inference({y}, options)[0].rename("out"); gopt::optimize_for_inference({y}, options)[0].rename("out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册