From 45e2beead6dac6069a6254062f43ee3b9588c5df Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 18 May 2020 13:33:00 +0800
Subject: [PATCH] feat(mgb/gopt): add nchw4 optpass

GitOrigin-RevId: 551b6b828d33916b8e0a8bec73e6d3c6abd65536
---
 python_module/megengine/_internal/__init__.py |   3 +
 python_module/megengine/jit/__init__.py       |   1 +
 python_module/src/swig/misc.i                 |   1 +
 sdk/load-and-run/dump_with_testcase_mge.py    |   7 +
 sdk/load-and-run/src/mgblar.cpp               |   1 +
 src/core/include/megbrain/graph/cg.h          |   2 +
 src/gopt/impl/framework.cpp                   |   7 +
 src/gopt/impl/tensor_reformat.cpp             | 443 ++++++++++++++++--
 src/gopt/include/megbrain/gopt/inference.h    |  13 +
 src/gopt/test/inference.cpp                   | 126 +++++
 10 files changed, 554 insertions(+), 50 deletions(-)

diff --git a/python_module/megengine/_internal/__init__.py b/python_module/megengine/_internal/__init__.py
index 58b8f1515..13372392e 100644
--- a/python_module/megengine/_internal/__init__.py
+++ b/python_module/megengine/_internal/__init__.py
@@ -541,6 +541,7 @@ def optimize_for_inference(
     fuse_conv_bias_nonlinearity=False,
     use_nchw32=False,
     fuse_conv_bias_with_z=False,
+    use_nchw4=False,
     use_nchw88=False,
     use_nchw44=False,
     use_chwn4=False
@@ -561,6 +562,7 @@ def optimize_for_inference(
         OpenCL devices
     :param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
         into one opr. This is supported only in NHWCD4 format.
+    :param use_nchw4: whether to use NCHW4 tensor format.
     :param use_nchw88: whether to use NCHW88 tensor format. This maybe faster some
         times.
     :param use_nchw44: whether to use NCHW44 tensor format. This maybe faster some
@@ -588,6 +590,7 @@ def optimize_for_inference(

     layout_tranform = None
     for k, v in {
+        "use_nchw4": "nchw4",
        "use_nhwcd4": "nhwcd4",
        "use_nchw32": "nchw32",
        "use_nchw88": "nchw88",
diff --git a/python_module/megengine/jit/__init__.py b/python_module/megengine/jit/__init__.py
index 89b134473..37a307594 100644
--- a/python_module/megengine/jit/__init__.py
+++ b/python_module/megengine/jit/__init__.py
@@ -463,6 +463,7 @@ class trace:
             "enable_io16xc32": "f16_io_f32_comp",
             "enable_ioc16": "f16_io_comp",
             "enable_hwcd4": "use_nhwcd4",
+            "enable_nchw4": "use_nchw4",
             "enable_nchw88": "use_nchw88",
             "enable_nchw32": "use_nchw32",
             "enable_nchw44": "use_nchw44",
diff --git a/python_module/src/swig/misc.i b/python_module/src/swig/misc.i
index d554bf5f3..cb7542634 100644
--- a/python_module/src/swig/misc.i
+++ b/python_module/src/swig/misc.i
@@ -80,6 +80,7 @@ struct _OptimizeForInferenceOptions {
 #define SET(_trans, _trans_capital) \
     void enable_##_trans(); \

+    SET(nchw4, NCHW4);
     SET(nhwcd4, NHWCD4);
     SET(nchw88, NCHW88);
     SET(nchw44, NCHW44);
diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py
index cd62283e5..c2c5d0d4c 100755
--- a/sdk/load-and-run/dump_with_testcase_mge.py
+++ b/sdk/load-and-run/dump_with_testcase_mge.py
@@ -252,6 +252,7 @@ def optimize_for_inference(args, outputs):
         'enable_io16xc32': 'f16_io_f32_comp',
         'enable_ioc16': 'f16_io_comp',
         'enable_hwcd4': 'use_nhwcd4',
+        'enable_nchw4': 'use_nchw4',
         'enable_nchw88': 'use_nchw88',
         'enable_nchw44': 'use_nchw44',
         'enable_nchw32': 'use_nchw32',
@@ -381,6 +382,12 @@ def main():
         'for inference; you may need to disable CUDA and set '
         'MGB_USE_MEGDNN_DBG=2'
     )
+    parser.add_argument(
+        '--enable-nchw4',
+        action='store_true',
+        help='transform the model format from NCHW to NCHW4 '
+        'for inference'
+    )
     parser.add_argument(
         '--enable-nchw88',
         action='store_true',
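Note: the new `--enable-nchw4` switch follows the existing `--enable-nchw88` / `--enable-nchw32` family. A hypothetical invocation of the dump script (input and output file names are placeholders, not from the patch):

    python3 sdk/load-and-run/dump_with_testcase_mge.py model.mge -o model-nchw4.mge --enable-nchw4

The flag is forwarded as `use_nchw4=True` to `optimize_for_inference` through the mapping added above.
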
diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp
index 6789e7e70..15103ab4f 100644
--- a/sdk/load-and-run/src/mgblar.cpp
+++ b/sdk/load-and-run/src/mgblar.cpp
@@ -980,6 +980,7 @@ Args Args::from_argv(int argc, char **argv) {
             continue; \
         }

+        cb(nchw4);
         cb(chwn4);
         cb(nchw44);
         cb(nchw88);
diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h
index 0f5bddfd5..e569a6f05 100644
--- a/src/core/include/megbrain/graph/cg.h
+++ b/src/core/include/megbrain/graph/cg.h
@@ -97,6 +97,7 @@ struct GraphCommonOptimizeOptions {
     bool fuse_conv_bias_with_z = false;
     enum LayoutTransform : uint32_t {
         DEFAULT,
+        NCHW4,      ///< compute using NCHW4 tensor format
         NHWCD4,     ///< compute using NHWCD4 tensor format
         NCHW88,     ///< compute using NCHW88 tensor format
         NCHW44,     ///< compute using NCHW44 tensor format
@@ -137,6 +138,7 @@ struct GraphCommonOptimizeOptions {
         return layout_transform == LayoutTransform::_trans_capital; \
     }

+    SET(nchw4, NCHW4);
     SET(nhwcd4, NHWCD4);
     SET(nchw88, NCHW88);
     SET(nchw44, NCHW44);
diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp
index 8552c3be1..f9e5a4c26 100644
--- a/src/gopt/impl/framework.cpp
+++ b/src/gopt/impl/framework.cpp
@@ -725,6 +725,13 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
     cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
     cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });
+    cb(nchw4, {
+        add_pass();
+        add_pass();
+        add_pass(EnableNCHW4Pass::make_nchw4_converter());
+        add_pass();
+        add_pass();
+    });
     cb(nhwcd4, {
         add_pass();
         add_pass(ConvertFormatPass::make_nhwcd4_converter());
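Note: before reading the relayout code below, it may help to see the pure shape arithmetic that the new relayout modes advertise. A minimal Python sketch, illustrative only and not part of the patch; it mirrors the `dst[...]` assignments in `init_output_static_infer_desc` below:

    # feature map: NCHW -> NCHW4, requires C % 4 == 0
    def nchw_to_nchw4_shape(n, c, h, w):
        assert c % 4 == 0
        return (n, c // 4, h, w, 4)

    # dense conv weight: (O, I, H, W) -> (O, I//4, H, W, 4)
    def weight_dense_nchw4_shape(o, i, h, w):
        assert i % 4 == 0
        return (o, i // 4, h, w, 4)

    # group conv weight: (G, O, I, H, W) -> (G, O, I//4, H, W, 4)
    def weight_group_nchw4_shape(g, o, i, h, w):
        assert i % 4 == 0
        return (g, o, i // 4, h, w, 4)
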
diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index 8f1507a56..c605117a2 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -63,10 +63,15 @@ public:
         NCHW32_TO_NCHW4,    //!< from nchw32 layout to nchw4 layout
         NCHW4_TO_CHWN4,     //!< from nchw4 layout to chwn4 layout
         CHWN4_TO_NCHW4,     //!< from chwn4 layout to nchw4 layout
+        NCHW_TO_NCHW4,      //!< from nchw layout to nchw4 layout
+        NCHW4_TO_NCHW,      //!< from nchw4 layout to nchw layout
         NCHW_TO_NCHW88,     //!< from nchw layout to nchw88 layout
-        NCHW_TO_NCHW44,     //!< from nchw layout to nchw44 layout
         NCHW88_TO_NCHW,     //!< from nchw88 layout to nchw layout
-        NCHW44_TO_NCHW,     //!< from nchw44 layout to nchw layout
+
+        WEIGHT_NCHW_TO_NCHW4_DENSE,  //!< weight from nchw layout to nchw4
+                                     //!< layout
+        WEIGHT_NCHW_TO_NCHW4_GROUP,  //!< group weight from nchw layout to
+                                     //!< nchw4 layout

         WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< weight from nchw layout to nchw88
                                       //!< layout
@@ -167,6 +172,42 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[3] = inp_shape[2];
             dst[4] = inp_shape[4];
         } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) {
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = 4;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] * 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
+                                            WEIGHT_NCHW_TO_NCHW4_DENSE) {
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = 4;
+        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
+                                            WEIGHT_NCHW_TO_NCHW4_GROUP) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[2] % 4 == 0);
+            dst.ndim = 6;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1];
+            dst[2] = inp_shape[2] / 4;
+            dst[3] = inp_shape[3];
+            dst[4] = inp_shape[4];
+            dst[5] = 4;
+        } else if (layout_type() ==
                    RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) {
             mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0);
             dst.ndim = 5;
@@ -226,23 +267,6 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[2] = inp_shape[3];
             dst[3] = inp_shape[1];
             dst[4] = 8;
-        } else if (layout_type() ==
-                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW44) {
-            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
-            dst.ndim = 5;
-            dst[0] = inp_shape[0];
-            dst[1] = inp_shape[1] / 4;
-            dst[2] = inp_shape[2];
-            dst[3] = inp_shape[3];
-            dst[4] = 4;
-        } else if (layout_type() ==
-                   RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW) {
-            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
-            dst.ndim = 4;
-            dst[0] = inp_shape[0];
-            dst[1] = inp_shape[1] * 4;
-            dst[2] = inp_shape[2];
-            dst[3] = inp_shape[3];
         } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                             WEIGHT_NCHW_TO_NCHW44_DENSE) {
             mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
@@ -394,6 +418,66 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
+    reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0), sub(1), sub(2) / 4, cv(4), sub(3), sub(4)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0), sub(1), sub(2) / 4, sub(3), sub(4), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 2, 4, 5, 3});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
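Note: the Reshape/Dimshuffle/Reshape pattern above is equivalent to the following numpy round trip, a sketch under the assumption C % 4 == 0 and not part of the patch:

    import numpy as np

    n, c, h, w = 2, 8, 3, 3
    x = np.arange(n * c * h * w, dtype=np.float32).reshape(n, c, h, w)

    # NCHW_TO_NCHW4: reshape to (n, c//4, 4, h, w), then Dimshuffle {0, 1, 3, 4, 2}
    x4 = x.reshape(n, c // 4, 4, h, w).transpose(0, 1, 3, 4, 2)
    assert x4.shape == (n, c // 4, h, w, 4)

    # NCHW4_TO_NCHW: Dimshuffle {0, 1, 4, 2, 3}, then reshape back to 4-D
    xr = x4.transpose(0, 1, 4, 2, 3).reshape(n, c, h, w)
    assert np.array_equal(x, xr)
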
     reformat[LayoutType::NCHW_TO_NCHW88] = [](VarNode* inp) -> VarNode* {
         auto x = SymbolVar(inp);
         auto xshp = opr::GetVarShape::make(x);
@@ -492,34 +576,6 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
-    reformat[LayoutType::NCHW_TO_NCHW44] = [](VarNode* inp) -> VarNode* {
-        auto x = SymbolVar(inp);
-        auto xshp = opr::GetVarShape::make(x);
-        auto cv = [&x](int v) { return x.make_scalar(v); };
-        auto sub = [&xshp, &cv](int idx) {
-            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
-        };
-        auto tshp0 = opr::Concat::make(
-                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
-             tshp1 = opr::Concat::make(
-                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
-        auto y0 = opr::Reshape::make(x, tshp0);
-        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
-        auto y2 = opr::Reshape::make(y1, tshp1);
-        return y2.node();
-    };
-    reformat[LayoutType::NCHW44_TO_NCHW] = [](VarNode* inp) -> VarNode* {
-        auto x = SymbolVar(inp);
-        auto xshp = opr::GetVarShape::make(x);
-        auto cv = [&x](int v) { return x.make_scalar(v); };
-        auto sub = [&xshp, &cv](int idx) {
-            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
-        };
-        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
-        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
-        auto y1 = opr::Reshape::make(y0, tshp0);
-        return y1.node();
-    };
     reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE] = [](VarNode* inp) -> VarNode* {
         auto x = SymbolVar(inp);
@@ -1239,6 +1295,293 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
     return ret;
 }

+/* ================ EnableNCHW4Pass ================ */
+VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
+                                                VarNode* orig_var) const {
+    if (!orig_var->shape().eq_shape(new_var->shape())) {
+        return RelayoutPlaceholder::make(
+                       new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
+                .node();
+    }
+    return new_var;
+}
+
+std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
+    auto ret = std::make_unique<EnableNCHW4Pass>();
+    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
+    using RelayoutMode = RelayoutPlaceholder::LayoutType;
+    megdnn::param::Convolution::Format conv_format =
+            megdnn::param::Convolution::Format::NCHW4;
+    megdnn::param::ConvBias::Format conv_bias_format =
+            megdnn::param::ConvBias::Format::NCHW4;
+    megdnn::param::BatchConvBias::Format batch_conv_bias_format =
+            megdnn::param::BatchConvBias::Format::NCHW4;
+    RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4;
+    RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
+    RelayoutMode weight_to_nchw4_mode_dense =
+            RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE;
+    RelayoutMode weight_to_nchw4_mode_group =
+            RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP;
+    auto trans_nchw4 = [weight_to_nchw4_mode_dense,
+                        weight_to_nchw4_mode_group](
+            const megdnn::param::Convolution::Sparse conv_mode,
+            const VarNode* filter) -> RelayoutMode {
+        if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
+            mgb_assert(filter->shape().ndim == 4,
+                       "The origin filter is not NCHW mode");
+            size_t IC = filter->shape()[1];
+            mgb_assert(IC % 4 == 0,
+                       "The input channel should be divisible by 4");
+            return weight_to_nchw4_mode_dense;
+        } else {
+            mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
+            mgb_assert(filter->shape().ndim == 5,
+                       "The origin filter is not NCHW mode");
+            size_t IC = filter->shape()[2];
+            mgb_assert(IC % 4 == 0,
+                       "The input channel should be divisible by 4");
+            return weight_to_nchw4_mode_group;
+        }
+    };
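Note: `trans_nchw4` only selects between the two weight relayout modes; the data movement they perform matches this numpy sketch (illustrative only, with filter shapes borrowed from the tests further below):

    import numpy as np

    # DENSE: (O, I, H, W) -> (O, I//4, H, W, 4), Dimshuffle {0, 1, 3, 4, 2}
    o, i, h, w = 8, 4, 3, 3
    wd = np.random.rand(o, i, h, w)
    wd4 = wd.reshape(o, i // 4, 4, h, w).transpose(0, 1, 3, 4, 2)

    # GROUP: (G, O, I, H, W) -> (G, O, I//4, H, W, 4), Dimshuffle {0, 1, 2, 4, 5, 3}
    g = 2
    wg = np.random.rand(g, o, i, h, w)
    wg4 = wg.reshape(g, o, i // 4, 4, h, w).transpose(0, 1, 2, 4, 5, 3)
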
+    auto replace_conv_opr = [trans_nchw4, conv_format, src_to_nchw4_mode](
+            OperatorNodeBase* opr, const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
+        mgb_assert(conv_opr.param().format ==
+                           megdnn::param::Convolution::Format::NCHW,
+                   "ConvertFormat Pass only support converting NCHW to NCHW4");
+        VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
+        // src: NCHW --> NCHW4
+        if (new_inp[0]->shape().ndim != 5) {
+            mgb_assert(new_inp[0]->shape().ndim == 4);
+            auto new_src = RelayoutPlaceholder::make(new_inp[0],
+                                                     src_to_nchw4_mode);
+            conv_src = new_src.node();
+        }
+        // weight: NCHW --> NCHW4
+        auto weight_mode = trans_nchw4(conv_opr.param().sparse, new_inp[1]);
+        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
+        conv_filter = new_filter.node();
+        // format: NCHW --> NCHW4
+        auto new_param = conv_opr.param();
+        new_param.format = conv_format;
+        // dst
+        auto new_conv_opr = opr::Convolution::make(
+                conv_src, conv_filter, new_param,
+                conv_opr.execution_policy(), conv_opr.config());
+        OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
+        mgb_assert(new_conv_opr.shape().ndim == 5,
+                   "The conv dst dim is not trans to nchw4");
+        return new_opr;
+    };
+
+    auto replace_batch_conv_bias_opr = [batch_conv_bias_format,
+                                        src_to_nchw4_mode](
+            OperatorNodeBase* opr, const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& batch_conv_bias_opr =
+                opr->cast_final_safe<opr::BatchConvBiasForward>();
+        mgb_assert(batch_conv_bias_opr.param().format ==
+                           megdnn::param::BatchConvBias::Format::NCHW,
+                   "ConvertFormat Pass only support converting NCHW to NCHW4");
+        // what should be converted: src, weight
+        VarNode *src = new_inp[0], *filter = new_inp[1];
+        // src: NCHW --> NCHW4
+        if (new_inp[0]->shape().ndim != 5) {
+            mgb_assert(new_inp[0]->shape().ndim == 4);
+            auto new_src = RelayoutPlaceholder::make(new_inp[0],
+                                                     src_to_nchw4_mode);
+            src = new_src.node();
+        }
+        // weight: BNCHW --> BNCHW4
+        // only dense mode is supported; its weight relayout is the same as a
+        // group conv's
+        auto weight_mode =
+                RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP;
+        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
+        filter = new_filter.node();
+        // format: NCHW --> NCHW4
+        auto new_param = batch_conv_bias_opr.param();
+        new_param.format = batch_conv_bias_format;
+        if (new_inp.size() == 2) {
+            auto dst = opr::BatchConvBias::make(
+                    src, filter, new_param,
+                    batch_conv_bias_opr.execution_policy(),
+                    batch_conv_bias_opr.config());
+            OperatorNodeBase* new_opr = dst.node()->owner_opr();
+            mgb_assert(dst.shape().ndim == 5,
+                       "The conv_bias dst dim is not trans to nchw4");
+            return new_opr;
+        }
+        // bias: NCHW --> NCHW4
+        VarNode* bias = new_inp[2];
+        if (new_inp[2]->shape().ndim == 4) {
+            auto new_bias = RelayoutPlaceholder::make(new_inp[2],
+                                                      src_to_nchw4_mode);
+            bias = new_bias.node();
+        }
+        if (new_inp.size() == 3) {
+            auto dst = opr::BatchConvBias::make(
+                    src, filter, bias, new_param,
+                    batch_conv_bias_opr.execution_policy(),
+                    batch_conv_bias_opr.config());
+            OperatorNodeBase* new_opr = dst.node()->owner_opr();
+            mgb_assert(dst.shape().ndim == 5,
+                       "The conv_bias dst dim is not trans to nchw4");
+            return new_opr;
+        }
+        // z_inp: NCHW --> NCHW4
+        VarNode* z_inp = new_inp[3];
+        if (new_inp[3]->shape().ndim == 4) {
+            auto new_z = RelayoutPlaceholder::make(new_inp[3],
+                                                   src_to_nchw4_mode);
+            z_inp = new_z.node();
+        }
+        auto dst = opr::BatchConvBias::make(
+                src, filter, bias, z_inp, new_param,
+                batch_conv_bias_opr.execution_policy(),
+                batch_conv_bias_opr.config());
+        OperatorNodeBase* new_opr = dst.node()->owner_opr();
+        mgb_assert(dst.shape().ndim == 5,
+                   "The conv_bias dst dim is not trans to nchw4");
+        return new_opr;
+    };
+    auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format,
+                                  src_to_nchw4_mode](
+            OperatorNodeBase* opr, const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
+        mgb_assert(conv_bias_opr.param().format ==
+                           megdnn::param::ConvBias::Format::NCHW,
+                   "ConvertFormat Pass only support converting NCHW to NCHW4");
+        // what should be converted: src, weight
+        VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1];
+        // src: NCHW --> NCHW4
+        if (new_inp[0]->shape().ndim != 5) {
+            mgb_assert(new_inp[0]->shape().ndim == 4);
+            auto new_src = RelayoutPlaceholder::make(new_inp[0],
+                                                     src_to_nchw4_mode);
+            conv_bias_src = new_src.node();
+        }
+        // weight: NCHW --> NCHW4 or GNCHW --> GNCHW4
+        auto weight_mode = trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
+        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
+        conv_bias_filter = new_filter.node();
+        // format: NCHW --> NCHW4
+        auto new_param = conv_bias_opr.param();
+        new_param.format = conv_bias_format;
+        if (new_inp.size() == 2) {
+            auto new_conv_bias_opr = opr::ConvBias::make(
+                    conv_bias_src, conv_bias_filter, new_param,
+                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
+            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
+                       "The conv_bias dst dim is not trans to nchw4");
+            return new_opr;
+        }
+        // bias: NCHW --> NCHW4
+        VarNode* conv_bias_bias = new_inp[2];
+        if (new_inp[2]->shape().ndim == 4) {
+            auto new_bias = RelayoutPlaceholder::make(new_inp[2],
+                                                      src_to_nchw4_mode);
+            conv_bias_bias = new_bias.node();
+        }
+        if (new_inp.size() == 3) {
+            auto new_conv_bias_opr = opr::ConvBias::make(
+                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
+                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
+            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
+                       "The conv_bias dst dim is not trans to nchw4");
+            return new_opr;
+        }
+        // z_inp: NCHW --> NCHW4
+        VarNode* z_inp = new_inp[3];
+        if (new_inp[3]->shape().ndim == 4) {
+            auto new_z = RelayoutPlaceholder::make(new_inp[3],
+                                                   src_to_nchw4_mode);
+            z_inp = new_z.node();
+        }
+        auto new_conv_bias_opr = opr::ConvBias::make(
+                conv_bias_src, conv_bias_filter, conv_bias_bias, z_inp,
+                new_param, conv_bias_opr.execution_policy(),
+                conv_bias_opr.config());
+        OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
+        mgb_assert(new_conv_bias_opr.shape().ndim == 5,
+                   "The conv_bias dst dim is not trans to nchw4");
+        return new_opr;
+    };
+    auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
+                                    const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        bool has_inp_changed = false;
+        for (size_t i = 0; i < opr->input().size(); i++) {
+            if (new_inp[i]->shape().ndim == 5) {
+                has_inp_changed = true;
+                break;
+            }
+        }
+        if (has_inp_changed) {
+            auto temp_inp = new_inp;
+            for (size_t i = 0; i < opr->input().size(); i++) {
+                if (new_inp[i]->shape().ndim == 4) {
+                    auto new_var = RelayoutPlaceholder::make(
+                            new_inp[i], src_to_nchw4_mode);
+                    temp_inp[i] = new_var.node();
+                } else {
+                    mgb_assert((new_inp[i]->shape().ndim == 5) ||
+                               new_inp[i]->shape().is_scalar());
+                }
+            }
+            return serialization::copy_opr_shallow(*opr, temp_inp,
+                                                   opr->config());
+        } else {
+            return serialization::copy_opr_shallow(*opr, new_inp,
+                                                   opr->config());
+        }
+    };
+    auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
+                                    const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        VarNodeArray temp_inp = new_inp;
+        for (size_t i = 0; i < opr->input().size(); i++) {
+            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
+                mgb_assert(opr->input(i)->shape().ndim == 4);
+                mgb_assert(new_inp[i]->shape().ndim == 5);
+                auto new_var =
+                        RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode);
+                temp_inp[i] = new_var.node();
+            }
+        }
+        return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
+    };
+    auto&& replace_func = ret->m_opr_replace_func;
+    //! supported nchw4
+    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
+    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
+    replace_func[opr::BatchConvBias::typeinfo()] =
+            replace_batch_conv_bias_opr;
+    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
+    replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
+    replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
+    replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr;
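Note: elemwise-like oprs (Elemwise, TypeCvt, ElemwiseMultiType, PowC) can stay in NCHW4 because the relayout is a pure permutation of elements, so elementwise computation commutes with it. A quick numpy check (a sketch, not part of the patch):

    import numpy as np

    def to_nchw4(t):
        n, c, h, w = t.shape
        return t.reshape(n, c // 4, 4, h, w).transpose(0, 1, 3, 4, 2)

    x = np.random.randn(2, 8, 3, 3)
    relu = lambda t: np.maximum(t, 0)
    assert np.array_equal(relu(to_nchw4(x)), to_nchw4(relu(x)))

Oprs without an NCHW4 implementation (pooling, resize, warp, etc.) are instead handed NCHW inputs again via `relayout_inp_to_nchw`, registered next.
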
+    //! not supported nchw4
+    replace_func[opr::PoolingForward::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::ConvolutionBackwardData::typeinfo()] =
+            relayout_inp_to_nchw;
+    replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
+            relayout_inp_to_nchw;
+    replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
+    return ret;
+}
+
 /* ================ EnableNchwxxPass =============== */
 VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
                                                  VarNode* orig_var) const {
@@ -1251,7 +1594,7 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
     } else if (m_pack_c_size == 4) {
         return RelayoutPlaceholder::make(
                        new_var,
-                       RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW)
+                       RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
                 .node();
     }
 }
@@ -1287,8 +1630,8 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
         weight_to_nchwxx_mode_chan = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
         hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
-        src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW44;
-        src_to_nchw_mode = RelayoutMode::NCHW44_TO_NCHW;
+        src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW4;
+        src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
         conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
         conv_format = megdnn::param::ConvolutionV0::Format::NCHW44;
         pooling_format = megdnn::param::Pooling::Format::NCHW44;
diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h
index 9e653c581..351a3a675 100644
--- a/src/gopt/include/megbrain/gopt/inference.h
+++ b/src/gopt/include/megbrain/gopt/inference.h
@@ -229,6 +229,19 @@ namespace gopt {
         static std::unique_ptr<EnableCHWN4Pass> make_chwn4_converter();
     };

+    /*!
+     * \brief convert tensor format to nchw4 to speed up inference on CUDA
+     */
+    class EnableNCHW4Pass final : public TensorReformatPass {
+        VarNode* on_graph_endpoint_var(VarNode* new_var,
+                                       VarNode* orig_var) const override;
+
+    public:
+        const char* name() const override { return mgb_cstr_log("tensor_format_nchw4"); }
+
+        //! make nchw -> nchw4 converter opt pass
+        static std::unique_ptr<EnableNCHW4Pass> make_nchw4_converter();
+    };
+
     /*!
      * \brief convert tensor format to nchwxx to speed up inference on certain
      * devices
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 40636d434..535db7f0e 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -2327,8 +2327,134 @@ TEST(TestGoptInference, EnableCHWN4ShuffleRemove) {
     MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
 }

+TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
+    REQUIRE_GPU(1);
+    auto cn = CompNode::load("gpu0");
+    cn.activate();
+    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
+    auto sm_ver = prop.major * 10 + prop.minor;
+    if (sm_ver < 61) {
+        printf("This testcase ignored due to insufficient cuda cap (got: %d, "
+               "expected: %d)\n",
+               sm_ver, 61);
+        return;
+    }
+
+    HostTensorGenerator<dtype::Int8> gen;
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp,
+                     const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
+                dtype);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp,
+                      const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                        .rename(name),
+                dtype);
+    };
+
+    auto x = mkvar("x", {2, 4, 16, 16}, dtype::QuantizedS8(2.5f));
+    opr::ConvBias::Param param_conv_bias;
+    param_conv_bias.format = opr::ConvBias::Param::Format::NCHW;
+    param_conv_bias.stride_h = param_conv_bias.stride_w = 1;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    param_conv_bias.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
+    // dense
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
+    auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
+         b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
+    auto conv1 = opr::ConvBiasForward::make(
+            x, w1, b1, param_conv_bias, {},
+            OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
+    // group
+    // icpg != 1 && ocpg != 1
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
+    auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
+         b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
+    auto conv2 = opr::ConvBiasForward::make(
+            conv1, w2, b2, param_conv_bias, {},
+            OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
+
+    auto y = opr::TypeCvt::make(conv2, dtype::Float32());
+
+    SymbolVar y_opt;
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_nchw4();
+        unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+    }
+
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4,
+              find_opr<opr::ConvBias>(y_opt).param().format);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.ConvertFormatNCHW4GPU.json"));
+
+    HostTensorND host_y, host_y_opt;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
+}
+
 #endif

+TEST(TestGoptInference, ConvertFormatNCHW4) {
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+
+    auto x = mkvar("x", {2, 4, 16, 16});
+    // ConvBias
+    opr::ConvBias::Param param_conv_bias;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
+    auto w1 = mkcvar("w1", {8, 4, 3, 3}), b1 = mkcvar("b1", {1, 8, 1, 1});
+    auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias);
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
+    auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1});
+    auto conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias);
+    // Convolution
+    opr::Convolution::Param param_conv;
+    param_conv.pad_h = param_conv.pad_w = 1;
+    param_conv.sparse = opr::Convolution::Param::Sparse::DENSE;
+    auto w3 = mkcvar("w3", {8, 8, 3, 3});
+    auto y = opr::Convolution::make(conv2, w3, param_conv);
+
+    SymbolVar y_opt;
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_nchw4();
+        unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+    }
+
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4,
+              find_opr<opr::ConvBias>(y_opt).param().format);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(
+                    output_file("TestGoptInference.ConvertFormatNCHW4.json"));
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
+}
+
 TEST(TestGoptInference, ConvertFormatNCHW88) {
     HostTensorGenerator<> gen;
     auto cn = CompNode::load("cpu0");
--
GitLab
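Note: putting the pieces together, the knob added by this patch flows as `enable_nchw4` (trace / dump script) -> `use_nchw4` (Python) -> `_OptimizeForInferenceOptions::enable_nchw4()` (SWIG) -> `GraphCommonOptimizeOptions::LayoutTransform::NCHW4` -> `EnableNCHW4Pass`. A minimal Python sketch of the API surface touched above (assuming `outputs` is a hypothetical list of output vars of an already-built inference graph):

    from megengine._internal import optimize_for_inference

    # returns graph outputs whose conv-style oprs compute in NCHW4
    # where supported; everything else is relayouted back to NCHW
    optimized = optimize_for_inference(outputs, use_nchw4=True)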