diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index d60153d400227e768c73d45f17a99da58c3cf9d3..b3ef8155cc70d46cff7fc4c0bb908cac69e404c8 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -434,7 +434,7 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), Doc('MK8', 'Split 8 from M and K, better for neon compute:' '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' - 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), + 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' 'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) @@ -858,7 +858,10 @@ when the ``I`` suffix is present. 'NCHW_NCHW88_CONV_CHAN_WEIGHT', 'NCHW_NCHW88_CONV_GROUP_WEIGHT', 'NCHW_NCHW88', - 'NCHW88_NCHW') + 'NCHW88_NCHW', + 'NCHW_NCHW4_IC_SMALL', + 'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT', + ) ) diff --git a/dnn/src/common/relayout_format.cpp b/dnn/src/common/relayout_format.cpp index 878f16f22349d045306fb8bda5a5bb0079021b4b..ee6d50c527927f5f70cf825c7222d7aed60d42bf 100644 --- a/dnn/src/common/relayout_format.cpp +++ b/dnn/src/common/relayout_format.cpp @@ -28,6 +28,26 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, dst[3] = src[3]; dst[4] = 4; break; + case Param::Mode::NCHW_NCHW4_IC_SMALL: + dst.ndim = 5; + megdnn_assert(src[1] <= 4_z, "ic should be less equal 4"); + dst[0] = src[0]; + dst[1] = div_ceil(src[1], 4_z); + dst[2] = src[2]; + dst[3] = src[3]; + dst[4] = 4; + break; + case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: + megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4"); + megdnn_assert(src[1] <= 4_z, "ic should be less equal 4"); + dst.ndim = 5; + dst[0] = src[0]; + dst[1] = div_ceil(src[1], 4_z); + dst[2] = src[2]; + dst[3] = src[3]; + dst[4] = 4; + break; + case 
Param::Mode::NCHW_NCHW88: dst.ndim = 5; dst[0] = src[0]; @@ -276,6 +296,8 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: + case Param::Mode::NCHW_NCHW4_IC_SMALL: + case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: CHECK_SRC(DefaultTensorFormat::make()); dst = src; break; @@ -374,6 +396,23 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, exec_dst = dst; } break; + + case Param::Mode::NCHW_NCHW4_IC_SMALL: + case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: + // nchw to nchw4c or oihw to oihw4i + { + TensorLayout work_space_layout( + {src[0], round_up(src[1], 4_z), src[2], src[3]}, + src.dtype, src.format); + exec_src = work_space_layout + .reshape({src[0], div_ceil(src[1], 4_z), 4, + src[2], src[3]}) + .dimshuffle({0, 1, 3, 4, 2}); + exec_dst = dst; + } + break; + + case Param::Mode::NCHW_NHWCD4: case Param::Mode::NCHW_NHWCD4I: // src is {N, C, H, W} diff --git a/dnn/src/cuda/relayout_format/opr_impl.cpp b/dnn/src/cuda/relayout_format/opr_impl.cpp index 6004630fd094066004404e8f343f97692ebf48ff..abfc6a21e047c1ca5a9e9347a0c06a8c25b4e21e 100644 --- a/dnn/src/cuda/relayout_format/opr_impl.cpp +++ b/dnn/src/cuda/relayout_format/opr_impl.cpp @@ -11,6 +11,7 @@ #include "src/cuda/relayout_format/opr_impl.h" #include "src/cuda/handle.h" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; @@ -20,15 +21,22 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, auto src_dtype = src.layout.dtype; megdnn_assert( param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || - param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4, + param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 || + param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL || + param().mode == + Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT, "relayout 
format of cuda only support NCHW4->CHWN4 or " - "CHWN4->NCHW4"); - if (src_dtype.enumv() == DTypeEnum::QuantizedS8) { + "CHWN4->NCHW4 or NCHW->NCHW4"); + if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || + param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4) && + src_dtype.enumv() == DTypeEnum::QuantizedS8) { size_t row = 0, col = 0; if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) { row = src.layout[0], col = src.layout[1] * src.layout[2] * src.layout[3]; } else { + megdnn_assert(param().mode == + param::RelayoutFormat::Mode::CHWN4_NCHW4); row = src.layout[0] * src.layout[1] * src.layout[2], col = src.layout[3]; } @@ -43,6 +51,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, return handle()->create_operator()->exec(trans_in, trans_out); } + if ((param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL || + param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) && + src.layout[1] % 4 != 0) { + megdnn_assert(src.raw_ptr != dst.raw_ptr && src.layout.ndim == 4, + "The mode of NCHW_NCHW4 and NCHW_NCHW4_CONV_DENSE_WEIGHT " + "of RelayoutFormat opr(cuda backend) does not support " + "src.ptr == dst.ptr"); + megdnn_assert(src.layout[1] <= 4); + cuda_check(cudaMemsetAsync(dst.raw_ptr, 0, + dst.layout.span().dist_byte(), + cuda_stream(this->handle()))); + TensorLayout exec_dst_layout = dst.layout; + exec_dst_layout[4] = src.layout[1]; + TensorLayout exec_src_layout = + src.layout + .reshape({src.layout[0], src.layout[1], 1, + src.layout[2], src.layout[3]}) + .dimshuffle({0, 2, 3, 4, 1}); + return handle()->create_operator()->exec( + {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); + } TensorLayout exec_src, exec_dst; deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst); TensorND exec_src_nd{src.raw_ptr, exec_src}; diff --git a/dnn/src/naive/relayout_format/opr_impl.cpp b/dnn/src/naive/relayout_format/opr_impl.cpp index 
731942e0394c9f3c002765dbf2fb83311f6bca64..21d4e811226e0454b38afd935a04ef03486551cb 100644 --- a/dnn/src/naive/relayout_format/opr_impl.cpp +++ b/dnn/src/naive/relayout_format/opr_impl.cpp @@ -79,6 +79,7 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src, } cb(Float32, dt_float32); + cb(QuantizedS8, dt_qint8); default: megdnn_assert(0); #undef cb @@ -138,7 +139,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, return n * c * h * w * src.dtype.size(); } case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: { - megdnn_assert(src.ndim == 4, "src must be oihw ,nmdim == 5"); + megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5"); megdnn_assert(src[0] % 8 == 0, "NCHW_NCHW88_CONV_DENSE_WEIGHT oc must align to 8"); if (src[1] % 8 == 0) @@ -150,7 +151,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, return oc * ic * h * w * src.dtype.size(); } case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: { - megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5"); + megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5"); megdnn_assert(src[1] % 8 == 0, "NCHW_NCHW88_CONV_CHAN_WEIGHT oc per group must " "align to 8"); @@ -164,7 +165,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, return group * ocpg * icpg * h * w * src.dtype.size(); } case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: { - megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5"); + megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5"); if (src[0] % 8 == 0) return 0; size_t group = round_up(src[0], 8_z); @@ -174,6 +175,27 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src, size_t w = src[4]; return group * ocpg * icpg * h * w * src.dtype.size(); } + + case Param::Mode::NCHW_NCHW4_IC_SMALL: { + if (src[1] % 4 == 0) + return 0; + size_t n = src[0]; + size_t c = round_up(src[1], 4_z); + size_t h = src[2]; + size_t w = src[3]; + return n * c * h * w * 
src.dtype.size(); + } + case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: { + megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4"); + if (src[1] % 4 == 0) + return 0; + size_t oc = src[0]; + size_t ic = round_up(src[1], 4_z); + size_t h = src[2]; + size_t w = src[3]; + return oc * ic * h * w * src.dtype.size(); + } + default: return 0; } } @@ -244,31 +266,28 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, exec_src_nd.raw_ptr = workspace.raw_ptr; } } else if (param().mode == Param::Mode::NCHW_NCHW88) { - size_t ic = src.layout[1]; - if (ic % 8 != 0) { - padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8); - exec_src_nd.raw_ptr = workspace.raw_ptr; - } +#define cb(_idx, _pack_size) \ + size_t val = src.layout[_idx]; \ + if (val % _pack_size != 0) { \ + padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \ + _pack_size); \ + exec_src_nd.raw_ptr = workspace.raw_ptr; \ + } + cb(1, 8); + } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) { megdnn_assert(src.layout[0] % 8 == 0); - size_t ic = src.layout[1]; - if (ic % 8 != 0) { - padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8_z); - exec_src_nd.raw_ptr = workspace.raw_ptr; - } + cb(1, 8); } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) { - size_t group = src.layout[0]; - if (group % 8 != 0) { - padding_to_workspace({workspace.raw_ptr, exec_src}, src, 0, 8_z); - exec_src_nd.raw_ptr = workspace.raw_ptr; - } + cb(0, 8); } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) { megdnn_assert(src.layout[1] % 8 == 0); - size_t ic = src.layout[2]; - if (ic % 8 != 0) { - padding_to_workspace({workspace.raw_ptr, exec_src}, src, 2, 8_z); - exec_src_nd.raw_ptr = workspace.raw_ptr; - } + cb(2, 8); + } else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) { + cb(1, 4); + } else if (param().mode == + Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) { + cb(1, 4); } 
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle()); } diff --git a/dnn/test/cuda/relayout_format.cpp b/dnn/test/cuda/relayout_format.cpp index 5f0c0aba44ebf39e8321f54144aab9c42b83c3ab..c4a99fc051f084d6cc9e112dbe087c526de97fc5 100644 --- a/dnn/test/cuda/relayout_format.cpp +++ b/dnn/test/cuda/relayout_format.cpp @@ -8,6 +8,7 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "megdnn/dtype.h" #include "megdnn/oprs.h" #include "test/common/checker.h" #include "test/common/rng.h" @@ -30,4 +31,25 @@ TEST_F(CUDA, RELAYOUT_FORMAT) { checker.execs({{22, 23, 24, 25, 4}, {}}); } +TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) { + Checker checker(handle_cuda()); + UniformIntRNG rng{-50, 50}; + param::RelayoutFormat param; + param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL; + + for (DType dtype : + std::vector({dtype::QuantizedS8{0.1f}, dtype::Float32{}})) { + checker.set_dtype(0, dtype).set_rng(0, &rng); + + checker.set_param(param).execs({{2, 4, 35, 36}, {}}); + checker.set_param(param).execs({{2, 3, 35, 36}, {}}); + checker.set_param(param).execs({{2, 1, 35, 36}, {}}); + param.mode = param::RelayoutFormat::Mode:: + NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT; + checker.set_param(param).execs({{4, 3, 3, 3}, {}}); + checker.set_param(param).execs({{4, 4, 3, 3}, {}}); + checker.set_param(param).execs({{1, 4, 3, 3}, {}}); + } +} + // vim: syntax=cpp.doxygen