提交 dd8fd273 编写于 作者: M Megvii Engine Team

fix(gopt): fix channel padding pass of channel wise conv

GitOrigin-RevId: e98568d0bd3fd2c51c2b033b691557c4f91f5ccf
上级 3b557b11
......@@ -316,6 +316,9 @@ OperatorNodeBase* PaddingChannelPass::padding_channel_wise_conv_policy(
size_t pad_channels_1 = new_in_channels - group;
if (pad_channels_1) {
inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
if (inps.size() >= 3) {
inps[2] = pad_in_channels(new_inp[2], pad_channels_1);
}
m_padding_oprs.insert(opr);
}
}
......
......@@ -60,7 +60,11 @@ T* find_opr(SymbolVar endpoint, const std::string& node_name) {
}
} // namespace
TEST(TestGoptInference, ChannelPaddingNCHW44) {
template <opr::Convolution::Param::Format T_format>
void check_channel_padding_conv() {
size_t scalar = 1;
if (T_format == opr::Convolution::Param::Format::NCHW88)
scalar = 2;
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
......@@ -69,74 +73,103 @@ TEST(TestGoptInference, ChannelPaddingNCHW44) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};
auto host_x = gen({1, 3, 8, 8}, cn);
auto host_x = gen({1, 3 * scalar, 8, 8}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
//! Hybrid nchw44 mode
opr::ConvBias::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 3, 3, 3}), b1 = mkcvar("w1", {1, 8, 1, 1}),
auto w1 = mkcvar("w1", {8 * scalar, 3 * scalar, 3, 3}),
b1 = mkcvar("b1", {1, 8 * scalar, 1, 1}),
conv1 = opr::ConvBias::make(
x, w1, b1, param_conv, {}, OperatorNodeConfig("conv1"));
auto w2 = mkcvar("w2", {6, 8, 3, 3}), b2 = mkcvar("b2", {1, 6, 1, 1}),
auto w2 = mkcvar("w2", {6 * scalar, 8 * scalar, 3, 3}),
b2 = mkcvar("b2", {1, 6 * scalar, 1, 1}),
conv2 = opr::ConvBias::make(
conv1, w2, b2, param_conv, {}, OperatorNodeConfig("conv2"));
auto w3 = mkcvar("w3", {3, 6, 3, 3}), b3 = mkcvar("b3", {1, 3, 1, 1}),
auto w3 = mkcvar("w3", {3 * scalar, 6 * scalar, 3, 3}),
b3 = mkcvar("b3", {1, 3 * scalar, 1, 1}),
conv3 = opr::ConvBias::make(
conv2, w3, b3, param_conv, {}, OperatorNodeConfig("conv3"));
//! channel wise conv bias
opr::ConvBias::Param param_channel_conv_bias;
param_channel_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto w4 = mkcvar("w4", {3 * scalar, 1, 1, 1, 1}),
b4 = mkcvar("b4", {1, 3 * scalar, 1, 1}),
conv4 = opr::ConvBias::make(
conv3, w4, b4, param_channel_conv_bias, {},
OperatorNodeConfig("conv4"));
opr::Convolution::Param param_convolution;
param_convolution.sparse = opr::Convolution::Param::Sparse::GROUP;
//! channel wise convolution
auto w4 = mkcvar("w4", {3, 1, 1, 1, 1}),
conv4 = opr::Convolution::make(
conv3, w4, param_convolution, {}, OperatorNodeConfig("conv4"));
param_convolution.sparse = opr::Convolution::Param::Sparse::DENSE;
auto w5 = mkcvar("w5", {6, 3, 1, 1}),
auto w5 = mkcvar("w5", {3 * scalar, 1, 1, 1, 1}),
conv5 = opr::Convolution::make(
conv4, w5, param_convolution, {}, OperatorNodeConfig("conv5"));
//! group convolution
param_convolution.sparse = opr::Convolution::Param::Sparse::GROUP;
auto w6 = mkcvar("w6", {2, 4, 3, 1, 1}),
param_convolution.sparse = opr::Convolution::Param::Sparse::DENSE;
auto w6 = mkcvar("w6", {6 * scalar, 3 * scalar, 1, 1}),
conv6 = opr::Convolution::make(
conv5, w6, param_convolution, {}, OperatorNodeConfig("conv6"));
//! group convolution
param_convolution.sparse = opr::Convolution::Param::Sparse::GROUP;
auto w7 = mkcvar("w7", {2 * scalar, 4, 3, 1, 1}),
conv7 = opr::Convolution::make(
conv6, w7, param_convolution, {}, OperatorNodeConfig("conv7"));
param_convolution.sparse = opr::Convolution::Param::Sparse::DENSE;
auto w7 = mkcvar("w7", {3, 8, 1, 1}),
auto w8 = mkcvar("w8", {3 * scalar, 8 * scalar, 1, 1}),
y = opr::Convolution::make(
conv6, w7, param_convolution, {}, OperatorNodeConfig("conv7"));
conv7, w8, param_convolution, {}, OperatorNodeConfig("conv8"));
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_conv_bias_nonlinearity();
if (T_format == opr::Convolution::Param::Format::NCHW44)
options.enable_nchw44();
else if (T_format == opr::Convolution::Param::Format::NCHW88)
options.enable_nchw88();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
auto conv1_opt = find_opr<opr::ConvBias>(y_opt, "conv1");
auto conv2_opt = find_opr<opr::ConvBias>(y_opt, "conv2");
auto conv3_opt = find_opr<opr::ConvBias>(y_opt, "conv3");
auto conv4_opt = find_opr<opr::Convolution>(y_opt, "conv4");
auto conv6_opt = find_opr<opr::Convolution>(y_opt, "conv6");
auto conv4_opt = find_opr<opr::ConvBias>(y_opt, "conv4");
auto conv5_opt = find_opr<opr::Convolution>(y_opt, "conv5");
auto conv7_opt = find_opr<opr::Convolution>(y_opt, "conv7");
//! do not padding input tensor
ASSERT_EQ(conv1_opt->input(0)->shape()[1], 3);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, conv1_opt->param().format);
ASSERT_EQ(conv1_opt->input(0)->shape()[1], 3 * scalar);
ASSERT_EQ(T_format, conv1_opt->param().format);
//! output tensor padding input tensor
ASSERT_EQ(conv2_opt->input(1)->shape()[0], 2);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, conv2_opt->param().format);
ASSERT_EQ(conv2_opt->input(2)->shape()[1], 2);
ASSERT_EQ(T_format, conv2_opt->param().format);
ASSERT_EQ(conv3_opt->input(1)->shape()[0], 1);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, conv3_opt->param().format);
ASSERT_EQ(conv3_opt->input(2)->shape()[1], 1);
ASSERT_EQ(T_format, conv3_opt->param().format);
ASSERT_EQ(conv4_opt->input(1)->shape()[0], 1);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, conv4_opt->param().format);
ASSERT_EQ(conv6_opt->input(0)->shape()[1], 6);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, conv6_opt->param().format);
ASSERT_EQ(conv4_opt->input(2)->shape()[1], 1);
ASSERT_EQ(T_format, conv4_opt->param().format);
ASSERT_EQ(conv5_opt->input(1)->shape()[0], 1);
ASSERT_EQ(T_format, conv5_opt->param().format);
ASSERT_EQ(conv7_opt->input(0)->shape()[1], 6 * scalar);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, conv7_opt->param().format);
//! the dst tensor channel must stay unchange
ASSERT_EQ(y_opt.node()->shape()[1], 3);
ASSERT_EQ(y_opt.node()->shape()[1], 3 * scalar);
if (T_format == opr::Convolution::Param::Format::NCHW44)
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file("TestGoptInference.ChannelPaddingNCHW44.json"));
->writeto_fpath(
output_file("TestGoptInference.ChannelPaddingNCHW44.json"));
else if (T_format == opr::Convolution::Param::Format::NCHW88)
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(
output_file("TestGoptInference.ChannelPaddingNCHW88.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile(
......@@ -145,11 +178,19 @@ TEST(TestGoptInference, ChannelPaddingNCHW44) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2);
//! test change the input shape
*host_x = *gen({2, 3, 32, 32}, cn);
*host_x = *gen({2, 3 * scalar, 32, 32}, cn);
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2);
}
TEST(TestGoptInference, ChannelPaddingNCHW44) {
check_channel_padding_conv<opr::Convolution::Param::Format::NCHW44>();
}
TEST(TestGoptInference, ChannelPaddingNCHW88) {
check_channel_padding_conv<opr::Convolution::Param::Format::NCHW88>();
}
TEST(TestGoptInference, ChannelPaddingSubtensor) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册