diff --git a/dnn/src/common/conv_bias.cpp b/dnn/src/common/conv_bias.cpp
index 2fb2395c522e93c34130867dbe5f73f2301ff7f8..c747c7e99c0f67877c016aed7c50de383af6a2c4 100644
--- a/dnn/src/common/conv_bias.cpp
+++ b/dnn/src/common/conv_bias.cpp
@@ -65,8 +65,8 @@ void do_check_exec_common(
                 bias.to_string().c_str(), dst.to_string().c_str());
         megdnn_assert(bias.shape[2] == 1);
         megdnn_assert(bias.shape[3] == 1);
-    } else if (param().format == param::ConvBias::Format::NHWC ||
-               param().format == param::ConvBias::Format::NCHW4_NHWC) {
+    } else if (opr->param().format == param::ConvBias::Format::NHWC ||
+               opr->param().format == param::ConvBias::Format::NCHW4_NHWC) {
         megdnn_assert(bias.shape[0] == 1);
         megdnn_assert(bias.shape[1] == 1);
         megdnn_assert(bias.shape[2] == 1);
diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index 7b35d7fd6ec5cd4a3a99ed2ba96ae7481927807e..cfbf72ba57cfe3278d9d685168b8cac93a96c4dd 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -420,7 +420,8 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[4] = 32;
         } else if (layout_type() ==
                    RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW64) {
-            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 64 == 0, "%s", inp_shape.to_string().c_str());
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 64 == 0, "%s",
+                       inp_shape.to_string().c_str());
             dst.ndim = 5;
             dst[0] = inp_shape[0];
             dst[1] = inp_shape[1] / 64;
@@ -438,8 +439,6 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[4] = 32;
         } else if (layout_type() ==
                    RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64) {
-            mgb_assert(layout_type() ==
-                       RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64);
             mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 16 == 0);
             dst.ndim = 5;
             dst[0] = inp_shape[0];
@@ -499,18 +498,17 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
         } else if (layout_type() ==
                    RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32) {
             mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 32 == 0);
-            dst.ndim = 4;
+            dst.ndim = 5;
             dst[0] = inp_shape[0];
             dst[1] = inp_shape[3] / 32;
             dst[2] = inp_shape[1];
             dst[3] = inp_shape[2];
             dst[4] = 32;
-        } else if (layout_type() ==
-                   RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64) {
+        } else {
             mgb_assert(layout_type() ==
                        RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64);
             mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 64 == 0);
-            dst.ndim = 4;
+            dst.ndim = 5;
             dst[0] = inp_shape[0];
             dst[1] = inp_shape[3] / 64;
             dst[2] = inp_shape[1];
@@ -3729,21 +3727,6 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
         return y1.node();
     };
 
-    auto nhwc2nchw64 = [](VarNode* inp) -> VarNode* {
-        mgb_assert(inp->shape().ndim == 4);
-        auto x = SymbolVar(inp);
-        auto xshp = opr::GetVarShape::make(x);
-        auto cv = [&x](int v) { return x.make_scalar(v); };
-        auto sub = [&xshp, &cv](int idx) {
-            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
-        };
-        auto tshp = opr::Concat::make(
-                {sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0);
-        auto y0 = opr::Reshape::make(x, tshp);
-        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
-        return y1.node();
-    };
-
     auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers,
                                                 &nchw42nchw](
                                                        OperatorNodeBase* opr) {
@@ -3915,31 +3898,29 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
         return true;
     };
 
