diff --git a/src/gopt/impl/padding_channel.cpp b/src/gopt/impl/padding_channel.cpp index 52a4368be8cd41a6dfe080b68b5dce82e2bb238f..9b1a71eb2639c47869143b3d03597157516d425f 100644 --- a/src/gopt/impl/padding_channel.cpp +++ b/src/gopt/impl/padding_channel.cpp @@ -558,7 +558,7 @@ void PaddingChannelPass::add_condition_padding_oprs_replace_func(LayoutTrans) { if (reduce->input().size() > 1) { can_forward_padding = false; } else { - can_forward_padding = reduce->param().axis != 1; + can_forward_padding = axis != 1; } } else if (auto subtensor = opr->try_cast_final()) { auto indexs = subtensor->index_desc(); @@ -605,6 +605,7 @@ void PaddingChannelPass::add_nonpadding_oprs_replace_func(LayoutTrans) { return serialization::copy_opr_shallow(*opr, inps, opr->config()); }; m_opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs; + m_opr_replace_funcs[opr::AxisAddRemove::typeinfo()] = replace_nonpadding_oprs; m_opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs; m_opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs; m_opr_replace_funcs[opr::Dimshuffle::typeinfo()] = replace_nonpadding_oprs; diff --git a/src/gopt/test/padding_channel.cpp b/src/gopt/test/padding_channel.cpp index 67659ef004ac5d78b4013e3cf5da9aae87b5503a..697525b953c35982d32ae2de6c1bb0e54777325b 100644 --- a/src/gopt/test/padding_channel.cpp +++ b/src/gopt/test/padding_channel.cpp @@ -282,6 +282,85 @@ TEST(TestGoptInference, ChannelPaddingSubtensor) { MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2); } +TEST(TestGoptInference, ChannelPaddingAxisAddRemove) { + HostTensorGenerator<> gen; + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkcvar = [&](const char* name, const TensorShape& shp) { + return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name); + }; + + auto host_x = gen({1, 3, 8, 8}, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x); + //! Hybrid nchw44 mode + opr::ConvBias::Param param_conv; + param_conv.pad_h = param_conv.pad_w = 1; + auto w1 = mkcvar("w1", {8, 3, 3, 3}), b1 = mkcvar("w1", {1, 8, 1, 1}), + conv1 = opr::ConvBias::make( + x, w1, b1, param_conv, {}, OperatorNodeConfig("conv1")); + + auto w2 = mkcvar("w2", {1, 8, 1, 1}), + conv2 = opr::Convolution::make(conv1, w2, {}, {}, OperatorNodeConfig("conv2")); + auto remove_axis_1 = opr::AxisAddRemove::make( + conv2, {opr::AxisAddRemove::AxisDesc::make_remove(1)}, + OperatorNodeConfig("remove_axis_1")); + auto add_axis_1 = opr::AxisAddRemove::make( + remove_axis_1, {opr::AxisAddRemove::AxisDesc::make_add(1)}); + auto w3 = mkcvar("w3", {3, 1, 1, 1}), + conv3 = opr::Convolution::make( + add_axis_1, w3, {}, {}, OperatorNodeConfig("conv3")); + auto remove_axis_0 = opr::AxisAddRemove::make( + conv3, {opr::AxisAddRemove::AxisDesc::make_remove(0)}, + OperatorNodeConfig("remove_axis_0")); + + SymbolVar y_pad; + unpack_vector( + gopt::GraphOptimizer{} + .add_pass(gopt::PaddingChannelPass::make( + cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW44, + true)) + .apply({{remove_axis_0}}) + .endpoint_vars(), + y_pad); + auto conv1_opt = find_opr(y_pad, "conv1"); + auto conv2_opt = find_opr(y_pad, "conv2"); + auto remove_axis_1_opt = find_opr(y_pad, "remove_axis_1"); + auto conv3_opt = find_opr(y_pad, "conv3"); + auto remove_axis_0_opt = find_opr(y_pad, "remove_axis_0"); + //! do not padding input tensor + ASSERT_EQ(conv1_opt->input(0)->shape()[1], 3); + //! output tensor padding input tensor + ASSERT_EQ(conv2_opt->input(1)->shape()[0], 4); + ASSERT_EQ(conv2_opt->output(0)->shape()[1], 4); + + //! AxisAddRemove always add subtensor + ASSERT_EQ(remove_axis_1_opt->input(0)->shape()[1], 1); + + ASSERT_EQ(conv3_opt->input(1)->shape()[0], 4); + ASSERT_EQ(conv3_opt->output(0)->shape()[1], 4); + + //! AxisAddRemove always add subtensor + ASSERT_EQ(remove_axis_0_opt->input(0)->shape()[1], 3); + + graph->compile({{y_pad, {}}}) + ->to_json() + ->writeto_fpath( + output_file("TestGoptInference.ChannelPaddingAxisAddRemove.json")); + + HostTensorND host_y_opt, host_y; + auto func = graph->compile( + {make_callback_copy(remove_axis_0, host_y), + make_callback_copy(y_pad, host_y_opt)}); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2); + + //! test change the input shape + *host_x = *gen({1, 3, 32, 32}, cn); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2); +} + TEST(TestGoptInference, ChannelPaddingReduce) { HostTensorGenerator<> gen; auto cn = CompNode::load("cpu0");