diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index 1b82f409b028bc2aca0267db4cb68c8b37bae046..b7865997a6edffcaaaf5243956d6603f6c597498 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -435,13 +435,10 @@ void TensorReformatPass::translate_pass(OptState& opt) const { return opr::IndexAt::make(xshp, {{0, cv(idx)}}); }; auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0); + {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0); auto y0 = opr::Reshape::make(x, tshp0); auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); + return y1.node(); }; reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* { auto x = SymbolVar(inp); @@ -455,7 +452,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const { auto y1 = opr::Reshape::make(y0, tshp0); return y1.node(); }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = [](VarNode* inp) -> VarNode* { + reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = + [](VarNode* inp) -> VarNode* { auto x = SymbolVar(inp); auto xshp = opr::GetVarShape::make(x); auto cv = [&x](int v) { return x.make_scalar(v); }; @@ -471,7 +469,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const { auto y2 = opr::Reshape::make(y1, tshp1); return y2.node(); }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = [](VarNode* inp) -> VarNode* { + reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = + [](VarNode* inp) -> VarNode* { auto x = SymbolVar(inp); auto xshp = opr::GetVarShape::make(x); auto cv = [&x](int v) { return x.make_scalar(v); }; diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 8321fd4de1fa5458d30875b1831c74483da1f6ba..86c7c0092ec581d18660be0053551e2f17afbcd6 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -2450,6 +2450,8 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, find_opr(y_opt).param().format); + auto nr_reshape = find_opr_num(y_opt); + ASSERT_EQ(2u, nr_reshape); graph->compile({{y_opt, {}}}) ->to_json()