-    auto try_conv_reformat_nchw42nchw64 = [&rewriter, &nchw42nhwc, &nhwc2nchw64,
-                                           &readers](OperatorNodeBase* opr) {
+    auto try_conv_reformat_nchw42nhwc = [&rewriter, &nchw42nhwc,
+                                         &readers](OperatorNodeBase* opr) {
         ThinHashSet<OperatorNodeBase*> opr_set;
         ThinHashSet<OperatorNodeBase*> reader_set;
         // check reshape
-        auto reshape1 =
-                try_cast_as_op<opr::Reshape>(opr);
-        if (reshape1 == nullptr)
+        auto reshape = try_cast_as_op<opr::Reshape>(opr);
+        if (reshape == nullptr)
             return false;
         opr_set.insert(opr);
 
         // check dimshuffle
         auto shuffle = try_cast_as_op<opr::Dimshuffle>(
-                reshape1->input(0)->owner_opr());
+                reshape->input(0)->owner_opr());
         if (shuffle == nullptr)
             return false;
         auto&& param = shuffle->param();
-        if (param.pattern_len != 6)
+        if (param.pattern_len != 5)
             return false;
-        bool is_nchw42nchw64 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
-                               param.pattern[2] == 3 && param.pattern[3] == 4 &&
-                               param.pattern[4] == 2 && param.pattern[5] == 5 &&
-                               shuffle->output(0)->shape()[5] == 4 &&
-                               shuffle->output(0)->shape()[4] == 16;
-        if (!is_nchw42nchw64)
+        bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 &&
+                             param.pattern[2] == 3 && param.pattern[3] == 1 &&
+                             param.pattern[4] == 4 &&
+                             shuffle->output(0)->shape()[4] == 4;
+        if (!is_nchw42nhwc)
             return false;
         opr_set.insert(shuffle);
         for (auto&& i : readers[shuffle]) {
             if (i.second & DepType::DEV_VALUE) {
                 reader_set.insert(i.first);
             }
         }
@@ -3948,20 +3929,8 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
             }
         }
 
-        // check reshape
-        auto reshape2 =
-                try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
-        if (reshape2 == nullptr)
-            return false;
-        opr_set.insert(reshape2);
-        for (auto&& i : readers[reshape2]) {
-            if (i.second & DepType::DEV_VALUE) {
-                reader_set.insert(i.first);
-            }
-        }
-
         auto typecvt =
-                try_cast_as_op<opr::TypeCvt>(reshape2->input(0)->owner_opr());
+                try_cast_as_op<opr::TypeCvt>(shuffle->input(0)->owner_opr());
         if (typecvt == nullptr)
             return false;
         auto in_dtype = typecvt->input(0)->dtype(),
@@ -3972,6 +3941,11 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
         if (!is_s82s4)
             return false;
         opr_set.insert(typecvt);
+        for (auto&& i : readers[typecvt]) {
+            if (i.second & DepType::DEV_VALUE) {
+                reader_set.insert(i.first);
+            }
+        }
 
         // check conv bias
         auto conv_bias =
@@ -4006,11 +3980,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
         auto conv_bias_shuffle = opr::ConvBias::make(
                 src, filter, new_bias, new_param,
                 conv_bias->execution_policy(), OperatorNodeConfig{out_dtype});
-        auto new_var = nhwc2nchw64(conv_bias_shuffle.node());
         rewriter.replace_var(
-                opr->output(0), new_var,
+                opr->output(0), conv_bias_shuffle.node(),
                 mgb_cstr_log("replace conv_bias + "
-                             "reformat to conv_bias(NCHW4_NCHW64)"));
+                             "reformat to conv_bias(NCHW4_NHWC)"));
         return true;
     };
 
