From a3caa5d3b73c20bdaac2e6d3e59cb8766d083827 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 10 Jan 2021 00:02:30 +0800 Subject: [PATCH] fix(mgb(dnn)): fix convbias cudnnConvBiasActivation GitOrigin-RevId: c0e44feffbe91f299227175e25b1e5c88b9b4724 --- dnn/src/cuda/conv_bias/algo.cpp | 5 +++-- dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp | 5 ++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp index 8614b7902..0d7739306 100644 --- a/dnn/src/cuda/conv_bias/algo.cpp +++ b/dnn/src/cuda/conv_bias/algo.cpp @@ -166,11 +166,12 @@ std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const { megdnn_throw("invalid conv bias nonlinear mode"); } return megdnn_mangle(ssprintf( - "src=%s, filter=%u{%u,%u,%u,%u}, dst=%s, " + "src=%s, filter=%u{%u,%u,%u,%u}, bias=%s, z=%s, dst=%s, " "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s, " "nonlinear_mode=%s", src_layout->to_string().c_str(), fm.group, fm.ocpg, fm.icpg, - fm.spatial[0], fm.spatial[1], dst_layout->to_string().c_str(), + fm.spatial[0], fm.spatial[1], bias_layout->to_string().c_str(), + z_layout->to_string().c_str(), dst_layout->to_string().c_str(), fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1], fm.dilation[0], fm.dilation[1], !fm.should_flip, src_layout->dtype.name(), dst_layout->dtype.name(), diff --git a/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp b/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp index 19c82aa8c..a0980c1eb 100644 --- a/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp +++ b/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp @@ -27,9 +27,12 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( args.src_layout->dtype == dtype::BFloat16()) { return false; } + if (args.bias_layout->ndim == 0 || - args.bias_layout->eq_shape(*args.dst_layout)) + !conv_bias::check_bias_share_in_channel(*(args.bias_layout), + args.opr->param().format)) { return false; + } auto&& param = args.opr->param(); if (param.format == param::ConvBias::Format::NCHW4_NCHW32 || param.format == param::ConvBias::Format::NCHW32_NCHW4) -- GitLab