diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index f80fac1700abeeec5b292974c0e8fe96142916c6..7d99c87341d37cbbb7fe694e14bde29254354fdb 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1635,6 +1635,7 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_chw; replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw; replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw; + replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw; replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; replace_func[opr::WarpPerspectiveForward::typeinfo()] = replace_warp_perspective_opr; diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index f1f5f952091bc72bf8b0d412e326692cef8ff2ce..229cd2c128e0a008314d9be9a2bd1d69b87bddde 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -1171,16 +1171,22 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) { opr::Elemwise::Param::Mode::RELU); param.pad_h = param.pad_w = 1; auto w2 = mkcvar("w2", {4, 4, 3, 3}), - y = opr::Convolution::make(elem, w2, param); + y = opr::Convolution::make(elem, w2, param), + z = opr::AxisAddRemove::make(y, {opr::AxisAddRemove::AxisDesc::make_add(0)}); - SymbolVar y_opt; + SymbolVar y_opt, z_opt; auto options = gopt::OptimizeForInferenceOptions{}; options.enable_nhwcd4(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); + unpack_vector(gopt::optimize_for_inference({z}, options), z_opt); ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4, find_opr(y_opt).param().format); + ASSERT_EQ(TensorFormat::Type::DEFAULT, + find_opr(z_opt).input(0)->format().type()); + ASSERT_EQ(4, find_opr(z_opt).input(0)->shape().ndim); + graph->compile({{y_opt, {}}}) ->to_json() ->writeto_fpath(