From 057343557c70e1a5ae88500c4b58d585dcba9356 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 2 Feb 2023 19:10:40 +0800
Subject: [PATCH] fix(opencl): support OpenCL NHWC_NHWCD4I channel not a multiple of 4

GitOrigin-RevId: 1e85cfc8589219a5339cd4a3b60fe50a7f20230a
---
 dnn/src/common/relayout_format.cpp         | 23 +++++++-
 dnn/src/naive/relayout_format/opr_impl.cpp | 66 +++++++++++++++++++++-
 2 files changed, 86 insertions(+), 3 deletions(-)

diff --git a/dnn/src/common/relayout_format.cpp b/dnn/src/common/relayout_format.cpp
index 6272aa917..be77138c9 100644
--- a/dnn/src/common/relayout_format.cpp
+++ b/dnn/src/common/relayout_format.cpp
@@ -91,7 +91,6 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, TensorLayout& ds
             dst[6] = 8;
             break;
         case Param::Mode::NHWC_NHWCD4:
-        case Param::Mode::NHWC_NHWCD4I:
             megdnn_assert(src.ndim == 4);
             //! channel mod 4 should == 4
             megdnn_assert(src[3] % 4 == 0);
@@ -102,6 +101,15 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, TensorLayout& ds
             dst[3] = src[2];
             dst[4] = 4;
             break;
+        case Param::Mode::NHWC_NHWCD4I:
+            megdnn_assert(src.ndim == 4);
+            dst.ndim = 5;
+            dst[0] = src[0];
+            dst[1] = src[1];
+            dst[2] = (src[3] + 3) / 4;
+            dst[3] = src[2];
+            dst[4] = 4;
+            break;
         case Param::Mode::NHWCD4I_NHWC:
         case Param::Mode::NHWCD4_NHWC:
             megdnn_assert(src.ndim == 5);
@@ -587,13 +595,24 @@ void RelayoutFormat::deduce_exec_layout(
             exec_dst = dst;
             break;
         case Param::Mode::NHWC_NHWCD4:
-        case Param::Mode::NHWC_NHWCD4I:
             // src is {N, H, W, C},
             // dst is {N, H, CB, W, 4}
             exec_src = src.reshape({src[0], src[1], src[2], src[3] / 4, 4})
                                .dimshuffle({0, 1, 3, 2, 4});
             exec_dst = dst;
             break;
+        case Param::Mode::NHWC_NHWCD4I:
+            // src is {N, H, W, C},
+            // dst is {N, H, CB, W, 4}
+            exec_src = src;
+            exec_src[3] = (exec_src[3] + 3) / 4 * 4;
+            exec_src.stride[2] = exec_src[3] * exec_src.stride[3];
+            exec_src.stride[1] = exec_src[2] * exec_src.stride[2];
+            exec_src.stride[0] = exec_src[1] * exec_src.stride[1];
+            exec_src = exec_src.reshape({src[0], src[1], src[2], (src[3] + 3) / 4, 4})
+                               .dimshuffle({0, 1, 3, 2, 4});
+            exec_dst = dst;
+            break;
         case Param::Mode::NHWCD4I_NHWC:
         case Param::Mode::NHWCD4_NHWC:
             // src is {N, H, CB, W, 4}
diff --git a/dnn/src/naive/relayout_format/opr_impl.cpp b/dnn/src/naive/relayout_format/opr_impl.cpp
index 9cc61412a..c7d95667b 100644
--- a/dnn/src/naive/relayout_format/opr_impl.cpp
+++ b/dnn/src/naive/relayout_format/opr_impl.cpp
@@ -134,6 +134,25 @@ void padding_src_to_workspace(
     }
 }
 
+template <typename dtype>
+void padding_nhwc_src_to_workspace(
+        dtype* dptr, const dtype* sptr, size_t N, size_t IH, size_t IW, size_t IC) {
+    size_t IC4 = (IC + 3) / 4 * 4;
+    size_t HW = IH * IW;
+    for (size_t n = 0; n < N; n++) {
+        for (size_t idx = 0; idx < HW; idx++) {
+            for (size_t c = 0; c < IC4; c++) {
+                if (c < IC) {
+                    *dptr = sptr[n * IC * HW + idx * IC + c];
+                } else {
+                    *dptr = 0;
+                }
+                dptr++;
+            }
+        }
+    }
+}
+
 template <typename dtype>
 void padding_to_workspace(
         dtype* dptr, const dtype* sptr, const TensorLayout& src_layout,
@@ -318,6 +337,15 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(
             size_t IW = src[3];
             return N * IC4 * IH * IW * src.dtype.size();
         }
+        case Param::Mode::NHWC_NHWCD4I: {
+            if (src[3] % 4 == 0)
+                return 0;
+            size_t IC4 = dst[2] * 4;
+            size_t N = src[0];
+            size_t IH = src[1];
+            size_t IW = src[2];
+            return N * IC4 * IH * IW * src.dtype.size();
+        }
         case Param::Mode::INTER_WEIGHT_DENSEI_DOT: {
             if (src[1] % 4 == 0)
                 return 0;
@@ -509,7 +537,43 @@ void RelayoutFormatImpl::exec(
                 cb(Uint8, dt_uint8);
 #undef cb
                 default:
-                    megdnn_assert(0);
+                    megdnn_assert(
+                            0, "NCHW_NHWCD4I not support dtype %s",
+                            src.layout.dtype.name());
+            }
+            exec_src_nd = TensorND{workspace.raw_ptr, exec_src_nd.layout};
+        }
+    } else if (param().mode == Param::Mode::NHWC_NHWCD4I) {
+        size_t N = src.layout[0];
+        size_t IC = src.layout[3];
+        size_t IH = src.layout[1];
+        size_t IW = src.layout[2];
+        //! ic % 4 != 0
+        if ((IC & 0x3)) {
+            switch (src.layout.dtype.enumv()) {
+#define cb(name, ctype)                                                              \
+    case (DTypeEnum::name): {                                                        \
+        MIDOUT_BEGIN(                                                                \
+                megdnn_naive_relayout_format, ctype,                                 \
+                midout_iv(Param::Mode::NHWC_NHWCD4I)) {                              \
+            MEGDNN_DISPATCH_CPU_KERN(                                                \
+                    m_handle, padding_nhwc_src_to_workspace<ctype>(                  \
+                                      workspace.ptr<ctype>(),                        \
+                                      src.compatible_ptr<ctype>(), N, IH, IW, IC);); \
+        }                                                                            \
+        MIDOUT_END();                                                                \
+        break;                                                                       \
+    }
+                cb(Float32, dt_float32);
+                DNN_INC_FLOAT16(cb(Float16, dt_float16));
+                cb(Quantized8Asymm, dt_uint8);
+                cb(QuantizedS8, dt_int8);
+                cb(Uint8, dt_uint8);
+#undef cb
+                default:
+                    megdnn_assert(
+                            0, "NHWC_NHWCD4I not support dtype %s",
+                            src.layout.dtype.name());
             }
             exec_src_nd = TensorND{workspace.raw_ptr, exec_src_nd.layout};
         }
-- 
GitLab