Commit 1e8337f1 authored by Megvii Engine Team

fix(mgb/gopt): remove redundant reshape in nchw->nchw4 pass

GitOrigin-RevId: 0f5c7c3e485b4da0cdfe9b0db3e23945ac43ee16
Parent 946a340c
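Why the trailing Reshape is a no-op: for an NCHW tensor of shape (N, C, H, W), reshaping to tshp0 = (N, C/4, 4, H, W) and then applying Dimshuffle with pattern {0, 1, 3, 4, 2} already yields shape (N, C/4, H, W, 4), which is exactly the target shape that tshp1 described. Below is a minimal shape-only sketch in plain C++ (a hypothetical illustration, not MegEngine code), tracking only shape metadata.

#include <array>
#include <cassert>
#include <cstddef>

int main() {
    // Example NCHW shape; C must be divisible by 4 for the NCHW4 layout.
    std::array<std::size_t, 4> nchw = {8, 16, 32, 32};  // N, C, H, W

    // Step 1: Reshape with tshp0 = (N, C/4, 4, H, W).
    std::array<std::size_t, 5> reshaped = {nchw[0], nchw[1] / 4, 4, nchw[2], nchw[3]};

    // Step 2: Dimshuffle with pattern {0, 1, 3, 4, 2}.
    const std::array<std::size_t, 5> pattern = {0, 1, 3, 4, 2};
    std::array<std::size_t, 5> shuffled{};
    for (std::size_t i = 0; i < 5; ++i)
        shuffled[i] = reshaped[pattern[i]];

    // tshp1 = (N, C/4, H, W, 4): the target of the Reshape removed by this commit.
    std::array<std::size_t, 5> tshp1 = {nchw[0], nchw[1] / 4, nchw[2], nchw[3], 4};

    // The Dimshuffle output already has that shape, so the second Reshape was redundant.
    assert(shuffled == tshp1);
    return 0;
}

The updated test in the diff reflects this by asserting on the number of Reshape operators left in the optimized graph.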
@@ -435,13 +435,10 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
             return opr::IndexAt::make(xshp, {{0, cv(idx)}});
         };
         auto tshp0 = opr::Concat::make(
-                {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
-             tshp1 = opr::Concat::make(
-                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
+                {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
         auto y0 = opr::Reshape::make(x, tshp0);
         auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
-        auto y2 = opr::Reshape::make(y1, tshp1);
-        return y2.node();
+        return y1.node();
     };
     reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* {
         auto x = SymbolVar(inp);
@@ -455,7 +452,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y1 = opr::Reshape::make(y0, tshp0);
         return y1.node();
     };
-    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = [](VarNode* inp) -> VarNode* {
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] =
+            [](VarNode* inp) -> VarNode* {
         auto x = SymbolVar(inp);
         auto xshp = opr::GetVarShape::make(x);
         auto cv = [&x](int v) { return x.make_scalar(v); };
@@ -471,7 +469,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
-    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = [](VarNode* inp) -> VarNode* {
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] =
+            [](VarNode* inp) -> VarNode* {
         auto x = SymbolVar(inp);
         auto xshp = opr::GetVarShape::make(x);
         auto cv = [&x](int v) { return x.make_scalar(v); };
......
@@ -2452,6 +2452,8 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
     ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4,
               find_opr<opr::ConvBias>(y_opt).param().format);
+    auto nr_reshape = find_opr_num<mgb::opr::Reshape>(y_opt);
+    ASSERT_EQ(2u, nr_reshape);
     graph->compile({{y_opt, {}}})
             ->to_json()
......