diff --git a/dnn/src/common/relayout_format.cpp b/dnn/src/common/relayout_format.cpp
index a175e36a3f591393c1b798a5f30ca88dea483a92..ac913ca27bb2760d783be9b317d510742a724624 100644
--- a/dnn/src/common/relayout_format.cpp
+++ b/dnn/src/common/relayout_format.cpp
@@ -252,10 +252,10 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
             megdnn_assert(dst[1] % param().group == 0);
             break;
         case Param::Mode::NCHW_NCHW64:
-            megdnn_assert(src.ndim == 4 && (src[1] % 64) == 0);
+            megdnn_assert(src.ndim == 4);
             dst.ndim = 5;
             dst[0] = src[0];
-            dst[1] = src[1] / 64;
+            dst[1] = div_ceil(src[1], 64_z);
             dst[2] = src[2];
             dst[3] = src[3];
             dst[4] = 64;
@@ -264,7 +264,7 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
             megdnn_assert(src.ndim == 5);
             dst.ndim = 4;
             dst[0] = src[0];
-            dst[1] = src[1] * 64;
+            dst[1] = param().oc == 0 ? src[1] * 64 : param().oc;
             dst[2] = src[2];
             dst[3] = src[3];
             break;
@@ -483,12 +483,11 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
         case Param::Mode::NCHW4_NCHW:
             // nchw to nchw4
             {
+                megdnn_assert(src.format == dst.format);
                 exec_workspace =
                         TensorLayout({src[0], src[1] * 4, src[2], src[3]},
-                                     src.dtype, src.format)
-                                .reshape({src[0], src[1], 4, src[2], src[3]})
-                                .dimshuffle({0, 1, 3, 4, 2});
-                exec_src = src;
+                                     dst.dtype, dst.format);
+                exec_src = src.dimshuffle({0, 1, 4, 2, 3});
                 exec_dst = dst;
             }
             break;
@@ -658,13 +657,20 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
         case Param::Mode::NCHW_NCHW64:
             // src is {N, C, H, W}
             // dst is {N, C/64, H, W, 64}
-            exec_src = src.reshape({src[0], src[1] / 64, 64, src[2], src[3]})
+            exec_workspace = TensorLayout(
+                    {src[0], round_up(src[1], 64_z), src[2], src[3]},
+                    src.dtype);
+            exec_src = exec_workspace
+                               .reshape({src[0], div_ceil(src[1], 64_z), 64,
+                                         src[2], src[3]})
                                .dimshuffle({0, 1, 3, 4, 2});
             exec_dst = dst;
             break;
         case Param::Mode::NCHW64_NCHW:
             // src is {N, C/64, H, W, 64}
             // dst is {N, C, H, W}
+            exec_workspace = TensorLayout({src[0], src[1] * 64, src[2], src[3]},
+                                          dst.dtype);
             exec_src = src.dimshuffle({0, 1, 4, 2, 3});
             exec_dst = dst;
             break;
