diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp
index 2db1f1da1cdc6e67e75015ee1904bdbe6c0d4731..4a337351654762e89145a80897105cb3fda0a0c3 100644
--- a/src/gopt/impl/inference.cpp
+++ b/src/gopt/impl/inference.cpp
@@ -1066,7 +1066,27 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
         }
     };
 
-    auto replace_conv_opr = [&filter_mode](OperatorNodeBase* opr,
+    auto size_one_conv_to_dense_conv =
+            [](VarNode* origin_filter_input,
+               megdnn::param::Convolution::Sparse sparse) {
+                VarNode* reshaped_filter = origin_filter_input;
+                bool is_size_one_group_conv = false;
+                if (sparse == megdnn::param::Convolution::Sparse::GROUP &&
+                    origin_filter_input->shape()[0] == 1) {
+                    is_size_one_group_conv = true;
+                    auto new_shape = origin_filter_input->shape();
+                    new_shape.ndim = 4;
+                    for (int i = 0; i < 4; i++) {
+                        new_shape[i] = origin_filter_input->shape()[i + 1];
+                    }
+                    SymbolVar new_var(origin_filter_input);
+                    reshaped_filter = new_var.reshape(new_shape).node();
+                }
+                return std::make_tuple(reshaped_filter, is_size_one_group_conv);
+            };
+
+    auto replace_conv_opr =
+            [&filter_mode, &size_one_conv_to_dense_conv](OperatorNodeBase* opr,
                                            const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         auto& conv_opr = opr->cast_final_safe<opr::Convolution>();
@@ -1131,19 +1151,27 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
             mgb_assert(new_inp[0]->shape().ndim == 5 && fmt.align_axis() == 2);
             conv_src = new_inp[0];
         }
+        VarNode* reshaped_filter;
+        bool is_size_one_group_conv;
+        std::tie(reshaped_filter, is_size_one_group_conv) =
+                size_one_conv_to_dense_conv(new_inp[1],
+                                            conv_opr.param().sparse);
+        auto new_conv_param = conv_opr.param();
+        if (is_size_one_group_conv) {
+            new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE;
+        }
         mgb_assert(new_inp[1]->format().type() !=
                    TensorFormat::Type::IMAGE2D_PACK4);
         auto param = megdnn::param::RelayoutFormat();
-        param.mode = filter_mode(conv_opr.param().sparse, new_inp[1]);
-        auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param);
+        param.mode = filter_mode(new_conv_param.sparse, reshaped_filter);
+        auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param);
         conv_weights = relayout_weight.node();
-        auto new_param = conv_opr.param();
-        new_param.format = megdnn::param::Convolution::Format::NHWCD4;
+        new_conv_param.format = megdnn::param::Convolution::Format::NHWCD4;
         mgb_assert(conv_src->shape().ndim == 5 &&
                    conv_src->format().type() ==
                            TensorFormat::Type::IMAGE2D_PACK4);
         auto new_conv_opr = opr::Convolution::make(
-                conv_src, conv_weights, new_param, conv_opr.execution_policy(),
+                conv_src, conv_weights, new_conv_param, conv_opr.execution_policy(),
                 conv_opr.config());
         OperatorNodeBase* ret = new_conv_opr.node()->owner_opr();
         mgb_assert(new_conv_opr.shape().ndim == 5 &&
@@ -1152,7 +1180,8 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
         return ret;
     };
 
-    auto replace_conv_bias_opr = [&filter_mode](OperatorNodeBase* opr,
+    auto replace_conv_bias_opr =
+            [&filter_mode, &size_one_conv_to_dense_conv](OperatorNodeBase* opr,
                                                 const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBias>();
@@ -1221,9 +1250,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
             mgb_assert(new_inp[1]->format().type() !=
                        TensorFormat::Type::IMAGE2D_PACK4);
 
+            VarNode* reshaped_filter;
+            bool is_size_one_group_conv;
+            std::tie(reshaped_filter, is_size_one_group_conv) =
+                    size_one_conv_to_dense_conv(new_inp[1],
+                                                conv_bias_opr.param().sparse);
+            auto new_conv_param = conv_bias_opr.param();
+            if (is_size_one_group_conv) {
+                new_conv_param.sparse = megdnn::param::Convolution::Sparse::DENSE;
+            }
             auto param = megdnn::param::RelayoutFormat();
-            param.mode = filter_mode(conv_bias_opr.param().sparse, new_inp[1]);
-            auto relayout_weight = opr::RelayoutFormat::make(new_inp[1], param);
+            param.mode = filter_mode(new_conv_param.sparse, reshaped_filter);
+            auto relayout_weight = opr::RelayoutFormat::make(reshaped_filter, param);
             conv_bias_weights = relayout_weight.node();
 
             mgb_assert(new_inp.size() < 4,
@@ -1238,19 +1276,18 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
                 conv_bias_bias = new_inp[2];
             }
 
-            auto new_param = conv_bias_opr.param();
-            new_param.format = megdnn::param::ConvBias::Format::NHWCD4;
+            new_conv_param.format = megdnn::param::ConvBias::Format::NHWCD4;
             mgb_assert(conv_bias_src->shape().ndim == 5 &&
                        conv_bias_src->format().type() ==
                                TensorFormat::Type::IMAGE2D_PACK4);
             SymbolVar new_conv_bias_opr;
             if (has_bias) {
                 new_conv_bias_opr = opr::ConvBias::make(
-                        conv_bias_src, conv_bias_weights, conv_bias_bias, new_param,
+                        conv_bias_src, conv_bias_weights, conv_bias_bias, new_conv_param,
                         conv_bias_opr.execution_policy(), conv_bias_opr.config());
             } else {
                 new_conv_bias_opr = opr::ConvBias::make(
-                        conv_bias_src, conv_bias_weights, new_param,
+                        conv_bias_src, conv_bias_weights, new_conv_param,
                         conv_bias_opr.execution_policy(), conv_bias_opr.config());
             }
             OperatorNodeBase* ret = new_conv_bias_opr.node()->owner_opr();
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index f2a4203577dfe7eee40b0cfa94785be04e8bd10a..78a61fc4404fbaa26a530cf7c323bcc18da727cb 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -3443,4 +3443,49 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
 }
 
+TEST(TestGoptInference, ConvertFormatCD4GroupOneConv) {
+    // hwcd4 is only supported in naive handle
+    NaiveMegDNNHandleScope naive_megdnn_handle;
+
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+
+    auto x = mkvar("x", {1, 3, 128, 128});
+    // ConvBias
+    opr::ConvBias::Param param_conv_bias;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
+    auto w1 = mkcvar("w1", {1, 16, 3, 3, 3}), b1 = mkcvar("b1", {1, 16, 1, 1});
+    auto conv1 = opr::ConvBias::make(x, w1, b1, param_conv_bias);
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
+    // Convolution
+    opr::Convolution::Param param_conv;
+    param_conv.pad_h = param_conv.pad_w = 1;
+    param_conv.sparse = opr::Convolution::Param::Sparse::GROUP;
+    auto w3 = mkcvar("w3", {1, 16, 16, 3, 3});
+    auto y = opr::Convolution::make(conv1, w3, param_conv);
+
+    SymbolVar y_opt;
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_nhwcd4();
+        unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+    }
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}