From 0ccb965c8e52e8d2cfd1fe969a5a031091d616d3 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 24 Jun 2020 19:29:54 +0800 Subject: [PATCH] fix(mgb/gopt): fix convert format nchw->nchw4 pass GitOrigin-RevId: 1813753b144fa70d53f4f97f1a2a509963440d04 --- src/gopt/impl/tensor_reformat.cpp | 90 ++++++++++++++++++++++++++-- src/gopt/test/inference.cpp | 98 +++++++++++++++++++++++++++++-- 2 files changed, 178 insertions(+), 10 deletions(-) diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index 1b82f409b..21053c0b7 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -1599,18 +1599,103 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){ } return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); }; + auto replace_pooling_opr = [](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + using Param = opr::PoolingForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& pooling = opr->cast_final_safe<opr::PoolingForward>(); + mgb_assert(pooling.param().format == Format::NCHW, + "ConvertFormat Pass only support converting NCHW to NCHW4."); + if (new_inp[0]->shape().ndim == 5) { + mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); + auto new_param = pooling.param(); + new_param.format = Format::NCHW4; + auto new_pooling = + opr::PoolingForward::make(new_inp[0], new_param, opr->config()); + mgb_assert(new_pooling.shape().ndim == 5, + "out var of Pooling opr after transform must be 5 (got: " + "%zu).", + new_pooling.shape().ndim); + return new_pooling.node()->owner_opr(); + } + auto new_opr = + serialization::copy_opr_shallow(*opr, new_inp, opr->config()); + return new_opr; + }; + auto replace_resize_opr = [](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + using Param = opr::ResizeForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& resize = opr->cast_final_safe<opr::ResizeForward>(); 
+ mgb_assert(resize.param().format == Format::NCHW, + "ConvertFormat Pass only support converting NCHW to NCHW4."); + if (new_inp[0]->shape().ndim == 5) { + mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); + auto new_param = resize.param(); + new_param.format = Format::NCHW4; + auto new_resize = opr::ResizeForward::make( + new_inp[0], new_inp[1], new_param, opr->config()); + mgb_assert(new_resize.shape().ndim == 5, + "out var of Resize opr after transform must be 5 (got: " + "%zu).", + new_resize.shape().ndim); + return new_resize.node()->owner_opr(); + } + auto new_opr = + serialization::copy_opr_shallow(*opr, new_inp, opr->config()); + return new_opr; + }; + auto replace_warp_perspective_opr = [](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + using Param = opr::WarpPerspective::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>(); + mgb_assert(warp.param().format == Format::NCHW, + "ConvertFormat Pass only support converting NCHW to NCHW4."); + if (new_inp[0]->shape().ndim == 5) { + mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); + auto new_param = warp.param(); + new_param.format = Format::NCHW4; + SymbolVar new_warp; + if (new_inp.size() == 3) { + new_warp = opr::WarpPerspectiveForward::make( + new_inp[0], new_inp[1], nullptr, new_inp[2], new_param, + opr->config()); + } else { + mgb_assert(new_inp.size() == 4); + new_warp = opr::WarpPerspectiveForward::make( + new_inp[0], new_inp[1], new_inp[2], new_inp[3], + new_param, opr->config()); + } + mgb_assert(new_warp.shape().ndim == 5, + "out var of WarpPerspective opr after transform must be " + "5 (got: " + "%zu).", + new_warp.shape().ndim); + return new_warp.node()->owner_opr(); + } + auto new_opr = + serialization::copy_opr_shallow(*opr, new_inp, opr->config()); + return new_opr; + }; + auto&& replace_func = ret->m_opr_replace_func; //! 
supportted nchw4 replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr; + replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; + replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; + replace_func[opr::WarpPerspectiveForward::typeinfo()] = + replace_warp_perspective_opr; replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr; replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; //! not supported nchw4 - replace_func[opr::PoolingForward::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::ConvolutionBackwardData::typeinfo()] = relayout_inp_to_nchw; @@ -1620,9 +1705,6 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter(){ replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::WarpPerspectiveForward::typeinfo()] = - relayout_inp_to_nchw; replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; return ret; } diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 8321fd4de..38748e0d6 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -2430,14 +2430,16 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); auto conv1 = opr::ConvBiasForward::make( - x, w1, b1, param_conv_bias, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); + x, w1, b1, param_conv_bias, {}, + 
OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); // group // icpg != 1 && ocpg != 1 param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); - auto conv2 = opr::ConvBiasForward::make(conv1, w2, b2, - param_conv_bias, {}, OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); + auto conv2 = opr::ConvBiasForward::make( + conv1, w2, b2, param_conv_bias, {}, + OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); auto y = opr::TypeCvt::make(conv2, dtype::Float32()); @@ -2453,8 +2455,8 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { graph->compile({{y_opt, {}}}) ->to_json() - ->writeto_fpath( - output_file("TestGoptInference.ConvertFormatNCHW4GPU.json")); + ->writeto_fpath(output_file( + "TestGoptInference.ConvertFormatNCHW4GPU.json")); HostTensorND host_y, host_y_opt; auto func = graph->compile({make_callback_copy(y, host_y), @@ -2465,6 +2467,90 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) { #endif +TEST(TestGoptInference, ConvertFormatNCHW4NonConvOpr) { + auto cn = CompNode::load("xpu0"); + HostTensorGenerator gen; + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name), + dtype); + }; + auto mkcvar = [&](const char* name, const TensorShape& shp, + const DType& dtype) { + return opr::TypeCvt::make( + opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name), + dtype); + }; + auto mkcvarf32 = [&](const char* name, const TensorShape& shp) { + return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name); + }; + + auto x = mkvar("x", {2, 4, 16, 16}, dtype::QuantizedS8(2.5f)); + opr::ConvBias::Param param_conv_bias; + param_conv_bias.format = opr::ConvBias::Param::Format::NCHW; + param_conv_bias.stride_h = 
param_conv_bias.stride_w = 1; + param_conv_bias.pad_h = param_conv_bias.pad_w = 1; + param_conv_bias.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU; + // dense + param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; + auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), + b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); + auto conv1 = opr::ConvBiasForward::make( + x, w1, b1, param_conv_bias, {}, + OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); + // test Resize + auto shape_of = opr::GetVarShape::make(x); + auto subtensor = opr::Subtensor::make( + shape_of, {opr::Subtensor::AxisIndexer::make_interval( + 0, x.make_scalar(2), None, x.make_scalar(1))}); + opr::Resize::Param param_resize; + param_resize.format = opr::Resize::Param::Format::NCHW; + auto resize = opr::ResizeForward::make(conv1, subtensor * 2, param_resize); + // test WarpPerspective + auto mat = mkcvarf32("mat", {2, 3, 3}), + warp = opr::WarpPerspectiveForward::make( + resize, mat, nullptr, cg::var_from_tensor_shape(x, {32, 32})); + opr::Pooling::Param pool_param; + pool_param.format = opr::Pooling::Param::Format::NCHW; + // test Pooling + auto pool = opr::Pooling::make(warp, pool_param); + // group + // icpg != 1 && ocpg != 1 + param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; + auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)), + b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); + auto conv2 = opr::ConvBiasForward::make( + pool, w2, b2, param_conv_bias, {}, + OperatorNodeConfig{dtype::QuantizedS8{2.5f}}); + + auto add = opr::ElemwiseMultiType::make( + {conv1, conv2}, {opr::ElemwiseMultiType::Param::Mode::QADD}, + OperatorNodeConfig{dtype::QuantizedS8{1.2f}}); + auto y = opr::TypeCvt::make(add, dtype::Float32()); + + SymbolVar y_opt; + { + auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_nchw4(); + unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); + } + auto nr_dimshuffle = 
find_opr_num<opr::Dimshuffle>(y_opt); + ASSERT_EQ(2u, nr_dimshuffle); + ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, + find_opr<opr::ConvBias>(y_opt).param().format); + ASSERT_EQ(opr::ResizeForward::Param::Format::NCHW4, + find_opr<opr::ResizeForward>(y_opt).param().format); + ASSERT_EQ(opr::WarpPerspectiveForward::Param::Format::NCHW4, + find_opr<opr::WarpPerspectiveForward>(y_opt).param().format); + ASSERT_EQ(opr::PoolingForward::Param::Format::NCHW4, + find_opr<opr::PoolingForward>(y_opt).param().format); +} + TEST(TestGoptInference, ConvertFormatNCHW4) { HostTensorGenerator<> gen; auto cn = CompNode::load("cpu0"); @@ -2479,7 +2565,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4) { }; auto x = mkvar("x", {2, 4, 16, 16}); - // ConvBias + // ConvBias test dense opr::ConvBias::Param param_conv_bias; param_conv_bias.pad_h = param_conv_bias.pad_w = 1; param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; -- GitLab