Commit cab46f2a authored by Megvii Engine Team

fix(dnn): fix relayout format when group conv group=1

GitOrigin-RevId: f2b53be77fc444743f72678878d6ec4637f9d740
Parent 8b154fd7
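The patch makes the NHWCD4 converter treat a group convolution whose group count is 1 as a dense convolution: the 5-D group filter {group=1, OC, IC, FH, FW} is reshaped to the 4-D dense filter {OC, IC, FH, FW} before the RelayoutFormat weight transform, and the conv param's sparse mode is switched to DENSE. Below is a minimal standalone sketch of that reshape rule only; the helper name and the std::vector stand-in for TensorShape are illustrative assumptions, not MegEngine API.

// Standalone sketch of the group==1 -> dense filter reshape performed by the
// new size_one_conv_to_dense_conv helper in this commit. A plain std::vector
// stands in for megbrain's TensorShape; everything here is illustrative.
#include <cassert>
#include <cstdio>
#include <vector>

// A 5-D group-conv filter is laid out as {group, OC/group, IC/group, FH, FW}.
// When group == 1 the leading axis carries no information, so the filter can
// be viewed as the 4-D dense layout {OC, IC, FH, FW}.
std::vector<size_t> group_filter_to_dense(const std::vector<size_t>& shape,
                                          bool& converted) {
    converted = false;
    if (shape.size() == 5 && shape[0] == 1) {
        converted = true;
        return {shape[1], shape[2], shape[3], shape[4]};
    }
    return shape;  // already dense, or a real multi-group filter: unchanged
}

int main() {
    bool converted = false;
    // Same filter shape as w1 in the new test case below: {1, 16, 3, 3, 3}.
    auto dense = group_filter_to_dense({1, 16, 3, 3, 3}, converted);
    assert(converted && dense.size() == 4);
    std::printf("{%zu, %zu, %zu, %zu}\n", dense[0], dense[1], dense[2], dense[3]);
    return 0;
}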
@@ -1066,7 +1066,27 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
}
};
auto replace_conv_opr = [&filter_mode](OperatorNodeBase* opr,
auto size_one_conv_to_dense_conv =
[](VarNode* origin_filter_input,
megdnn::param::Convolution::Sparse sparse) {
VarNode* reshaped_filter = origin_filter_input;
bool is_size_one_group_conv = false;
if (sparse == megdnn::param::Convolution::Sparse::GROUP &&
origin_filter_input->shape()[0] == 1) {
is_size_one_group_conv = true;
auto new_shape = origin_filter_input->shape();
new_shape.ndim = 4;
for (int i = 0; i < 4; i++) {
new_shape[i] = origin_filter_input->shape()[i + 1];
}
SymbolVar new_var(origin_filter_input);
reshaped_filter = new_var.reshape(new_shape).node();
}
return std::make_tuple(reshaped_filter, is_size_one_group_conv);
};
auto replace_conv_opr =
[&filter_mode, &size_one_conv_to_dense_conv](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
@@ -1131,19 +1151,27 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
conv_src = new_inp[0];
}
VarNode* reshaped_filter;
bool is_size_one_group_conv;
std::tie(reshaped_filter, is_size_one_group_conv) =
size_one_conv_to_dense_conv(new_inp[1],
conv_opr.param().sparse);
auto new_conv_param = conv_opr.param();
if (is_size_one_group_conv) {
new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE;
}
mgb_assert(new_inp[1]->format().type() !=
TensorFormat::Type::IMAGE2D_PACK4);
auto param = megdnn::param::RelayoutFormat();
param.mode = filter_mode(conv_opr.param().sparse, new_inp[1]);
auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param);
param.mode = filter_mode(new_conv_param.sparse, reshaped_filter);
auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param);
conv_weights = relayout_weight.node();
auto new_param = conv_opr.param();
new_param.format = megdnn::param::Convolution::Format::NHWCD4;
new_conv_param.format = megdnn::param::Convolution::Format::NHWCD4;
mgb_assert(conv_src->shape().ndim == 5 &&
conv_src->format().type() ==
TensorFormat::Type::IMAGE2D_PACK4);
auto new_conv_opr = opr::Convolution::make(
conv_src, conv_weights, new_param, conv_opr.execution_policy(),
conv_src, conv_weights, new_conv_param, conv_opr.execution_policy(),
conv_opr.config());
OperatorNodeBase* ret = new_conv_opr.node()->owner_opr();
mgb_assert(new_conv_opr.shape().ndim == 5 &&
@@ -1152,7 +1180,8 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
return ret;
};
auto replace_conv_bias_opr = [&filter_mode](OperatorNodeBase* opr,
auto replace_conv_bias_opr =
[&filter_mode, &size_one_conv_to_dense_conv](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
@@ -1221,9 +1250,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
mgb_assert(new_inp[1]->format().type() !=
TensorFormat::Type::IMAGE2D_PACK4);
VarNode* reshaped_filter;
bool is_size_one_group_conv;
std::tie(reshaped_filter, is_size_one_group_conv) =
size_one_conv_to_dense_conv(new_inp[1],
conv_bias_opr.param().sparse);
auto new_conv_param = conv_bias_opr.param();
if (is_size_one_group_conv) {
new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE;
}
auto param = megdnn::param::RelayoutFormat();
param.mode = filter_mode(conv_bias_opr.param().sparse, new_inp[1]);
auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param);
param.mode = filter_mode(new_conv_param.sparse, reshaped_filter);
auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param);
conv_bias_weights = relayout_weight.node();
mgb_assert(new_inp.size() < 4,
@@ -1238,19 +1276,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
conv_bias_bias = new_inp[2];
}
auto new_param = conv_bias_opr.param();
new_param.format = megdnn::param::ConvBias::Format::NHWCD4;
new_conv_param.format = megdnn::param::ConvBias::Format::NHWCD4;
mgb_assert(conv_bias_src->shape().ndim == 5 &&
conv_bias_src->format().type() ==
TensorFormat::Type::IMAGE2D_PACK4);
SymbolVar new_conv_bias_opr;
if (has_bias) {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_weights, conv_bias_bias, new_param,
conv_bias_src, conv_bias_weights, conv_bias_bias, new_conv_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
} else {
new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_weights, new_param,
conv_bias_src, conv_bias_weights, new_conv_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
}
OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr();
@@ -3443,4 +3443,49 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
}
TEST(TestGoptInference, ConvertFormatCD4GroupOneConv) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name);
};
auto x = mkvar("x", {1, 3, 128, 128});
// ConvBias
opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto w1 = mkcvar("w1", {1, 16, 3, 3, 3}), b1 = mkcvar("b1", {1, 16, 1, 1});
auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias);
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
// Convolution
opr::Convolution::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 1;
param_conv.sparse = opr::Convolution::Param::Sparse::GROUP;
auto w3 = mkcvar("w3", {1, 16, 16, 3, 3});
auto y = opr::Convolution::make(conv1, w3, param_conv);
SymbolVar y_opt;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}