From f1c86606cbac3856e6e0be850d62fc2858800ded Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 4 Jun 2020 16:06:29 +0800
Subject: [PATCH] fix(dnn/cuda): fix FuseConvBiasWithZ pass for HSwish
 activation

GitOrigin-RevId: b290469cb1993214c01cab823864476778963f70
---
 src/gopt/impl/inference.cpp |   5 +-
 src/gopt/test/inference.cpp | 158 ++++++++++++++++++++----------------
 2 files changed, 91 insertions(+), 72 deletions(-)

diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp
index cb1ae70c2..b569b2fd4 100644
--- a/src/gopt/impl/inference.cpp
+++ b/src/gopt/impl/inference.cpp
@@ -1875,6 +1875,8 @@ void FuseConvBiasZPass::apply(OptState& state) const {
             auto elem = try_cast_as_op<opr::ElemwiseMultiType>(opr);
             if (elem->param().mode == MultiMode::QFUSE_ADD_RELU)
                 return NonlineMode::RELU;
+            else if (elem->param().mode == MultiMode::QFUSE_ADD_H_SWISH)
+                return NonlineMode::H_SWISH;
         }
         return NonlineMode::IDENTITY;
     };
@@ -1941,7 +1943,8 @@ void FuseConvBiasZPass::apply(OptState& state) const {
         if (elem->input().size() != 2)
             return false;
         if (elem->param().mode != MultiMode::QADD &&
-            elem->param().mode != MultiMode::QFUSE_ADD_RELU)
+            elem->param().mode != MultiMode::QFUSE_ADD_RELU &&
+            elem->param().mode != MultiMode::QFUSE_ADD_H_SWISH)
             return false;
         return try_replace_var_node(opr);
     };
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index adc2aba60..ce58a0ed0 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -1701,80 +1701,96 @@ TEST(FuseConvBiasZPass, BlockFuse) {
                        dtype);
     };
 
-    auto x = mkvar("x", {32, 16, 16, 16, 4}, dtype::QuantizedS8(2.5f)),
-         w1 = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
-         b1 = mkcvar("b1", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)),
-         w2 = mkcvar("w2", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
-         b2 = mkcvar("b2", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)),
-         w3 = mkcvar("w3", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
-         b3 = mkcvar("b3", {1, 16, 1, 1, 4}, dtype::QuantizedS32(3.0f));
+    using ElemMultiMode = opr::ElemwiseMultiType::Param::Mode;
+    using NonlineMode = opr::ConvBias::Param::NonlineMode;
+    for (auto mode :
+         {ElemMultiMode::QFUSE_ADD_RELU, ElemMultiMode::QFUSE_ADD_H_SWISH}) {
+        auto x = mkvar("x", {32, 16, 16, 16, 4}, dtype::QuantizedS8(2.5f)),
+             w1 = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
+             b1 = mkcvar("b1", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)),
+             w2 = mkcvar("w2", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
+             b2 = mkcvar("b2", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)),
+             w3 = mkcvar("w3", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
+             b3 = mkcvar("b3", {1, 16, 1, 1, 4}, dtype::QuantizedS32(3.0f));
+        NonlineMode nonline_mode = NonlineMode::RELU;
+        if (mode == ElemMultiMode::QFUSE_ADD_H_SWISH) {
+            nonline_mode = NonlineMode::H_SWISH;
+        }
 
-    opr::ConvBias::Param param;
-    param.format = opr::Convolution::Param::Format::NCHW4;
-    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
-    param.stride_h = param.stride_w = 1;
-    param.pad_h = param.pad_w = 1;
+        opr::ConvBias::Param param;
+        param.format = opr::Convolution::Param::Format::NCHW4;
+        param.nonlineMode = nonline_mode;
+        param.stride_h = param.stride_w = 1;
+        param.pad_h = param.pad_w = 1;
 
-    auto y1 = opr::ConvBias::make(x, w1, b1, param, {},
-                                  OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
-    param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
-    auto y2 = opr::ConvBias::make(y1, w2, b2, param, {},
-                                  OperatorNodeConfig{dtype::QuantizedS8(2.5f)}),
-         y3 = opr::ElemwiseMultiType::make(
-                 {y1, y2},
-                 {opr::ElemwiseMultiType::Param::Mode::QFUSE_ADD_RELU},
-                 OperatorNodeConfig{dtype::QuantizedS8(1.2f)});
-    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
-    auto y4 = opr::ConvBias::make(y3, w3, b3, param, {},
-                                  OperatorNodeConfig{dtype::QuantizedS8(2.5f)}),
-         z = opr::ElemwiseMultiType::make(
-                 {y3, y4},
-                 {opr::ElemwiseMultiType::Param::Mode::QADD},
-                 OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
-    z = opr::TypeCvt::make(z, dtype::Float32());
-
-    //! fuse z mannually
-    auto z0 = opr::ConvBias::make(x, w1, b1, param, {},
-                                  OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
-    auto z1 = opr::ConvBias::make(z0, w2, b2, z0, param, {},
-                                  OperatorNodeConfig{dtype::QuantizedS8(1.2f)}),
-         z2 = opr::ConvBias::make(z1, w3, b3, param, {},
-                                  OperatorNodeConfig{dtype::QuantizedS8(2.5f)}),
-         z4 = opr::ElemwiseMultiType::make(
-                 {z1, z2}, {opr::ElemwiseMultiType::Mode::QADD},
-                 OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
-    z4 = opr::TypeCvt::make(z4, dtype::Float32());
-
-    SymbolVar z_fuse;
-    SymbolVar z_nonfuse;
-    {
-        auto options = gopt::OptimizeForInferenceOptions{};
-        options.enable_fuse_conv_bias_nonlinearity()
-                .enable_fuse_conv_bias_with_z();
-        unpack_vector(gopt::optimize_for_inference({z}, options), z_fuse);
-    }
-    {
-        auto options = gopt::OptimizeForInferenceOptions{};
-        options.enable_fuse_conv_bias_nonlinearity();
-        unpack_vector(gopt::optimize_for_inference({z4}, options), z_nonfuse);
-    }
-    auto nr_elem_multi_type = find_opr_num<opr::ElemwiseMultiType>(z_fuse);
-    MGB_MARK_USED_VAR(nr_elem_multi_type);
-    ASSERT_EQ(1u, nr_elem_multi_type);
-    graph->compile({{z_fuse, {}}})
-            ->to_json()
-            ->writeto_fpath(
-                    output_file("FuseConvBiasZPass.BlockFuse_fuse.json"));
-    graph->compile({{z_nonfuse, {}}})
-            ->to_json()
-            ->writeto_fpath(
-                    output_file("FuseConvBiasZPass.BlockFuse_nonfuse.json"));
-
-    HostTensorND host_z_fuse, host_z_nonfuse;
-    auto func = graph->compile({make_callback_copy(z_nonfuse, host_z_nonfuse),
+        auto y1 = opr::ConvBias::make(
+                x, w1, b1, param, {},
+                OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
+        param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
+        auto y2 = opr::ConvBias::make(
+                y1, w2, b2, param, {},
+                OperatorNodeConfig{dtype::QuantizedS8(2.5f)}),
+             y3 = opr::ElemwiseMultiType::make(
+                {y1, y2}, {mode},
+                OperatorNodeConfig{dtype::QuantizedS8(1.2f)});
+        param.nonlineMode = nonline_mode;
+        auto y4 = opr::ConvBias::make(
+                y3, w3, b3, param, {},
+                OperatorNodeConfig{dtype::QuantizedS8(2.5f)}),
+             z = opr::ElemwiseMultiType::make(
+                {y3, y4}, {opr::ElemwiseMultiType::Param::Mode::QADD},
+                OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
+        z = opr::TypeCvt::make(z, dtype::Float32());
+
+        //! fuse z manually
+        auto z0 = opr::ConvBias::make(
+                x, w1, b1, param, {},
+                OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
+        auto z1 = opr::ConvBias::make(
+                z0, w2, b2, z0, param, {},
+                OperatorNodeConfig{dtype::QuantizedS8(1.2f)}),
+             z2 = opr::ConvBias::make(
+                z1, w3, b3, param, {},
+                OperatorNodeConfig{dtype::QuantizedS8(2.5f)}),
+             z4 = opr::ElemwiseMultiType::make(
+                {z1, z2}, {opr::ElemwiseMultiType::Mode::QADD},
+                OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
+        z4 = opr::TypeCvt::make(z4, dtype::Float32());
+
+        SymbolVar z_fuse;
+        SymbolVar z_nonfuse;
+        {
+            auto options = gopt::OptimizeForInferenceOptions{};
+            options.enable_fuse_conv_bias_nonlinearity()
+                    .enable_fuse_conv_bias_with_z();
+            unpack_vector(gopt::optimize_for_inference({z}, options), z_fuse);
+        }
+        {
+            auto options = gopt::OptimizeForInferenceOptions{};
+            options.enable_fuse_conv_bias_nonlinearity();
+            unpack_vector(gopt::optimize_for_inference({z4}, options),
+                          z_nonfuse);
+        }
+        auto nr_elem_multi_type =
+                find_opr_num<opr::ElemwiseMultiType>(z_fuse);
+        MGB_MARK_USED_VAR(nr_elem_multi_type);
+        ASSERT_EQ(1u, nr_elem_multi_type);
+        graph->compile({{z_fuse, {}}})
+                ->to_json()
+                ->writeto_fpath(
+                        output_file("FuseConvBiasZPass.BlockFuse_fuse.json"));
+        graph->compile({{z_nonfuse, {}}})
+                ->to_json()
+                ->writeto_fpath(output_file(
+                        "FuseConvBiasZPass.BlockFuse_nonfuse.json"));
+
+        HostTensorND host_z_fuse, host_z_nonfuse;
+        auto func =
+                graph->compile({make_callback_copy(z_nonfuse, host_z_nonfuse),
                                 make_callback_copy(z_fuse, host_z_fuse)});
-    func->execute();
-    MGB_ASSERT_TENSOR_EQ(host_z_fuse, host_z_nonfuse);
+        func->execute();
+        MGB_ASSERT_TENSOR_EQ(host_z_fuse, host_z_nonfuse);
+    }
 }
 
 TEST(TestEnableTensorCore, ShuffleMerge) {
-- 
GitLab
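
Note on the semantics this patch relies on: FuseConvBiasZPass rewrites a ConvBias followed by an ElemwiseMultiType QFUSE_ADD_* opr into a single ConvBias that takes the shortcut tensor z as a fourth input, choosing the ConvBias activation from the elemwise mode. Since h_swish(x) = x * min(max(x + 3, 0), 6) / 6, the identity QFUSE_ADD_H_SWISH(a, z) = h_swish(a + z) makes the fused operator equivalent to the unfused graph. The standalone C++ sketch below illustrates that mapping in plain float arithmetic; the enum values mirror the MegEngine names, but nonline_mode_for, apply_nonline and elemwise_add_activate are hypothetical helpers written for this note, and the rounding of the real QuantizedS8 path is ignored.

#include <algorithm>
#include <cassert>

// Illustrative mirrors of the MegEngine mode enums; not MegEngine code.
enum class MultiMode { QADD, QFUSE_ADD_RELU, QFUSE_ADD_H_SWISH };
enum class NonlineMode { IDENTITY, RELU, H_SWISH };

// Activation a fused ConvBias must apply so that it matches the elemwise
// "add then activate" opr; QFUSE_ADD_H_SWISH -> H_SWISH is the case this
// patch adds, and plain QADD maps to IDENTITY.
static NonlineMode nonline_mode_for(MultiMode mode) {
    switch (mode) {
        case MultiMode::QFUSE_ADD_RELU:
            return NonlineMode::RELU;
        case MultiMode::QFUSE_ADD_H_SWISH:
            return NonlineMode::H_SWISH;
        default:
            return NonlineMode::IDENTITY;
    }
}

// Scalar reference for the activations; h_swish(x) = x * relu6(x + 3) / 6.
static float apply_nonline(NonlineMode m, float x) {
    switch (m) {
        case NonlineMode::RELU:
            return std::max(x, 0.f);
        case NonlineMode::H_SWISH:
            return x * std::min(std::max(x + 3.f, 0.f), 6.f) / 6.f;
        default:
            return x;
    }
}

// Unfused graph: the ElemwiseMultiType opr computes "activate(a + z)".
static float elemwise_add_activate(MultiMode mode, float a, float z) {
    float s = a + z;
    switch (mode) {
        case MultiMode::QFUSE_ADD_RELU:
            return std::max(s, 0.f);
        case MultiMode::QFUSE_ADD_H_SWISH:
            return s * std::min(std::max(s + 3.f, 0.f), 6.f) / 6.f;
        default:
            return s;
    }
}

int main() {
    // `a` stands for the ConvBias output conv(x, w) + b; `z` is the
    // shortcut branch that the pass folds in as the fourth ConvBias input.
    for (float a : {-5.f, -1.f, 0.f, 2.f, 7.f}) {
        for (float z : {-2.f, 0.f, 3.f}) {
            for (MultiMode mode : {MultiMode::QADD, MultiMode::QFUSE_ADD_RELU,
                                   MultiMode::QFUSE_ADD_H_SWISH}) {
                float unfused = elemwise_add_activate(mode, a, z);
                // Fused ConvBias-with-z: add z, then apply the mapped
                // activation inside the conv opr itself.
                float fused = apply_nonline(nonline_mode_for(mode), a + z);
                assert(fused == unfused);
            }
        }
    }
    return 0;
}

Under these assumptions the assertion holds exactly because both paths perform the same float operations; the BlockFuse test above checks the same equivalence on the quantized graph via MGB_ASSERT_TENSOR_EQ.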