提交 0d169524 编写于 作者: M Megvii Engine Team

fix(mgb/cuda): fix conv error when the input tensor is too large

GitOrigin-RevId: 1b1d693795e665630a65a28997078694a78cb214
上级 ee634beb
...@@ -40,6 +40,12 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(const SizeArgs& args) cons ...@@ -40,6 +40,12 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(const SizeArgs& args) cons
return false; return false;
} }
// In conv_args.init_conv_desc will call cudnnSetTensor4dDescriptorEx(),which can't
// been supported when total_nr_elems() > 2 ^ 31
if (args.src_layout->total_nr_elems() > INT_MAX ||
args.dst_layout->total_nr_elems() > INT_MAX) {
return false;
}
auto dst_layout = *args.dst_layout; auto dst_layout = *args.dst_layout;
if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
dst_layout.dtype = DType(); dst_layout.dtype = DType();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册