diff --git a/dnn/src/cuda/relayout_format/helper.cuh b/dnn/src/cuda/relayout_format/helper.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..f77c7cc9d0d2fb64eb298ac9d6c4545754efcd91
--- /dev/null
+++ b/dnn/src/cuda/relayout_format/helper.cuh
@@ -0,0 +1,149 @@
+/**
+ * \file dnn/src/cuda/relayout_format/helper.cuh
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+namespace megdnn {
+namespace cuda {
+namespace relayout_format {
+
+#define devfunc __forceinline__ __device__
+template <int size_nbits>
+devfunc int make_zero(int zero_point);
+
+template <>
+devfunc int make_zero<4>(int zero_point) {
+    return transform_int8_to_uint4x8(zero_point, zero_point, zero_point,
+                                     zero_point, zero_point, zero_point,
+                                     zero_point, zero_point);
+}
+
+template <typename AccessType, int LoadBytes>
+struct global_load_with_zero_point;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// Specializations
+//
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// The redundant mov PTX instruction is used to enforce the compiler to
+// initialize data to zero before ld.global
+template <typename AccessType>
+struct global_load_with_zero_point<AccessType, 32> {
+    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
+                                        bool pred_guard, int zero_point) {
+        uint4* data = reinterpret_cast<uint4*>(&D);
+
+        asm volatile(
+                "{\n"
+                " .reg .pred p;\n"
+                " setp.ne.b32 p, %9, 0;\n"
+                " mov.b32 %0, %10;\n"
+                " mov.b32 %1, %10;\n"
+                " mov.b32 %2, %10;\n"
+                " mov.b32 %3, %10;\n"
+                " mov.b32 %4, %10;\n"
+                " mov.b32 %5, %10;\n"
+                " mov.b32 %6, %10;\n"
+                " mov.b32 %7, %10;\n"
+                " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
+                " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%11];\n"
+                "}\n"
+                : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z),
+                  "=r"(data[0].w), "=r"(data[1].x), "=r"(data[1].y),
+                  "=r"(data[1].z), "=r"(data[1].w)
+                : "l"(ptr), "r"((int)pred_guard),
+                  "r"(reinterpret_cast<unsigned&>(zero_point)),
+                  "l"(((uint8_t*)ptr) + 16));
+    }
+};
+
+template <typename AccessType>
+struct global_load_with_zero_point<AccessType, 16> {
+    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
+                                        bool pred_guard, int zero_point) {
+        uint4& data = reinterpret_cast<uint4&>(D);
+
+        asm volatile(
+                "{\n"
+                " .reg .pred p;\n"
+                " setp.ne.b32 p, %5, 0;\n"
+                " mov.b32 %0, %6;\n"
+                " mov.b32 %1, %6;\n"
+                " mov.b32 %2, %6;\n"
+                " mov.b32 %3, %6;\n"
+                " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+                "}\n"
+                : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
+                : "l"(ptr), "r"((int)pred_guard),
+                  "r"(reinterpret_cast<unsigned&>(zero_point)));
+    }
+};
+
+template <typename AccessType>
+struct global_load_with_zero_point<AccessType, 8> {
+    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
+                                        bool pred_guard, int zero_point) {
+        uint2& data = reinterpret_cast<uint2&>(D);
+
+        asm volatile(
+                "{\n"
+                " .reg .pred p;\n"
+                " setp.ne.b32 p, %3, 0;\n"
+                " mov.b32 %0, %4;\n"
+                " mov.b32 %1, %4;\n"
+                " @p ld.global.v2.u32 {%0, %1}, [%2];\n"
+                "}\n"
+                : "=r"(data.x), "=r"(data.y)
+                : "l"(ptr), "r"((int)pred_guard),
+                  "r"(reinterpret_cast<unsigned&>(zero_point)));
+    }
+};
+
+template <typename AccessType>
+struct global_load_with_zero_point<AccessType, 4> {
+    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
+                                        bool pred_guard, int zero_point) {
+        unsigned& data = reinterpret_cast<unsigned&>(D);
+
+        asm volatile(
+                "{\n"
+                " .reg .pred p;\n"
+                " setp.ne.b32 p, %2, 0;\n"
+                " mov.b32 %0, %3;\n"
+                " @p ld.global.u32 %0, [%1];\n"
+                "}\n"
+                : "=r"(data)
+                : "l"(ptr), "r"((int)pred_guard),
+                  "r"(reinterpret_cast<unsigned&>(zero_point)));
+    }
+};
+
+template <typename AccessType>
+struct global_load_with_zero_point<AccessType, 1> {
+    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
+                                        bool pred_guard, int zero_point) {
+        if (pred_guard)
+            D = *(reinterpret_cast<AccessType const*>(ptr));
+        else {
+            unsigned uv = reinterpret_cast<unsigned&>(zero_point);
+            uint8_t& data = reinterpret_cast<uint8_t&>(D);
+            data = uv & 0xff;
+        }
+    }
+};
+
+#undef devfunc
+} // namespace relayout_format
+} // namespace cuda
+} // namespace megdnn
diff --git a/dnn/src/cuda/relayout_format/relayout_format.cu b/dnn/src/cuda/relayout_format/relayout_format.cu
index 0e74e7610ff64519930b6f05335d5d844f61ad38..c28d5ffcfabf809618fed35f5af85cdf35dc1b9a 100644
--- a/dnn/src/cuda/relayout_format/relayout_format.cu
+++ b/dnn/src/cuda/relayout_format/relayout_format.cu
@@ -18,6 +18,7 @@
 #pragma GCC diagnostic pop
 #include "src/cuda/query_blocksize.cuh"
 #include "src/cuda/relayout_format/relayout_format.cuh"
+#include "src/cuda/relayout_format/helper.cuh"
 
 using namespace megdnn;
 using namespace cuda;
@@ -728,17 +729,18 @@ struct Translayout<64, 2, SrcType, dtype::Quantized4Asymm,
 #undef pack
 
 template <typename DstType>
-inline __device__ DstType make_zero_pad(const char zero_point) {
+inline __device__ DstType make_zero_pad(const uint8_t zero_point) {
     return zero_point;
 }
 
 template <>
-inline __device__ char4 make_zero_pad<char4>(const char zero_point) {
-    return {zero_point, zero_point, zero_point, zero_point};
+inline __device__ char4 make_zero_pad<char4>(const uint8_t zero_point) {
+    char izp = reinterpret_cast<const char&>(zero_point);
+    return {izp, izp, izp, izp};
 }
 
 template <>
-inline __device__ int4 make_zero_pad<int4>(const char zero_point) {
+inline __device__ int4 make_zero_pad<int4>(const uint8_t zero_point) {
     return {zero_point, zero_point, zero_point, zero_point};
 }
 
@@ -767,7 +769,7 @@ inline __device__ void write_helper>(
             : "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
               "r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x),
               "r"(data[1].y), "r"(data[1].z), "r"(data[1].w));
-};
+}
 
 template
                 & post_process,
-        const char zero_point) {
+        const uint8_t zero_point) {
     InnerDtype read_channel[pack_c];
     if (all_pad) {
         const InnerDtype zero_pad = make_zero_pad<InnerDtype>(zero_point);
@@ -855,7 +857,7 @@ __global__ void kern_nchw_nchwx(
         const SrcType* src, DstType* dst, int in_n, int ic, int ihw,
         int n_stride_src, int ic_stride, int n_stride_dst, int oc_stride,
         CudaPostProcess post_process,
-        const char zero_point, const int group, const int ocpg) {
+        const uint8_t zero_point, const int group, const int ocpg) {
     static constexpr int size_src_type = sizeof(SrcType);
     static constexpr int size_dst_type = sizeof(DstType);
 #ifndef MEGDNN_COMMA
@@ -1072,6 +1074,7 @@ public:
     MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
         pointer += (c_idx / pack_size) * chan_stride_in_elements +
                    hw_idx * pack_size * size_nbits / (8 * sizeof(Type));
+        channel -= c_idx;
     }
 
     MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
@@ -1079,7 +1082,7 @@ public:
         pointer += offset_in_type;
     }
 
-    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
+    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
         AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
         Type* pointer_ = pointer;
 #pragma unroll
@@ -1090,11 +1093,12 @@ public:
                                 (lane_size_in_type / pack_size_in_type) +
                         j;
                 bool guard = i < channel;
-                cutlass::arch::global_load(
+                relayout_format::global_load_with_zero_point(
                         frag_ptr[frag_idx],
                         reinterpret_cast(pointer_ + j * pack_size_in_type),
-                        guard);
+                        guard, zero_point);
             }
             pointer_ += chan_stride_in_elements;
         }
@@ -1173,6 +1177,7 @@ public:
     MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
         pointer += (c_idx / pack_size) * chan_stride_in_elements;
+        channel -= c_idx;
 #pragma unroll
         for (int i = 0; i < mask_size; ++i) {
             mask[i] = 0;
         }
@@ -1201,7 +1206,7 @@ public:
         pointer += offset_in_type;
     }
 
-    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) {
+    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
         AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
         Type* pointer_ = pointer;
 #pragma unroll
@@ -1214,9 +1219,11 @@ public:
                 int mask_index = (frag_idx >> 5);
                 int mask_shift = (frag_idx & 0x1f);
                 bool guard = (mask[mask_index] & (1 << mask_shift));
-                cutlass::arch::global_load(
+                relayout_format::global_load_with_zero_point(
                         frag_ptr[frag_idx],
-                        reinterpret_cast(pointer_ + stride[j]), guard);
+                        reinterpret_cast(pointer_ + stride[j]), guard,
+                        zero_point);
             }
             pointer_ += chan_stride_in_elements;
         }
@@ -1306,11 +1313,13 @@ struct RelayoutProblem {
         int batch_size;
         int channels;
         int hw;
+        int zero_point;
         MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_,
                                         DstIterator dst_iterator_,
                                         CudaPostProcess post_process_,
                                         int n_stride_src_, int n_stride_dst_,
-                                        int batch_size_, int channels_, int hw_)
+                                        int batch_size_, int channels_, int hw_,
+                                        int zero_point_)
                 : src_iterator{src_iterator_},
                   dst_iterator{dst_iterator_},
                   post_process{post_process_},
@@ -1318,7 +1327,8 @@ struct RelayoutProblem {
                   n_stride_dst{n_stride_dst_},
                   batch_size{batch_size_},
                   channels{channels_},
-                  hw{hw_} {}
+                  hw{hw_},
+                  zero_point{zero_point_} {}
     };
 };
 
@@ -1345,7 +1355,9 @@ __global__ void relayout_kern(typename RelayoutProblem_::Param param) {
         param.dst_iterator.initialize(c_idx, hw_idx);
         typename SrcIterator::Fragment src_frag;
         typename DstIterator::Fragment dst_frag;
-        param.src_iterator.load(src_frag);
+        int zp = relayout_format::make_zero(
+                param.zero_point);
+        param.src_iterator.load(src_frag, zp);
         RelayoutProblem_::Transpose::trans(
                 reinterpret_cast(dst_frag), src_frag,
                 param.post_process);
@@ -1382,7 +1394,8 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
                   stype.name(), dtype.name());
 #undef DEF
     // no padding
-    if (src.layout.stride[2] == static_cast<ptrdiff_t>(src.layout[3])) {
+    if (stype.enumv().ev != DTypeEnum::Ev::QuantizedS4 &&
+        stype.enumv().ev != DTypeEnum::Ev::Quantized4Asymm) {
         const int in_n = src.layout[0];
         const int out_n = dst.layout[0];
         const int ic = src.layout[1];
@@ -1428,18 +1441,10 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     DISPATCH_RAW(false, 4, 4, _src_type, _dst_type, char, char, 8);  \
     DISPATCH_RAW(true, 1, 4, _src_type, _dst_type, char, char, 8);   \
     DISPATCH_RAW(false, 1, 4, _src_type, _dst_type, char, char, 8);
-#define DISPATCH_4BITS(_src_type, _dst_type)                         \
-    DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4);  \
-    DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \
-    DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4);  \
-    DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
         DISPATCH_INT(QuantizedS32, QuantizedS32);
         DISPATCH_BYTE(Uint8, QuantizedS8);
         DISPATCH_BYTE(Quantized8Asymm, QuantizedS8);
         DISPATCH_BYTE(QuantizedS8, QuantizedS8);
-        DISPATCH_4BITS(QuantizedS4, QuantizedS4);
-        DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
-#undef DISPATCH_4BITS
 #undef DISPATCH_BYTE
 #undef DISPATCH_INT
 #undef DISPATCH_RAW
@@ -1450,7 +1455,8 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     } else {
         megdnn_assert(src_layout.dtype.is_low_bit());
         int n = src.layout[0];
-        int c = src.layout[1];
+        int ic = src.layout[1];
+        int oc = dst.layout[1] * 64;
         int h = src.layout[2];
         // align to byte
         int w = src.layout[3];
@@ -1460,12 +1466,13 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
     int ic_stride = src_layout.stride[1];
     int n_stride_dst = dst_layout.stride[0];
    int oc_stride = dst_layout.stride[1];
-    int problem_size = n * (c / pack_oc) * hw;
+    int problem_size = n * (oc / pack_oc) * hw;
     bool same_scale = src_scale == dst_scale;
-#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type,  \
-                     _src_c_type, _dst_c_type, _size_nbits)                 \
-    if (same_scale == _same_scale && hw % _pack_w == 0 &&                   \
-        stype.enumv().ev == DTypeEnum::Ev::_src_type &&                     \
+    bool padding = w % 2 != 0;
+#define DISPATCH_RAW(_padding, _same_scale, _pack_w, _pack_oc, _src_type,   \
+                     _dst_type, _src_c_type, _dst_c_type, _size_nbits)      \
+    if (padding == _padding && same_scale == _same_scale &&                 \
+        hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \
         dtype.enumv().ev == DTypeEnum::Ev::_dst_type) {                     \
         using InnerDtype_ = typename DTypeRWHelper<                         \
                 typename DTypeTrait::ctype,                                 \
@@ -1473,8 +1480,10 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
         using SrcIterator_ =                                                \
                 TensorIteratorOverChannel;                                  \
-        using DstIterator_ = MaskedTensorIteratorOverChannel<               \
-                _dst_c_type, _pack_oc, _pack_oc, _pack_w, _size_nbits>;     \
+        using DstIterator_ =                                                \
+                typename TensorIteratorPolicy<_padding, _dst_c_type, _pack_oc, \
+                                              _pack_oc, _pack_w,            \
+                                              _size_nbits>::TensorIterator; \
         using CudaPostProcess_ =                                            \
                 CudaPostProcess;                                            \
@@ -1489,17 +1498,18 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
         n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(_dst_c_type)); \
         oc_stride = oc_stride * _size_nbits / (8 * sizeof(_dst_c_type));    \
         typename RelayoutProblem_::Param param{                             \
-                SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, c, w,     \
+                SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, ic, w,    \
                              w_pad},                                         \
-                DstIterator_{(_dst_c_type*)dst.raw_ptr, oc_stride, c, w,     \
+                DstIterator_{(_dst_c_type*)dst.raw_ptr, oc_stride, oc, w,    \
                              w_pad},                                         \
                 CudaPostProcess_{src_scale, src_zero_point, dst_scale,       \
                                  dst_zero_point},                            \
                 n_stride_src,                                                \
                 n_stride_dst,                                                \
                 n,                                                           \
-                c,                                                           \
-                hw};                                                         \
+                oc,                                                          \
+                hw,                                                          \
+                src_zero_point};                                             \
         auto kernel = relayout_kern;                                         \
         int nr_threads = query_blocksize_for_kernel(kernel);                 \
         nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w));     \
@@ -1507,11 +1517,15 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
         const dim3 thread_dim(nr_threads);                                   \
         return kernel<<>>(param);                                            \
     }
-#define DISPATCH_4BITS(_src_type, _dst_type)                                 \
-    DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4);          \
-    DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4);         \
-    DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4);          \
-    DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4);
+#define DISPATCH_4BITS(_src_type, _dst_type)                                 \
+    DISPATCH_RAW(true, true, 8, 64, _src_type, _dst_type, char, char, 4);    \
+    DISPATCH_RAW(true, false, 8, 64, _src_type, _dst_type, char, char, 4);   \
+    DISPATCH_RAW(true, true, 2, 64, _src_type, _dst_type, char, char, 4);    \
+    DISPATCH_RAW(true, false, 2, 64, _src_type, _dst_type, char, char, 4);   \
+    DISPATCH_RAW(false, true, 8, 64, _src_type, _dst_type, char, char, 4);   \
+    DISPATCH_RAW(false, false, 8, 64, _src_type, _dst_type, char, char, 4);  \
+    DISPATCH_RAW(false, true, 2, 64, _src_type, _dst_type, char, char, 4);   \
+    DISPATCH_RAW(false, false, 2, 64, _src_type, _dst_type, char, char, 4);
     DISPATCH_4BITS(QuantizedS4, QuantizedS4);
     DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm);
 #undef DISPATCH_4BITS
@@ -1521,7 +1535,6 @@ void relayout_format::relayout_format_cuda_nchw_nchwx(
                 "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                 stype.name(), dtype.name(), h, w);
     }
-    after_kernel_launch();
 }
 
 bool relayout_format::relayout_format_cuda_usable(
@@ -1568,7 +1581,7 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     megdnn_assert(pack_ic == 64, "Unsupport pack size(pack_ic:%d)", pack_ic);
 #undef DEF
     int n = src.layout[0];
-    int c = src.layout[1] * pack_ic;
+    int ic = src.layout[1] * pack_ic;
     int h = src.layout[2];
     // align to byte
     int w = src.layout[3];
@@ -1578,7 +1591,8 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     int ic_stride = src_layout.stride[1];
     int n_stride_dst = dst_layout.stride[0];
     int oc_stride = dst_layout.stride[1];
-    int problem_size = n * (c / pack_ic) * hw;
+    int problem_size = n * (ic / pack_ic) * hw;
+    int oc = dst.layout[1];
     bool same_scale = src_scale == dst_scale;
     bool padding = w % 2 != 0;
@@ -1611,17 +1625,18 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
         n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(InnerDtype_)); \
         oc_stride = oc_stride * _size_nbits / (8 * sizeof(InnerDtype_));    \
         typename RelayoutProblem_::Param param{                             \
-                SrcIterator_{(_src_c_type*)src.raw_ptr, ic_stride, c, w,     \
+                SrcIterator_{(_src_c_type*)src.raw_ptr, ic_stride, ic, w,    \
                              w_pad},                                         \
-                DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, c, w,     \
+                DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, oc, w,    \
                              w_pad},                                         \
                 CudaPostProcess_{src_scale, src_zero_point, dst_scale,       \
                                  dst_zero_point},                            \
                 n_stride_src,                                                \
                 n_stride_dst,                                                \
                 n,                                                           \
-                c,                                                           \
-                hw};                                                         \
+                ic,                                                          \
+                hw,                                                          \
+                src_zero_point};                                             \
         auto kernel = relayout_kern;                                         \
         int nr_threads = query_blocksize_for_kernel(kernel);                 \
         nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w));     \
@@ -1645,7 +1660,6 @@ void relayout_format::relayout_format_cuda_nchwx_nchw(
     megdnn_assert(false,
                   "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).",
                   stype.name(), dtype.name(), h, w);
-    after_kernel_launch();
 }
 
 void relayout_format::relayout_format_cuda_nchw4_nchw(
diff --git a/dnn/src/cuda/utils.cuh b/dnn/src/cuda/utils.cuh
index d8ff16ad19b8795203c1548d0bbce6f70cc3a816..87983f808af51bbe5a1335c0732e2c045c8328a6 100644
--- a/dnn/src/cuda/utils.cuh
+++ b/dnn/src/cuda/utils.cuh
@@ -21,6 +21,7 @@
 #include "cuda.h"
 #include "src/cuda/cudnn_with_check.h"
 #include "cutlass/cutlass.h"
+#include "cutlass/platform/platform.h"
 
 #define cuda_check(_x) \
     do {               \
@@ -448,13 +449,12 @@ MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8(
 template <bool signedness, typename T>
 MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage,
                                                               int bits) {
-    uint8_t result = (uint8_t)((storage >> bits) & 0xf);
-    if (signedness) {
-        static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1);
-        return (result & uint8_t(1 << 3)) ? ((int)(result) | ~(int)(mask))
-                                          : (int)(result);
-    }
-    return int(result);
+    static constexpr int shift = 28;
+    using type = typename cutlass::platform::conditional<signedness, int,
+                                                         unsigned>::type;
+    unsigned intermediate = static_cast<unsigned>(storage);
+    type result = reinterpret_cast<type&>(intermediate);
+    return (result << (shift - bits)) >> shift;
 }
 
 MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
diff --git a/dnn/src/naive/relayout_format/opr_impl.cpp b/dnn/src/naive/relayout_format/opr_impl.cpp
index dd78acb6be3ab3549bc9fbc35842bf86b3a9ba45..74e79bf7987a3f8d7b98bbbf060720c18e9228d8 100644
--- a/dnn/src/naive/relayout_format/opr_impl.cpp
+++ b/dnn/src/naive/relayout_format/opr_impl.cpp
@@ -42,6 +42,36 @@ void recursive_cp(const TensorND& dst, const TensorND& src, size_t idx = 0,
     }
 }
 
+template <size_t size_nbits>
+void lowbit_recursive_cp(const TensorND& dst, const TensorND& src,
+                         size_t idx = 0, size_t src_offset = 0,
+                         size_t dst_offset = 0) {
+    MEGDNN_STATIC_ASSERT(!(8_z % size_nbits),
+                         "size in bits of lowbit data type can only be 1, 2, 4 "
+                         "or 8");
+    if (idx < (src.layout.ndim - 1)) {
+        for (size_t i = 0; i < src.layout[idx]; ++i) {
+            lowbit_recursive_cp<size_nbits>(
+                    dst, src, idx + 1, src_offset + i * src.layout.stride[idx],
+                    dst_offset + i * dst.layout.stride[idx]);
+        }
+    } else {
+        megdnn_assert(src.layout.stride[idx] == 1);
+        megdnn_assert(dst.layout.stride[idx] == 1);
+        size_t dim_bytes = div_ceil(src.layout[idx], 8_z / size_nbits);
+        // offset in elements
+        uint8_t* dptr = reinterpret_cast<uint8_t*>(dst.raw_ptr) +
+                        (dst_offset * size_nbits / 8);
+        uint8_t* sptr = reinterpret_cast<uint8_t*>(src.raw_ptr) +
+                        (src_offset * size_nbits / 8);
+        for (size_t i = 0; i < dim_bytes; ++i) {
+            *dptr = *sptr;
+            dptr++;
+            sptr++;
+        }
+    }
+}
+
 void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src) {
     switch (src.layout.dtype.enumv()) {
 #define cb(name, ctype) \
@@ -54,10 +84,17 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src) {
         cb(Int32, dt_int32);
         cb(QuantizedS32, dt_int32);
         cb(QuantizedS8, dt_qint8);
-
+#undef cb
+#define cb(name, size_nbits)                       \
+    case (DTypeEnum::name): {                      \
+        lowbit_recursive_cp<size_nbits>(dst, src); \
+        break;                                     \
+    }
+        cb(QuantizedS4, 4);
+        cb(Quantized4Asymm, 4);
+#undef cb
         default:
             megdnn_assert(0, "not support dtype %s", src.layout.dtype.name());
-#undef cb
     }
 }
 
@@ -66,24 +103,27 @@ void extract_from_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src,
     megdnn_assert(dst.layout.is_contiguous() && src.layout.is_contiguous(),
                   "dst %s, src %s", dst.layout.to_string().c_str(),
                   src.layout.to_string().c_str());
-    const size_t type_size = dst.layout.dtype.size();
     const size_t n = dst.layout[0];
-    const size_t n_stride_dst = dst.layout.stride[0];
-    const size_t n_stride_src = src.layout.stride[0];
+    const size_t n_stride_dst_in_bytes =
+            dst.layout.dtype.size(dst.layout.stride[0]);
+    const size_t n_stride_src_in_bytes =
+            src.layout.dtype.size(src.layout.stride[0]);
     const size_t ocpg = dst.layout[1] / group;
     const size_t icpg = src.layout[1] / group;
-    const size_t dst_hw = dst.layout[2] * dst.layout[3];
-    const size_t src_hw = src.layout[2] * src.layout[3];
-    megdnn_assert(dst_hw == src_hw);
+    const size_t dst_c_stride_in_bytes =
+            dst.layout.dtype.size(dst.layout.stride[1]);
+    const size_t src_c_stride_in_bytes =
+            src.layout.dtype.size(src.layout.stride[1]);
+    megdnn_assert(dst_c_stride_in_bytes == src_c_stride_in_bytes);
     for (size_t nid = 0; nid < n; ++nid) {
-        const size_t n_offset_dst = nid * n_stride_dst * type_size;
-        const size_t n_offset_src = nid * n_stride_src * type_size;
+        const size_t n_offset_dst = nid * n_stride_dst_in_bytes;
+        const size_t n_offset_src = nid * n_stride_src_in_bytes;
         for (size_t gid = 0; gid < group; ++gid) {
             memcpy((char*)dst.raw_ptr + n_offset_dst +
-                           gid * ocpg * dst_hw * type_size,
+                           gid * ocpg * dst_c_stride_in_bytes,
                    (char*)src.raw_ptr + n_offset_src +
-                           gid * icpg * src_hw * type_size,
-                   ocpg * dst_hw * type_size);
+                           gid * icpg * src_c_stride_in_bytes,
+                   ocpg * dst_c_stride_in_bytes);
         }
     }
 };
@@ -415,6 +455,30 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
             return oc * ic * h * w * src.dtype.size();
         }
 
+        case Param::Mode::NCHW_NCHW64: {
+            if (src[1] % 64 != 0) {
+                size_t n = src[0];
+                size_t c = round_up(src[1], 64_z);
+                size_t h = src[2];
+                size_t w = src[3];
+                TensorLayout wsly({n, c, h, w}, src.dtype);
+                return wsly.span().dist_byte();
+            }
+            return 0_z;
+        }
+
+        case Param::Mode::NCHW64_NCHW: {
+            if (param().oc != 0) {
+                size_t n = src[0];
+                size_t c = src[1] * 64;
+                size_t h = src[2];
+                size_t w = src[3];
+                TensorLayout wsly({n, c, h, w}, dst.dtype);
+                return wsly.span().dist_byte();
+            }
+            return 0_z;
+        }
+
         default:
             return 0;
     }
@@ -437,6 +501,7 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
     // clean dst
     MEGDNN_DISPATCH_CPU_KERN(
             m_handle, memset(dst.raw_ptr, 0, dst.layout.span().dist_byte()));
+    // pre
     if (param().mode == Param::Mode::NCHW_NHWCD4I) {
         size_t N = src.layout[0];
         size_t IC = src.layout[1];
@@ -551,6 +616,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                 cb2(2, 4, NCHW_NCHW4, group_src_layout, workspace_layout);
             }
         }
+    } else if (param().mode == Param::Mode::NCHW_NCHW64) {
+        MIDOUT_BEGIN(megdnn_naive_relayout_format,
+                     midout_iv(Param::Mode::NCHW_NCHW64)) {
+            size_t c = src.layout[1];
+            if (c % 64 != 0) {
+                uint8_t zp = 0;
+                if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+                    zp = src.layout.dtype.param<dtype::Quantized4Asymm>()
+                                 .zero_point;
+                    zp = (zp & 0xf) | (zp << 4);
+                }
+                MEGDNN_DISPATCH_CPU_KERN(
+                        m_handle, memset(workspace.raw_ptr, zp,
+                                         exec_workspace.span().dist_byte()));
+                TensorND ws_nd(workspace.raw_ptr, exec_workspace);
+                MEGDNN_DISPATCH_CPU_KERN(m_handle,
+                                         padding_to_workspace(ws_nd, src););
+                exec_src_nd.raw_ptr = workspace.raw_ptr;
+            }
+        }
+        MIDOUT_END();
     } else if (param().mode ==
                Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) {
         cb(1, 4, NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
@@ -574,24 +660,16 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             cb(1, 2, 4, NCHW_NCHW4_WEIGHT);
         }
     } else if (param().mode == Param::Mode::NCHW4_NCHW) {
-        if (exec_workspace.total_nr_elems() == dst.layout.total_nr_elems()) {
-            m_handle->relayout_opr()->exec(
-                    exec_src_nd, {dst.raw_ptr, exec_workspace}, handle());
-            return;
-        } else {
-            m_handle->relayout_opr()->exec(
-                    exec_src_nd, {workspace.raw_ptr, exec_workspace}, handle());
-            TensorLayout workspace_layout{{src.layout[0], src.layout[1] * 4,
-                                           src.layout[2], src.layout[3]},
-                                          src.layout.dtype,
-                                          src.layout.format};
-            extract_from_workspace(exec_dst_nd,
-                                   {workspace.raw_ptr, workspace_layout},
-                                   param().group);
-            return;
+        if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
+            exec_dst_nd = {workspace.raw_ptr, exec_workspace};
+        }
+    } else if (param().mode == Param::Mode::NCHW64_NCHW) {
+        if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
+            exec_dst_nd = {workspace.raw_ptr, exec_workspace};
         }
     }
-
+
+    // do relayout
     if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm &&
         dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -600,7 +678,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_qu8_q8(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -609,7 +686,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_u8_q8(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -618,7 +694,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_q8_q8(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS32 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS32) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -627,7 +702,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_q32_q32(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
                dst.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -636,7 +710,6 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_q4_q4(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm &&
                dst.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
         TensorND src0 = exec_src_nd, dst0 = exec_dst_nd;
@@ -645,9 +718,20 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             do_copy_diff_qu4_qu4(dst, src);
         };
         MEGDNN_DISPATCH_CPU_KERN_OPR(func(dst0, src0));
-        return;
     } else {
         m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
+    }
+
+    // post
+    if (param().mode == Param::Mode::NCHW4_NCHW ||
+        param().mode == Param::Mode::NCHW64_NCHW) {
+        if (exec_workspace.total_nr_elems() != dst.layout.total_nr_elems()) {
+            megdnn_assert(exec_workspace.dtype == dst.layout.dtype);
+            TensorND ws_nd{workspace.raw_ptr, exec_workspace};
+            MEGDNN_DISPATCH_CPU_KERN(
+                    m_handle,
+                    extract_from_workspace(dst, ws_nd, param().group););
+        }
     }
 #undef cb
 }
diff --git a/dnn/test/cuda/relayout_format.cpp b/dnn/test/cuda/relayout_format.cpp
index 1faf409a6bec2cd03edd90c8593372426074a6e4..0da9533b8b713d3099bdc342572b7ae3ef8da267 100644
--- a/dnn/test/cuda/relayout_format.cpp
+++ b/dnn/test/cuda/relayout_format.cpp
@@ -18,7 +18,6 @@
 using namespace megdnn;
 using namespace test;
 
-#define MEGDNN_WITH_BENCHMARK 1
 
 TEST_F(CUDA, RELAYOUT_FORMAT) {
     Checker<RelayoutFormat> checker(handle_cuda());
@@ -245,7 +244,7 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW64) {
     param::RelayoutFormat param;
     param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
     for (size_t n : {1, 3}) {
-        for (size_t c : {64, 128}) {
+        for (size_t c : {15, 64, 128}) {
             for (size_t h : {7, 14, 16, 28}) {
                 for (size_t w : {2, 3, 7, 8, 16, 31}) {
                     checker.set_dtype(0, dtype::QuantizedS4{2.f})
@@ -285,36 +284,41 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW64_NCHW) {
     param::RelayoutFormat param;
     param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;
     for (size_t n : {1, 3}) {
-        for (size_t c : {64, 128}) {
+        for (size_t c : {15, 64, 128}) {
             for (size_t h : {7, 14, 16, 28}) {
                 for (size_t w : {2, 3, 4, 7, 14, 16, 17}) {
+                    if (c % 64 != 0) {
+                        param.oc = c;
+                    } else {
+                        param.oc = 0;
+                    }
                     checker.set_dtype(0, dtype::QuantizedS4{2.f})
                             .set_dtype(1, dtype::QuantizedS4{2.f})
                             .set_rng(0, &s4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                     checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 4})
                             .set_dtype(1, dtype::Quantized4Asymm{1.2f, 8})
                             .set_rng(0, &u4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                     checker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
                             .set_dtype(1, dtype::QuantizedS4{1.f})
                             .set_rng(0, &s4)
                             .set_param(param)
                             .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                     checker.set_dtype(0, dtype::Quantized4Asymm{1.20211209f, 8})
                            .set_dtype(1, dtype::Quantized4Asymm{1.f, 4})
                            .set_rng(0, &u4)
                            .set_param(param)
                            .set_epsilon(1e-3)
-                            .execs({{n, c / 64, h, w, 64}, {}});
+                            .execs({{n, (c + 63) / 64, h, w, 64}, {}});
                 }
             }
         }
@@ -375,10 +379,14 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
     CUBenchmarker<RelayoutFormat> benchmarker(handle_cuda());
     benchmarker.set_param(param);
     benchmarker.set_dtype(0, dtype::QuantizedS4{1.19990307f})
-            .set_dtype(1, dtype::QuantizedS4{1.20210322f});
+            .set_dtype(1, dtype::QuantizedS4{1.19990307f});
     for (auto&& shape : shapes) {
-        double memaccess = double(shape.total_nr_elems()) * 1e-6;
+        double memaccess =
+                double(TensorLayout(shape, dtype::QuantizedS4{1.f})
+                               .span()
+                               .dist_byte()) *
+                2e-6;
         auto time_ms = benchmarker.execs({shape, {}});
         printf("execute %s, time %.4f ms, %.4f GB/s\n",
                shape.to_string().c_str(), time_ms, memaccess / time_ms);
@@ -387,8 +395,9 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
 
     {
         TensorShapeArray shapes = {
-                {1, 64, 56, 56}, {16, 64, 56, 56}, {64, 64, 56, 56},
-                {1, 64, 56, 55}, {16, 64, 56, 55}, {64, 64, 56, 55},
+                {1, 64, 56, 56},    {16, 64, 56, 56}, {64, 64, 56, 56},
+                {1, 64, 56, 55},    {16, 64, 56, 55}, {64, 64, 56, 55},
+                {1, 256, 384, 640},
         };
         Param param;
         param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64;
@@ -399,7 +408,8 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) {
                 {64, 1, 56, 56, 64},
                 {1, 32, 7, 7, 64},
                 {16, 32, 7, 7, 64},
-                {64, 32, 7, 7, 64},
+                {64, 32, 7, 7, 64},
+                {1, 4, 384, 640, 64},
         };
         Param param;
         param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW;