diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 7f3c04ac73303f51e3e128d652238e041232bc9a..17b8318cd0acea586e7b4e7f4b24319c12400020 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1771,7 +1771,6 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const { bool can_be_fused = true; can_be_fused &= (elem->input().size() == 1); can_be_fused &= (elem->param().mode == Mode::RELU) || - (elem->param().mode == Mode::TANH) || (elem->param().mode == Mode::SIGMOID); return can_be_fused; @@ -1911,13 +1910,14 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const { } } else if (try_fuse_nonlinearity(elem)) { auto inp = rewriter.get_var(elem->input(0)); + auto elem_noline = get_nonlinearity_mode(elem); { auto conv = try_cast_as_op(inp->owner_opr()); if (conv && check_conv(conv) && m_deps[elem->input(0)].size() == 1) { opr::ConvBiasForward::Param param = convert_to_conv_bias_param(conv->param()); - param.nonlineMode = get_nonlinearity_mode(elem); + param.nonlineMode = elem_noline; auto new_var = opr::ConvBiasForward::make( conv->input(0), conv->input(1), param, conv->execution_policy(), conv->config()) @@ -1941,9 +1941,16 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const { ; }; if (conv && check_conv_bias(conv) && - m_deps[elem->input(0)].size() == 1) { + m_deps[elem->input(0)].size() == 1 && + conv->input().size() > 2) { auto param = conv->param(); - param.nonlineMode = get_nonlinearity_mode(elem); + bool noline_ok = param.nonlineMode == NonlineMode::IDENTITY || + (param.nonlineMode == NonlineMode::RELU && + elem_noline == NonlineMode::RELU); + if (!noline_ok) { + return; + } + param.nonlineMode = elem_noline; auto new_var = opr::ConvBiasForward::make( conv->input(0), conv->input(1), conv->input(2), diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 2f783530824e5decbee8921341ac259145e46548..ff0e6f8b91350f3fd0af7d3808ff5368f6502c34 100644 --- a/src/gopt/test/inference.cpp +++ 
b/src/gopt/test/inference.cpp
@@ -1731,6 +1731,61 @@ TEST(TestGoptInference, ConvBiasNonlinearityFusePass) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-4);
 }
 
+// Builds two convolution branches that each end in SIGMOID followed by RELU
+// (the first branch also adds a bias), sums them, and checks the result of
+// optimize_for_inference with nhwcd4 and conv-bias-nonlinearity fusion enabled.
+TEST(TestGoptInference, ConvBiasNonlinearityFusePass2) {
+    // nhwcd4 is only supported in naive handle
+    NaiveMegDNNHandleScope naive_megdnn_handle;
+
+    auto cn = CompNode::load("cpu0");
+
+    HostTensorGenerator<> gen;
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
+    };
+    opr::Convolution::Param param;
+    auto x = mkvar("x", {5, 8, 16, 24}), w1 = mkcvar("w1", {4, 8, 1, 1}),
+         w2 = mkcvar("w2", {4, 8, 1, 1});
+
+    auto b1 = mkcvar("b1", {1, 4, 1, 1});
+    auto y_cut = opr::Convolution::make(x, w1, param);
+    auto y = opr::Elemwise::make({y_cut + b1}, opr::Elemwise::Param::Mode::SIGMOID);
+    y = opr::Elemwise::make({y}, opr::Elemwise::Param::Mode::RELU);
+    auto y_cut2 = opr::Convolution::make(x, w2, param);
+    y_cut2 = opr::Elemwise::make({y_cut2}, opr::Elemwise::Param::Mode::SIGMOID);
+    y_cut2 = opr::Elemwise::make({y_cut2}, opr::Elemwise::Param::Mode::RELU);
+    y = y + y_cut2;
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_nhwcd4().enable_fuse_conv_bias_nonlinearity();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+    // The ConvBias returned by find_opr is expected to carry SIGMOID.
+    // NOTE(review): presumably this guards against the trailing RELU
+    // overwriting the SIGMOID mode -- confirm against the fuse pass.
+    ASSERT_EQ(
+            opr::ConvBias::Param::NonlineMode::SIGMOID,
+            find_opr<opr::ConvBias>(y_opt).param().nonlineMode);
+    // Write the optimized graph to a JSON file via output_file().
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(
+                    output_file("TestGoptInference.FuseConvBiasNonlinPass2.json"));
+
+    HostTensorND host_y, host_y_opt;
+    // Evaluate the original and optimized outputs in one compiled function
+    // and require them to agree within 1e-4.
+    auto func = graph->compile(
+            {make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-4);
+}
+
 TEST(TestGoptInference, ConvBiasNonlinearityFusePass_FullBias) {
     NaiveMegDNNHandleScope naive_megdnn_handle;