diff --git a/dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp b/dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp index 82fc8c7cba1038dab4c405023ef9c398c9a8a1bb..f89a756202aacf614fddb29bcbd61e8ccf7d0850 100644 --- a/dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp +++ b/dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp @@ -57,8 +57,6 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS4::is_available( // param buffer size is 4K, use 3K to store precomputed offset, fh * fw <= // (3*1024/4/2/2) - 1 available &= fh * fw <= 191; - // channels should be multiples of 64 - available &= ci % 64 == 0 && co % 64 == 0; return available; } @@ -75,11 +73,11 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec( auto ws = get_workspace_bundle(args.workspace.raw_ptr, args); auto ws_src = ws.get(0); auto ws_filter = ws.get(1); + auto ws_bias = args.bias_tensor->raw_ptr; auto ws_dst = ws.get(2); void* ws_z = nullptr; if (args.z_layout->ndim > 0) ws_z = ws.get(4); - // auto&& stream = cuda_stream(args.opr->handle()); auto nchw2nchw64 = [&args](const TensorND& src, TensorND&& dst) { if (dst.raw_ptr == nullptr) return; @@ -91,17 +89,40 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec( auto nchw642nchw = [&args](const TensorND& src, TensorND&& dst) { auto relayout = args.handle->create_operator(); relayout->param() = RelayoutFormat::Param::Mode::NCHW64_NCHW; + relayout->param().oc = dst.layout[1]; Workspace dummy; relayout->exec(src, dst, dummy); }; // reformat src nchw2nchw64(*(args.src_tensor), {ws_src, layouts[0]}); // reformat filter - nchw2nchw64(*(args.filter_tensor), {ws_filter, layouts[1]}); + size_t co = args.filter_layout->operator[](0); + if (co % 64 != 0) { + const auto& stream = cuda_stream(args.opr->handle()); + auto ws_filter_ = reinterpret_cast(ws_filter) + + layouts[1].span().dist_byte(); + size_t ci = args.filter_layout->operator[](1), + fh = args.filter_layout->operator[](2), + fw = args.filter_layout->operator[](3); + TensorLayout intermediate({round_up(co, 64_z), ci, fh, fw}, + args.filter_layout->dtype); + ws_bias = ws_filter_ + intermediate.span().dist_byte(); + cuda_check(cudaMemsetAsync(ws_filter_, 0, + intermediate.span().dist_byte(), stream)); + cuda_check(cudaMemcpyAsync(ws_filter_, args.filter_tensor->raw_ptr, + args.filter_layout->span().dist_byte(), + cudaMemcpyDeviceToDevice, stream)); + nchw2nchw64({ws_filter_, intermediate}, {ws_filter, layouts[1]}); + cuda_check(cudaMemcpyAsync(ws_bias, args.bias_tensor->raw_ptr, + co * args.bias_layout->dtype.size(), + cudaMemcpyDeviceToDevice, stream)); + } else { + nchw2nchw64(*(args.filter_tensor), {ws_filter, layouts[1]}); + } // reformat z nchw2nchw64(*(args.z_tensor), {ws_z, layouts[3]}); TensorND src_{ws_src, layouts[0]}, filter_{ws_filter, layouts[1]}, - bias_{args.bias_tensor->raw_ptr, layouts[2]}, z_{ws_z, layouts[3]}, + bias_{ws_bias, layouts[2]}, z_{ws_z, layouts[3]}, dst_{ws_dst, layouts[4]}; auto conv_op = args.opr->handle()->create_operator(); conv_op->param() = args.opr->param(); @@ -128,29 +149,43 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout( size_t co = dst[1], ho = dst[2], wo = dst[3]; size_t fh = filter[2], fw = filter[3]; SmallVector rst; - rst.emplace_back(TensorLayout{{n, ci / 64, hi, wi, 64}, src.dtype}); - rst.emplace_back(TensorLayout{{co, ci / 64, fh, fw, 64}, filter.dtype}); - rst.emplace_back(TensorLayout{{1, co / 64, 1, 1, 64}, bias.dtype}); + rst.emplace_back( + TensorLayout{{n, div_ceil(ci, 64_z), hi, wi, 64}, src.dtype}); + rst.emplace_back( + TensorLayout{{round_up(co, 64_z), div_ceil(ci, 64_z), fh, fw, 64}, + filter.dtype}); + rst.emplace_back( + TensorLayout{{1, div_ceil(co, 64_z), 1, 1, 64}, bias.dtype}); if (z.ndim > 0) { - rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, z.dtype}); + rst.emplace_back( + TensorLayout{{n, div_ceil(co, 64_z), ho, wo, 64}, z.dtype}); } else { rst.emplace_back(TensorLayout{}); } - rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, dst.dtype}); - for (auto& i : rst) { - i.init_contiguous_stride(); - } + rst.emplace_back( + TensorLayout{{n, div_ceil(co, 64_z), ho, wo, 64}, dst.dtype}); return rst; } WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle( void* raw_ptr, const SizeArgs& args) const { - size_t ws_size_src = args.src_layout->span().dist_byte(); - size_t ws_size_filter = args.filter_layout->span().dist_byte(); - size_t ws_size_dst = args.dst_layout->span().dist_byte(); auto layouts = make_underlying_tensor_layout( *(args.src_layout), *(args.filter_layout), *(args.bias_layout), *(args.z_layout), *(args.dst_layout)); + size_t ws_size_src = layouts[0].span().dist_byte(); + size_t ws_size_filter = layouts[1].span().dist_byte(); + size_t ws_size_dst = layouts.back().span().dist_byte(); + size_t co = args.filter_layout->operator[](0); + if (co % 64 != 0) { + size_t ci = args.filter_layout->operator[](1), + fh = args.filter_layout->operator[](2), + fw = args.filter_layout->operator[](3); + ws_size_filter += TensorLayout({round_up(co, 64_z), ci, fh, fw}, + args.filter_layout->dtype) + .span() + .dist_byte(); + ws_size_filter += sizeof(int) * round_up(co, 64_z); + } auto conv_op = args.opr->handle()->create_operator(); conv_op->param() = args.opr->param(); using Format = param::ConvBias::Format; @@ -164,7 +199,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle( size_t ws_size_underlying_algo = m_underlying_algo.get_workspace_in_bytes(args_); if (args.z_layout->ndim > 0) { - size_t ws_size_z = args.z_layout->span().dist_byte(); + size_t ws_size_z = layouts[3].span().dist_byte(); return WorkspaceBundle{raw_ptr, {ws_size_src, ws_size_filter, ws_size_dst, ws_size_underlying_algo, ws_size_z}}; diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index a0ece8fc85ce99dbc5d0827e6099a4d68c61320f..ef4500e8fd3454ba7ff155273d0fbd87153060ec 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -4535,7 +4535,7 @@ EnableNCHW64Pass::make_nchw64_converter() { ThinHashMap format_size; bool same_format = true; bool first_touch = false; - Format format; + Format format(Format::NCHW); for (const auto& i : opr->input()) { Format cur; auto iter = format_map.find(i->owner_opr()); @@ -4561,7 +4561,7 @@ EnableNCHW64Pass::make_nchw64_converter() { opr->config()); } - Format max_format; + Format max_format(Format::NCHW); size_t max_size = std::numeric_limits::min(); for (const auto& item : format_size) { if (item.second > max_size) {