diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h
index a5caeeca04b4ca7ff692382ec562aca1f7b71744..d6a1fefdf0dc447f6b332c1b1d6af479ccdd8e4d 100644
--- a/dnn/include/megdnn/oprs/nn.h
+++ b/dnn/include/megdnn/oprs/nn.h
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 #include "megdnn/internal/opr_header_prologue.h"
@@ -314,8 +315,10 @@ public:
     /**
      * \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
      * \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
-     * 4*ic) \param[in] bias (1, oc, 1, 1) \param[in] z same as dst \param[out]
-     * dst (n, oc, oh, ow) or (n, oh, ow, oc)
+     * 4 * ic)
+     * \param[in] bias (1, oc, 1, 1)
+     * \param[in] z same as dst
+     * \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc)
      *
      * \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
      * alphaw, oc, ic)
@@ -407,6 +410,26 @@ public:
      */
     static WinogradParam parse_winograd_name(const std::string& algo_name);
 
+    /**
+     * @brief find whether there is an nchw_nchwxx conv kernel optimized for
+     * the given arguments; nchw44 is used on ARM, nchw88 on x86
+     *
+     * @param src_dtype conv feature map data type
+     * @param filter_dtype conv filter or weight data type
+     * @param dst_dtype output data type
+     * @param fm filter meta param
+     * @param bias_mode bias mode, no_bias or broadcast or bias
+     * @param nonline_mode identity or relu or h_swish or sigmoid
+     * @return true if an optimized kernel exists
+     * @return false otherwise
+     */
+    static bool is_nchw_nchwxx_optimized(
+            const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
+            const DTypeEnum dst_dtype,
+            const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
+            const ConvBiasForward::BiasMode bias_mode,
+            const param::ConvBias::NonlineMode nonline_mode);
+
 protected:
     CanonizedFilterMeta check_exec(
             const TensorLayout& src, const TensorLayout& filter,
diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp
index 2588857823d9844dc746b5dc0c6e68ec064ca92f..5f9418b38da3d94e94c47b12fa5acd05a8c1be00 100644
--- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp
+++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp
@@ -16,10 +16,10 @@
 #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
 #include "src/arm_common/conv_bias/fp32/strategy.h"
 #include "src/arm_common/elemwise_op.h"
+#include "src/common/nchw_nchwxx_valid.h"
 #include "src/common/opr_delegate.h"
 
 #include "midout.h"
-
 using namespace megdnn;
 using namespace arm_common;
 using conv_fun = std::function<void(const WorkspaceBundle& bundle,
@@ … @@
 bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable(
         const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
-    auto&& fm = param.filter_meta;
-    auto fh = fm.spatial[0];
-    int oc = fm.ocpg;
-    bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float32 &&
-                     param.filter_type.enumv() == DTypeEnum::Float32 &&
-                     (param.dst_type.enumv() == DTypeEnum::Float32))) &&
-                   (fm.format == param::Convolution::Format::NCHW44);
-    bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1;
-    bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
-                     (fh == 2 || fh == 3 || fh == 5 || fh == 7);
-    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
-                    fm.stride[0] == fm.stride[1] &&
-                    (fm.stride[0] == 1 || fm.stride[0] == 2);
-    bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
-    bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
-    return avaible;
+    return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>(
+            param.src_type.enumv(), param.filter_type.enumv(),
+            param.dst_type.enumv(), param.filter_meta, param.bias_mode,
+            param.nonlineMode);
 }
 
 size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace(
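For context while reviewing the backend changes below: the new entry point declared in nn.h can be queried on its own. A minimal sketch, assuming the `CanonizedFilterMeta` fields shown here (the field set mirrors what tensor_reformat.cpp fills later in this patch; the concrete shapes and the zero-initialization are illustrative, not from the patch):

```cpp
#include "megdnn/oprs/nn.h"

using namespace megdnn;

// Sketch: ask whether a dense 3x3 stride-1 f32 conv (ic=3 -> oc=8) has an
// optimized hybrid NCHW->NCHW44 kernel. Callers fill the filter meta.
bool query_example() {
    ConvolutionBase<param::Convolution>::CanonizedFilterMeta fm{};
    fm.format = param::Convolution::Format::NCHW44;  // candidate layout
    fm.group = 1;
    fm.spatial_ndim = 2;
    fm.ocpg = 8;  // output channels per group; must be a multiple of 4
    fm.icpg = 3;  // input channels per group; must be < 4 for the hybrid path
    fm.spatial[0] = fm.spatial[1] = 3;
    fm.stride[0] = fm.stride[1] = 1;
    fm.dilation[0] = fm.dilation[1] = 1;
    fm.should_flip = false;
    return ConvBiasForward::is_nchw_nchwxx_optimized(
            DTypeEnum::Float32, DTypeEnum::Float32, DTypeEnum::Float32, fm,
            ConvBiasForward::BiasMode::NO_BIAS,
            param::ConvBias::NonlineMode::IDENTITY);
}
```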
diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
index 4890f3d939d31b50d559dde5805578ccdb687adb..ed01a68b4ed4360ee08b013194c37f23f2c2d62e 100644
--- a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
+++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
@@ -15,6 +15,7 @@
 #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
 #include "src/arm_common/conv_bias/int8/strategy.h"
 #include "src/arm_common/elemwise_op.h"
+#include "src/common/nchw_nchwxx_valid.h"
 #include "src/common/opr_delegate.h"
 
 #include "midout.h"
@@ -214,26 +215,12 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
                  ow, op);
 }
 
-bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable(
-        const NCBKernSizeParam& param,
-        AlgoSelectionStrategy algo_selection_strategy) const {
-    MEGDNN_MARK_USED_VAR(algo_selection_strategy);
-    auto&& fm = param.filter_meta;
-    auto FH = fm.spatial[0];
-    auto OC = fm.ocpg;
-    bool avaible = //! src and filter are qint8, dst is qint8
-            fm.icpg < 4 && // must be nchw input
-            ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
-              param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
-              (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
-            (fm.format == param::Convolution::Format::NCHW44) &&
-            (OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 &&
-            fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
-            fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] &&
-            (fm.stride[0] == 1 || fm.stride[0] == 2) && FH == fm.spatial[1] &&
-            (FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.group == 1 &&
-            param.bias_mode != BiasMode::BIAS;
-    return avaible;
+bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable(const NCBKernSizeParam& param,
+                                                  AlgoSelectionStrategy) const {
+    return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>(
+            param.src_type.enumv(), param.filter_type.enumv(),
+            param.dst_type.enumv(), param.filter_meta, param.bias_mode,
+            param.nonlineMode);
 }
 
 bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred(
diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
index dfd29f77f29361190f8830d8712747eef2bf060d..4232581b0a79e0035a653f5cf4f98fbddd4c6837 100644
--- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
+++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
@@ -16,6 +16,7 @@
 #include "src/arm_common/conv_bias/int8/algos.h"
 #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
 #include "src/arm_common/elemwise_op.h"
+#include "src/common/nchw_nchwxx_valid.h"
 
 #include "midout.h"
 
@@ -174,23 +175,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
 
 bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(
         const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
-    auto&& fm = param.filter_meta;
-    auto fh = fm.spatial[0];
-    int oc = fm.ocpg;
-    int ic = fm.icpg;
-    bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
-                     param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
-                     (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
-                   (fm.format == param::Convolution::Format::NCHW44_DOT);
-    bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4);
-    bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
-                     (fh == 2 || fh == 3 || fh == 5 || fh == 7);
-    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
-                    fm.stride[0] == fm.stride[1] &&
-                    (fm.stride[0] == 1 || fm.stride[0] == 2);
-    bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
-    bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
-    return avaible;
+    return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>(
+            param.src_type.enumv(), param.filter_type.enumv(),
+            param.dst_type.enumv(), param.filter_meta, param.bias_mode,
+            param.nonlineMode);
 }
 
 size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace(
diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp
index d23f13f94394aaf7b2d888db026de8e6f34d90bd..f8483f1d2882951058e9a4b8b3a9c1414994444c 100644
--- a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp
+++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp
@@ -16,6 +16,7 @@
 #include "src/arm_common/conv_bias/int8x8x16/algos.h"
 #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h"
 #include "src/arm_common/elemwise_op.h"
+#include "src/common/nchw_nchwxx_valid.h"
 #include "src/common/opr_delegate.h"
 
 #include "midout.h"
@@ -220,23 +221,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
 
 bool ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::usable(
         const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
-    auto&& fm = param.filter_meta;
-    auto fh = fm.spatial[0];
-    int oc = fm.ocpg;
-    bool ok_type = ((param.src_type.enumv() == DTypeEnum::Int8 &&
-                     param.filter_type.enumv() == DTypeEnum::Int8 &&
-                     (param.dst_type.enumv() == DTypeEnum::Int16))) &&
-                   (fm.format == param::Convolution::Format::NCHW44);
-    bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1;
-    bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
-                     (fh == 2 || fh == 3 || fh == 5 || fh == 7);
-    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
-                    fm.stride[0] == fm.stride[1] &&
-                    (fm.stride[0] == 2 || fm.stride[0] == 1);
-    bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS &&
-                   param.nonlineMode == param::ConvBias::NonlineMode::IDENTITY;
-    bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
-    return avaible;
+    return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>(
+            param.src_type.enumv(), param.filter_type.enumv(),
+            param.dst_type.enumv(), param.filter_meta, param.bias_mode,
+            param.nonlineMode);
 }
 
 size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace(
diff --git a/dnn/src/common/nchw_nchwxx_valid.cpp b/dnn/src/common/nchw_nchwxx_valid.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..56113152e9b9a5efa0a15968800dc6e2dbf85178
--- /dev/null
+++ b/dnn/src/common/nchw_nchwxx_valid.cpp
@@ -0,0 +1,43 @@
+/**
+ * \file dnn/src/common/nchw_nchwxx_valid.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "megdnn/oprs/nn.h"
+#include "src/common/nchw_nchwxx_valid.h"
+using namespace megdnn;
+namespace {
+using NchwNchwxxFuncInterface = std::function<bool(
+        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
+        const DTypeEnum dst_dtype,
+        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
+        const ConvBiasForward::BiasMode bias_mode,
+        const param::ConvBias::NonlineMode nonline_mode)>;
+static SmallVector<NchwNchwxxFuncInterface> g_func_vec{
+        nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>,
+        nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>,
+        nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>,
+        nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>,
+        nchw_nchwxx_valid<NchwNchwxxType::NCHW88>,
+};
+} // namespace
+bool ConvBiasForward::is_nchw_nchwxx_optimized(
+        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
+        const DTypeEnum dst_dtype,
+        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
+        const ConvBiasForward::BiasMode bias_mode,
+        const param::ConvBias::NonlineMode nonline_mode) {
+    for (auto& func : g_func_vec) {
+        if (func(src_dtype, filter_dtype, dst_dtype, fm, bias_mode,
+                 nonline_mode)) {
+            return true;
+        }
+    }
+    return false;
+}
\ No newline at end of file
diff --git a/dnn/src/common/nchw_nchwxx_valid.h b/dnn/src/common/nchw_nchwxx_valid.h
new file mode 100644
index 0000000000000000000000000000000000000000..e1d0b6f09803ffe1256e018fb71c9e8d9662947e
--- /dev/null
+++ b/dnn/src/common/nchw_nchwxx_valid.h
@@ -0,0 +1,161 @@
+/**
+ * \file dnn/src/common/nchw_nchwxx_valid.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+#include "megdnn/oprs.h"
+#include "src/fallback/conv_bias/opr_impl.h"
+namespace megdnn {
+namespace {
+enum NchwNchwxxType {
+    NCHW44_FP32,
+    NCHW44_INT8,
+    NCHW44_INT8_INT8_INT16,
+    NCHW44_INT8_DOT,
+    NCHW88,
+};
+template <NchwNchwxxType T>
+static inline bool nchw_nchwxx_valid(
+        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
+        const DTypeEnum dst_dtype,
+        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
+        const BiasMode bias_mode,
+        const param::ConvBias::NonlineMode nonline_mode);
+
+template <>
+inline bool nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>(
+        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
+        const DTypeEnum dst_dtype,
+        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
+        const BiasMode bias_mode,
+        const param::ConvBias::NonlineMode nonline_mode) {
+    bool ok_type = ((src_dtype == DTypeEnum::Float32 &&
+                     filter_dtype == DTypeEnum::Float32 &&
+                     (dst_dtype == DTypeEnum::Float32))) &&
+                   (fm.format == param::Convolution::Format::NCHW44);
+    bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY ||
+                      nonline_mode == param::ConvBias::NonlineMode::RELU ||
+                      nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
+    bool ok_src_dst =
+            fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
+    bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
+                     (fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
+                      fm.spatial[0] == 5 || fm.spatial[0] == 7);
+    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
+                    fm.stride[0] == fm.stride[1] &&
+                    (fm.stride[0] == 1 || fm.stride[0] == 2);
+    bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
+    bool available = ok_type && ok_nonline && ok_src_dst && ok_filter &&
+                     ok_slide && ok_conv;
+    return available;
+}
+template <>
+inline bool nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>(
+        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
+        const DTypeEnum dst_dtype,
+        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
+        const BiasMode bias_mode,
+        const param::ConvBias::NonlineMode nonline_mode) {
+    bool ok_type = ((src_dtype == DTypeEnum::QuantizedS8 &&
+                     filter_dtype == DTypeEnum::QuantizedS8 &&
+                     (dst_dtype == DTypeEnum::QuantizedS8))) &&
+                   (fm.format == param::Convolution::Format::NCHW44);
+    bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY ||
+                      nonline_mode == param::ConvBias::NonlineMode::RELU ||
+                      nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
+    bool ok_src_dst =
+            fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
+    bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
+                     (fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
+                      fm.spatial[0] == 5 || fm.spatial[0] == 7);
+    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
+                    fm.stride[0] == fm.stride[1] &&
+                    (fm.stride[0] == 1 || fm.stride[0] == 2);
+    bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
+    bool available = ok_type && ok_nonline && ok_src_dst && ok_filter &&
+                     ok_slide && ok_conv;
+    return available;
+}
+template <>
+inline bool nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>(
+        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
+        const DTypeEnum dst_dtype,
+        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
+        const BiasMode bias_mode,
+        const param::ConvBias::NonlineMode nonline_mode) {
+    bool ok_type =
+            ((src_dtype == DTypeEnum::Int8 && filter_dtype == DTypeEnum::Int8 &&
+              (dst_dtype == DTypeEnum::Int16))) &&
+            (fm.format == param::Convolution::Format::NCHW44);
+    bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY;
+    bool ok_src_dst =
+            fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
+    bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
+                     (fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
+                      fm.spatial[0] == 5 || fm.spatial[0] == 7);
+    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
+                    fm.stride[0] == fm.stride[1] &&
+                    (fm.stride[0] == 2 || fm.stride[0] == 1);
+    bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
+    bool available = ok_type && ok_nonline && ok_src_dst && ok_filter &&
+                     ok_slide && ok_conv;
+    return available;
+}
+template <>
+inline bool nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>(
+        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
+        const DTypeEnum dst_dtype,
+        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
+        const BiasMode bias_mode,
+        const param::ConvBias::NonlineMode nonline_mode) {
+    bool ok_type = ((src_dtype == DTypeEnum::QuantizedS8 &&
+                     filter_dtype == DTypeEnum::QuantizedS8 &&
+                     (dst_dtype == DTypeEnum::QuantizedS8))) &&
+                   (fm.format == param::Convolution::Format::NCHW44_DOT);
+    bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY ||
+                      nonline_mode == param::ConvBias::NonlineMode::RELU ||
+                      nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
+    bool ok_src_dst =
+            fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
+    bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
+                     (fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
+                      fm.spatial[0] == 5 || fm.spatial[0] == 7);
+    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
+                    fm.stride[0] == fm.stride[1] &&
+                    (fm.stride[0] == 1 || fm.stride[0] == 2);
+    bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
+    bool available = ok_type && ok_nonline && ok_src_dst && ok_filter &&
+                     ok_slide && ok_conv;
+    return available;
+}
+
+template <>
+inline bool nchw_nchwxx_valid<NchwNchwxxType::NCHW88>(
+        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
+        const DTypeEnum dst_dtype,
+        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
+        const BiasMode bias_mode,
+        const param::ConvBias::NonlineMode nonline_mode) {
+    bool ok_type = ((src_dtype == DTypeEnum::Float32 &&
+                     filter_dtype == DTypeEnum::Float32 &&
+                     (dst_dtype == DTypeEnum::Float32))) &&
+                   (fm.format == param::Convolution::Format::NCHW88);
+    bool ok_src_dst =
+            fm.icpg < 8 && (fm.ocpg % 8 == 0 && fm.ocpg >= 8) && fm.group == 1;
+    bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
+    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1;
+    bool available = ok_type && ok_src_dst && ok_slide && ok_conv;
+    return available;
+}
+
+} // namespace
+} // namespace megdnn
\ No newline at end of file
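A quick spot-check of the per-type rules may help review; the sketch below drives two specializations directly. The helper, the concrete shapes, and the unqualified `BiasMode` lookup are illustrative assumptions, not part of the patch:

```cpp
#include "src/common/nchw_nchwxx_valid.h"

using namespace megdnn;

// Illustrative helper (not in the patch): dense 2D conv filter meta.
static ConvolutionBase<param::Convolution>::CanonizedFilterMeta make_fm(
        size_t oc, size_t ic, size_t filter, size_t stride) {
    ConvolutionBase<param::Convolution>::CanonizedFilterMeta fm{};
    fm.format = param::Convolution::Format::NCHW44;
    fm.group = 1;
    fm.spatial_ndim = 2;
    fm.ocpg = oc;
    fm.icpg = ic;
    fm.spatial[0] = fm.spatial[1] = filter;
    fm.stride[0] = fm.stride[1] = stride;
    fm.dilation[0] = fm.dilation[1] = 1;
    fm.should_flip = false;
    return fm;
}

// ic=3 -> oc=16, 3x3 stride 2, RELU: accepted by the FP32 validator.
bool fp32_hit = nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>(
        DTypeEnum::Float32, DTypeEnum::Float32, DTypeEnum::Float32,
        make_fm(16, 3, 3, 2), BiasMode::BROADCAST_CHANNEL_BIAS,
        param::ConvBias::NonlineMode::RELU);

// Same geometry, but the int8x8x16 validator only accepts IDENTITY, so a
// RELU request is rejected while an IDENTITY one passes.
bool i16_relu = nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>(
        DTypeEnum::Int8, DTypeEnum::Int8, DTypeEnum::Int16,
        make_fm(16, 3, 3, 2), BiasMode::NO_BIAS,
        param::ConvBias::NonlineMode::RELU);  // -> false
bool i16_id = nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>(
        DTypeEnum::Int8, DTypeEnum::Int8, DTypeEnum::Int16,
        make_fm(16, 3, 3, 2), BiasMode::NO_BIAS,
        param::ConvBias::NonlineMode::IDENTITY);  // -> true
```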
diff --git a/dnn/src/x86/conv_bias/f32/algos.h b/dnn/src/x86/conv_bias/f32/algos.h
index a231c5656b382daba5bf387e545d4bfa85ae25e4..66c751d76e3e88aa4a4b24c382aec92306ae2fb3 100644
--- a/dnn/src/x86/conv_bias/f32/algos.h
+++ b/dnn/src/x86/conv_bias/f32/algos.h
@@ -11,6 +11,7 @@
  */
 
 #pragma once
+#include "src/common/nchw_nchwxx_valid.h"
 #include "src/x86/conv_bias/opr_impl.h"
 
 using namespace megdnn;
@@ -29,6 +30,7 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase {
             const NCBKernParam& kern_param, const NCBKernIndex& ncb_index,
             const CpuNDRange& workspace_ids);
+
 public:
     bool is_reproducible() const override { return true; }
     const char* name() const override {
@@ -61,6 +63,7 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase {
             const NCBKernParam& kern_param, const NCBKernIndex& ncb_index,
             const CpuNDRange& workspace_ids);
+
 public:
     bool is_reproducible() const override { return true; }
     const char* name() const override {
@@ -163,13 +166,19 @@ public:
                 AlgoSelectionStrategy) const override {
         auto&& fm = param.filter_meta;
 
-        bool ok = (fm.format == param::ConvBias::Format::NCHW88) &&
-                  fm.spatial_ndim == 2 &&
-                  param.src_type.enumv() == DTypeEnum::Float32 &&
-                  param.filter_type.enumv() == DTypeEnum::Float32 &&
-                  param.dst_type.enumv() == DTypeEnum::Float32 &&
-                  fm.dilation[0] == 1 && fm.dilation[1] == 1;
-        return ok;
+        bool nchw_nchw88_ok = nchw_nchwxx_valid<NchwNchwxxType::NCHW88>(
+                param.src_type.enumv(), param.filter_type.enumv(),
+                param.dst_type.enumv(), param.filter_meta, param.bias_mode,
+                param.nonlineMode);
+
+        bool normal_conv_ok = (fm.format == param::ConvBias::Format::NCHW88) &&
+                              fm.spatial_ndim == 2 &&
+                              param.src_type.enumv() == DTypeEnum::Float32 &&
+                              param.filter_type.enumv() == DTypeEnum::Float32 &&
+                              param.dst_type.enumv() == DTypeEnum::Float32 &&
+                              fm.dilation[0] == 1 && fm.dilation[1] == 1;
+
+        return nchw_nchw88_ok || normal_conv_ok;
     };
 
     size_t get_workspace(const NCBKernSizeParam&) const override { return 0; }
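One design note on the x86 hunk above: the `usable()` change deliberately keeps the old whole-NCHW88 predicate and ORs it with the new hybrid check, because the same direct algorithm serves two input layouts. A sketch of the two accepted cases; the tensor layouts stated here are an assumption, extrapolated from the NCHW44 filter-layout documentation in nn.h above:

```cpp
// Case accepted by normal_conv_ok: everything already channel-blocked.
//   src    (n, ic/8, ih, iw, 8)
//   filter (oc/8, ic/8, fh, fw, 8, 8)
//   dst    (n, oc/8, oh, ow, 8)
//
// Case accepted by nchw_nchw88_ok: hybrid first layer with ic < 8, validated
// by nchw_nchwxx_valid<NchwNchwxxType::NCHW88>.
//   src    (n, ic, ih, iw)          // plain NCHW
//   filter (oc/8, fh, fw, ic, 8)    // assumed hybrid weight layout
//   dst    (n, oc/8, oh, ow, 8)
//
// Dropping either branch would silently reject one of the two cases.
```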
diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index 505399faa9a7795790e6936992b62cab70319138..bae9498fc6c94df5bb98a96f05d6fab9dab3b7cf 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -1816,155 +1816,67 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
 }
 
 template <typename OprType>
-static inline bool nchw_nchwxx_valid(const OprType& opr,
-                                     const VarNodeArray& new_inp,
-                                     const size_t pack_size, bool is_dense,
-                                     bool is_dot = false);
-template <>
-inline bool nchw_nchwxx_valid<opr::ConvolutionForward>(
-        const opr::ConvolutionForward& opr, const VarNodeArray& new_inp,
-        const size_t pack_size, bool is_dense, bool is_dot) {
-    auto& filter_shape = new_inp[1]->shape();
-    auto filter_dtype = new_inp[1]->dtype();
-    bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                   filter_dtype.enumv() == DTypeEnum::Int8;
-    const size_t oc = filter_shape[0];
-    const size_t ic = filter_shape[1];
-    bool is_like_nchw_nchwxx =
-            is_dense && oc % pack_size == 0 && ic < pack_size;
-    if (!is_like_nchw_nchwxx) {
+static inline bool nchw_nchwxx_valid(
+        const OprType& opr, const VarNodeArray& new_inp, const size_t pack_size,
+        megdnn::param::ConvBias::NonlineMode nonline_mode =
+                megdnn::param::ConvBias::NonlineMode::IDENTITY,
+        bool is_dot = false) {
+    auto& src_node = new_inp[0];
+    auto& filter_node = new_inp[1];
+    auto dst_node = opr.output(0);
+    if (filter_node->shape().ndim != 4) {
         return false;
     }
-    SmallVector<TensorLayout> layouts;
-
-    //! src
-    layouts.push_back(
-            {new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()});
-
-    //! weight
-    layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2],
-                        filter_shape[3], filter_shape[1], pack_size},
-                       new_inp[1]->dtype(),
-                       new_inp[1]->format()});
-
-    auto out0 = opr.output(0);
-    auto& out_shape = out0->shape();
-    //! FIXME: return false if oc is invalid
-    layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2],
-                        out_shape[3], pack_size},
-                       out0->dtype(),
-                       out0->format()});
-
-    auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node())
-                               ->create_operator<megdnn::ConvolutionForward>();
-    megdnn_conv.get()->param() = opr.param();
-    //! set by dtype
-    switch (pack_size) {
-        case 4:
-            if (is_dot && is_int8) {
-                megdnn_conv.get()->param().format =
-                        megdnn::param::Convolution::Format::NCHW44_DOT;
-            } else {
-                megdnn_conv.get()->param().format =
-                        megdnn::param::Convolution::Format::NCHW44;
-            }
-            break;
-        case 8:
-            megdnn_conv.get()->param().format =
-                    megdnn::param::Convolution::Format::NCHW88;
-            break;
-
-        default:
-            break;
-    }
-
-    bool find_valid_algo = false;
-    auto algos = megdnn_conv.get()->get_all_algorithms(layouts[0], layouts[1],
-                                                       layouts[2]);
-    for (auto i : algos) {
-        if (i->type() != nullptr) {
-            find_valid_algo = true;
+    megdnn::ConvolutionBase<megdnn::param::Convolution>::CanonizedFilterMeta fm;
+    fm.format = megdnn::param::Convolution::Format::NCHW;
+    fm.should_flip =
+            opr.param().mode == megdnn::ConvBiasForward::Mode::CONVOLUTION;
+    fm.group = 1;
+    fm.spatial_ndim = 2;
+    fm.ocpg = filter_node->shape()[0];
+    fm.icpg = filter_node->shape()[1];
+    fm.spatial[0] = filter_node->shape()[2];
+    fm.spatial[1] = filter_node->shape()[3];
+    fm.stride[0] = opr.param().stride_h;
+    fm.stride[1] = opr.param().stride_w;
+    fm.padding[0] = opr.param().pad_h;
+    fm.padding[1] = opr.param().pad_w;
+    fm.dilation[0] = opr.param().dilate_h;
+    fm.dilation[1] = opr.param().dilate_w;
+
+    megdnn::ConvBiasForward::BiasMode bias_mode =
+            megdnn::ConvBiasForward::BiasMode::NO_BIAS;
+    if (std::is_same<OprType, opr::ConvBiasForward>::value) {
+        auto& bias_shape = new_inp[2]->shape();
+        if (bias_shape.ndim == 0) {
+            bias_mode = megdnn::ConvBiasForward::BiasMode::NO_BIAS;
+        } else if (bias_shape.eq_shape(dst_node->shape())) {
+            bias_mode = megdnn::ConvBiasForward::BiasMode::BIAS;
+        } else {
+            //! just check the ndim, the detailed shape check is in check_exec
+            mgb_assert(bias_shape.ndim == dst_node->shape().ndim);
+            bias_mode =
+                    megdnn::ConvBiasForward::BiasMode::BROADCAST_CHANNEL_BIAS;
         }
     }
-    return find_valid_algo;
-}
-template <>
-inline bool nchw_nchwxx_valid<opr::ConvBiasForward>(
-        const opr::ConvBiasForward& opr, const VarNodeArray& new_inp,
-        const size_t pack_size, bool is_dense, bool is_dot) {
-    auto& filter_shape = new_inp[1]->shape();
-    auto filter_dtype = new_inp[1]->dtype();
-    bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                   filter_dtype.enumv() == DTypeEnum::Int8;
-    const size_t oc = filter_shape[0];
-    const size_t ic = filter_shape[1];
-    bool is_like_nchw_nchwxx =
-            is_dense && oc % pack_size == 0 && ic < pack_size;
-    if (!is_like_nchw_nchwxx) {
-        return false;
-    }
-    SmallVector<TensorLayout> layouts;
-
-    //! src
-    layouts.push_back(
-            {new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()});
-
-    //! weight
-    layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2],
-                        filter_shape[3], filter_shape[1], pack_size},
-                       new_inp[1]->dtype(),
-                       new_inp[1]->format()});
-
-    auto& bias_shape = new_inp[2]->shape();
-    layouts.push_back({{bias_shape[0], bias_shape[1] / pack_size,
-                        bias_shape[2], bias_shape[3], pack_size},
-                       new_inp[2]->dtype(),
-                       new_inp[2]->format()});
-
-    auto out0 = opr.output(0);
-    auto& out_shape = out0->shape();
-    //! FIXME: return false if oc is invalid
-    layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2],
-                        out_shape[3], pack_size},
-                       out0->dtype(),
-                       out0->format()});
-
-    // megdnn::ConvolutionForward
-    auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node())
-                               ->create_operator<megdnn::ConvBiasForward>();
-    megdnn_conv.get()->param() = opr.param();
-
-    //! FIXME: set by dtype
-    switch (pack_size) {
-        case 4:
-            if (is_dot && is_int8) {
-                megdnn_conv.get()->param().format =
-                        megdnn::param::Convolution::Format::NCHW44_DOT;
-            } else {
-                megdnn_conv.get()->param().format =
-                        megdnn::param::Convolution::Format::NCHW44;
-            }
-            break;
-        case 8:
-            megdnn_conv.get()->param().format =
-                    megdnn::param::Convolution::Format::NCHW88;
-            break;
-
-        default:
-            break;
-    }
-    bool find_valid_algo = false;
-    auto algos = megdnn_conv.get()->get_all_algorithms(
-            layouts[0], layouts[1], layouts[2], {}, layouts[3]);
-    for (auto i : algos) {
-        if (i->type() != nullptr) {
-            find_valid_algo = true;
-        }
+    if (pack_size == 4) {
+        if (is_dot && filter_node->dtype().enumv() == DTypeEnum::QuantizedS8) {
+            fm.format = megdnn::param::Convolution::Format::NCHW44_DOT;
+        } else {
+            fm.format = megdnn::param::Convolution::Format::NCHW44;
+        }
+    } else if (pack_size == 8) {
+        fm.format = megdnn::param::Convolution::Format::NCHW88;
+    } else {
+        mgb_assert(0, "only support nchw44 nchw88");
     }
-    return find_valid_algo;
+    return megdnn::ConvBiasForward::is_nchw_nchwxx_optimized(
+            src_node->dtype().enumv(), filter_node->dtype().enumv(),
+            dst_node->dtype().enumv(), fm, bias_mode, nonline_mode);
 }
+
 void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
     using TestFilterResult = std::pair<TransType, RelayoutMode>;
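The bias-mode inference is the subtle part of the hunk above, so here is the same decision rule isolated into a standalone function. This is a sketch restating the patch's semantics, not patch code; the include paths are assumptions:

```cpp
#include "megbrain/common.h"  // mgb_assert
#include "megbrain/tensor.h"  // mgb::TensorShape
#include "megdnn/oprs/nn.h"   // megdnn::ConvBiasForward

// Sketch: a scalar/absent bias means NO_BIAS, an exact dst match means full
// BIAS, and anything else with matching rank is treated as the per-channel
// (1, oc, 1, 1) broadcast; the detailed shape check is left to check_exec.
static megdnn::ConvBiasForward::BiasMode infer_bias_mode(
        const mgb::TensorShape& bias, const mgb::TensorShape& dst) {
    using BiasMode = megdnn::ConvBiasForward::BiasMode;
    if (bias.ndim == 0)
        return BiasMode::NO_BIAS;
    if (bias.eq_shape(dst))
        return BiasMode::BIAS;
    mgb_assert(bias.ndim == dst.ndim);
    return BiasMode::BROADCAST_CHANNEL_BIAS;
}
```

For example, with dst (1, 8, 32, 32), a bias of (1, 8, 1, 1) yields BROADCAST_CHANNEL_BIAS, a bias of (1, 8, 32, 32) yields BIAS, and an empty bias yields NO_BIAS, matching the three branches in the rewritten helper.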
@@ -1984,19 +1896,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     megdnn::param::Pooling::Format pooling_format =
             megdnn::param::Pooling::Format::NCHW88;
     std::string convter_pass_name = "conv_format_nchw88";
-#if MEGDNN_AARCH64 || MEGDNN_ARMv7
-    if (pack_c_size == 8) {
-        mgb_log_error(
-                "runtime backend is ARM, but nchw88 only support X86, you may "
-                "have performance loss\n");
-    }
-#elif MEGDNN_X86
-    if (pack_c_size == 4) {
-        mgb_log_error(
-                "runtime backend is X86, but nchw44 only support arm, you may "
-                "have performance loss\n");
-    }
-#endif
     if (pack_c_size == 4) {
         weight_to_nchwxx_mode_dense =
                 RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
@@ -2053,10 +1952,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
         mgb_assert(conv_opr.param().format ==
                            megdnn::param::Convolution::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NCHWXX");
-        bool is_dense = conv_opr.param().sparse ==
-                        megdnn::param::Convolution::Sparse::DENSE;
         bool valid_nchw_nchw44 =
-                nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense);
+                nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size);
         auto is_trans = test_trans_nchwxx(
                 conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
                 conv_opr.param().stride_w, valid_nchw_nchw44);
@@ -2133,10 +2030,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NCHWXX");
-        bool is_dense = conv_bias_opr.param().sparse ==
-                        megdnn::param::Convolution::Sparse::DENSE;
-        bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp,
-                                                   pack_c_size, is_dense);
+        bool valid_nchw_nchw44 =
+                nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size,
+                                  conv_bias_opr.param().nonlineMode);
         auto is_trans = test_trans_nchwxx(
                 conv_bias_opr.param().sparse, new_inp[1],
                 conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w,
@@ -2371,13 +2267,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
     MIDOUT_B("EnableNchw44DotPass::make")
     auto ret = std::make_unique<EnableNchw44DotPass>();
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
-//! First is whether the conv can trans to nchwxx, second is the filter
-//! trans mode
-#if MEGDNN_X86
-    mgb_log_error(
-            "backend is X86, but nchw44_dot only support arm, you may have "
-            "performance loss\n");
-#endif
+    //! First is whether the conv can trans to nchwxx, second is the filter
+    //! trans mode
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
 
     struct TestTransResult {
@@ -2453,14 +2344,12 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         mgb_assert(conv_opr.param().format ==
                            megdnn::param::Convolution::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to "
                    "NCHW44_DOT");
-        bool is_dense = conv_opr.param().sparse ==
-                        megdnn::param::Convolution::Sparse::DENSE;
-        bool valid_nchw_nchw44 =
-                nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense);
+        bool valid_nchw_nchw44 = nchw_nchwxx_valid(
+                conv_opr, new_inp, pack_c_size,
+                megdnn::param::ConvBias::NonlineMode::IDENTITY, true);
         auto is_trans = test_trans_nchw44_dot(
                 conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
                 conv_opr.param().stride_w, valid_nchw_nchw44);
-        //! can not trans to nchwxx
         if (is_trans.trans_type == TransType::TRANS_NONE) {
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
@@ -2533,10 +2422,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NCHWXX");
-        bool is_dense = conv_bias_opr.param().sparse ==
-                        megdnn::param::Convolution::Sparse::DENSE;
-        bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp,
-                                                   pack_c_size, is_dense);
+        bool valid_nchw_nchw44 =
+                nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size,
+                                  conv_bias_opr.param().nonlineMode, true);
         auto is_trans = test_trans_nchw44_dot(
                 conv_bias_opr.param().sparse, new_inp[1],
                 conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w,
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 093f704f0b23a6529eae4bcea48920a27a085c0b..ee8dc0ecd6df091bc54a44f903a958ca60148582 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -2913,7 +2913,8 @@ TEST(TestGoptInference, ConvertFormatNCHW88) {
     opr::Convolution::Param param_conv;
     param_conv.pad_h = param_conv.pad_w = 1;
     auto w1 = mkcvar("w1", {8, 3, 3, 3}),
-         conv1 = opr::Convolution::make(x, w1, param_conv);
+         conv1 = opr::Convolution::make(x, w1, param_conv, {},
+                                        OperatorNodeConfig("conv1"));
     //! channel wise
     opr::ConvBias::Param param_conv_bias;
     param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
@@ -2954,7 +2955,8 @@ TEST(TestGoptInference, ConvertFormatNCHW88) {
         options.enable_nchw88();
         unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
     }
-
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW88,
+              find_opr<opr::Convolution>(y_opt, "conv1").param().format);
     ASSERT_EQ(opr::ConvBias::Param::Format::NCHW88,
               find_opr<opr::ConvBias>(y_opt).param().format);
 
@@ -3084,13 +3086,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
     options.enable_nchw44();
     unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
 
-#if MEGDNN_AARCH64 || MEGDNN_ARMV7
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
               find_opr<opr::Convolution>(y_opt, "conv1").param().format);
-#else
-    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
-              find_opr<opr::Convolution>(y_opt, "conv1").param().format);
-#endif
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
               find_opr<opr::Convolution>(y_opt, "conv1_f1").param().format);
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
@@ -3325,17 +3322,10 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
     options.enable_nchw44_dot();
     unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
 
-#if MEGDNN_AARCH64 || MEGDNN_ARMV7
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
               find_opr<opr::Convolution>(y_opt, "conv1").param().format);
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT,
               find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format);
-#else
-    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
-              find_opr<opr::Convolution>(y_opt, "conv1").param().format);
-    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
-              find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format);
-#endif
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
               find_opr<opr::Convolution>(y_opt, "conv1_f1").param().format);
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT,
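These assertions fish a named operator out of the optimized graph with the test file's existing `find_opr` helper. In case the pattern is unfamiliar, here is a minimal equivalent; this is a sketch only (the real helper already lives in src/gopt/test/inference.cpp), and the `DepOprIter` traversal is an assumption about the cleanest way to express it:

```cpp
#include "megbrain/graph.h"

// Sketch: locate the unique operator of type Opr named `name` among the
// dependencies of an endpoint var, as the test assertions above do.
template <typename Opr>
const Opr& find_opr_by_name(mgb::SymbolVar endpoint, const std::string& name) {
    const Opr* found = nullptr;
    auto cb = [&](mgb::cg::OperatorNodeBase* opr) {
        if (!found && opr->dyn_typeinfo() == Opr::typeinfo() &&
            opr->name() == name) {
            found = &opr->cast_final_safe<Opr>();
        }
    };
    mgb::cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
    mgb_assert(found, "operator %s not found", name.c_str());
    return *found;
}
```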
diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp
index 588464738c59d4c4330edd5f779086cb793c5b7c..1180695d4ced157fa7a24f9bed0b65940f7bc7cf 100644
--- a/src/opr/impl/dnn/convolution.cpp
+++ b/src/opr/impl/dnn/convolution.cpp
@@ -611,11 +611,11 @@ public:
                 "%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s "
                 "workspace=%.2fMiB reproducible=%d",
                 mgb_opr->dyn_typeinfo()->name,
-                layouts[0].TensorShape::to_string().c_str(),
+                layouts[0].to_string().c_str(),
                 layouts[0].dtype.name(),
-                layouts[1].TensorShape::to_string().c_str(),
+                layouts[1].to_string().c_str(),
                 layouts[1].dtype.name(),
-                layouts[layouts.size() - 1].TensorShape::to_string().c_str(),
+                layouts[layouts.size() - 1].to_string().c_str(),
                 layouts[layouts.size() - 1].dtype.name(), algo->name(),
                 workspace / (1024 * 1024.0), algo->is_reproducible());