Commit c218d4b0 authored by Megvii Engine Team

feat(dnn/cuda): fallback conv qs4 support channels not aligned to 64

GitOrigin-RevId: f0d080f35cc3194f6d45dba5fc58986b4f6c06ab
Parent 4fe68ac9
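Context for the diff below: the fallback NCHW64 quantized-S4 conv path previously required both the input channel count (ci) and the output channel count (co) to be multiples of 64; this commit removes that restriction by padding. The new layout arithmetic uses two helpers, div_ceil and round_up (64_z is a size_t literal). A minimal sketch of their assumed semantics:

    #include <cstddef>

    // Assumed semantics of the helpers used throughout this diff.
    // div_ceil(a, b): number of b-sized blocks needed to cover a elements.
    constexpr size_t div_ceil(size_t a, size_t b) { return (a + b - 1) / b; }
    // round_up(a, b): smallest multiple of b that is >= a.
    constexpr size_t round_up(size_t a, size_t b) { return div_ceil(a, b) * b; }

    // Example: co = 96 -> div_ceil(96, 64) == 2 blocks of 64 channels,
    //          round_up(96, 64) == 128 padded output channels.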
@@ -57,8 +57,6 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS4::is_available(
     // param buffer size is 4K, use 3K to store precomputed offset, fh * fw <=
     // (3*1024/4/2/2) - 1
     available &= fh * fw <= 191;
-    // channels should be multiples of 64
-    available &= ci % 64 == 0 && co % 64 == 0;
     return available;
 }
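With the padding path added to exec() in the next hunks, the availability check no longer needs ci % 64 == 0 && co % 64 == 0: unaligned input channels are expected to be padded by the NCHW -> NCHW64 relayout itself, while unaligned output channels get an explicit zero-padded staging copy of the filter.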
@@ -75,11 +73,11 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
     auto ws = get_workspace_bundle(args.workspace.raw_ptr, args);
     auto ws_src = ws.get(0);
     auto ws_filter = ws.get(1);
+    auto ws_bias = args.bias_tensor->raw_ptr;
     auto ws_dst = ws.get(2);
     void* ws_z = nullptr;
     if (args.z_layout->ndim > 0)
         ws_z = ws.get(4);
-    // auto&& stream = cuda_stream(args.opr->handle());
     auto nchw2nchw64 = [&args](const TensorND& src, TensorND&& dst) {
         if (dst.raw_ptr == nullptr)
             return;
@@ -91,17 +89,40 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
     auto nchw642nchw = [&args](const TensorND& src, TensorND&& dst) {
         auto relayout = args.handle->create_operator<RelayoutFormat>();
         relayout->param() = RelayoutFormat::Param::Mode::NCHW64_NCHW;
+        relayout->param().oc = dst.layout[1];
         Workspace dummy;
         relayout->exec(src, dst, dummy);
     };
     // reformat src
     nchw2nchw64(*(args.src_tensor), {ws_src, layouts[0]});
     // reformat filter
-    nchw2nchw64(*(args.filter_tensor), {ws_filter, layouts[1]});
+    size_t co = args.filter_layout->operator[](0);
+    if (co % 64 != 0) {
+        const auto& stream = cuda_stream(args.opr->handle());
+        auto ws_filter_ = reinterpret_cast<dt_byte*>(ws_filter) +
+                          layouts[1].span().dist_byte();
+        size_t ci = args.filter_layout->operator[](1),
+               fh = args.filter_layout->operator[](2),
+               fw = args.filter_layout->operator[](3);
+        TensorLayout intermediate({round_up(co, 64_z), ci, fh, fw},
+                                  args.filter_layout->dtype);
+        ws_bias = ws_filter_ + intermediate.span().dist_byte();
+        cuda_check(cudaMemsetAsync(ws_filter_, 0,
+                                   intermediate.span().dist_byte(), stream));
+        cuda_check(cudaMemcpyAsync(ws_filter_, args.filter_tensor->raw_ptr,
+                                   args.filter_layout->span().dist_byte(),
+                                   cudaMemcpyDeviceToDevice, stream));
+        nchw2nchw64({ws_filter_, intermediate}, {ws_filter, layouts[1]});
+        cuda_check(cudaMemcpyAsync(ws_bias, args.bias_tensor->raw_ptr,
+                                   co * args.bias_layout->dtype.size(),
+                                   cudaMemcpyDeviceToDevice, stream));
+    } else {
+        nchw2nchw64(*(args.filter_tensor), {ws_filter, layouts[1]});
+    }
     // reformat z
     nchw2nchw64(*(args.z_tensor), {ws_z, layouts[3]});
     TensorND src_{ws_src, layouts[0]}, filter_{ws_filter, layouts[1]},
-             bias_{args.bias_tensor->raw_ptr, layouts[2]}, z_{ws_z, layouts[3]},
+             bias_{ws_bias, layouts[2]}, z_{ws_z, layouts[3]},
              dst_{ws_dst, layouts[4]};
     auto conv_op = args.opr->handle()->create_operator<ConvBiasForward>();
     conv_op->param() = args.opr->param();
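When co % 64 != 0, the branch above stages the filter into a zero-filled buffer whose output channel dimension is rounded up to a multiple of 64 before the NCHW -> NCHW64 relayout; the padded output channels stay zero, and the NCHW64 -> NCHW relayout at the end (with param().oc set to the real channel count) drops them again. A standalone sketch of the same staging pattern, with hypothetical buffer names:

    #include <cuda_runtime.h>
    #include <cstddef>

    // Sketch of the zero-pad-then-copy staging step (hypothetical helper,
    // not a MegDNN API). co is the outermost NCHW dimension, so the pad
    // region is a contiguous tail of the staging buffer.
    void stage_padded_filter(void* staging, size_t staging_bytes,
                             const void* filter, size_t filter_bytes,
                             cudaStream_t stream) {
        // zero everything first so the rounded-up channels read as zero
        cudaMemsetAsync(staging, 0, staging_bytes, stream);
        // then overwrite the front with the real filter data
        cudaMemcpyAsync(staging, filter, filter_bytes,
                        cudaMemcpyDeviceToDevice, stream);
    }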
@@ -128,29 +149,43 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout(
     size_t co = dst[1], ho = dst[2], wo = dst[3];
     size_t fh = filter[2], fw = filter[3];
     SmallVector<TensorLayout> rst;
-    rst.emplace_back(TensorLayout{{n, ci / 64, hi, wi, 64}, src.dtype});
-    rst.emplace_back(TensorLayout{{co, ci / 64, fh, fw, 64}, filter.dtype});
-    rst.emplace_back(TensorLayout{{1, co / 64, 1, 1, 64}, bias.dtype});
+    rst.emplace_back(
+            TensorLayout{{n, div_ceil(ci, 64_z), hi, wi, 64}, src.dtype});
+    rst.emplace_back(
+            TensorLayout{{round_up(co, 64_z), div_ceil(ci, 64_z), fh, fw, 64},
+                         filter.dtype});
+    rst.emplace_back(
+            TensorLayout{{1, div_ceil(co, 64_z), 1, 1, 64}, bias.dtype});
     if (z.ndim > 0) {
-        rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, z.dtype});
+        rst.emplace_back(
+                TensorLayout{{n, div_ceil(co, 64_z), ho, wo, 64}, z.dtype});
     } else {
         rst.emplace_back(TensorLayout{});
     }
-    rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, dst.dtype});
-    for (auto& i : rst) {
-        i.init_contiguous_stride();
-    }
+    rst.emplace_back(
+            TensorLayout{{n, div_ceil(co, 64_z), ho, wo, 64}, dst.dtype});
     return rst;
 }
 
 WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle(
         void* raw_ptr, const SizeArgs& args) const {
-    size_t ws_size_src = args.src_layout->span().dist_byte();
-    size_t ws_size_filter = args.filter_layout->span().dist_byte();
-    size_t ws_size_dst = args.dst_layout->span().dist_byte();
+    auto layouts = make_underlying_tensor_layout(
+            *(args.src_layout), *(args.filter_layout), *(args.bias_layout),
+            *(args.z_layout), *(args.dst_layout));
+    size_t ws_size_src = layouts[0].span().dist_byte();
+    size_t ws_size_filter = layouts[1].span().dist_byte();
+    size_t ws_size_dst = layouts.back().span().dist_byte();
+    size_t co = args.filter_layout->operator[](0);
+    if (co % 64 != 0) {
+        size_t ci = args.filter_layout->operator[](1),
+               fh = args.filter_layout->operator[](2),
+               fw = args.filter_layout->operator[](3);
+        ws_size_filter += TensorLayout({round_up(co, 64_z), ci, fh, fw},
+                                       args.filter_layout->dtype)
+                                  .span()
+                                  .dist_byte();
+        ws_size_filter += sizeof(int) * round_up(co, 64_z);
+    }
     auto conv_op = args.opr->handle()->create_operator<ConvBiasForward>();
     conv_op->param() = args.opr->param();
     using Format = param::ConvBias::Format;
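get_workspace_bundle now sizes every slot from the padded NCHW64 layouts instead of the caller's layouts, and enlarges the filter slot when co needs padding. The resulting bundle, as read from this diff (the sizeof(int) term assumes the quantized-S4 conv carries an int32 bias):

    // ws.get(0): src relaid out to NCHW64 (ci rounded up to 64)
    // ws.get(1): filter in NCHW64
    //            + staging NCHW filter with co rounded up to 64
    //            + round_up(co, 64) int32 bias values (both only if co % 64 != 0)
    // ws.get(2): dst in NCHW64 (co rounded up to 64)
    // ws.get(3): workspace of the underlying conv algorithm
    // ws.get(4): z in NCHW64 (present only when z_layout->ndim > 0)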
@@ -164,7 +199,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle(
     size_t ws_size_underlying_algo =
             m_underlying_algo.get_workspace_in_bytes(args_);
     if (args.z_layout->ndim > 0) {
-        size_t ws_size_z = args.z_layout->span().dist_byte();
+        size_t ws_size_z = layouts[3].span().dist_byte();
         return WorkspaceBundle{raw_ptr,
                                {ws_size_src, ws_size_filter, ws_size_dst,
                                 ws_size_underlying_algo, ws_size_z}};
......
@@ -4535,7 +4535,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
        ThinHashMap<Format, size_t> format_size;
        bool same_format = true;
        bool first_touch = false;
-       Format format;
+       Format format(Format::NCHW);
        for (const auto& i : opr->input()) {
            Format cur;
            auto iter = format_map.find(i->owner_opr());
@@ -4561,7 +4561,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
                    opr->config());
        }
-       Format max_format;
+       Format max_format(Format::NCHW);
        size_t max_size = std::numeric_limits<size_t>::min();
        for (const auto& item : format_size) {
            if (item.second > max_size) {