diff --git a/dnn/src/naive/convolution/helper.h b/dnn/src/naive/convolution/helper.h
index 3060375a56cf31df3186029ffd8e79203ddb77c3..ee5e8aacc36b1e2c8fe589838af32c7788c3e976 100644
--- a/dnn/src/naive/convolution/helper.h
+++ b/dnn/src/naive/convolution/helper.h
@@ -287,8 +287,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
             }
         } else if (filter_meta.format == Format::NCHW44 ||
                    filter_meta.format == Format::NCHW44_DOT) {
-            if (filter_meta.format == Format::NCHW44 && !is_output &&
-                src.layout.ndim == 4) {
+            if (!is_output && src.layout.ndim == 4) {
                 return n * layout.stride[0] + c * layout.stride[1] +
                        h * layout.stride[2] + w * layout.stride[3];
             } else {
diff --git a/python_module/megengine/_internal/__init__.py b/python_module/megengine/_internal/__init__.py
index 00e314e259c76dd75a84305fd86038a80c7077e5..e95067473529bb8d05cd78c756fbf5db02ca1be1 100644
--- a/python_module/megengine/_internal/__init__.py
+++ b/python_module/megengine/_internal/__init__.py
@@ -554,6 +554,7 @@ def optimize_for_inference(
     use_nchw4=False,
     use_nchw88=False,
     use_nchw44=False,
+    use_nchw44_dot=False,
     use_chwn4=False
 ):
     """optimize computing graph for inference
@@ -577,6 +578,8 @@
         times.
     :param use_nchw44: whether to use NCHW44 tensor format. This maybe
         faster some times.
+    :param use_nchw44_dot: whether to use NCHW44_DOT tensor format. This format
+        is optimized for inference on ARMv8.2 with dot-product instructions.
     :param use_nchw32: whether to use NCHW32 tensor format. Mainly used for
         nvidia tensorcore.
     :param use_chwn4: whether to use CHWN4 tensor format. Mainly used for
@@ -605,6 +608,7 @@
         "use_nchw32": "nchw32",
         "use_nchw88": "nchw88",
         "use_nchw44": "nchw44",
+        "use_nchw44_dot": "nchw44_dot",
         "use_chwn4": "chwn4",
     }.items():
         if settings[k]:
diff --git a/python_module/src/swig/misc.i b/python_module/src/swig/misc.i
index 8b343058ff51a3d4993f8fe80814efbb27289f0a..d3f18946646de306ef130e31fc9864f3fa86423a 100644
--- a/python_module/src/swig/misc.i
+++ b/python_module/src/swig/misc.i
@@ -84,6 +84,7 @@ struct _OptimizeForInferenceOptions {
     SET(nhwcd4, NHWCD4);
     SET(nchw88, NCHW88);
     SET(nchw44, NCHW44);
+    SET(nchw44_dot, NCHW44_DOT);
     SET(nchw32, NCHW32);
     SET(chwn4, CHWN4);
 #undef SET
diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py
index c2c5d0d4c8bc404f22845013c93b82cac18ee5f0..5af630db6eee238914956141bce27c02713a90a5 100755
--- a/sdk/load-and-run/dump_with_testcase_mge.py
+++ b/sdk/load-and-run/dump_with_testcase_mge.py
@@ -255,6 +255,7 @@
         'enable_nchw4': 'use_nchw4',
         'enable_nchw88': 'use_nchw88',
         'enable_nchw44': 'use_nchw44',
+        'enable_nchw44_dot': 'use_nchw44_dot',
         'enable_nchw32': 'use_nchw32',
         'enable_chwn4': 'use_chwn4',
         'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
@@ -400,6 +401,12 @@
         help='transform the model format from NCHW to NCHW44 '
              'for inference'
     )
+    parser.add_argument(
+        '--enable-nchw44-dot',
+        action='store_true',
+        help='transform the model format from NCHW to NCHW44_DOT '
+             'to optimize inference with ARMv8.2 dot-product instructions'
+    )
     parser.add_argument(
         '--enable-nchw32',
         action='store_true',
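For orientation, the Python-side changes above are used roughly as follows. This is a minimal, illustrative sketch only and is not part of the patch; the variable `net_outputs` (the network's output SymbolVars) and the surrounding model-building code are assumptions, not shown here.

# Illustrative sketch (not part of the patch): request the NCHW -> NCHW44_DOT
# transform when optimizing a graph for inference.
import megengine._internal as mgb

# `net_outputs` is assumed to be the list of output SymbolVars of a built network.
optimized_outputs = mgb.optimize_for_inference(
    net_outputs,
    use_nchw44_dot=True,  # keyword added by this patch
)

On the load-and-run side, the same transform is requested by passing the new `--enable-nchw44-dot` flag to `dump_with_testcase_mge.py`; the argument table above maps it to `use_nchw44_dot`.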
diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h
index e569a6f05604c31a64a7cab4c3ca6cfc1b8ff22f..3b33175f58a3559ff43a244249e38fbd81219017 100644
--- a/src/core/include/megbrain/graph/cg.h
+++ b/src/core/include/megbrain/graph/cg.h
@@ -97,14 +97,15 @@ struct GraphCommonOptimizeOptions {
     bool fuse_conv_bias_with_z = false;
     enum LayoutTransform : uint32_t {
         DEFAULT,
-        NCHW4,      ///< compute using NCHW4 tensor format
-        NHWCD4,     ///< compute using NHWCD4 tensor format
-        NCHW88,     ///< compute using NCHW88 tensor format
-        NCHW44,     ///< compute using NCHW44 tensor format
-        NCHW32,     ///< compute using NCHW32 tensor format, used for
-                    ///< tensorcore
-        CHWN4,      ///< compute using CHWN4 tensor format, transformed mainly
-                    ///< used for cuda
+        NCHW4,       ///< compute using NCHW4 tensor format
+        NHWCD4,      ///< compute using NHWCD4 tensor format
+        NCHW88,      ///< compute using NCHW88 tensor format
+        NCHW44,      ///< compute using NCHW44 tensor format
+        NCHW44_DOT,  ///< compute using NCHW44_DOT tensor format
+        NCHW32,      ///< compute using NCHW32 tensor format, used for
+                     ///< tensorcore
+        CHWN4,       ///< compute using CHWN4 tensor format, transformed mainly
+                     ///< used for cuda
     };
     LayoutTransform layout_transform = LayoutTransform::DEFAULT;
 
@@ -142,6 +143,7 @@ struct GraphCommonOptimizeOptions {
     SET(nhwcd4, NHWCD4);
     SET(nchw88, NCHW88);
     SET(nchw44, NCHW44);
+    SET(nchw44_dot, NCHW44_DOT);
     SET(nchw32, NCHW32);
     SET(chwn4, CHWN4);
 #undef SET
diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp
index f9e5a4c26b41b10ed584c250acbc0059262b7c81..64315ab0aabd0f0b3e0352ec5c14a2a4bad1acda 100644
--- a/src/gopt/impl/framework.cpp
+++ b/src/gopt/impl/framework.cpp
@@ -738,6 +738,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
     });
     cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); });
     cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); });
+    cb(nchw44_dot,
+       { add_pass(EnableNchw44DotPass::make_nchw44_dot_converter()); });
     cb(nchw32, {
         add_pass<FuseConvBiasNonlinPass>();
         add_pass<FuseConvBiasZPass>();
diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index c605117a2f27b1b6ff39b755081d59352dfb42fe..0b27dbcc2561e198e949156b8cad5668c7f4101a 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -28,6 +28,7 @@
 #include "megbrain/opr/imgproc.h"
 #include "megbrain/opr/nn_int.h"
 
+#include "megdnn/opr_param_defs.h"
 #include "megdnn/tensor_format.h"
 
 #if MGB_ENABLE_TENSOR_RT
@@ -59,19 +60,19 @@ MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
 public:
     //! relayout type of this opr
     enum class LayoutType {
-        NCHW4_TO_NCHW32,  //!< from nchw4 layout to nchw32 layout
-        NCHW32_TO_NCHW4,  //!< from nchw32 layout to nchw4 layout
-        NCHW4_TO_CHWN4,   //!< from nchw4 layout to chwn4 layout
-        CHWN4_TO_NCHW4,   //!< from chwn4 layout to nchw4 layout
-        NCHW_TO_NCHW4,    //!< from nchw layout to nchw4 layout
-        NCHW4_TO_NCHW,    //!< from nchw4 layout to nchw layout
-        NCHW_TO_NCHW88,   //!< from nchw layout to nchw88 layout
-        NCHW88_TO_NCHW,   //!< from nchw88 layout to nchw layout
-
-        WEIGHT_NCHW_TO_NCHW4_DENSE,  //!< weight from nchw layout to nchw4
-                                     //!< layout
-        WEIGHT_NCHW_TO_NCHW4_GROUP,  //!< group weight from nchw layout to
-                                     //!< nchw4 layout
+        NCHW4_TO_NCHW32,   //!< from nchw4 layout to nchw32 layout
+        NCHW32_TO_NCHW4,   //!< from nchw32 layout to nchw4 layout
+        NCHW4_TO_CHWN4,    //!< from nchw4 layout to chwn4 layout
+        CHWN4_TO_NCHW4,    //!< from chwn4 layout to nchw4 layout
+        NCHW_TO_NCHW4,     //!< from nchw layout to nchw4 layout
+        NCHW4_TO_NCHW,     //!< from nchw4 layout to nchw layout
+        NCHW_TO_NCHW88,    //!< from nchw layout to nchw88 layout
+        NCHW88_TO_NCHW,    //!< from nchw88 layout to nchw layout
+
+        WEIGHT_NCHW_TO_NCHW4_DENSE,   //!< weight from nchw layout to nchw4
+                                      //!< layout
+        WEIGHT_NCHW_TO_NCHW4_GROUP,   //!< group weight from nchw layout to
+                                      //!< nchw4 layout
 
         WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< weight from nchw layout to nchw88
                                       //!< layout
@@ -92,6 +93,10 @@ public:
         //!< the weight layout of input is nchw output is nchw44, special for
         //!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4}
         WEIGHT_HYBIRD_NCHW_NCHW44,
+        WEIGHT_NCHW_TO_NCHW44_DOT_DENSE,  //!< weight from nchw layout to
+                                          //!< nchw44_dot layout, dense
+        WEIGHT_NCHW_TO_NCHW44_DOT_GROUP,  //!< weight from nchw layout to
+                                          //!< nchw44_dot layout, group
     };
 
     RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
@@ -268,7 +273,9 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
         dst[3] = inp_shape[1];
         dst[4] = 8;
     } else if (layout_type() == RelayoutPlaceholder::LayoutType::
-                                        WEIGHT_NCHW_TO_NCHW44_DENSE) {
+                                        WEIGHT_NCHW_TO_NCHW44_DENSE ||
+               layout_type() == RelayoutPlaceholder::LayoutType::
+                                        WEIGHT_NCHW_TO_NCHW44_DOT_DENSE) {
         mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
                    inp_shape[1] % 4 == 0);
         dst.ndim = 6;
@@ -279,7 +286,9 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
         dst[4] = 4;
         dst[5] = 4;
     } else if (layout_type() == RelayoutPlaceholder::LayoutType::
-                                        WEIGHT_NCHW_TO_NCHW44_GROUP) {
+                                        WEIGHT_NCHW_TO_NCHW44_GROUP ||
+               layout_type() == RelayoutPlaceholder::LayoutType::
+                                        WEIGHT_NCHW_TO_NCHW44_DOT_GROUP) {
         mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 4 == 0 &&
                    inp_shape[2] % 4 == 0);
         dst.ndim = 7;
@@ -646,6 +655,42 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 1, 3});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
+                                        cv(4), sub(3), sub(4)},
+                                       0),
+             tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
+                                        sub(4), cv(4), cv(4)},
+                                       0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
 
     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
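As a reading aid for the two `WEIGHT_NCHW_TO_NCHW44_DOT_*` reformat lambdas added in the hunk above: they are Reshape/Dimshuffle/Reshape pipelines on symbolic shapes. The NumPy sketch below (illustrative only, not part of the patch; the function names are mine) applies the same permutation to a concrete array, which makes the resulting layout easy to inspect. The inner two axes come out as (oc % 4, ic % 4); if I read the existing NCHW44 weight reformat correctly, this swapped inner order is what distinguishes the DOT layout from plain NCHW44.

import numpy as np

def weight_nchw_to_nchw44_dot_dense(w):
    # Dense conv weight in OIHW layout; requires OC % 4 == 0 and IC % 4 == 0.
    oc, ic, fh, fw = w.shape
    w = w.reshape(oc // 4, 4, ic // 4, 4, fh, fw)     # tshp0
    w = w.transpose(0, 2, 4, 5, 1, 3)                 # Dimshuffle {0, 2, 4, 5, 1, 3}
    return w.reshape(oc // 4, ic // 4, fh, fw, 4, 4)  # tshp1

def weight_nchw_to_nchw44_dot_group(w):
    # Grouped conv weight (G, OCpg, ICpg, FH, FW); requires OCpg % 4 == 0 and ICpg % 4 == 0.
    g, ocpg, icpg, fh, fw = w.shape
    w = w.reshape(g, ocpg // 4, 4, icpg // 4, 4, fh, fw)     # tshp0
    w = w.transpose(0, 1, 3, 5, 6, 2, 4)                     # Dimshuffle {0, 1, 3, 5, 6, 2, 4}
    return w.reshape(g, ocpg // 4, icpg // 4, fh, fw, 4, 4)  # tshp1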
@@ -1601,12 +1646,7 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
     return new_var;
 }
 
-std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
-        size_t pack_c_size) {
-    auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
-    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
-    //! First is whether the conv can trans to nchwxx, second is the filter
-    //! trans mode
+void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
     using TestFilterResult = std::pair<TransType, RelayoutMode>;
     RelayoutMode weight_to_nchwxx_mode_dense =
@@ -1954,8 +1994,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
     };
 
-    ret->set_name(convter_pass_name);
-    auto&& replace_func = ret->m_opr_replace_func;
+    auto&& replace_func = m_opr_replace_func;
     //! supportted nchwxx
     replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
     replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
@@ -1978,6 +2017,246 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     replace_func[opr::WarpPerspectiveForward::typeinfo()] =
             relayout_inp_to_nchw;
     replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
+}
+
+std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
+        size_t pack_c_size) {
+    auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
+    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
+    std::string convter_pass_name = "conv_format_nchw88";
+    if (pack_c_size == 4) {
+        convter_pass_name = "conv_format_nchw44";
+    }
+    ret->fill_opr_convert_fun(pack_c_size);
+    ret->set_name(convter_pass_name);
+    return ret;
+}
+
+/* ================ EnableNchw44DotPass =============== */
+VarNode* EnableNchw44DotPass::on_graph_endpoint_var(VarNode* new_var,
+                                                    VarNode* orig_var) const {
+    if (!orig_var->shape().eq_shape(new_var->shape())) {
+        return RelayoutPlaceholder::make(
+                       new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
+                .node();
+    }
+    return new_var;
+}
+
+std::unique_ptr<EnableNchw44DotPass>
+EnableNchw44DotPass::make_nchw44_dot_converter() {
+    auto ret = std::make_unique<EnableNchw44DotPass>();
+    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
+    //! First is whether the conv can trans to nchwxx, second is the filter
+    //! trans mode
+    using RelayoutMode = RelayoutPlaceholder::LayoutType;
+    using TestTransResult = std::pair<TransType, RelayoutMode>;
+    megdnn::param::ConvolutionV0::Format conv_dot_format =
+            megdnn::param::ConvBias::Format::NCHW44_DOT;
+    constexpr size_t pack_c_size = 4_z;
+    auto test_trans_nchw44_dot =
+            [](const megdnn::param::Convolution::Sparse conv_mode,
+               const VarNode* filter) -> TestTransResult {
+        TestTransResult ret{TransType::TRANS_NONE, {}};
+        if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
+            size_t IC = filter->shape()[1];
+            size_t OC = filter->shape()[0];
+            if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
+                ret.first = TransType::TRANS_PURE_NCHWXX;
+                ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
+            } else if (IC < pack_c_size && OC % pack_c_size == 0) {
+                ret.first = TransType::TRANS_HYBIRD_NCHWXX;
+                ret.second = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
+            }
+        } else {
+            mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
+            size_t group = filter->shape()[0];
+            size_t ocpg = filter->shape()[1];
+            size_t icpg = filter->shape()[2];
+            if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) {
+                ret.first = TransType::TRANS_NONE;
+            } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
+                ret.first = TransType::TRANS_PURE_NCHWXX;
+                ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
+            }
+        }
+        return ret;
+    };
+    auto replace_conv_opr = [test_trans_nchw44_dot, conv_dot_format](
+                                    OperatorNodeBase* opr,
+                                    const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
+        mgb_assert(conv_opr.param().format ==
+                           megdnn::param::Convolution::Format::NCHW,
+                   "ConvertFormat Pass only support converting NCHW to "
+                   "NCHW44_DOT");
+        auto is_trans =
+                test_trans_nchw44_dot(conv_opr.param().sparse, new_inp[1]);
+        //! can not trans to nchwxx
+        if (is_trans.first == TransType::TRANS_NONE) {
+            mgb_assert(new_inp[1]->shape().ndim == 4 ||
+                               new_inp[1]->shape().ndim == 5,
+                       "The origin filter is not NCHW mode");
+            VarNodeArray temp_inp = new_inp;
+            //! if src is nchwxx, should RelayoutPlaceholder to nchw
+            if (temp_inp[0]->shape().ndim == 5) {
+                auto new_src = RelayoutPlaceholder::make(
+                        new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
+                temp_inp[0] = new_src.node();
+            }
+            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
+                                                           opr->config());
+            return new_opr;
+        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
+            //! filter trans to nchwxx mode
+            mgb_assert(new_inp[1]->shape().ndim == 4 ||
+                               new_inp[1]->shape().ndim == 5,
+                       "The origin filter is not NCHW mode");
+            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
+            auto new_filter =
+                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
+            conv_filter = new_filter.node();
+            //! src trans to nchwxx mode
+            if (new_inp[0]->shape().ndim != 5) {
+                mgb_assert(new_inp[0]->shape().ndim == 4);
+                auto new_src = RelayoutPlaceholder::make(
+                        new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
+                conv_src = new_src.node();
+            }
+            auto new_param = conv_opr.param();
+            new_param.format = conv_dot_format;
+            mgb_assert(conv_src->shape().ndim == 5 &&
+                               conv_filter->shape().ndim >= 6,
+                       "The conv src dim is not trans to nchwxx");
+            auto new_conv_opr = opr::Convolution::make(
+                    conv_src, conv_filter, new_param,
+                    conv_opr.execution_policy(), conv_opr.config());
+            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
+            mgb_assert(new_conv_opr.shape().ndim == 5,
+                       "The conv dst dim is not trans to nchwxx");
+            return new_opr;
+        } else {
+            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
+            VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
+            auto new_filter =
+                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
+            conv_filter = new_filter.node();
+            mgb_assert(conv_src->shape().ndim == 4 &&
+                               conv_filter->shape().ndim == 5,
+                       "The src and filter is OK");
+            auto new_param = conv_opr.param();
+            new_param.format = conv_dot_format;
+            auto new_conv_opr = opr::Convolution::make(
+                    conv_src, conv_filter, new_param,
+                    conv_opr.execution_policy(), conv_opr.config());
+            OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
+            mgb_assert(new_conv_opr.shape().ndim == 5,
+                       "The conv dst dim is not trans to nchwxx");
+            return new_opr;
+        }
+    };
+
+    auto replace_conv_bias_opr = [test_trans_nchw44_dot, conv_dot_format](
+                                         OperatorNodeBase* opr,
+                                         const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
+        mgb_assert(conv_bias_opr.param().format ==
+                           megdnn::param::ConvBias::Format::NCHW,
+                   "ConvertFormat Pass only support converting NCHW to NCHWXX");
+        auto is_trans =
+                test_trans_nchw44_dot(conv_bias_opr.param().sparse, new_inp[1]);
+        //! can not trans to nchwxx
+        if (is_trans.first == TransType::TRANS_NONE) {
+            mgb_assert(new_inp[1]->shape().ndim == 4 ||
+                               new_inp[1]->shape().ndim == 5,
+                       "The origin filter is not NCHW mode");
+            VarNodeArray temp_inp = new_inp;
+            //! if src is nchwxx, should RelayoutPlaceholder to nchw
+            if (temp_inp[0]->shape().ndim == 5) {
+                auto new_src = RelayoutPlaceholder::make(
+                        new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
+                temp_inp[0] = new_src.node();
+            }
+            //! the bias is nchwxx
+            if (temp_inp[2]->shape().ndim == 5) {
+                auto new_bias = RelayoutPlaceholder::make(
+                        new_inp[2], RelayoutMode::NCHW4_TO_NCHW);
+                temp_inp[2] = new_bias.node();
+            }
+            auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
+                                                           opr->config());
+            return new_opr;
+        } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
+            VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
+                    *conv_bias_bias = new_inp[2];
+            //! filter trans to nchwxx mode
+            mgb_assert(new_inp[1]->shape().ndim == 4 ||
+                               new_inp[1]->shape().ndim == 5,
+                       "The origin filter is not NCHW mode");
+            auto new_filter =
+                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
+            conv_bias_filter = new_filter.node();
+            //! src trans to nchwxx mode
+            if (new_inp[0]->shape().ndim != 5) {
+                mgb_assert(new_inp[0]->shape().ndim == 4);
+                auto new_src = RelayoutPlaceholder::make(
+                        new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
+                conv_bias_src = new_src.node();
+            }
+            //! bias trans to nchwxx mode, bias may be scale
+            if (new_inp[2]->shape().ndim == 4) {
+                auto new_bias = RelayoutPlaceholder::make(
+                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
+                conv_bias_bias = new_bias.node();
+            }
+
+            auto new_param = conv_bias_opr.param();
+            new_param.format = conv_dot_format;
+            mgb_assert(conv_bias_src->shape().ndim == 5 &&
+                               conv_bias_filter->shape().ndim >= 6,
+                       "The conv_bias src dim is not trans to nchwxx");
+            auto new_conv_bias_opr = opr::ConvBias::make(
+                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
+                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
+            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
+                       "The conv_bias dst dim is not trans to nchwxx");
+            return new_opr;
+        } else {
+            mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
+            VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
+                    *conv_bias_bias = new_inp[2];
+            auto new_filter =
+                    RelayoutPlaceholder::make(new_inp[1], is_trans.second);
+            conv_bias_filter = new_filter.node();
+            //! bias trans to nchwxx mode, bias may be scale
+            if (new_inp[2]->shape().ndim == 4) {
+                auto new_bias = RelayoutPlaceholder::make(
+                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
+                conv_bias_bias = new_bias.node();
+            }
+            mgb_assert(conv_bias_src->shape().ndim == 4 &&
+                       conv_bias_filter->shape().ndim == 5);
+            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
+                       conv_bias_bias->shape().is_scalar());
+            auto new_param = conv_bias_opr.param();
+            new_param.format = conv_dot_format;
+            auto new_conv_bias_opr = opr::ConvBias::make(
+                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
+                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
+            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
+                       "The conv dst dim is not trans to nchwxx");
+            return new_opr;
+        }
+    };
+    ret->fill_opr_convert_fun(4);
+    auto&& replace_func = ret->m_opr_replace_func;
+    //! supported nchwxx
+    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
+    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
     return ret;
 }
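To summarize the conversion policy encoded by `test_trans_nchw44_dot` in the pass above, here is a rough Python rendering of the same decision table (illustrative only; the function and string names are mine, not part of the patch):

def nchw44_dot_trans_type(sparse, filter_shape, pack=4):
    # Mirrors test_trans_nchw44_dot: returns (trans_type, weight_relayout_mode).
    if sparse == "DENSE":
        oc, ic = filter_shape[0], filter_shape[1]
        if oc % pack == 0 and ic % pack == 0:
            return "TRANS_PURE_NCHWXX", "WEIGHT_NCHW_TO_NCHW44_DOT_DENSE"
        if ic < pack and oc % pack == 0:
            return "TRANS_HYBIRD_NCHWXX", "WEIGHT_HYBIRD_NCHW_NCHW44"
        return "TRANS_NONE", None
    # GROUP sparse mode
    group, ocpg, icpg = filter_shape[0], filter_shape[1], filter_shape[2]
    if icpg == 1 and ocpg == 1 and group % pack == 0:
        return "TRANS_NONE", None  # channel-wise convolutions keep their format
    if icpg % pack == 0 and ocpg % pack == 0:
        return "TRANS_PURE_NCHWXX", "WEIGHT_NCHW_TO_NCHW44_DOT_GROUP"
    return "TRANS_NONE", None

In other words, only convolutions whose (per-group) input and output channel counts are multiples of 4 are switched to NCHW44_DOT, plus the hybrid first-layer case that reuses `WEIGHT_HYBIRD_NCHW_NCHW44`; channel-wise convolutions and everything else fall through to `TRANS_NONE` and stay in NCHW.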
diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h
index 351a3a6750af5b2967a7477675df099180a5818f..d2556e15d0706da3c312a144bbc5d21b212a6bef 100644
--- a/src/gopt/include/megbrain/gopt/inference.h
+++ b/src/gopt/include/megbrain/gopt/inference.h
@@ -236,8 +236,10 @@ namespace gopt {
         VarNode* on_graph_endpoint_var(VarNode* new_var,
                                        VarNode* orig_var) const override;
     public:
-        const char* name() const override { return mgb_cstr_log("tensor_format_nchw4"); }
-
+        const char* name() const override {
+            return mgb_cstr_log("tensor_format_nchw4");
+        }
+
         //! make nchw -> nchw4 converter opt pass
         static std::unique_ptr<EnableNCHW4Pass> make_nchw4_converter();
     };
@@ -246,30 +248,48 @@ namespace gopt {
      * \brief convert tensor format to nchwxx to speed up inference on certain
      * devices
      */
-    class EnableNchwxxPass final : public TensorReformatPass {
+    class EnableNchwxxPass : public TensorReformatPass {
         std::string m_name = "tensor_format_nchwxx";
         size_t m_pack_c_size;
         VarNode* on_graph_endpoint_var(VarNode* new_var,
                                        VarNode* orig_var) const override;
+    public:
+        EnableNchwxxPass(size_t pack_c_size) : m_pack_c_size(pack_c_size) {}
+
+        //! the flag for conv to transform to nchwxx
         enum class TransType {
             TRANS_PURE_NCHWXX,    //!< weight and src all trans to nchwxx
             TRANS_HYBIRD_NCHWXX,  //!< input is nchw, output is nchwxx
             TRANS_NONE,           //!< no need trans
         };
-
-    public:
-        EnableNchwxxPass(size_t pack_c_size) : m_pack_c_size(pack_c_size) {}
         const char* name() const override {
             return mgb_cstr_log(m_name.c_str());
         }
         void set_name(std::string in_name) { m_name = in_name; }
+
+        void fill_opr_convert_fun(size_t pack_c_size);
+
         //! make nchw -> nchwxx converter opt pass, pack_c_size is the x, like
         //! 4,8,16
         static std::unique_ptr<EnableNchwxxPass> make_nchwxx_converter(
                 size_t pack_c_size);
     };
 
+    /*!
+     * \brief convert tensor format from nchw to nchw44_dot to speed up
+     * inference on armv8.2
+     */
+    class EnableNchw44DotPass final : public EnableNchwxxPass {
+        std::string m_name = "tensor_format_nchw44_dot";
+        VarNode* on_graph_endpoint_var(VarNode* new_var,
+                                       VarNode* orig_var) const override;
+
+    public:
+        EnableNchw44DotPass() : EnableNchwxxPass(4) {}
+        //! make nchw -> nchw44_dot converter opt pass
+        static std::unique_ptr<EnableNchw44DotPass> make_nchw44_dot_converter();
+    };
+
     struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {};
 
     /*!
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 535db7f0efa971405f8e355be5c7e937706ae614..adc2aba60cfef3d4dd2f0f69d3981e4ef71e4cc3 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -2356,7 +2356,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
                         .rename(name),
                 dtype);
     };
-
+
     auto x = mkvar("x", {2, 4, 16, 16}, dtype::QuantizedS8(2.5f));
     opr::ConvBias::Param param_conv_bias;
     param_conv_bias.format = opr::ConvBias::Param::Format::NCHW;
@@ -2376,7 +2376,7 @@
         b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
     auto conv2 = opr::ConvBiasForward::make(conv1, w2, b2, param_conv_bias, {},
                                             OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
-
+
     auto y = opr::TypeCvt::make(conv2, dtype::Float32());
 
     SymbolVar y_opt;
@@ -2617,4 +2617,86 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
 }
 
+TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+
+    auto host_x = gen({2, 3, 16, 16}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+    //! hybrid nchw44_dot mode
+    opr::Convolution::Param param_conv;
+    param_conv.pad_h = param_conv.pad_w = 1;
+    auto w1 = mkcvar("w1", {8, 3, 3, 3}),
+         conv1 = opr::Convolution::make(x, w1, param_conv);
+    //! channel wise
+    opr::ConvBias::Param param_conv_bias;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
+    auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}),
+         conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias);
group + auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}), + conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias); + + auto shape_of = opr::GetVarShape::make(conv3); + auto subtensor = opr::Subtensor::make( + shape_of, {opr::Subtensor::AxisIndexer::make_interval( + 0, x.make_scalar(2), None, x.make_scalar(1))}); + opr::Resize::Param param_resize; + param_resize.format = opr::Resize::Param::Format::NCHW; + auto resize = opr::ResizeForward::make(conv3, subtensor * 2, param_resize); + auto mat = mkcvar("mat", {2, 3, 3}), + warp = opr::WarpPerspectiveForward::make( + resize, mat, nullptr, cg::var_from_tensor_shape(x, {4, 4})); + + auto b = mkvar("b", {1, 8, 1, 1}), + elem = opr::Elemwise::make({warp + b}, + opr::Elemwise::Param::Mode::RELU); + //! Dense + param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; + param_conv_bias.pad_h = param_conv_bias.pad_w = 1; + auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}), + conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias); + auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}), + conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias); + auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), + y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); + + SymbolVar y_opt; + auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_nchw44_dot(); + unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); + + ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT, + find_opr(y_opt).param().format); + ASSERT_EQ(opr::Convolution::Param::Format::NCHW, + find_opr(y_opt).param().format); + + graph->compile({{y_opt, {}}}) + ->to_json() + ->writeto_fpath( + output_file("TestGoptInference.ConvertFormatNCHW44.json")); + + HostTensorND host_y_opt, host_y; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_opt, host_y_opt)}); + func->execute(); + //! meybe go to winograd in x86-32, so set error 1e-1 + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); + + *host_x = *gen({2, 3, 32, 32}, cn); + func->execute(); + //! meybe go to winograd in x86-32, so set error 1e-1 + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}