From a437ec8e88b604e6092f94a724317d503c0a65c5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 7 Apr 2021 20:46:59 +0800 Subject: [PATCH] fix(src/gopt): add replace func of typecvt opr for nhwcd4 pass GitOrigin-RevId: 801eb1dab3ccbdcf71e8a153e0d3c7c9a7dbe6db --- src/gopt/impl/inference.cpp | 3 ++- src/gopt/test/inference.cpp | 49 +++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 58acd3d4..1defadb9 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1565,7 +1565,7 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { if (new_inp[i]->shape()[1] % 4 != 0) { can_exec_cd4 = false; } - //! cd4 elemwise with scaler is supported + //! cd4 elemwise with scalar is unsupported } else if (!new_inp[i]->shape().is_scalar()) { can_exec_cd4 = false; } @@ -1627,6 +1627,7 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw; replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw; replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw; + replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; replace_func[opr::WarpPerspectiveForward::typeinfo()] = replace_warp_perspective_opr; diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 5646ea3d..808cc174 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -1265,6 +1265,55 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) { MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); } +TEST(TestGoptInference, ConvertFormatNHWCD4TypeCvt) { + NaiveMegDNNHandleScope naive_megdnn_handle; + + HostTensorGenerator<> gen; + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkcvar = [&](const char* name, 
const TensorShape& shp) { + return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name); + }; + auto host_x = gen({8, 8, 8, 8}, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x); + + opr::Convolution::Param param; + + param.pad_h = param.pad_w = 0; + auto w1 = mkcvar("w1", {8, 8, 3, 3}), + conv1 = opr::Convolution::make(x, w1, param), + tcvt1 = opr::TypeCvt::make(conv1, dtype::Float16()); + auto w2 = mkcvar("w2", {8, 8, 3, 3}), + conv2 = opr::Convolution::make(x, w2, param), + tcvt2 = opr::TypeCvt::make(conv2, dtype::Float16()); + auto y = opr::Elemwise::make({tcvt1, tcvt2}, opr::Elemwise::Param::Mode::ADD); + + SymbolVar y_opt; + auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_nhwcd4(); + unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); + + ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4, + find_opr(y_opt).param().format); + + graph->compile({{y_opt, {}}}) + ->to_json() + ->writeto_fpath(output_file( + "TestGoptInference.ConvertFormatNHWCD4TypeCvt.json")); + + HostTensorND host_y_opt, host_y; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_opt, host_y_opt)}); + func->execute(); + MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); + + *host_x = *gen({8, 8, 16, 16}, cn); + func->execute(); + MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); +} + TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) { // hwcd4 is only supported in naive handle NaiveMegDNNHandleScope naive_megdnn_handle; -- GitLab