diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index 87e4d4f3e8df986fab39cadcb4d9f0388389086a..3a0bc60189a521eaa8f44807576e8876c02cbd8c 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -1753,6 +1753,24 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { } }; + auto replace_reduce_opr = [=](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + auto& reduce_opr = opr->cast_final_safe(); + + VarNodeArray temp_inp = new_inp; + if (!opr->input(0)->shape().eq_shape(new_inp[0]->shape())) { + mgb_assert(opr->input(0)->shape().ndim == 4); + mgb_assert(new_inp[0]->shape().ndim == 5); + if (reduce_opr.param().axis != 2 && reduce_opr.param().axis != 3) { + auto new_var = + RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode); + temp_inp[0] = new_var.node(); + } + } + return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); + }; + //! When input change and all input can convert to nchwxx, this opr will run //! in nchwxx mode, else it will run in nchw mode, for example concat and //! elemwise opr @@ -1829,13 +1847,13 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { replace_func[opr::TypeCvt::typeinfo()] = replace_multi_inp_opr; replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_multi_inp_opr; replace_func[opr::PowC::typeinfo()] = replace_multi_inp_opr; + replace_func[opr::Reduce::typeinfo()] = replace_reduce_opr; //! not support yet replace_func[opr::ConvolutionBackwardData::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::WarpPerspectiveForward::typeinfo()] = diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 3b8f59cf07139de670ebf10be98ad90cc8f50c9b..b16c0df91ab5c284fe50cdc01349ee24e9adf9cb 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -3476,14 +3476,24 @@ TEST(TestGoptInference, ConvertFormatNCHW88) { //! group auto w3 = mkcvar("w3", {1, 8, 8, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}), conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias); - - auto shape_of = opr::GetVarShape::make(conv3); + //! reduce + opr::Reduce::Param param_reduce1; + param_reduce1.axis = 2; + param_reduce1.mode = opr::Reduce::Mode::SUM; + opr::Reduce::Param param_reduce2; + param_reduce2.axis = 0; + param_reduce2.mode = opr::Reduce::Mode::MAX; + auto reduce1 = conv3 + opr::Reduce::make(conv3, param_reduce1) + + opr::Reduce::make(conv3, param_reduce2); + + auto shape_of = opr::GetVarShape::make(reduce1); auto subtensor = opr::Subtensor::make( shape_of, {opr::Subtensor::AxisIndexer::make_interval( 0, x.make_scalar(2), None, x.make_scalar(1))}); opr::Resize::Param param_resize; param_resize.format = opr::Resize::Param::Format::NCHW; - auto resize = opr::ResizeForward::make(conv3, subtensor * 2, param_resize); + auto resize = + opr::ResizeForward::make(reduce1, subtensor * 2, param_resize); auto mat = mkcvar("mat", {2, 3, 3}), warp = opr::WarpPerspectiveForward::make( resize, mat, nullptr, cg::var_from_tensor_shape(x, {4, 4})); @@ -3586,14 +3596,24 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { //! group auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}), conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias); - - auto shape_of = opr::GetVarShape::make(conv3); + //! reduce + opr::Reduce::Param param_reduce1; + param_reduce1.axis = 1; + param_reduce1.mode = opr::Reduce::Mode::MIN; + opr::Reduce::Param param_reduce2; + param_reduce2.axis = 3; + param_reduce2.mode = opr::Reduce::Mode::SUM_SQR; + auto reduce1 = conv3 + opr::Reduce::make(conv3, param_reduce1) + + opr::Reduce::make(conv3, param_reduce2); + + auto shape_of = opr::GetVarShape::make(reduce1); auto subtensor = opr::Subtensor::make( shape_of, {opr::Subtensor::AxisIndexer::make_interval( 0, x.make_scalar(2), None, x.make_scalar(1))}); opr::Resize::Param param_resize; param_resize.format = opr::Resize::Param::Format::NCHW; - auto resize = opr::ResizeForward::make(conv3, subtensor * 2, param_resize); + auto resize = + opr::ResizeForward::make(reduce1, subtensor * 2, param_resize); auto mat = mkcvar("mat", {2, 3, 3}), warp = opr::WarpPerspectiveForward::make( resize, mat, nullptr, cg::var_from_tensor_shape(x, {4, 4}));