/**
 * \file dnn/src/aarch64/conv_bias/fp32/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/aarch64/conv_bias/fp32/algos.h"
#include "src/aarch64/conv_bias/fp32/stride2_kern.h"
#include "src/arm_common/conv_bias/direct/multi_thread_common.h"
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/fallback/conv_bias/common.h"

#include "midout.h"

using namespace megdnn;
using namespace aarch64;

MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32)

//! Report whether this algorithm can handle \p param.
//!
//! All of the following must hold:
//!   - filter format is NCHW;
//!   - src, filter and dst dtypes are all Float32;
//!   - the filter is not flipped, is 2-D, and has dilation 1 on both axes;
//!   - stride is 2 on both axes;
//!   - the filter is square with side 2, 3, 5 or 7.
//!
//! Under AlgoSelectionStrategy::HEURISTIC the result is further restricted so
//! that only the instance whose m_large_group flag matches
//! (group >= nr_threads) reports itself as usable.
bool ConvBiasImpl::AlgoF32DirectStride2::usable(
        const NCBKernSizeParam& param,
        AlgoSelectionStrategy algo_selection_strategy) const {
    MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) {
        auto&& fm = param.filter_meta;
        auto FH = fm.spatial[0];
        bool available =
                param.filter_meta.format == param::ConvBias::Format::NCHW &&
                param.src_type.enumv() == DTypeEnum::Float32 &&
                param.filter_type.enumv() == DTypeEnum::Float32 &&
                param.dst_type.enumv() == DTypeEnum::Float32 &&
                !fm.should_flip && fm.spatial_ndim == 2 &&
                fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                fm.stride[0] == 2 && fm.stride[1] == 2 &&
                FH == fm.spatial[1] &&
                (FH == 2 || FH == 3 || FH == 5 || FH == 7);
        if (algo_selection_strategy == AlgoSelectionStrategy::HEURISTIC) {
            bool large_group = param.filter_meta.group >= param.nr_threads;
            available &= (large_group == m_large_group);
        }
        return available;
    }
    MIDOUT_END();
    return false;
}

//! Return the workspace size in bytes: the total size of the bundle produced
//! by MultithreadDirectConvCommon<float, float>::get_bundle_stride() for
//! \p param and this instance's m_large_group flag.
size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 1) {
        auto wbundle = arm_common::MultithreadDirectConvCommon<
                float, float>::get_bundle_stride(param, m_large_group);
        return wbundle.total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}

//! Return the kernels to run for \p param; thin wrapper around get_kimpls().
//! NOTE(review): the SmallVector element type was missing from the reviewed
//! text and is reconstructed as ConvBiasImpl::NCBKern -- confirm against the
//! declaration in algos.h.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) {
        return get_kimpls(param);
    }
    MIDOUT_END();
    return {};
}

//! Build the kernel list for a stride-2 direct convolution.
//!
//! The convolution routine is chosen from the filter height (2, 3, 5 or 7).
//! Depending on m_large_group, either one fused kernel over (group, N, 1) or
//! two kernels (padding copy over (group, N, IC), then convolution over
//! (group, N, OC)) are returned.
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF32DirectStride2::get_kimpls(
        const NCBKernSizeParam& param) const {
    //! bind by reference, as usable() does, instead of copying the metadata
    auto&& fm = param.filter_meta;
    auto FH = fm.spatial[0];
    size_t N = param.n;
    size_t IC = fm.icpg;
    size_t OC = fm.ocpg;
    size_t group = fm.group;

    //! NOTE(review): the std::function signature was missing from the
    //! reviewed text; it is reconstructed here and must match the
    //! do_conv_*_stride2 declarations in stride2_kern.h -- confirm.
    using Func = std::function<void(const float*, const float*, float*, size_t,
                                    size_t, size_t, size_t, size_t)>;
    //! stays null for any other FH; usable() only accepts 2, 3, 5 and 7
    //! (presumably callers check usable() first -- verify)
    Func conv = nullptr;
    if (FH == 2) {
        conv = fp32::conv_stride2::do_conv_2x2_stride2;
    } else if (FH == 3) {
        conv = fp32::conv_stride2::do_conv_3x3_stride2;
    } else if (FH == 5) {
        conv = fp32::conv_stride2::do_conv_5x5_stride2;
    } else if (FH == 7) {
        conv = fp32::conv_stride2::do_conv_7x7_stride2;
    }

    WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon<
            float, float>::get_bundle_stride(param, m_large_group);
    SmallVector<ConvBiasImpl::NCBKern> ret_kerns;

    if (m_large_group) {
        //! Channel wise conv and big groups: a single kernel per (group,
        //! batch) item does the padding copy for every input channel and then
        //! the convolution for every output channel, passing
        //! {thread_id, 0, channel} as the index triple.
        auto exec_one_group = [bundle, conv](
                                      const NCBKernParam& kern_param,
                                      const NCBKernIndex& ncb_index) mutable {
            auto&& fm = kern_param.filter_meta;
            size_t IC = fm.icpg;
            size_t OC = fm.ocpg;
            //! the lambda owns a copy of the bundle; bind it to the workspace
            //! supplied for this invocation
            bundle.set(kern_param.workspace_ptr);
            for (size_t ic = 0; ic < IC; ic++) {
                arm_common::MultithreadDirectConvCommon<float, float>::
                        copy_padding_kern_stride(
                                bundle, kern_param, ncb_index,
                                {ncb_index.thread_id, 0, ic});
            }
            for (size_t oc = 0; oc < OC; oc++) {
                arm_common::MultithreadDirectConvCommon<float, float>::
                        do_conv_kern_stride(bundle, kern_param, ncb_index,
                                            conv,
                                            {ncb_index.thread_id, 0, oc});
            }
        };
        ret_kerns.push_back({exec_one_group, {group, N, 1_z}});
    } else {
        //! Dense conv and small group: two kernels, the padding copy spread
        //! over (group, N, IC) followed by the convolution spread over
        //! (group, N, OC); both pass ndrange_id as the index triple.
        auto copy_padding = [bundle](const NCBKernParam& kern_param,
                                     const NCBKernIndex& ncb_index) mutable {
            bundle.set(kern_param.workspace_ptr);
            arm_common::MultithreadDirectConvCommon<float, float>::
                    copy_padding_kern_stride(bundle, kern_param, ncb_index,
                                             ncb_index.ndrange_id);
        };
        ret_kerns.push_back({copy_padding, {group, N, IC}});

        auto do_conv = [bundle, conv](const NCBKernParam& kern_param,
                                      const NCBKernIndex& ncb_index) mutable {
            bundle.set(kern_param.workspace_ptr);
            arm_common::MultithreadDirectConvCommon<float, float>::
                    do_conv_kern_stride(bundle, kern_param, ncb_index, conv,
                                        ncb_index.ndrange_id);
        };
        ret_kerns.push_back({do_conv, {group, N, OC}});
    }
    return ret_kerns;
}