@@ -4098,14 +4071,14 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
     auto on_opr = [&try_conv_dimshuffle_reshape_typecvt,
                    &try_conv_reformat_nchw42nchw32,
-                   &try_conv_reformat_nchw42nchw64,
+                   &try_conv_reformat_nchw42nhwc,
 #if CUDA_VERSION >= 10020
                    &try_conv_reformat_nchw322nchw4,
 #endif
                    &rewriter](OperatorNodeBase* opr) {
         if (!try_conv_dimshuffle_reshape_typecvt(opr) &&
-            !try_conv_reformat_nchw42nchw32(opr) &&
-            !try_conv_reformat_nchw42nchw64(opr)
+            !try_conv_reformat_nchw42nchw32(opr) &&
+            !try_conv_reformat_nchw42nhwc(opr)
 #if CUDA_VERSION >= 10020
             && !try_conv_reformat_nchw322nchw4(opr)
 #endif
@@ -4546,6 +4519,9 @@ VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var,
         case Format::NCHW64:
             type = LayoutType::NCHW64_TO_NCHW;
             break;
+        case Format::NHWC:
+            type = LayoutType::NHWC_TO_NCHW;
+            break;
         default:
             mgb_throw(AssertionError,
                       "format(%d) is not supported, related var "
@@ -4980,7 +4956,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
                     case Format::NHWC:
                         inps[1] = RelayoutPlaceholder::make(
                                           inps[1], RelayoutPlaceholder::LayoutType::
-                                                  NCHW_TO_NHWC)
+                                                  NHWC_TO_NCHW4)
                                           .node();
                         break;
                     case Format::NCHW32:
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 97b45ecb957b860e643d107cab6314d98047cb75..d988cb0f93e5022ce5a74598849a49eb8316f2a6 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -4404,10 +4404,10 @@ TEST(TestGoptInference, FoldingConvDimshuffleNCHW4NHWC) {
     };
 
     auto x = mkvar("x", {32, 4, 23, 40}, dtype::QuantizedS8(2.5f)),
-         w = mkcvar("w", {64, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
-         b = mkcvar("b", {1, 64, 1, 1}, dtype::QuantizedS32(6.25f)),
-         w1 = mkcvar("w1", {64, 64, 3, 3}, dtype::QuantizedS4(1.234f)),
-         b1 = mkcvar("b1", {1, 64, 1, 1}, dtype::QuantizedS32(12.34567f*1.234f));
+         w = mkcvar("w", {32, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
+         b = mkcvar("b", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)),
+         w1 = mkcvar("w1", {32, 32, 3, 3}, dtype::QuantizedS4(1.234f)),
+         b1 = mkcvar("b1", {1, 32, 1, 1}, dtype::QuantizedS32(12.34567f*1.234f));
     opr::ConvBias::Param param;
     param.format = opr::ConvBias::Param::Format::NCHW;
     param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
@@ -4438,7 +4438,7 @@ TEST(TestGoptInference, FoldingConvDimshuffleNCHW4NHWC) {
             ->writeto_fpath(output_file(
                     "TestGoptInference.FoldingConvDimshuffleNCHW4NHWC.json"));
     size_t nr_dimshuffle = find_opr_num<opr::Dimshuffle>(y_fuse);
-    ASSERT_EQ(3u, find_opr_num<opr::Dimshuffle>(y_fuse));
+    ASSERT_EQ(2u, nr_dimshuffle);
     bool found = false;
     cg::DepOprIter{[&found](cg::OperatorNodeBase* opr) {
         if (!found && opr->same_type<opr::ConvBias>()) {
@@ -4735,101 +4735,6 @@ TEST(TestGoptInference, PaddingChannelsWithWarpPerspective) {
     MGB_ASSERT_TENSOR_EQ(t1, t2);
 }
 
-TEST(TestGoptInference, PaddingChannelsB4) {
-    REQUIRE_GPU(1);
-    auto cn = CompNode::load("gpu0");
-    cn.activate();
-    REQUIRE_CUDA_COMPUTE_CAPABILITY(7, 5);
-
-    HostTensorGenerator<dtype::Int8> gen;
-    auto graph = ComputingGraph::make();
-    graph->options().graph_opt_level = 0;
-    auto mkvar = [&](const char* name, const TensorShape& shp,
-                     const DType& dtype) {
-        return opr::TypeCvt::make(
-                opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
-                dtype);
-    };
-    auto mkcvar = [&](const char* name, const TensorShape& shp,
-                      const DType& dtype) {
-        return opr::TypeCvt::make(
-                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
-                        .rename(name),
-                dtype);
-    };
-
-    auto x = mkvar("x", {16, 3, 14, 14}, dtype::QuantizedS8(2.5f)),
-         w = mkcvar("w", {16, 3, 3, 3}, dtype::QuantizedS8(2.5f)),
-         b = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f));
-    opr::ConvBias::Param param;
-    param.format = opr::ConvBias::Param::Format::NCHW;
-    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
-    param.stride_h = param.stride_w = 1;
-    param.pad_h = param.pad_w = 1;
-
-    auto y = opr::ConvBias::make(x, w, b, param, {},
-                                 OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
-    y = opr::TypeCvt::make(y, dtype::Quantized4Asymm{20.f, 8});
-    opr::Pooling::Param pool;
-    pool.format = opr::Pooling::Param::Format::NCHW;
-    y = opr::Pooling::make(y, pool);
-    auto w1 = mkcvar("w1", {48, 16, 3, 3}, dtype::QuantizedS4(1.234f)),
-         b1 = mkcvar("b1", {1, 48, 1, 1}, dtype::QuantizedS32(20.f*1.234f));
-    auto y1 = opr::ConvBias::make(y, w1, b1, param, {},
-                                  OperatorNodeConfig{dtype::Quantized4Asymm(20.f, 8)});
-    auto w2 = mkcvar("w2", {48, 48, 3, 3}, dtype::QuantizedS4(1.234f)),
-         b2 = mkcvar("b2", {1, 48, 1, 1}, dtype::QuantizedS32(20.f*1.234f));
-    auto y2 = opr::ConvBias::make(
-            y1, w2, b2, param, {},
-            OperatorNodeConfig{dtype::Quantized4Asymm(20.f, 8)});
-    auto w3 = mkcvar("w2", {16, 48, 3, 3}, dtype::QuantizedS4(1.234f)),
-         b3 = mkcvar("b2", {1, 16, 1, 1}, dtype::QuantizedS32(20.f*1.234f));
-    auto y3 = opr::ConvBias::make(
-            y2, w3, b3, param, {},
-            OperatorNodeConfig{dtype::Quantized4Asymm(20.f, 8)});
-    using ElemMultiMode = opr::ElemwiseMultiType::Param::Mode;
-    auto y4 = opr::ElemwiseMultiType::make(
-            {y, y3}, {ElemMultiMode::QFUSE_ADD_RELU},
-            OperatorNodeConfig{dtype::Quantized4Asymm{20.f, 7}});
-    y4 = opr::TypeCvt::make(y4, dtype::Float32());
-    SymbolVar y4_pad;
-    unpack_vector(gopt::GraphOptimizer{}
-                          .add_pass()
-                          .add_pass()
-                          .apply({{y4}})
-                          .endpoint_vars(),
-                  y4_pad);
-    ASSERT_EQ(y4_pad.node()->shape()[1], y4.node()->shape()[1]);
-    SmallVector<cg::OperatorNodeBase*> oprs;
-    auto cb1 = [&oprs](cg::OperatorNodeBase* opr) {
-        if (opr->same_type<opr::ConvBias>()) {
-            oprs.push_back(opr);
-        }
-    };
-    cg::DepOprIter{cb1}.add(y4_pad.node()->owner_opr());
-    ASSERT_EQ(oprs.size(), 4);
-    ASSERT_EQ(oprs[0]->output(0)->shape()[1], 16);
-    ASSERT_EQ(oprs[1]->output(0)->shape()[1], 64);
-    ASSERT_EQ(oprs[2]->output(0)->shape()[1], 64);
-    ASSERT_EQ(oprs[3]->output(0)->shape()[1], 16);
-    size_t nr_concat = find_opr_num<opr::Concat>(y4_pad);
-    ASSERT_EQ(nr_concat, 1);
-    cg::OperatorNodeBase* concat = nullptr;
-    auto cb2 = [&concat](cg::OperatorNodeBase* opr) {
-        if (opr->same_type<opr::Concat>()) {
-            concat = opr;
-        }
-    };
-    cg::DepOprIter{cb2}.add(y4_pad.node()->owner_opr());
-    ASSERT_EQ(oprs[0]->input(0)->owner_opr(), concat);
-    HostTensorND t1, t2;
-    auto func1 = graph->compile({make_callback_copy(y4, t1)});
-    func1->execute();
-    auto func2 = graph->compile({make_callback_copy(y4_pad, t2)});
-    func2->execute();
-    MGB_ASSERT_TENSOR_EQ(t1, t2);
-}
-
 #endif