Commit e8a16929 authored by Megvii Engine Team, committed by 王彪

feat(dnn/cuda): add heuristic rule for implicit batched gemm large kernel dwconv2d kernels

GitOrigin-RevId: 2d2c213bfdf91e85b2513cafb1dda0f6940199e5
Parent 38067472
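For reference, the rule this commit adds can be read in isolation: the implicit batched-GEMM (implbmm) large-kernel depthwise kernels are preferred whenever the spatial extent is at most twice the filter size in each dimension. Below is a minimal standalone sketch of that predicate; the function name and signature are illustrative, not part of MegEngine's API — in the actual diff the check is inlined in the heuristic functions.

```cpp
#include <cstddef>

// Sketch of the selection rule added in this commit: prefer the
// implicit batched-GEMM depthwise kernels when the spatial size is
// at most twice the filter size in each dimension. Hypothetical
// helper; the real check is written inline in the diff below.
bool prefer_large_kernel_implbmm(
        std::size_t h, std::size_t w,     // input (fwd) or diff (bwd) size
        std::size_t fh, std::size_t fw) { // filter height / width
    return h <= 2 * fh && w <= 2 * fw;
}
```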
@@ -145,9 +145,20 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
     const bool prefer_dnn_chanwise = slow_cudnn_chanwise_impl ||
                                      args.filter_meta.stride[0] != 1 ||
                                      args.filter_meta.stride[1] != 1 || hw_size < 512;
+    //! choose for large kernel cases
+    size_t fh = args.filter_meta.spatial[2], fw = args.filter_meta.spatial[3];
+    size_t hi = src[2], wi = src[3];
+    const bool prefer_dnn_lk_implbmm = hi <= 2 * fh && wi <= 2 * fw;
     //! avoid bad case in cudnn, check dnn chanwise impl first
     if (is_chanwise) {
-        if (prefer_dnn_chanwise) {
+        if (prefer_dnn_lk_implbmm) {
+            if (sm_algo_pack.f16_implicit_bmm[0].is_available_attribute(
+                        args, positive_attr, negative_attr, workspace_limit_in_bytes))
+                return &sm_algo_pack.f16_implicit_bmm[0];
+            if (sm_algo_pack.f32_implicit_bmm[0].is_available_attribute(
+                        args, positive_attr, negative_attr, workspace_limit_in_bytes))
+                return &sm_algo_pack.f32_implicit_bmm[0];
+        } else if (prefer_dnn_chanwise) {
             if (sm_algo_pack.chanwise.is_available_attribute(
                         args, positive_attr, negative_attr, workspace_limit_in_bytes))
                 return &sm_algo_pack.chanwise;
......
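In the forward heuristic above, the check runs on the input feature map (hi, wi) and takes precedence over the chanwise preference, with the f16 implicit-bmm kernel tried before the f32 one. A hypothetical spot-check of the predicate, reusing the sketch from above:

```cpp
#include <cassert>

int main() {
    // A 31x31 depthwise filter on a 48x48 input: 48 <= 2 * 31 holds in
    // both dimensions, so the implicit-bmm kernels are tried first.
    assert(prefer_large_kernel_implbmm(48, 48, 31, 31));
    // A 3x3 filter on a 224x224 input fails the check and falls through
    // to the original chanwise / cuDNN selection.
    assert(!prefer_large_kernel_implbmm(224, 224, 3, 3));
    return 0;
}
```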
@@ -115,6 +115,19 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
         const AlgoAttribute& negative_attr) {
     AlgoBase::SizeArgs args(this, filter, diff, grad);
+    //! choose for large kernel cases
+    size_t fh = args.filter_meta.spatial[2], fw = args.filter_meta.spatial[3];
+    size_t ho = diff[2], wo = diff[3];
+    const bool prefer_dnn_lk_implbmm = args.filter_meta.format == Param::Format::NCHW &&
+                                       ho <= 2 * fh && wo <= 2 * fw;
+    if (prefer_dnn_lk_implbmm) {
+        if (sm_algo_pack.implbmm_nchw_hmma[0].is_available_attribute(
+                    args, positive_attr, negative_attr, workspace_limit_in_bytes))
+            return &sm_algo_pack.implbmm_nchw_hmma[0];
+        if (sm_algo_pack.implbmm_nchw_fma[0].is_available_attribute(
+                    args, positive_attr, negative_attr, workspace_limit_in_bytes))
+            return &sm_algo_pack.implbmm_nchw_fma[0];
+    }
     if (args.filter_meta.group > 1 &&
         sm_algo_pack.chanwise.is_available_attribute(
                 args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
......
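The backward-data heuristic applies the same 2x threshold, but measures the diff (output-gradient) spatial size rather than the input, and additionally requires the NCHW format; the HMMA (tensor-core) kernel is preferred over the FMA one. A sketch of that variant, again with a hypothetical helper name:

```cpp
#include <cstddef>

// Backward-data variant of the rule (hypothetical helper): the same
// 2x threshold, applied to the output-gradient size and restricted
// to NCHW layouts, mirroring the check inlined in the diff above.
bool prefer_large_kernel_implbmm_bwd_data(
        bool is_nchw, std::size_t ho, std::size_t wo,
        std::size_t fh, std::size_t fw) {
    return is_nchw && ho <= 2 * fh && wo <= 2 * fw;
}
```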