提交 cab46f2a 编写于 作者: M Megvii Engine Team

fix(dnn): fix relayout format when group conv group=1

GitOrigin-RevId: f2b53be77fc444743f72678878d6ec4637f9d740
上级 8b154fd7
...@@ -1066,7 +1066,27 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1066,7 +1066,27 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
} }
}; };
auto replace_conv_opr = [&filter_mode](OperatorNodeBase* opr, auto size_one_conv_to_dense_conv =
[](VarNode* origin_filter_input,
megdnn::param::Convolution::Sparse sparse) {
VarNode* reshaped_filter = origin_filter_input;
bool is_size_one_group_conv = false;
if (sparse == megdnn::param::Convolution::Sparse::GROUP &&
origin_filter_input->shape()[0] == 1) {
is_size_one_group_conv = true;
auto new_shape = origin_filter_input->shape();
new_shape.ndim = 4;
for (int i = 0; i < 4; i++) {
new_shape[i] = origin_filter_input->shape()[i + 1];
}
SymbolVar new_var(origin_filter_input);
reshaped_filter = new_var.reshape(new_shape).node();
}
return std::make_tuple(reshaped_filter, is_size_one_group_conv);
};
auto replace_conv_opr =
[&filter_mode, &size_one_conv_to_dense_conv](OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
...@@ -1131,19 +1151,27 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1131,19 +1151,27 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2); mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
conv_src = new_inp[0]; conv_src = new_inp[0];
} }
VarNode* reshaped_filter;
bool is_size_one_group_conv;
std::tie(reshaped_filter, is_size_one_group_conv) =
size_one_conv_to_dense_conv(new_inp[1],
conv_opr.param().sparse);
auto new_conv_param = conv_opr.param();
if (is_size_one_group_conv) {
new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE;
}
mgb_assert(new_inp[1]->format().type() != mgb_assert(new_inp[1]->format().type() !=
TensorFormat::Type::IMAGE2D_PACK4); TensorFormat::Type::IMAGE2D_PACK4);
auto param = megdnn::param::RelayoutFormat(); auto param = megdnn::param::RelayoutFormat();
param.mode = filter_mode(conv_opr.param().sparse, new_inp[1]); param.mode = filter_mode(new_conv_param.sparse, reshaped_filter);
auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param); auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param);
conv_weights = relayout_weight.node(); conv_weights = relayout_weight.node();
auto new_param = conv_opr.param(); new_conv_param.format = megdnn::param::Convolution::Format::NHWCD4;
new_param.format = megdnn::param::Convolution::Format::NHWCD4;
mgb_assert(conv_src->shape().ndim == 5 && mgb_assert(conv_src->shape().ndim == 5 &&
conv_src->format().type() == conv_src->format().type() ==
TensorFormat::Type::IMAGE2D_PACK4); TensorFormat::Type::IMAGE2D_PACK4);
auto new_conv_opr = opr::Convolution::make( auto new_conv_opr = opr::Convolution::make(
conv_src, conv_weights, new_param, conv_opr.execution_policy(), conv_src, conv_weights, new_conv_param, conv_opr.execution_policy(),
conv_opr.config()); conv_opr.config());
OperatorNodeBase* ret = new_conv_opr.node()->owner_opr(); OperatorNodeBase* ret = new_conv_opr.node()->owner_opr();
mgb_assert(new_conv_opr.shape().ndim == 5 && mgb_assert(new_conv_opr.shape().ndim == 5 &&
...@@ -1152,7 +1180,8 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1152,7 +1180,8 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
return ret; return ret;
}; };
auto replace_conv_bias_opr = [&filter_mode](OperatorNodeBase* opr, auto replace_conv_bias_opr =
[&filter_mode, &size_one_conv_to_dense_conv](OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
...@@ -1221,9 +1250,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1221,9 +1250,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
mgb_assert(new_inp[1]->format().type() != mgb_assert(new_inp[1]->format().type() !=
TensorFormat::Type::IMAGE2D_PACK4); TensorFormat::Type::IMAGE2D_PACK4);
VarNode* reshaped_filter;
bool is_size_one_group_conv;
std::tie(reshaped_filter, is_size_one_group_conv) =
size_one_conv_to_dense_conv(new_inp[1],
conv_bias_opr.param().sparse);
auto new_conv_param = conv_bias_opr.param();
if (is_size_one_group_conv) {
new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE;
}
auto param = megdnn::param::RelayoutFormat(); auto param = megdnn::param::RelayoutFormat();
param.mode = filter_mode(conv_bias_opr.param().sparse, new_inp[1]); param.mode = filter_mode(new_conv_param.sparse, reshaped_filter);
auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param); auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param);
conv_bias_weights = relayout_weight.node(); conv_bias_weights = relayout_weight.node();
mgb_assert(new_inp.size() < 4, mgb_assert(new_inp.size() < 4,
...@@ -1238,19 +1276,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { ...@@ -1238,19 +1276,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
conv_bias_bias = new_inp[2]; conv_bias_bias = new_inp[2];
} }
auto new_param = conv_bias_opr.param(); new_conv_param.format = megdnn::param::ConvBias::Format::NHWCD4;
new_param.format = megdnn::param::ConvBias::Format::NHWCD4;
mgb_assert(conv_bias_src->shape().ndim == 5 && mgb_assert(conv_bias_src->shape().ndim == 5 &&
conv_bias_src->format().type() == conv_bias_src->format().type() ==
TensorFormat::Type::IMAGE2D_PACK4); TensorFormat::Type::IMAGE2D_PACK4);
SymbolVar new_conv_bias_opr; SymbolVar new_conv_bias_opr;
if (has_bias) { if (has_bias) {
new_conv_bias_opr = opr::ConvBias::make( new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_weights, conv_bias_bias, new_param, conv_bias_src, conv_bias_weights, conv_bias_bias, new_conv_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config()); conv_bias_opr.execution_policy(), conv_bias_opr.config());
} else { } else {
new_conv_bias_opr = opr::ConvBias::make( new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_weights, new_param, conv_bias_src, conv_bias_weights, new_conv_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config()); conv_bias_opr.execution_policy(), conv_bias_opr.config());
} }
OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr(); OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr();
......
...@@ -3443,4 +3443,49 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { ...@@ -3443,4 +3443,49 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
} }
TEST(TestGoptInference, ConvertFormatCD4GroupOneConv) {
    // Regression test: the NHWCD4 convert pass must handle GROUP convs with
    // group == 1 (5-D filter with leading dim 1) by reshaping the filter to
    // 4-D and lowering the conv to DENSE sparse mode.
    // hwcd4 is only supported in naive handle
    NaiveMegDNNHandleScope naive_megdnn_handle;
    HostTensorGenerator<> gen;
    auto cn = CompNode::load("cpu0");
    auto graph = ComputingGraph::make();
    // disable generic graph opts so only the explicit NHWCD4 pass runs
    graph->options().graph_opt_level = 0;
    // helper: runtime (host-to-device) input var
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
    };
    // helper: constant (weight/bias) var
    auto mkcvar = [&](const char* name, const TensorShape& shp) {
        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
                .rename(name);
    };
    auto x = mkvar("x", {1, 3, 128, 128});
    // ConvBias: GROUP sparse with group == 1, filter shape {1, 16, 3, 3, 3}
    opr::ConvBias::Param param_conv_bias;
    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
    auto w1 = mkcvar("w1", {1, 16, 3, 3, 3}), b1 = mkcvar("b1", {1, 16, 1, 1});
    auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias);
    // Convolution: GROUP sparse with group == 1, filter shape {1, 16, 16, 3, 3}
    opr::Convolution::Param param_conv;
    param_conv.pad_h = param_conv.pad_w = 1;
    param_conv.sparse = opr::Convolution::Param::Sparse::GROUP;
    auto w3 = mkcvar("w3", {1, 16, 16, 3, 3});
    auto y = opr::Convolution::make(conv1, w3, param_conv);
    SymbolVar y_opt;
    {
        auto options = gopt::OptimizeForInferenceOptions{};
        options.enable_nhwcd4();
        unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
    }
    // numerical check: the NHWCD4-converted graph must match the original
    HostTensorND host_y_opt, host_y;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_opt, host_y_opt)});
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册