Commit 66b79160 authored by Megvii Engine Team

fix(src/gopt): fix padding channel pass bug where no Subtensor was inserted before AxisAddRemove

GitOrigin-RevId: 01bd8c70e96f72e6c5089d73ec7da607462837db
Parent 801265e9
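In plain terms: PaddingChannelPass may widen a tensor's channel dimension (e.g. 3 -> 4 for NCHW44), while opr::AxisAddRemove only adds or removes axes and knows nothing about that padding. Before this fix a channel-padded tensor could reach an AxisAddRemove directly, so results computed downstream of it could be wrong. The fix registers AxisAddRemove as a padding-unaware opr so the pass slices the padded channels off first. Below is a minimal sketch of that idea, assuming NCHW with the channel on axis 1; the helper name, the header path, and the explicit begin/end/step scalars are illustrative assumptions, not the pass's actual code:

```cpp
// Illustrative only: cut a channel-padded NCHW tensor back to its original
// channel count before it reaches a padding-unaware opr such as AxisAddRemove.
#include "megbrain/opr/tensor_manip.h"  // opr::Subtensor / opr::AxisAddRemove (assumed header)

using namespace mgb;

SymbolVar strip_channel_padding(SymbolVar padded, int orig_channels) {
    using AIdx = opr::Subtensor::AxisIndexer;
    auto cv = [&](int v) { return padded.make_scalar(v); };
    // keep channels [0, orig_channels) on axis 1, dropping the padded tail
    return opr::Subtensor::make(
            padded, {AIdx::make_interval(1, cv(0), cv(orig_channels), cv(1))});
}

// usage sketch: restore the 3 real channels before removing an axis
//   auto y = opr::AxisAddRemove::make(
//           strip_channel_padding(conv_out, 3),
//           {opr::AxisAddRemove::AxisDesc::make_remove(0)});
```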
@@ -558,7 +558,7 @@ void PaddingChannelPass::add_condition_padding_oprs_replace_func(LayoutTrans) {
             if (reduce->input().size() > 1) {
                 can_forward_padding = false;
             } else {
-                can_forward_padding = reduce->param().axis != 1;
+                can_forward_padding = axis != 1;
             }
         } else if (auto subtensor = opr->try_cast_final<opr::Subtensor>()) {
             auto indexs = subtensor->index_desc();
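The only change in this hunk is that the forward-padding condition now reads a precomputed `axis` value instead of the raw `reduce->param().axis`. The definition of `axis` sits above the shown context, so the snippet below is an assumption about its form rather than the pass's actual code: presumably the reduce axis normalized against the input rank, so a negative axis is compared correctly with the channel axis 1.

```cpp
// Assumed form of the earlier (not shown) normalization; illustrative only.
int axis = reduce->param().axis;
if (axis < 0)
    axis += static_cast<int>(reduce->input(0)->shape().ndim);
// forwarding channel padding through the Reduce is only safe when it does
// not reduce over the channel axis
can_forward_padding = axis != 1;
```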
@@ -605,6 +605,7 @@ void PaddingChannelPass::add_nonpadding_oprs_replace_func(LayoutTrans) {
         return serialization::copy_opr_shallow(*opr, inps, opr->config());
     };
     m_opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs;
+    m_opr_replace_funcs[opr::AxisAddRemove::typeinfo()] = replace_nonpadding_oprs;
     m_opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs;
     m_opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs;
     m_opr_replace_funcs[opr::Dimshuffle::typeinfo()] = replace_nonpadding_oprs;
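Registering `opr::AxisAddRemove::typeinfo()` with the shared `replace_nonpadding_oprs` handler is what actually fixes the bug: the pass now treats AxisAddRemove like Reshape, Concat, Dimshuffle and the other padding-unaware oprs, whose handler restores each padded input to its original shape before shallow-copying the opr. Only the handler's final `copy_opr_shallow` line is visible in the hunk above, so the sketch below is an assumed reconstruction; in particular the `extract_subtensor` helper name is an assumption.

```cpp
// Rough sketch of a "non-padding opr" replace function (assumed reconstruction;
// only the trailing copy_opr_shallow line appears in the diff).
auto replace_nonpadding_oprs = [=](OperatorNodeBase* opr,
                                   const VarNodeArray& new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());
    VarNodeArray inps = new_inp;
    for (size_t i = 0; i < new_inp.size(); ++i) {
        // a producer handled earlier may have padded this input's channels;
        // if so, insert a Subtensor that restores the original shape
        if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
            inps[i] = extract_subtensor(new_inp[i], opr->input(i)->shape());
        }
    }
    return serialization::copy_opr_shallow(*opr, inps, opr->config());
};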
@@ -282,6 +282,85 @@ TEST(TestGoptInference, ChannelPaddingSubtensor) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2);
 }
+
+TEST(TestGoptInference, ChannelPaddingAxisAddRemove) {
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
+    };
+
+    auto host_x = gen({1, 3, 8, 8}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+    //! Hybrid nchw44 mode
+    opr::ConvBias::Param param_conv;
+    param_conv.pad_h = param_conv.pad_w = 1;
+    auto w1 = mkcvar("w1", {8, 3, 3, 3}), b1 = mkcvar("b1", {1, 8, 1, 1}),
+         conv1 = opr::ConvBias::make(
+                 x, w1, b1, param_conv, {}, OperatorNodeConfig("conv1"));
+    auto w2 = mkcvar("w2", {1, 8, 1, 1}),
+         conv2 = opr::Convolution::make(
+                 conv1, w2, {}, {}, OperatorNodeConfig("conv2"));
+    auto remove_axis_1 = opr::AxisAddRemove::make(
+            conv2, {opr::AxisAddRemove::AxisDesc::make_remove(1)},
+            OperatorNodeConfig("remove_axis_1"));
+    auto add_axis_1 = opr::AxisAddRemove::make(
+            remove_axis_1, {opr::AxisAddRemove::AxisDesc::make_add(1)});
+    auto w3 = mkcvar("w3", {3, 1, 1, 1}),
+         conv3 = opr::Convolution::make(
+                 add_axis_1, w3, {}, {}, OperatorNodeConfig("conv3"));
+    auto remove_axis_0 = opr::AxisAddRemove::make(
+            conv3, {opr::AxisAddRemove::AxisDesc::make_remove(0)},
+            OperatorNodeConfig("remove_axis_0"));
+
+    SymbolVar y_pad;
+    unpack_vector(
+            gopt::GraphOptimizer{}
+                    .add_pass(gopt::PaddingChannelPass::make(
+                            cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW44,
+                            true))
+                    .apply({{remove_axis_0}})
+                    .endpoint_vars(),
+            y_pad);
+    auto conv1_opt = find_opr<opr::ConvBias>(y_pad, "conv1");
+    auto conv2_opt = find_opr<opr::Convolution>(y_pad, "conv2");
+    auto remove_axis_1_opt = find_opr<opr::AxisAddRemove>(y_pad, "remove_axis_1");
+    auto conv3_opt = find_opr<opr::Convolution>(y_pad, "conv3");
+    auto remove_axis_0_opt = find_opr<opr::AxisAddRemove>(y_pad, "remove_axis_0");
+    //! the input tensor of conv1 is not padded
+    ASSERT_EQ(conv1_opt->input(0)->shape()[1], 3);
+    //! the weight and output channels of conv2 are padded to 4
+    ASSERT_EQ(conv2_opt->input(1)->shape()[0], 4);
+    ASSERT_EQ(conv2_opt->output(0)->shape()[1], 4);
+    //! a Subtensor is always inserted before AxisAddRemove to strip the padding
+    ASSERT_EQ(remove_axis_1_opt->input(0)->shape()[1], 1);
+    ASSERT_EQ(conv3_opt->input(1)->shape()[0], 4);
+    ASSERT_EQ(conv3_opt->output(0)->shape()[1], 4);
+    //! a Subtensor is always inserted before AxisAddRemove to strip the padding
+    ASSERT_EQ(remove_axis_0_opt->input(0)->shape()[1], 3);
+
+    graph->compile({{y_pad, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.ChannelPaddingAxisAddRemove.json"));
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile(
+            {make_callback_copy(remove_axis_0, host_y),
+             make_callback_copy(y_pad, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2);
+    //! test changing the input shape
+    *host_x = *gen({1, 3, 32, 32}, cn);
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2);
+}
 
 TEST(TestGoptInference, ChannelPaddingReduce) {
     HostTensorGenerator<> gen;
     auto cn = CompNode::load("cpu0");