提交 a0231a79 编写于 作者: M Megvii Engine Team

fix(dnn/cuda): fix algo matmul for conv bwd filter

Fix the fastrun "workspace size not available" exception and the device OOM error, both caused by an incorrect workspace size calculation in the matmul algorithm of the convolution backward filter operator.

GitOrigin-RevId: de96b4fe117ed9691d4a5555261e9f10fa8d2ae4
上级 f3ed59d3
...@@ -21,11 +21,11 @@ using namespace cuda; ...@@ -21,11 +21,11 @@ using namespace cuda;
namespace { namespace {
std::pair<TensorLayoutArray, MatrixMulForward::Param> sub_opr_config( std::pair<TensorLayoutArray, MatrixMulForward::Param> sub_opr_config(
const ConvolutionBackwardDataImpl::CanonizedFilterMeta& fm, const ConvolutionBackwardFilterImpl::CanonizedFilterMeta& fm,
const TensorLayout& src_layout, const TensorLayout& diff_layout, const TensorLayout& src_layout, const TensorLayout& diff_layout,
const TensorLayout& grad_layout, const TensorLayout& grad_layout,
const ConvolutionBackwardFilterImpl* opr) { const ConvolutionBackwardFilterImpl* opr) {
size_t N = grad_layout.shape[0], IC = fm.icpg, size_t N = src_layout.shape[0], IC = fm.icpg,
OC = fm.ocpg, OH = diff_layout.shape[2], OC = fm.ocpg, OH = diff_layout.shape[2],
OW = diff_layout.shape[3], FH = fm.spatial[0], OW = diff_layout.shape[3], FH = fm.spatial[0],
FW = fm.spatial[1]; FW = fm.spatial[1];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册