From 8b0315b3b17c5023fc4fb5536f4e31d0156bffae Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 9 Apr 2020 23:44:35 +0800 Subject: [PATCH] fix(mgb): fix nhwcd4 optpass GitOrigin-RevId: 9295abec77af2763d301ba372116e4b2281f442a --- dnn/src/naive/handle.cpp | 5 ++ dnn/src/naive/handle.h | 18 +++++++ dnn/src/naive/local_share/algorithms.h | 41 +++++++++++++++ dnn/src/naive/local_share/opr_impl.cpp | 73 ++++++++++++++++++++++++++ dnn/src/naive/local_share/opr_impl.h | 24 +++------ src/gopt/impl/inference.cpp | 47 ++++++++++++----- src/gopt/test/inference.cpp | 64 ++++++++++++++++++++++ 7 files changed, 241 insertions(+), 31 deletions(-) create mode 100644 dnn/src/naive/local_share/algorithms.h diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index 79f28d07d..c6a583302 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -94,6 +94,11 @@ DefaultConvolution3DBackwardFilterAlgorithm HandleImpl::m_default_conv3d_bwd_filter_algo; DefaultBatchConvBiasForwardAlgorithm HandleImpl::m_default_batch_conv_bias_fwd_algo; +DefaultLocalShareForwardAlgorithm HandleImpl::m_default_local_share_fwd_algo; +DefaultLocalShareBackwardDataAlgorithm + HandleImpl::m_default_local_share_bwd_data_algo; +DefaultLocalShareBackwardFilterAlgorithm + HandleImpl::m_default_local_share_bwd_filter_algo; HandleImpl::HandleImpl(megcoreComputingHandle_t computing_handle, HandleType type) diff --git a/dnn/src/naive/handle.h b/dnn/src/naive/handle.h index b11bf7a60..db301b191 100644 --- a/dnn/src/naive/handle.h +++ b/dnn/src/naive/handle.h @@ -13,6 +13,7 @@ #include "src/common/handle_impl.h" #include "src/naive/convolution/algorithms.h" +#include "src/naive/local_share/algorithms.h" #include "src/naive/convolution3d/algorithms.h" #include @@ -39,6 +40,11 @@ class HandleImpl : public HandleImplHelper { m_default_conv3d_bwd_filter_algo; static DefaultBatchConvBiasForwardAlgorithm m_default_batch_conv_bias_fwd_algo; + static DefaultLocalShareForwardAlgorithm m_default_local_share_fwd_algo; + static DefaultLocalShareBackwardDataAlgorithm + m_default_local_share_bwd_data_algo; + static DefaultLocalShareBackwardFilterAlgorithm + m_default_local_share_bwd_filter_algo; //! move KernFunc to alloc_kern()->func, destruct func, and call dispatch template @@ -91,6 +97,18 @@ public: return &m_default_batch_conv_bias_fwd_algo; } + LocalShareForward::Algorithm* default_local_share_fwd_algo() { + return &m_default_local_share_fwd_algo; + } + + LocalShareBackwardData::Algorithm* default_local_share_bwd_data_algo() { + return &m_default_local_share_bwd_data_algo; + } + + LocalShareBackwardFilter::Algorithm* default_local_share_bwd_filter_algo() { + return &m_default_local_share_bwd_filter_algo; + } + Relayout* relayout_opr() override { return get_helper_opr(this); } diff --git a/dnn/src/naive/local_share/algorithms.h b/dnn/src/naive/local_share/algorithms.h new file mode 100644 index 000000000..5b5c88651 --- /dev/null +++ b/dnn/src/naive/local_share/algorithms.h @@ -0,0 +1,41 @@ +/** + * \file dnn/src/naive/local_share/algorithms.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class DefaultLocalShareForwardAlgorithm final: + public megdnn::LocalShareForward::Algorithm { + bool is_reproducible() const override + { return true; } + const char* name() const override + { return "DEFAULT"; } +}; +class DefaultLocalShareBackwardDataAlgorithm final: + public megdnn::LocalShareBackwardData::Algorithm { + bool is_reproducible() const override + { return true; } + const char* name() const override + { return "DEFAULT"; } +}; +class DefaultLocalShareBackwardFilterAlgorithm final: + public megdnn::LocalShareBackwardFilter::Algorithm { + bool is_reproducible() const override + { return true; } + const char* name() const override + { return "DEFAULT"; } +}; +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/local_share/opr_impl.cpp b/dnn/src/naive/local_share/opr_impl.cpp index 85668b392..4be1554e0 100644 --- a/dnn/src/naive/local_share/opr_impl.cpp +++ b/dnn/src/naive/local_share/opr_impl.cpp @@ -152,4 +152,77 @@ void LocalShareBackwardFilterImpl::exec(_megdnn_tensor_in src, StrategyBwdFlt>(src, grad, diff, param()));); } +std::vector +LocalShareForwardImpl::get_all_algorithms(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle())->default_local_share_fwd_algo()}; +} + +LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( + const TensorLayout& /* src */, const TensorLayout& /* diff */, + const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, + bool reproducible) { + auto algo = + static_cast(handle())->default_local_share_fwd_algo(); + if (reproducible) { + megdnn_assert(algo->is_reproducible(), + "require reproducible algorithm, but heuristic " + "algorithm(%s) is not " + "reproducible", + algo->name()); + } + return algo; +} + +std::vector +LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle()) + ->default_local_share_bwd_data_algo()}; +} + +LocalShareBackwardData::Algorithm* +LocalShareBackwardDataImpl::get_algorithm_heuristic( + const TensorLayout& /* filter */, const TensorLayout& /* diff */, + const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, + bool reproducible) { + auto algo = static_cast(handle()) + ->default_local_share_bwd_data_algo(); + if (reproducible) { + megdnn_assert(algo->is_reproducible(), + "require reproducible algorithm, but heuristic " + "algorithm(%s) is not " + "reproducible", + algo->name()); + } + return algo; +} + +std::vector +LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle()) + ->default_local_share_bwd_filter_algo()}; +} + +LocalShareBackwardFilter::Algorithm* +LocalShareBackwardFilterImpl::get_algorithm_heuristic( + const TensorLayout& /* src */, const TensorLayout& /* diff */, + const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */, + bool reproducible) { + auto algo = static_cast(handle()) + ->default_local_share_bwd_filter_algo(); + if (reproducible) { + megdnn_assert(algo->is_reproducible(), + "require reproducible algorithm, but heuristic " + "algorithm(%s) is not " + "reproducible", + algo->name()); + } + return algo; +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/local_share/opr_impl.h b/dnn/src/naive/local_share/opr_impl.h index 11e98e3ff..6d8b26efc 100644 --- a/dnn/src/naive/local_share/opr_impl.h +++ b/dnn/src/naive/local_share/opr_impl.h @@ -27,17 +27,13 @@ public: std::vector get_all_algorithms( const TensorLayout& /*src*/, const TensorLayout& /*filter*/, - const TensorLayout& /*dst*/) override { - return {}; - } + const TensorLayout& /*dst*/) override; Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/, const TensorLayout& /*filter*/, const TensorLayout& /*dst*/, size_t /*workspace_limit_in_bytes*/, - bool /*reproducible*/) override { - return nullptr; - } + bool /*reproducible*/) override; const char* get_algorithm_set_name() const override { return "DEFAULT"; } }; @@ -55,17 +51,13 @@ public: std::vector get_all_algorithms( const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, - const TensorLayout& /*grad*/) override { - return {}; - } + const TensorLayout& /*grad*/) override; Algorithm* get_algorithm_heuristic(const TensorLayout& /*filter*/, const TensorLayout& /*diff*/, const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, - bool /*reproducible*/) override { - return nullptr; - } + bool /*reproducible*/) override; const char* get_algorithm_set_name() const override { return "DEFAULT"; } }; @@ -83,17 +75,13 @@ public: std::vector get_all_algorithms( const TensorLayout& /*src*/, const TensorLayout& /*diff*/, - const TensorLayout& /*grad*/) override { - return {}; - } + const TensorLayout& /*grad*/) override; Algorithm* get_algorithm_heuristic(const TensorLayout& /*src*/, const TensorLayout& /*diff*/, const TensorLayout& /*grad*/, size_t /*workspace_limit_in_bytes*/, - bool /*reproducible*/) override { - return nullptr; - } + bool /*reproducible*/) override; const char* get_algorithm_set_name() const override { return "DEFAULT"; } }; diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 16207f03b..dd9390e5f 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -14,6 +14,7 @@ #include "megbrain/gopt/basic_arith.h" #include "megbrain/graph/event.h" #include "megbrain/opr/dnn/batch_norm.h" +#include "megbrain/opr/dnn/local.h" #include "megbrain/utils/shared_set.h" #include "megbrain/serialization/opr_shallow_copy.h" #include "megbrain/opr/basic_arith.h" @@ -1358,23 +1359,28 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { return new_pooling_opr.node()->owner_opr(); }; - auto relayout_inp_to_chw = [](OperatorNodeBase* opr, + auto var_to_chw = [](VarNode* inp, VarNode* new_inp) { + if (!inp->shape().eq_shape(new_inp->shape())) { + mgb_assert(inp->shape().ndim == 4 && + inp->format().type() != + TensorFormat::Type::IMAGE2D_PACK4); + mgb_assert(new_inp->shape().ndim == 5 && + new_inp->format().type() == + TensorFormat::Type::IMAGE2D_PACK4); + auto param = megdnn::param::RelayoutFormat(); + param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW; + auto rf = opr::RelayoutFormat::make(new_inp, param); + return rf.node(); + } + return new_inp; + }; + + auto relayout_inp_to_chw = [var_to_chw](OperatorNodeBase* opr, const VarNodeArray& new_inp) { mgb_assert(opr->input().size() == new_inp.size()); VarNodeArray t_inp = new_inp; for (size_t i = 0; i < opr->input().size(); i++) { - if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { - mgb_assert(opr->input(i)->shape().ndim == 4 && - opr->input(i)->format().type() != - TensorFormat::Type::IMAGE2D_PACK4); - mgb_assert(new_inp[i]->shape().ndim == 5 && - new_inp[i]->format().type() == - TensorFormat::Type::IMAGE2D_PACK4); - auto param = megdnn::param::RelayoutFormat(); - param.mode = megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW; - auto rf = opr::RelayoutFormat::make(new_inp[i], param); - t_inp[i] = rf.node(); - } + t_inp[i] = var_to_chw(opr->input(i), new_inp[i]); } auto new_opr = serialization::copy_opr_shallow(*opr, t_inp, opr->config()); @@ -1415,6 +1421,18 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { } }; + /* This helper function converts the first input to the NCHW format to + * handle operations that do not support NHWCD4 format + */ + auto relayout_first_inp_to_chw = + [var_to_chw](OperatorNodeBase* opr, + const VarNodeArray& new_inp) -> OperatorNodeBase* { + mgb_assert(opr->input().size() == new_inp.size()); + VarNodeArray t_inp = new_inp; + t_inp[0] = var_to_chw(opr->input(0), new_inp[0]); + return serialization::copy_opr_shallow(*opr, t_inp, opr->config()); + }; + auto ret = std::make_unique(); ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); auto&& replace_func = ret->m_opr_replace_func; @@ -1436,6 +1454,9 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { replace_func[opr::WarpPerspectiveForward::typeinfo()] = replace_warp_perspective_opr; replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr; + replace_func[opr::LocalForward::typeinfo()] = relayout_first_inp_to_chw; + replace_func[opr::GroupLocalForward::typeinfo()] = + relayout_first_inp_to_chw; return ret; } diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 5fe6fa64f..43cc6d4e9 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -9,6 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "megbrain/opr/dnn/local.h" #include "megbrain/test/helper.h" #include "megbrain/gopt/inference.h" @@ -919,6 +920,69 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) { MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); } +TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) { + // hwcd4 is only supported in naive handle + NaiveMegDNNHandleScope naive_megdnn_handle; + + HostTensorGenerator<> gen; + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkcvar = [&](const char* name, const TensorShape& shp) { + return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name); + }; + + auto host_x = gen({2, 8, 8, 16}, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x); + + opr::Convolution::Param param; + param.pad_h = param.pad_w = 1; + auto w1 = mkcvar("w1", {4, 8, 3, 3}), + conv1 = opr::Convolution::make(x, w1, param); + + auto w2 = mkcvar("w2", {8, 16, 4, 3, 3, 4}), + local = opr::Local::make(conv1, w2, param); + + auto w3 = mkcvar("w3", {4, 4, 3, 3}), + conv2 = opr::Convolution::make(local, w3, param); + + opr::GroupLocal::Param param_group_local; + param_group_local.pad_h = param_group_local.pad_w = 1; + auto w4 = mkcvar("w4", {2, 8, 16, 2, 3, 3, 2}), + group_local = opr::GroupLocal::make(conv2, w4, param_group_local); + + auto w5 = mkcvar("w5", {4, 4, 3, 3}), + y = opr::Convolution::make(group_local, w5, param); + + SymbolVar y_opt; + unpack_vector( + gopt::optimize_for_inference( + {y}, + gopt::OptimizeForInferenceOptions{}.enable_use_nhwcd4()), + y_opt); + + ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4, + find_opr(y_opt).param().format); + + ASSERT_EQ(opr::Local::Param::Format::NCHW, + find_opr(y_opt).param().format); + + ASSERT_EQ(opr::GroupLocal::Param::Format::NCHW, + find_opr(y_opt).param().format); + + graph->compile({{y_opt, {}}}) + ->to_json() + ->writeto_fpath(output_file( + "TestGoptInference.ConvertFormatNHWCD4LOCAL.json")); + + HostTensorND host_y_opt, host_y; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_opt, host_y_opt)}); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); +} + TEST(TestGoptInference, ConvertFormatNHWCD4Deconv) { // hwcd4 is only supported in naive handle NaiveMegDNNHandleScope naive_megdnn_handle; -- GitLab