From 554ce352c8f6d5974cdfd321ad3a98bd8161a0bc Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 17 Apr 2020 17:27:28 +0800
Subject: [PATCH] feat(mgb/gopt): add nchw44 optpass

GitOrigin-RevId: dc38724558b0c6635ea9a3137e1c0d0acc665a0f
---
 python_module/megengine/_internal/__init__.py |   8 +-
 python_module/src/swig/misc.i                 |   1 +
 sdk/load-and-run/dump_with_testcase_mge.py    |   7 +
 src/gopt/impl/framework.cpp                   |   3 +
 src/gopt/impl/tensor_reformat.cpp             | 263 +++++++++++++++++-
 src/gopt/include/megbrain/gopt/inference.h    |   9 +-
 src/gopt/test/inference.cpp                   |  81 ++++++
 src/plugin/impl/opr_footprint.cpp             |  12 +-
 8 files changed, 364 insertions(+), 20 deletions(-)

diff --git a/python_module/megengine/_internal/__init__.py b/python_module/megengine/_internal/__init__.py
index c691f2ddd..5106ffd96 100644
--- a/python_module/megengine/_internal/__init__.py
+++ b/python_module/megengine/_internal/__init__.py
@@ -541,7 +541,8 @@ def optimize_for_inference(
     fuse_conv_bias_nonlinearity=False,
     use_tensor_core=False,
     fuse_conv_bias_with_z=False,
-    use_nchw88=False
+    use_nchw88=False,
+    use_nchw44=False
 ):
     """optimize computing graph for inference

@@ -559,7 +560,9 @@
         OpenCL devices
     :param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
         into one opr. This is supported only in NHWCD4 format.
-    :param use_nchw88: whether to use NCHW4 tensor format. This maybe faster some
+    :param use_nchw88: whether to use NCHW88 tensor format. This may be faster
+        sometimes.
+    :param use_nchw44: whether to use NCHW44 tensor format. This may be faster
         sometimes.
@@ -577,6 +580,7 @@
         "use_tensor_core",
         "fuse_conv_bias_with_z",
         "use_nchw88",
+        "use_nchw44",
     ]:
         if settings[i]:
             getattr(opt, "enable_{}".format(i))()
diff --git a/python_module/src/swig/misc.i b/python_module/src/swig/misc.i
index 9e1421efc..0b96f763f 100644
--- a/python_module/src/swig/misc.i
+++ b/python_module/src/swig/misc.i
@@ -79,6 +79,7 @@ struct _OptimizeForInferenceOptions {
     SET(use_tensor_core);
     SET(fuse_conv_bias_with_z);
     SET(use_nchw88);
+    SET(use_nchw44);
 #undef SET
 };
diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py
index 42a99f4ac..a87ade529 100755
--- a/sdk/load-and-run/dump_with_testcase_mge.py
+++ b/sdk/load-and-run/dump_with_testcase_mge.py
@@ -253,6 +253,7 @@ def optimize_for_inference(args, outputs):
         'enable_ioc16': 'f16_io_comp',
         'enable_hwcd4': 'use_nhwcd4',
         'enable_nchw88': 'use_nchw88',
+        'enable_nchw44': 'use_nchw44',
         'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
         'enable_tensorcore': 'use_tensor_core',
         'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z',
@@ -385,6 +386,12 @@ def main():
         help='transform the model format from NCHW to NCHW88 '
         'for inference'
     )
+    parser.add_argument(
+        '--enable-nchw44',
+        action='store_true',
+        help='transform the model format from NCHW to NCHW44 '
+        'for inference'
+    )
     parser.add_argument(
         '--enable-tensorcore',
         action='store_true',
diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp
index c7656d01b..570e05f7d 100644
--- a/src/gopt/impl/framework.cpp
+++ b/src/gopt/impl/framework.cpp
@@ -700,6 +700,9 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
     if (inference_opt->use_nchw88) {
         add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
     }
+    if (inference_opt->use_nchw44) {
+        add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
+    }
     if (inference_opt->use_tensor_core) {
         mgb_assert(inference_opt->fuse_conv_bias_nonlinearity,
                    "enable tensor core should fuse conv bias activation "
diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index e8c475e9d..0579eaa7e 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */

 #include "megbrain/gopt/inference.h"
@@ -63,7 +64,10 @@ public:
         NCHW4_TO_CHWN4,  //!< from nchw4 layout to chwn4 layout
         CHWN4_TO_NCHW4,  //!< from chwn4 layout to nchw4 layout
         NCHW_TO_NCHW88,  //!< from nchw layout to nchw88 layout
+        NCHW_TO_NCHW44,  //!< from nchw layout to nchw44 layout
         NCHW88_TO_NCHW,  //!< from nchw88 layout to nchw layout
+        NCHW44_TO_NCHW,  //!< from nchw44 layout to nchw layout
+
         WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< weight from nchw layout to nchw88
                                       //!< layout
         WEIGHT_NCHW_TO_NCHW88_GROUP,  //!< group weight from nchw layout to
@@ -73,6 +77,16 @@ public:
         //!< the weight layout of input is nchw output is nchw88, special for
         //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8}
         WEIGHT_HYBIRD_NCHW_NCHW88,
+
+        WEIGHT_NCHW_TO_NCHW44_DENSE,  //!< weight from nchw layout to nchw44
+                                      //!< layout
+        WEIGHT_NCHW_TO_NCHW44_GROUP,  //!< group weight from nchw layout to
+                                      //!< nchw44 layout
+        WEIGHT_NCHW_TO_NCHW44_CHAN,   //!< channel wise weight from nchw layout
+                                      //!< to nchw44 layout
+        //!< the input weight layout is nchw and the output is nchw44; special
+        //!< for weights with nchw shape like {64, 2, 3, 3} -> {16, 3, 3, 2, 4}
+        WEIGHT_HYBIRD_NCHW_NCHW44,
     };
     RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
@@ -203,10 +217,8 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[3] = inp_shape[3];
             dst[4] = inp_shape[4];
             dst[5] = 8;
-        } else {
-            mgb_assert(
-                    layout_type() ==
-                    RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88);
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88) {
             mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0);
             dst.ndim = 5;
             dst[0] = inp_shape[0] / 8;
@@ -214,6 +226,68 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[2] = inp_shape[3];
             dst[3] = inp_shape[1];
             dst[4] = 8;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW44) {
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = 4;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] * 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
+                                            WEIGHT_NCHW_TO_NCHW44_DENSE) {
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
+                       inp_shape[1] % 4 == 0);
+            dst.ndim = 6;
+            dst[0] = inp_shape[0] / 4;
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = 4;
+            dst[5] = 4;
+        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
+                                            WEIGHT_NCHW_TO_NCHW44_GROUP) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 4 == 0 &&
+                       inp_shape[2] % 4 == 0);
+            dst.ndim = 7;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2] / 4;
+            dst[3] = inp_shape[3];
+            dst[4] = inp_shape[4];
+            dst[5] = 4;
+            dst[6] = 4;
+        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
+                                            WEIGHT_NCHW_TO_NCHW44_CHAN) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 &&
+                       inp_shape[2] == 1 && inp_shape[0] % 4 == 0);
+            dst.ndim = 6;
+            dst[0] = inp_shape[0] / 4;
+            dst[1] = inp_shape[1];
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = inp_shape[4];
+            dst[5] = 4;
+        } else {
+            mgb_assert(
+                    layout_type() ==
+                    RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44);
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0] / 4;
+            dst[1] = inp_shape[2];
+            dst[2] = inp_shape[3];
+            dst[3] = inp_shape[1];
+            dst[4] = 4;
         }
         return true;
     };
@@ -418,6 +492,104 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
+    reformat[LayoutType::NCHW_TO_NCHW44] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::NCHW44_TO_NCHW] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_GROUP] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
+                                        cv(4), sub(3), sub(4)},
+                                       0),
+             tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
+                                        sub(4), cv(4), cv(4)},
+                                       0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_CHAN] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0) / 4, cv(4), sub(1), sub(2), sub(3), sub(4)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0) / 4, sub(1), sub(2), sub(3), sub(4), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0) / 4, cv(4), sub(1), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0) / 4, sub(2), sub(3), sub(1), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };

     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
@@ -1071,16 +1243,24 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
 VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
                                                  VarNode* orig_var) const {
     if (!orig_var->shape().eq_shape(new_var->shape())) {
-        return RelayoutPlaceholder::make(
-                       new_var, RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW)
-                .node();
+        if (m_pack_c_size == 8) {
+            return RelayoutPlaceholder::make(
+                           new_var,
+                           RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW)
+                    .node();
+        } else if (m_pack_c_size == 4) {
+            return RelayoutPlaceholder::make(
+                           new_var,
+                           RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW)
+                    .node();
+        }
     }
     return new_var;
 }

 std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         size_t pack_c_size) {
-    auto ret = std::make_unique<EnableNchwxxPass>();
+    auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     //! First is whether the conv can trans to nchwxx, second is the filter
     //! trans mode
@@ -1102,8 +1282,18 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     megdnn::param::Pooling::Format pooling_format =
             megdnn::param::Pooling::Format::NCHW88;
     std::string convter_pass_name = "conv_format_nchw88";
-    mgb_assert(pack_c_size == static_cast<size_t>(8),
-               "The ConvertFormatPass to nchwxx only support NCHW88 now !");
+    if (pack_c_size == 4) {
+        weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
+        weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
+        weight_to_nchwxx_mode_chan = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
+        hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
+        src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW44;
+        src_to_nchw_mode = RelayoutMode::NCHW44_TO_NCHW;
+        conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
+        conv_format = megdnn::param::ConvolutionV0::Format::NCHW44;
+        pooling_format = megdnn::param::Pooling::Format::NCHW44;
+        convter_pass_name = "conv_format_nchw44";
+    }
     auto test_trans_nchwxx = [pack_c_size, weight_to_nchwxx_mode_dense,
                              weight_to_nchwxx_mode_group,
                              weight_to_nchwxx_mode_chan,
@@ -1297,7 +1487,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         auto new_param = conv_bias_opr.param();
         new_param.format = conv_bias_format;
         auto new_conv_bias_opr = opr::ConvBias::make(
-                conv_bias_src, conv_bias_filter, new_param,
+                conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                 conv_bias_opr.execution_policy(), conv_bias_opr.config());
         OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
         mgb_assert(new_conv_bias_opr.shape().ndim == 5,
@@ -1330,6 +1520,51 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         }
     };

+    auto replace_concat_opr = [=](OperatorNodeBase* opr,
+                                  const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        bool has_inp_changed = false;
+        bool can_exec_ncwxx = true;
+        for (size_t i = 0; i < opr->input().size(); i++) {
+            if (new_inp[i]->shape().ndim == 5) {
+                has_inp_changed = true;
+                break;
+            } else if (new_inp[i]->shape().ndim == 4) {
+                if (new_inp[i]->shape()[1] % pack_c_size != 0) {
+                    can_exec_ncwxx = false;
+                }
+            }
+        }
+        if (has_inp_changed) {
+            auto temp_inp = new_inp;
+            if (can_exec_ncwxx) {
+                for (size_t i = 0; i < opr->input().size(); i++) {
+                    if (new_inp[i]->shape().ndim == 4) {
+                        auto new_var = RelayoutPlaceholder::make(
+                                new_inp[i], src_to_nchwxx_mode);
+                        temp_inp[i] = new_var.node();
+                    } else {
+                        mgb_assert((new_inp[i]->shape().ndim == 5) ||
+                                   new_inp[i]->shape().is_scalar());
+                    }
+                }
+            } else {
+                for (size_t i = 0; i < opr->input().size(); i++) {
+                    if (new_inp[i]->shape().ndim == 5) {
+                        auto new_var = RelayoutPlaceholder::make(
+                                new_inp[i], src_to_nchw_mode);
+                        temp_inp[i] = new_var.node();
+                    }
+                }
+            }
+            return serialization::copy_opr_shallow(*opr, temp_inp,
+                                                   opr->config());
+        } else {
+            return serialization::copy_opr_shallow(*opr, new_inp,
+                                                   opr->config());
+        }
+    };
+
     auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
                                     const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
@@ -1382,6 +1617,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
     replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
     replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
+    replace_func[opr::Concat::typeinfo()] = replace_concat_opr;
     replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
     replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
     replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
@@ -1390,13 +1626,10 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     replace_func[opr::ConvolutionBackwardData::typeinfo()] =
             relayout_inp_to_nchw;
     replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
-    replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw;
-    replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
-    replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::WarpPerspectiveForward::typeinfo()] =
diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h
index ce353d501..af86d61bf 100644
--- a/src/gopt/include/megbrain/gopt/inference.h
+++ b/src/gopt/include/megbrain/gopt/inference.h
@@ -234,16 +234,18 @@ namespace gopt {
      */
    class EnableNchwxxPass final : public TensorReformatPass {
        std::string m_name = "tensor_format_nchwxx";
+       size_t m_pack_c_size;
        VarNode* on_graph_endpoint_var(VarNode* new_var,
                                       VarNode* orig_var) const override;
        //! the flag for conv to transform to nchwxx
        enum class TransType {
-           TRANS_PURE_NCHWXX,    //!< weight and src all trans to nchw88
-           TRANS_HYBIRD_NCHWXX,  //!< input is nchw, output is nchw88
+           TRANS_PURE_NCHWXX,    //!< weight and src all trans to nchwxx
+           TRANS_HYBIRD_NCHWXX,  //!< input is nchw, output is nchwxx
            TRANS_NONE,           //!< no need trans
        };

    public:
+       EnableNchwxxPass(size_t pack_c_size) : m_pack_c_size(pack_c_size) {}
        const char* name() const override {
            return mgb_cstr_log(m_name.c_str());
        }
@@ -265,6 +267,8 @@ namespace gopt {
            bool use_nhwcd4 = false;
            //! whether to compute using NCHW88 tensor format
            bool use_nchw88 = false;
+           //! whether to compute using NCHW44 tensor format
+           bool use_nchw44 = false;
            //! whether to enable tensor core
            bool use_tensor_core = false;
            //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
@@ -283,6 +287,7 @@ namespace gopt {
            SET(use_tensor_core);
            SET(fuse_conv_bias_with_z);
            SET(use_nchw88);
+           SET(use_nchw44);
 #undef SET
        };
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 43cc6d4e9..9ee21c1ce 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -2325,5 +2325,86 @@ TEST(TestGoptInference, ConvertFormatNCHW88) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
 }

+TEST(TestGoptInference, ConvertFormatNCHW44) {
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+
+    auto host_x = gen({2, 3, 16, 16}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+    //! hybrid nchw44 mode
+    opr::Convolution::Param param_conv;
+    param_conv.pad_h = param_conv.pad_w = 1;
+    auto w1 = mkcvar("w1", {8, 3, 3, 3}),
+         conv1 = opr::Convolution::make(x, w1, param_conv);
+    //! channel wise
+    opr::ConvBias::Param param_conv_bias;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
+    auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}),
+         conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias);
+    //! group
+    auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}),
+         conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias);
+
+    auto shape_of = opr::GetVarShape::make(conv3);
+    auto subtensor = opr::Subtensor::make(
+            shape_of, {opr::Subtensor::AxisIndexer::make_interval(
+                              0, x.make_scalar(2), None, x.make_scalar(1))});
+    opr::Resize::Param param_resize;
+    param_resize.format = opr::Resize::Param::Format::NCHW;
+    auto resize = opr::ResizeForward::make(conv3, subtensor * 2, param_resize);
+    auto mat = mkcvar("mat", {2, 3, 3}),
+         warp = opr::WarpPerspectiveForward::make(
+                 resize, mat, nullptr, cg::var_from_tensor_shape(x, {4, 4}));
+
+    auto b = mkvar("b", {1, 8, 1, 1}),
+         elem = opr::Elemwise::make({warp + b},
+                                    opr::Elemwise::Param::Mode::RELU);
+    //! dense
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}),
+         conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias);
+    auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}),
+         conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias);
+    auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
+         y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias);
+
+    SymbolVar y_opt;
+    unpack_vector(
+            gopt::optimize_for_inference(
+                    {y},
+                    gopt::OptimizeForInferenceOptions{}.enable_use_nchw44()),
+            y_opt);
+
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
+              find_opr<opr::ConvBias>(y_opt).param().format);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.ConvertFormatNCHW44.json"));
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    //! may go to winograd in x86-32, so set error 1e-1
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
+
+    *host_x = *gen({2, 3, 32, 32}, cn);
+    func->execute();
+    //! may go to winograd in x86-32, so set error 1e-1
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
+}

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp
index a33efbef3..ae01f39d5 100644
--- a/src/plugin/impl/opr_footprint.cpp
+++ b/src/plugin/impl/opr_footprint.cpp
@@ -99,7 +99,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
         group = filter_shape[0];
     }
     if (param.format == Param::Format::NCHW88) {
-        //! if channel wise weight layout is {group/8, 1, 1, FH, FW, 8}
+        //! if channel wise weight layout is {group/8, FH, FW, 1, 1, 8}
         if (filter_shape[1] == 1 && filter_shape[2] == 1) {
             group *= 8;
         }
@@ -107,6 +107,15 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
                              src_shape[1] / group * 2;
         return hybird_nchwx ? computation : computation * 8;
     }
+    if (param.format == Param::Format::NCHW44) {
+        //! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4}
+        if (filter_shape[1] == 1 && filter_shape[2] == 1) {
+            group *= 4;
+        }
+        size_t computation = dst_shape.total_nr_elems() * fh * fw *
+                             src_shape[1] / group * 2;
+        return hybird_nchwx ? computation : computation * 4;
+    }
     if (param.format == Param::Format::NCHW32) {
         return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * 32 /
                group * 2;
@@ -135,6 +144,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
     };
     if (param.format == Param::Format::NCHW4 ||
         param.format == Param::Format::NCHW88 ||
+        param.format == Param::Format::NCHW44 ||
         param.format == Param::Format::NCHW32) {
         return eval_conv_computation_nchwx();
     }
--
GitLab
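
Notes for reviewers (everything below the signature is ignored by git am).

The new option is reachable from Python as well as from the dump tool. A
minimal usage sketch; the variable name `outputs` is illustrative and stands
for the SymbolVars returned by network building, not something defined in
this patch:

    # Sketch only: enable the new NCHW44 layout pass from Python, mirroring
    # the existing use_nchw88 switch extended above.
    import megengine._internal as mgb

    optimized_outputs = mgb.optimize_for_inference(outputs, use_nchw44=True)

From the command line, the equivalent switch added by this patch is
--enable-nchw44 in sdk/load-and-run/dump_with_testcase_mge.py.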
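The feature-map relayouts in tensor_reformat.cpp are easier to check against
a host-side model. The following NumPy sketch (illustrative only, not
MegEngine API) computes the same permutation as the NCHW_TO_NCHW44 and
NCHW44_TO_NCHW reshape -> dimshuffle -> reshape sequences:

    import numpy as np

    def nchw_to_nchw44(x):
        # reshape (N, C, H, W) -> (N, C//4, 4, H, W), then move the packed
        # channel group innermost, exactly like Dimshuffle {0, 1, 3, 4, 2}
        n, c, h, w = x.shape
        assert c % 4 == 0, "NCHW_TO_NCHW44 requires C % 4 == 0"
        return x.reshape(n, c // 4, 4, h, w).transpose(0, 1, 3, 4, 2)

    def nchw44_to_nchw(x):
        # inverse: Dimshuffle {0, 1, 4, 2, 3}, then reshape back to 4-D
        n, c4, h, w, pack = x.shape
        assert pack == 4
        return x.transpose(0, 1, 4, 2, 3).reshape(n, c4 * 4, h, w)

    x = np.arange(2 * 8 * 3 * 3, dtype=np.float32).reshape(2, 8, 3, 3)
    packed = nchw_to_nchw44(x)
    assert packed.shape == (2, 2, 3, 3, 4)
    assert np.array_equal(nchw44_to_nchw(packed), x)

The round trip is exact, which is what lets the pass insert NCHW44_TO_NCHW
placeholders at graph endpoints without changing results.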
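The weight relayouts follow the same pattern. Below is a NumPy model of the
dense and hybrid modes (group and channel-wise modes only add the leading
group axis); function names are illustrative:

    import numpy as np

    def weight_nchw_to_nchw44_dense(w):
        # (O, I, FH, FW) -> (O//4, I//4, FH, FW, 4i, 4o); the transpose axes
        # correspond to Dimshuffle {0, 2, 4, 5, 3, 1} in the pass
        o, i, fh, fw = w.shape
        assert o % 4 == 0 and i % 4 == 0
        return w.reshape(o // 4, 4, i // 4, 4, fh, fw) \
                .transpose(0, 2, 4, 5, 3, 1)

    def weight_hybrid_nchw_nchw44(w):
        # first-conv case where the input stays NCHW: {64, 2, 3, 3} becomes
        # {16, 3, 3, 2, 4}, matching Dimshuffle {0, 3, 4, 2, 1}
        o, i, fh, fw = w.shape
        assert o % 4 == 0
        return w.reshape(o // 4, 4, i, fh, fw).transpose(0, 3, 4, 2, 1)

    assert weight_nchw_to_nchw44_dense(np.zeros((8, 8, 3, 3))).shape == \
            (2, 2, 3, 3, 4, 4)
    assert weight_hybrid_nchw_nchw44(np.zeros((64, 2, 3, 3))).shape == \
            (16, 3, 3, 2, 4)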
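Finally, the opr_footprint change mirrors the existing NCHW88 accounting:
for a pure NCHW44 input, src_shape[1] is the packed C/4, so the result is
scaled back up by the pack size except in hybrid mode, and channel-wise
convolutions first fold the pack size into the group count. A NumPy
restatement of that branch, with illustrative names:

    import numpy as np

    def conv_nchw44_computation(dst_shape, fh, fw, src_c, group=1,
                                hybrid=False):
        # dst_shape: the 5-D packed output (N, C//4, H, W, 4); src_c mirrors
        # src_shape[1], i.e. the packed C//4 for a pure NCHW44 input
        comp = int(np.prod(dst_shape)) * fh * fw * src_c // group * 2
        # pure NCHW44 multiplies by the pack size; hybrid input is plain NCHW
        return comp if hybrid else comp * 4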