提交 a3caa5d3 编写于 作者: M Megvii Engine Team

fix(mgb(dnn)): fix convbias cudnnConvBiasActivation

GitOrigin-RevId: c0e44feffbe91f299227175e25b1e5c88b9b4724
上级 c418d3cd
......@@ -166,11 +166,12 @@ std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const {
megdnn_throw("invalid conv bias nonlinear mode");
}
return megdnn_mangle(ssprintf(
"src=%s, filter=%u{%u,%u,%u,%u}, dst=%s, "
"src=%s, filter=%u{%u,%u,%u,%u}, bias=%s, z=%s, dst=%s, "
"pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s, "
"nonlinear_mode=%s",
src_layout->to_string().c_str(), fm.group, fm.ocpg, fm.icpg,
fm.spatial[0], fm.spatial[1], dst_layout->to_string().c_str(),
fm.spatial[0], fm.spatial[1], bias_layout->to_string().c_str(),
z_layout->to_string().c_str(), dst_layout->to_string().c_str(),
fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1],
fm.dilation[0], fm.dilation[1], !fm.should_flip,
src_layout->dtype.name(), dst_layout->dtype.name(),
......
......@@ -27,9 +27,12 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.bias_layout->ndim == 0 ||
args.bias_layout->eq_shape(*args.dst_layout))
!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
args.opr->param().format)) {
return false;
}
auto&& param = args.opr->param();
if (param.format == param::ConvBias::Format::NCHW4_NCHW32 ||
param.format == param::ConvBias::Format::NCHW32_NCHW4)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册