From 8fef78d06d297d2efbdbd1cce515bcde567fa61c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 22 Mar 2021 17:47:55 +0800 Subject: [PATCH] feat(dnn/cuda): add relayout format when width is an odd number GitOrigin-RevId: f059f1f56dd66c33633118c893027ddd50ac8f1d --- dnn/src/common/relayout_format.cpp | 2 +- .../cuda/relayout_format/relayout_format.cu | 549 ++++++++++++++---- dnn/test/common/benchmarker.h | 9 +- dnn/test/common/checker.cpp | 74 +-- dnn/test/common/utils.h | 17 + dnn/test/cuda/relayout_format.cpp | 47 +- 6 files changed, 512 insertions(+), 186 deletions(-) diff --git a/dnn/src/common/relayout_format.cpp b/dnn/src/common/relayout_format.cpp index d14f17ea6..a175e36a3 100644 --- a/dnn/src/common/relayout_format.cpp +++ b/dnn/src/common/relayout_format.cpp @@ -380,7 +380,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { break; } - if (!dst.is_default() && + if (dst.type() == TensorFormat::Type::IMAGE2D_PACK4 && ( handle()->type() != Handle::HandleType::NAIVE)) { #if MEGDNN_ENABLE_MANGLING diff --git a/dnn/src/cuda/relayout_format/relayout_format.cu b/dnn/src/cuda/relayout_format/relayout_format.cu index 05de91cbb..a06787a51 100644 --- a/dnn/src/cuda/relayout_format/relayout_format.cu +++ b/dnn/src/cuda/relayout_format/relayout_format.cu @@ -10,10 +10,10 @@ * implied. */ -#include #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "cutlass/fast_math.h" #include "cutlass/arch/memory.h" #pragma GCC diagnostic pop #include "src/cuda/query_blocksize.cuh" @@ -112,6 +112,8 @@ struct CudaPostProcess { template <> struct CudaPostProcess { + using SrcType = dtype::QuantizedS4; + using DstType = dtype::QuantizedS4; CudaDTypeParamImpl m_dst_type_cvt; CudaDTypeParamImpl m_src_type_cvt; CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) { @@ -126,12 +128,16 @@ struct CudaPostProcess { template <> struct CudaPostProcess { + using SrcType = dtype::QuantizedS4; + using DstType = dtype::QuantizedS4; CudaPostProcess(float, uint8_t, float, uint8_t){}; inline __device__ int8_t operator()(int8_t val) { return val; } }; template <> struct CudaPostProcess { + using SrcType = dtype::Quantized4Asymm; + using DstType = dtype::Quantized4Asymm; CudaDTypeParamImpl m_dst_type_cvt; CudaDTypeParamImpl m_src_type_cvt; CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale, @@ -149,6 +155,8 @@ struct CudaPostProcess { template <> struct CudaPostProcess { + using SrcType = dtype::Quantized4Asymm; + using DstType = dtype::Quantized4Asymm; uint8_t m_src_zero_point = 0; uint8_t m_dst_zero_point = 0; CudaPostProcess(float, uint8_t src_zero_point, float, @@ -328,13 +336,20 @@ struct Translayout<2, 64, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, unpack_int4x2(6) unpack_int4x2(7) // clang-format on - + int frag_idx = i / 8; dst_frag[0 * 8 + frag_idx] = pack_channel(0); dst_frag[1 * 8 + frag_idx] = pack_channel(1); #undef unpack_int4x2 } } + using Fragment = array_wrapper; + static inline __device__ void trans( + Fragment& dst, Fragment& src, + CudaPostProcess& post_process) { + trans(reinterpret_cast(dst), + reinterpret_cast(src), post_process, 0); + } }; template @@ -375,6 +390,13 @@ struct Translayout<8, 64, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, dst_frag[7 * 8 + frag_idx] = pack_channel(7); } } + using Fragment = array_wrapper; + static inline __device__ void trans( + Fragment& dst, Fragment& src, + CudaPostProcess& post_process) { + 
trans(reinterpret_cast(dst), + reinterpret_cast(src), post_process, 0); + } }; #undef pack_channel @@ -428,6 +450,13 @@ struct Translayout<2, 64, SrcType, dtype::Quantized4Asymm, #undef unpack_int4x2 } } + using Fragment = array_wrapper; + static inline __device__ void trans( + Fragment& dst, Fragment& src, + CudaPostProcess& post_process) { + trans(reinterpret_cast(dst), + reinterpret_cast(src), post_process, 0); + } }; template @@ -468,6 +497,13 @@ struct Translayout<8, 64, SrcType, dtype::Quantized4Asymm, dst_frag[7 * 8 + frag_idx] = pack_channel(7); } } + using Fragment = array_wrapper; + static inline __device__ void trans( + Fragment& dst, Fragment& src, + CudaPostProcess& post_process) { + trans(reinterpret_cast(dst), + reinterpret_cast(src), post_process, 0); + } }; #undef pack_channel @@ -1028,11 +1064,21 @@ public: : pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {} MEGDNN_DEVICE TensorIteratorOverChannel(Type* pointer_, int chan_stride_in_elements_, - int channel_) + int channel_, int, int) : pointer{pointer_}, chan_stride_in_elements{chan_stride_in_elements_}, channel{channel_} {} + MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) { + pointer += (c_idx / pack_size) * chan_stride_in_elements + + hw_idx * pack_size * size_nbits / (8 * sizeof(Type)); + } + + MEGDNN_DEVICE __forceinline__ void add_pointer_offset( + size_t offset_in_type) { + pointer += offset_in_type; + } + MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) { AccessType* frag_ptr = reinterpret_cast(&frag); Type* pointer_ = pointer; @@ -1087,64 +1133,224 @@ private: int channel; }; -template -__global__ void kern_nchwx_nchw( - const SrcType* src, DstType* dst, int ic, int ihw, int n_stride_src, - int ic_stride, int n_stride_dst, int oc_stride, - CudaPostProcess post_process, - const char zero_point) { - using InnerDtype = - typename DTypeRWHelper::ctype, - pack_w>::InnerDtype; - using SrcIterator = TensorIteratorOverChannel; - using DstIteraotr = TensorIteratorOverChannel; - using Transpose = Translayout; - static constexpr int size_src_type = sizeof(SrcType); - static constexpr int size_dst_type = sizeof(DstType); - MEGDNN_STATIC_ASSERT(std::is_same::value, - "Currently this kernel only support accessing tensor " - "src and dst in same data type."); - n_stride_src /= size_src_type; - ic_stride /= size_src_type; - n_stride_dst /= size_dst_type; - oc_stride /= size_dst_type; -#undef MEGDNN_COMMA +template +class MaskedTensorIteratorOverChannel { +public: + using Type = Type_; + static constexpr int pack_size = pack_size_; + static constexpr int chan_blk = chan_blk_; + static constexpr int width = width_; + static constexpr int size_nbits = size_nbits_; + static constexpr int elements_in_type = + chan_blk * width * size_nbits / (8 * sizeof(Type)); + static constexpr int lane_size_in_type = + (width * pack_size * size_nbits) / (8 * sizeof(Type)); + static constexpr int pack_size_in_type = + (pack_size * size_nbits) >= (8 * sizeof(Type)) + ? 
(pack_size * size_nbits / (8 * sizeof(Type))) + : (width * pack_size * size_nbits / (8 * sizeof(Type))); + static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type); + static constexpr int accesses = elements_in_type / pack_size_in_type; + static constexpr int mask_size = (accesses + 32 - 1) / 32; + using AccessType = array_wrapper; + using Fragment = array_wrapper; - const int n_idx = blockIdx.y; - const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x; - const int ihw_offset = ihw_block_idx * pack_w; - const int ihw_offset_in_type = - ihw_offset * size_nbits / (8 * size_src_type); - const int oc_stride_inner_dtype = - oc_stride * size_dst_type / sizeof(InnerDtype); - if (ihw_offset < ihw) { - const int ic_block = (ic + pack_c - 1) / pack_c; - const int src_offset_base = - n_idx * n_stride_src + ihw_offset_in_type * pack_c; - const int dst_offset_base = n_idx * n_stride_dst + ihw_offset_in_type; - SrcIterator src_iterator{const_cast(src + src_offset_base), - ic_stride, ic}; - DstIteraotr dst_iterator{ - reinterpret_cast(dst + dst_offset_base), - oc_stride_inner_dtype, ic}; - - for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) { - typename SrcIterator::Fragment src_frag; - typename DstIteraotr::Fragment dst_frag; - src_iterator.load(src_frag); - Transpose::trans( - reinterpret_cast(dst_frag), - src_frag, post_process); - dst_iterator.store(dst_frag); - src_iterator.advance(); - dst_iterator.advance(); + MEGDNN_HOST MEGDNN_DEVICE MaskedTensorIteratorOverChannel() + : pointer{nullptr}, + chan_stride_in_elements{0}, + channel{0} {} + MEGDNN_HOST MEGDNN_DEVICE MaskedTensorIteratorOverChannel( + Type* pointer_, int chan_stride_in_elements_, int channel_, + int bound_, int div_) + : pointer{pointer_}, + chan_stride_in_elements{chan_stride_in_elements_}, + channel{channel_}, + bound{bound_}, + div{div_} { + cutlass::find_divisor(mul, shr, div); + } + + MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) { + pointer += (c_idx / pack_size) * chan_stride_in_elements; +#pragma unroll + for (int i = 0; i < mask_size; ++i) { + mask[i] = 0; + } +#pragma unroll + for (int i = 0; i < chan_blk; i += pack_size) { +#pragma unroll + for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { + int offset = hw_idx + j; + int h, w; + cutlass::fast_divmod(h, w, offset, div, mul, shr); + bool guard = (i < channel) && (w < bound); + int index = (i / pack_size) * + (lane_size_in_type / pack_size_in_type) + + j; + int mask_index = (index >> 5); + int mask_shift = (index & 0x1f); + mask[mask_index] |= (guard << mask_shift); + stride[j] = (h * bound + w) * pack_size * size_nbits / + (8 * sizeof(Type)); + } + } + } + + MEGDNN_DEVICE __forceinline__ void add_pointer_offset(size_t offset_in_type) { + pointer += offset_in_type; + } + + MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) { + AccessType* frag_ptr = reinterpret_cast(&frag); + Type* pointer_ = pointer; +#pragma unroll + for (int i = 0; i < chan_blk; i += pack_size) { +#pragma unroll + for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { + int frag_idx = i / pack_size * + (lane_size_in_type / pack_size_in_type) + + j; + int mask_index = (frag_idx >> 5); + int mask_shift = (frag_idx & 0x1f); + bool guard = (mask[mask_index] & (1 << mask_shift)); + cutlass::arch::global_load( + frag_ptr[frag_idx], + reinterpret_cast(pointer_ + stride[j]), guard); + } + pointer_ += chan_stride_in_elements; + } + } + + MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) { + const AccessType* 
frag_ptr = reinterpret_cast(&frag); + Type* pointer_ = pointer; +#pragma unroll + for (int i = 0; i < chan_blk; i += pack_size) { +#pragma unroll + for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { + int frag_idx = i / pack_size * + (lane_size_in_type / pack_size_in_type) + + j; + int mask_index = (frag_idx >> 5); + int mask_shift = (frag_idx & 0x1f); + bool guard = (mask[mask_index] & (1 << mask_shift)); + cutlass::arch::global_store( + frag_ptr[frag_idx], + reinterpret_cast(pointer_ + stride[j]), guard); + } + pointer_ += chan_stride_in_elements; } } + + MEGDNN_DEVICE __forceinline__ void advance() { + pointer += (chan_blk / pack_size) * chan_stride_in_elements; + channel -= chan_blk; + } + +private: + Type* pointer; + int chan_stride_in_elements; + int channel; + int bound; + int div; + uint32_t mul; + uint32_t shr; + uint32_t mask[mask_size]; + size_t stride[accesses]; +}; + +template +struct TensorIteratorPolicy; +template +struct TensorIteratorPolicy { + using TensorIterator = + MaskedTensorIteratorOverChannel; +}; +template +struct TensorIteratorPolicy { + using TensorIterator = + TensorIteratorOverChannel; +}; + +template +struct RelayoutProblem { + using SrcIterator = SrcIterator_; + using DstIterator = DstIterator_; + using Transpose = Transpose_; + using CudaPostProcess = CudaPostProcess_; + MEGDNN_STATIC_ASSERT(SrcIterator::chan_blk == DstIterator::chan_blk, + "channel block mismatch"); + MEGDNN_STATIC_ASSERT(SrcIterator::width == DstIterator::width, + "width block mismatch"); + MEGDNN_STATIC_ASSERT(SrcIterator::size_nbits == DstIterator::size_nbits, + "size in bits of elements mismatch"); + static constexpr int pack_chan = SrcIterator::chan_blk; + static constexpr int pack_width = SrcIterator::width; + using DnnSrcType = typename CudaPostProcess::SrcType; + using DnnDstType = typename CudaPostProcess::DstType; + struct Param { + SrcIterator src_iterator; + DstIterator dst_iterator; + CudaPostProcess post_process; + int n_stride_src; + int n_stride_dst; + int batch_size; + int channels; + int hw; + MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_, + DstIterator dst_iterator_, + CudaPostProcess post_process_, + int n_stride_src_, int n_stride_dst_, + int batch_size_, int channels_, int hw_) + : src_iterator{src_iterator_}, + dst_iterator{dst_iterator_}, + post_process{post_process_}, + n_stride_src{n_stride_src_}, + n_stride_dst{n_stride_dst_}, + batch_size{batch_size_}, + channels{channels_}, + hw{hw_} {} + }; +}; + +template +__global__ void relayout_kern(typename RelayoutProblem_::Param param) { + using SrcIterator = typename RelayoutProblem_::SrcIterator; + using DstIterator = typename RelayoutProblem_::DstIterator; + static constexpr int pack_chan = RelayoutProblem_::pack_chan; + static constexpr int pack_width = RelayoutProblem_::pack_width; + const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + const int thread_offset = thread_idx * pack_width; + const int hw_idx = (thread_offset % param.hw); + const int nc_blks = thread_offset / param.hw; + const int c_blks = (param.channels + pack_chan - 1) / pack_chan; + const int n_idx = nc_blks / c_blks; + const int c_blk_idx = nc_blks % c_blks; + const int c_idx = c_blk_idx * pack_chan; + if (n_idx < param.batch_size) { + const int src_offset = n_idx * param.n_stride_src; + const int dst_offset = n_idx * param.n_stride_dst; + param.src_iterator.add_pointer_offset(src_offset); + param.dst_iterator.add_pointer_offset(dst_offset); + param.src_iterator.initialize(c_idx, hw_idx); + 
param.dst_iterator.initialize(c_idx, hw_idx); + typename SrcIterator::Fragment src_frag; + typename DstIterator::Fragment dst_frag; + param.src_iterator.load(src_frag); + RelayoutProblem_::Transpose::trans( + reinterpret_cast(dst_frag), + src_frag, param.post_process); + param.dst_iterator.store(dst_frag); + } } } // namespace @@ -1175,21 +1381,23 @@ void relayout_format::relayout_format_cuda_nchw_nchwx( "Unsupport pack size(pack_oc:%d, src:%s, dst:%s)", pack_oc, stype.name(), dtype.name()); #undef DEF - const int in_n = src.layout[0]; - const int out_n = dst.layout[0]; - const int ic = src.layout[1]; - const int h = src.layout[2]; - const int w = src.layout[3]; - const int oc = dst.layout[1] * pack_oc; - const int hw = h * w; - const int ocpg = oc / group; - // stride in byte - const int n_stride_src = src_layout.dtype.size(src_layout.stride[0]); - const int ic_stride = src_layout.dtype.size(src_layout.stride[1]); - const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]); - const int oc_stride = dst_layout.dtype.size(dst_layout.stride[1]); + // no padding + if (src.layout.stride[2] == static_cast(src.layout[3])) { + const int in_n = src.layout[0]; + const int out_n = dst.layout[0]; + const int ic = src.layout[1]; + const int h = src.layout[2]; + const int w = src.layout[3]; + const int oc = dst.layout[1] * pack_oc; + const int hw = h * w; + const int ocpg = oc / group; + // stride in byte + const int n_stride_src = src_layout.dtype.size(src_layout.stride[0]); + const int ic_stride = src_layout.dtype.size(src_layout.stride[1]); + const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]); + const int oc_stride = dst_layout.dtype.size(dst_layout.stride[1]); - bool same_scale = src_scale == dst_scale; + bool same_scale = src_scale == dst_scale; #define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \ _src_c_type, _dst_c_type, _size_nbits) \ if (same_scale == _same_scale && hw % _pack_w == 0 && \ @@ -1225,19 +1433,95 @@ void relayout_format::relayout_format_cuda_nchw_nchwx( DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \ DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4); \ DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4); - DISPATCH_INT(QuantizedS32, QuantizedS32); - DISPATCH_BYTE(Uint8, QuantizedS8); - DISPATCH_BYTE(Quantized8Asymm, QuantizedS8); - DISPATCH_BYTE(QuantizedS8, QuantizedS8); - DISPATCH_4BITS(QuantizedS4, QuantizedS4); - DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm); + DISPATCH_INT(QuantizedS32, QuantizedS32); + DISPATCH_BYTE(Uint8, QuantizedS8); + DISPATCH_BYTE(Quantized8Asymm, QuantizedS8); + DISPATCH_BYTE(QuantizedS8, QuantizedS8); + DISPATCH_4BITS(QuantizedS4, QuantizedS4); + DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm); #undef DISPATCH_4BITS #undef DISPATCH_BYTE #undef DISPATCH_INT #undef DISPATCH_RAW - megdnn_assert(false, - "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).", - stype.name(), dtype.name(), h, w); + megdnn_assert( + false, + "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).", + stype.name(), dtype.name(), h, w); + } else { + megdnn_assert(src_layout.dtype.is_low_bit()); + int n = src.layout[0]; + int c = src.layout[1]; + int h = src.layout[2]; + // align to byte + int w = src.layout[3]; + int w_pad = DIVUP(w, 2) * 2; + int hw = h * w_pad; + int n_stride_src = src_layout.stride[0]; + int ic_stride = src_layout.stride[1]; + int n_stride_dst = dst_layout.stride[0]; + int oc_stride = dst_layout.stride[1]; + int problem_size = n * (c / pack_oc) * 
hw; + bool same_scale = src_scale == dst_scale; +#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \ + _src_c_type, _dst_c_type, _size_nbits) \ + if (same_scale == _same_scale && hw % _pack_w == 0 && \ + stype.enumv().ev == DTypeEnum::Ev::_src_type && \ + dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ + using InnerDtype_ = typename DTypeRWHelper< \ + typename DTypeTrait::ctype, \ + _pack_w>::InnerDtype; \ + using SrcIterator_ = \ + TensorIteratorOverChannel; \ + using DstIterator_ = MaskedTensorIteratorOverChannel< \ + _dst_c_type, _pack_oc, _pack_oc, _pack_w, _size_nbits>; \ + using CudaPostProcess_ = \ + CudaPostProcess; \ + using Transpose_ = \ + Translayout<_pack_w, _pack_oc, _src_c_type, dtype::_src_type, \ + dtype::_dst_type, _same_scale>; \ + using RelayoutProblem_ = \ + RelayoutProblem; \ + n_stride_src = n_stride_src * _size_nbits / (8 * sizeof(InnerDtype_)); \ + ic_stride = ic_stride * _size_nbits / (8 * sizeof(InnerDtype_)); \ + n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(_dst_c_type)); \ + oc_stride = oc_stride * _size_nbits / (8 * sizeof(_dst_c_type)); \ + typename RelayoutProblem_::Param param{ \ + SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, c, w, \ + w_pad}, \ + DstIterator_{(_dst_c_type*)dst.raw_ptr, oc_stride, c, w, \ + w_pad}, \ + CudaPostProcess_{src_scale, src_zero_point, dst_scale, \ + dst_zero_point}, \ + n_stride_src, \ + n_stride_dst, \ + n, \ + c, \ + hw}; \ + auto kernel = relayout_kern; \ + int nr_threads = query_blocksize_for_kernel(kernel); \ + nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w)); \ + const dim3 block_dim(DIVUP(problem_size, nr_threads* _pack_w)); \ + const dim3 thread_dim(nr_threads); \ + return kernel<<>>(param); \ + } +#define DISPATCH_4BITS(_src_type, _dst_type) \ + DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4); + DISPATCH_4BITS(QuantizedS4, QuantizedS4); + DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm); +#undef DISPATCH_4BITS +#undef DISPATCH_RAW + megdnn_assert( + false, + "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).", + stype.name(), dtype.name(), h, w); + } + after_kernel_launch(); } bool relayout_format::relayout_format_cuda_usable( @@ -1283,43 +1567,77 @@ void relayout_format::relayout_format_cuda_nchwx_nchw( // clang-format on megdnn_assert(pack_ic == 64, "Unsupport pack size(pack_ic:%d)", pack_ic); #undef DEF - const int n = src.layout[0]; - const int c = src.layout[1] * pack_ic; - const int h = src.layout[2]; + int n = src.layout[0]; + int c = src.layout[1] * pack_ic; + int h = src.layout[2]; // align to byte - const int w = src.layout[3]; - const int hw = h * w; - const int n_stride_src = src_layout.dtype.size(src_layout.stride[0]); - const int ic_stride = src_layout.dtype.size(src_layout.stride[1]); - const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]); - const int oc_stride = dst_layout.dtype.size(dst_layout.stride[1]); + int w = src.layout[3]; + int w_pad = DIVUP(w, 2) * 2; + int hw = h * w_pad; + int n_stride_src = src_layout.stride[0]; + int ic_stride = src_layout.stride[1]; + int n_stride_dst = dst_layout.stride[0]; + int oc_stride = dst_layout.stride[1]; + int problem_size = n * (c / pack_ic) * hw; bool same_scale = src_scale == dst_scale; -#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, 
\ - _src_c_type, _dst_c_type, _size_nbits) \ - if (same_scale == _same_scale && hw % _pack_w == 0 && \ - stype.enumv().ev == DTypeEnum::Ev::_src_type && \ - dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ - auto kernel = \ - kern_nchwx_nchw<_pack_w, _pack_oc, _same_scale, _src_c_type, \ - _dst_c_type, dtype::_src_type, \ - dtype::_dst_type, _size_nbits>; \ - int nr_threads = query_blocksize_for_kernel(kernel); \ - const dim3 block_dim(DIVUP(hw, nr_threads* _pack_w), n); \ - const dim3 thread_dim(nr_threads); \ - return kernel<<>>( \ - (_src_c_type*)src.raw_ptr, (_dst_c_type*)dst.raw_ptr, c, hw, \ - n_stride_src, ic_stride, n_stride_dst, oc_stride, \ - CudaPostProcess(src_scale, src_zero_point, \ - dst_scale, dst_zero_point), \ - src_zero_point); \ + bool padding = w % 2 != 0; +#define DISPATCH_RAW(_padding, _same_scale, _pack_w, _pack_oc, _src_type, \ + _dst_type, _src_c_type, _dst_c_type, _size_nbits) \ + if (padding == _padding && same_scale == _same_scale && \ + hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \ + dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ + using SrcIterator_ = \ + typename TensorIteratorPolicy<_padding, _src_c_type, _pack_oc, \ + _pack_oc, _pack_w, \ + _size_nbits>::TensorIterator; \ + using InnerDtype_ = typename DTypeRWHelper< \ + typename DTypeTrait::ctype, \ + _pack_w>::InnerDtype; \ + using DstIterator_ = \ + TensorIteratorOverChannel; \ + using CudaPostProcess_ = \ + CudaPostProcess; \ + using Transpose_ = \ + Translayout<_pack_oc, _pack_w, _src_c_type, dtype::_src_type, \ + dtype::_dst_type, _same_scale>; \ + using RelayoutProblem_ = \ + RelayoutProblem; \ + n_stride_src = n_stride_src * _size_nbits / (8 * sizeof(_src_c_type)); \ + ic_stride = ic_stride * _size_nbits / (8 * sizeof(_src_c_type)); \ + n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(InnerDtype_)); \ + oc_stride = oc_stride * _size_nbits / (8 * sizeof(InnerDtype_)); \ + typename RelayoutProblem_::Param param{ \ + SrcIterator_{(_src_c_type*)src.raw_ptr, ic_stride, c, w, \ + w_pad}, \ + DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, c, w, \ + w_pad}, \ + CudaPostProcess_{src_scale, src_zero_point, dst_scale, \ + dst_zero_point}, \ + n_stride_src, \ + n_stride_dst, \ + n, \ + c, \ + hw}; \ + auto kernel = relayout_kern; \ + int nr_threads = query_blocksize_for_kernel(kernel); \ + nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w)); \ + const dim3 block_dim(DIVUP(problem_size, nr_threads* _pack_w)); \ + const dim3 thread_dim(nr_threads); \ + return kernel<<>>(param); \ } -#define DISPATCH_4BITS(_src_type, _dst_type) \ - DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4); \ - DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \ - DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4); \ - DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4); +#define DISPATCH_4BITS(_src_type, _dst_type) \ + DISPATCH_RAW(true, true, 8, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(true, false, 8, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(true, true, 2, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(true, false, 2, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, true, 8, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, false, 8, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, true, 2, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, false, 2, 64, _src_type, _dst_type, char, char, 4); DISPATCH_4BITS(QuantizedS4, 
QuantizedS4); DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm); #undef DISPATCH_4BITS @@ -1327,6 +1645,7 @@ void relayout_format::relayout_format_cuda_nchwx_nchw( megdnn_assert(false, "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).", stype.name(), dtype.name(), h, w); + after_kernel_launch(); } void relayout_format::relayout_format_cuda_nchw4_nchw( @@ -1344,6 +1663,7 @@ void relayout_format::relayout_format_cuda_nchw4_nchw( const dim3 thread_dim(nr_threads); kern_nchw4_nchw<<>>( (int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, n, ic, oc, h, w, group); + after_kernel_launch(); } void relayout_format::relayout_format_cuda_nchw_nchw4_weight( @@ -1372,4 +1692,5 @@ void relayout_format::relayout_format_cuda_nchw_nchw4_weight( (char*)src.raw_ptr, (char*)dst.raw_ptr, oc, ic, hw, oc_stride_src, ic_stride, oc_stride_dst, group_stride_src, group_stride_dst, 0, {}); + after_kernel_launch(); } diff --git a/dnn/test/common/benchmarker.h b/dnn/test/common/benchmarker.h index ca0c00dd1..3a706526c 100644 --- a/dnn/test/common/benchmarker.h +++ b/dnn/test/common/benchmarker.h @@ -87,10 +87,11 @@ public: for (size_t i = 0; i < shapes.size(); ++i) { DType dt = (m_dtype.find(i) != m_dtype.end() ? m_dtype[i] : dtype::Float32()); - TensorFormat fmt = (m_fmt.find(i) != m_fmt.end() - ? m_fmt[i] - : DefaultTensorFormat::make()); - layouts[i] = TensorLayout(shapes[i], dt, fmt); + if (m_fmt.find(i) == m_fmt.end()) { + layouts[i] = TensorLayout(shapes[i], dt); + layouts[i].init_contiguous_stride(); + } else + layouts[i] = TensorLayout(shapes[i], dt, m_fmt[i]); } return layouts; } diff --git a/dnn/test/common/checker.cpp b/dnn/test/common/checker.cpp index 3e2fab9ff..614d16da4 100644 --- a/dnn/test/common/checker.cpp +++ b/dnn/test/common/checker.cpp @@ -19,7 +19,6 @@ using namespace megdnn; using namespace test; namespace { - template ::testing::AssertionResult assert_tensor_eq_with_iter( const char *expr0, const char *expr1, @@ -30,7 +29,7 @@ namespace { double error_sum = 0; double error_sum_biased = 0; for (size_t i = 0; i < nr_elem; ++ i) { - ctype iv0 = *it0, iv1 = *it1; + ctype iv0 = ctype(*it0), iv1 = ctype(*it1); float err = diff(iv0, iv1); error_sum += std::abs(err); error_sum_biased += err; @@ -84,12 +83,14 @@ namespace { const char *expr0, const char *expr1, const TensorND &v0, const TensorND &v1, float maxerr, float maxerr_avg, float maxerr_avg_biased) { - - if (v0.layout.is_physical_contiguous() && - v1.layout.is_physical_contiguous()) { - return assert_tensor_eq_with_iter( - expr0, expr1, v0.ptr(), v1.ptr(), v0.layout, - maxerr, maxerr_avg, maxerr_avg_biased); + if (!std::is_same::value && + !std::is_same::value) { + if (v0.layout.is_physical_contiguous() && + v1.layout.is_physical_contiguous()) { + return assert_tensor_eq_with_iter( + expr0, expr1, v0.ptr(), v1.ptr(), + v0.layout, maxerr, maxerr_avg, maxerr_avg_biased); + } } auto it0 = megdnn::tensor_iter_valonly(v0).begin(), @@ -100,56 +101,6 @@ namespace { maxerr_avg_biased); } - template - ::testing::AssertionResult assert_tensor_eq_with_lowbit4( - const char* expr0, const char* expr1, - const TensorND& v0, const TensorND& v1, - float maxerr, float maxerr_avg) { - if (!v0.layout.eq_layout(v1.layout)) { - return ::testing::AssertionFailure() - << "Layout mismatch for testing equality of lowbit4\n" - << "Value of: " << expr1 << "\n" - << " Actual: " << v1.layout.TensorShape::to_string() << "\n" - << "Expected: " << expr0 << "\n" - << "Which is: " << v0.layout.TensorShape::to_string() << "\n"; - } - auto v0_ptr = static_cast(v0.raw_ptr) - 
v0.layout.span().low_byte; - auto v1_ptr = static_cast(v1.raw_ptr) - v1.layout.span().low_byte; - double error_sum = 0; - for (size_t i = 0; i < v0.layout.span().dist_elem(); ++i) { - ITYPE iv0 = (v0_ptr[i / 2] << (i ^ 1) * 4); - iv0 = iv0 >> 4; - ITYPE iv1 = (v1_ptr[i / 2] << (i ^ 1) * 4); - iv1 = iv1 >> 4; - - float err = std::abs(diff(iv0, iv1)); - error_sum += err; - if (!good_float(iv0) || !good_float(iv1) || err >= maxerr) { - Index index(v0.layout, i); - return ::testing::AssertionFailure() - << "Unequal value\n" - << "Value of: " << expr1 << "\n" - << " Actual: " << (iv1+0) << "\n" - << "Expected: " << expr0 << "\n" - << "Which is: " << (iv0+0) << "\n" - << "At index: " << - index.to_string() << "/" << v0.layout.TensorShape::to_string() << "\n" - << " Dtype: " << v0.layout.dtype.name() << "\n" - << " error: " << err << "/" << maxerr; - } - } - float error_avg = error_sum / v0.layout.total_nr_elems(); - if (error_avg > maxerr_avg) { - return ::testing::AssertionFailure() - << "Average error too high\n" - << "Value of: " << expr1 << "\n" - << "Expected: " << expr0 << "\n" - << "Average error: " << error_avg << "/" << maxerr_avg; - } - - return ::testing::AssertionSuccess(); - } - template void memcpy_noncontig( void *dst, const void *src, const TensorLayout &layout, @@ -215,12 +166,7 @@ namespace { //! In order to avoid an unnecessary increase in binary size, we just //! use QuantizedS16 dtype in winograd_filter_preprocess now. cb(::megdnn::dtype::QuantizedS16) - case DTypeTrait::enumv: - return assert_tensor_eq_with_lowbit4(expr0, expr1, v0, v1, - maxerr, maxerr_avg); - case DTypeTrait::enumv: - return assert_tensor_eq_with_lowbit4(expr0, expr1, v0, v1, - maxerr, maxerr_avg); + MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) #undef cb default: megdnn_trap(); diff --git a/dnn/test/common/utils.h b/dnn/test/common/utils.h index 5cebb0017..9974022fa 100644 --- a/dnn/test/common/utils.h +++ b/dnn/test/common/utils.h @@ -228,6 +228,14 @@ static inline int diff(dt_qint8 x, dt_qint8 y) { return x.as_int8() - y.as_int8(); } +static inline int diff(dt_qint4 x, dt_qint4 y) { + return x.as_int8() - y.as_int8(); +} + +static inline int diff(dt_quint4 x, dt_quint4 y) { + return x.as_uint8() - y.as_uint8(); +} + inline TensorShape cvt_src_or_dst_nchw2nhwc(const TensorShape& shape) { megdnn_assert(shape.ndim == 4); auto N = shape[0], C = shape[1], H = shape[2], W = shape[3]; @@ -356,6 +364,15 @@ static inline int operator+(dt_qint16 lhs, int rhs) { return lhs.as_int16(); } +static inline int operator+(dt_quint4 lhs, int rhs) { + megdnn_assert(rhs == 0, "unexpected rhs"); + return lhs.as_uint8(); +} + +static inline int operator+(dt_qint4 lhs, int rhs) { + megdnn_assert(rhs == 0, "unexpected rhs"); + return lhs.as_int8(); +} } // namespace test static inline bool operator==(const TensorLayout& a, const TensorLayout& b) { diff --git a/dnn/test/cuda/relayout_format.cpp b/dnn/test/cuda/relayout_format.cpp index a7d1dd71e..1faf409a6 100644 --- a/dnn/test/cuda/relayout_format.cpp +++ b/dnn/test/cuda/relayout_format.cpp @@ -11,13 +11,14 @@ */ #include "megdnn/dtype.h" #include "megdnn/oprs.h" -#include "test/common/benchmarker.h" +#include "test/cuda/benchmark.h" #include "test/common/checker.h" #include "test/common/rng.h" #include "test/cuda/fixture.h" using namespace megdnn; using namespace test; +#define MEGDNN_WITH_BENCHMARK 1 TEST_F(CUDA, RELAYOUT_FORMAT) { Checker checker(handle_cuda()); @@ -246,7 +247,7 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NCHW64) { for (size_t n : {1, 3}) { for (size_t c : {64, 128}) 
{ for (size_t h : {7, 14, 16, 28}) { - for (size_t w : {2, 4, 14, 16}) { + for (size_t w : {2, 3, 7, 8, 16, 31}) { checker.set_dtype(0, dtype::QuantizedS4{2.f}) .set_dtype(1, dtype::QuantizedS4{2.f}) .set_rng(0, &s4) @@ -286,7 +287,7 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW64_NCHW) { for (size_t n : {1, 3}) { for (size_t c : {64, 128}) { for (size_t h : {7, 14, 16, 28}) { - for (size_t w : {2, 4, 14, 16}) { + for (size_t w : {2, 3, 4, 7, 14, 16, 17}) { checker.set_dtype(0, dtype::QuantizedS4{2.f}) .set_dtype(1, dtype::QuantizedS4{2.f}) .set_rng(0, &s4) @@ -366,6 +367,46 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT) { run(shapes, param, default_param); } } + +TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) { + using Param = RelayoutFormat::Param; + + auto run = [&](const TensorShapeArray& shapes, Param param) { + CUBenchmarker benchmarker(handle_cuda()); + benchmarker.set_param(param); + benchmarker.set_dtype(0, dtype::QuantizedS4{1.19990307f}) + .set_dtype(1, dtype::QuantizedS4{1.20210322f}); + + for (auto&& shape : shapes) { + double memaccess = double(shape.total_nr_elems()) * 1e-6; + auto time_ms = benchmarker.execs({shape, {}}); + printf("execute %s, time %.4f ms, %.4f GB/s\n", + shape.to_string().c_str(), time_ms, memaccess / time_ms); + } + }; + + { + TensorShapeArray shapes = { + {1, 64, 56, 56}, {16, 64, 56, 56}, {64, 64, 56, 56}, + {1, 64, 56, 55}, {16, 64, 56, 55}, {64, 64, 56, 55}, + }; + Param param; + param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64; + run(shapes, param); + } + { + TensorShapeArray shapes = { + {64, 1, 56, 56, 64}, + {1, 32, 7, 7, 64}, + {16, 32, 7, 7, 64}, + {64, 32, 7, 7, 64}, + }; + Param param; + param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW; + run(shapes, param); + } +} + #endif TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) { -- GitLab
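
The reason an odd width needs special handling at all is that two 4-bit elements share one byte: an NCHW row of QuantizedS4/Quantized4Asymm data with odd W is padded to w_pad = DIVUP(w, 2) * 2 so every row starts on a byte boundary. Below is a minimal host-side sketch of that packing rule, assuming the usual low-nibble-first layout; the helper name unpack_int4 and the sample values are illustrative only and are not the kernel's unpack_int4x2/pack_channel macros.

#include <cassert>
#include <cstdint>
#include <vector>

// Sign-extend the low or high nibble of a packed byte back to int8.
inline int8_t unpack_int4(uint8_t byte, int nibble /* 0 = low, 1 = high */) {
    return static_cast<int8_t>(static_cast<int8_t>(byte << (nibble ? 0 : 4)) >> 4);
}

int main() {
    const int w = 7;                    // odd spatial width
    const int w_pad = (w + 1) / 2 * 2;  // DIVUP(w, 2) * 2 == 8, byte aligned
    // One padded int4 row occupies w_pad / 2 whole bytes, so each row starts
    // on a byte boundary even though only w elements are meaningful.
    std::vector<uint8_t> row(w_pad / 2, 0);
    for (int x = 0; x < w; ++x) {
        int8_t v = static_cast<int8_t>(x - 3);  // any value in [-8, 7]
        row[x / 2] |= static_cast<uint8_t>(v & 0xf) << ((x & 1) * 4);
    }
    for (int x = 0; x < w; ++x)
        assert(unpack_int4(row[x / 2], x & 1) == static_cast<int8_t>(x - 3));
    return 0;
}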
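
MaskedTensorIteratorOverChannel keeps one predicate bit per vectorized access (32 bits per uint32_t word) so that cutlass::arch::global_load/global_store skip positions whose column lands in the alignment padding. The sketch below replays that bookkeeping on the host for a single thread and a single channel block, using plain / and % where the real iterator uses cutlass::fast_divmod() with a multiplier precomputed by cutlass::find_divisor() in its constructor; the concrete numbers (w = 7, pack_w = 2, hw_idx = 6) are assumptions chosen to make one access hit the padded column.

#include <cstdint>
#include <cstdio>

int main() {
    const int w = 7;                    // real (odd) width
    const int w_pad = (w + 1) / 2 * 2;  // byte-aligned width used for iteration
    const int pack_w = 2;               // spatial positions handled per thread

    // One thread whose lane starts at column 6 of row 0: position 7 lies in
    // the alignment padding and must not be loaded or stored.
    const int hw_idx = 6;
    uint32_t mask = 0;                  // one predicate bit per access
    int elem_offset[pack_w];

    for (int j = 0; j < pack_w; ++j) {
        int offset = hw_idx + j;        // consecutive positions in padded space
        int row = offset / w_pad;       // kernel: cutlass::fast_divmod with a
        int col = offset % w_pad;       // multiplier from cutlass::find_divisor
        bool guard = col < w;           // false for the padded column
        mask |= static_cast<uint32_t>(guard) << j;
        // Position on the unpadded side; the iterator scales this by
        // pack_size * size_nbits / 8 to get a byte offset, and its real mask
        // also spans the channel-block dimension (chan_blk / pack_size rows).
        elem_offset[j] = row * w + col;
    }
    for (int j = 0; j < pack_w; ++j)
        std::printf("access %d: guard=%u elem_offset=%d\n", j, (mask >> j) & 1u,
                    elem_offset[j]);
    return 0;                           // prints guard=1 then guard=0
}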
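
relayout_kern flattens (batch, channel block, padded HW) into a single problem_size = n * (c / pack_oc) * hw, hands each thread pack_w consecutive positions, and launches DIVUP(problem_size, nr_threads * pack_w) blocks. The host-side sketch below replays that index decomposition with made-up sizes to show the n_idx bound check; nr_threads is a fixed assumption here, whereas the kernel obtains it from query_blocksize_for_kernel.

#include <cassert>
#include <cstdio>

int main() {
    const int n = 2, c = 128, h = 3, w = 7;    // made-up sizes
    const int pack_c = 64, pack_w = 2;
    const int w_pad = (w + 1) / 2 * 2;
    const int hw = h * w_pad;                  // dispatch requires hw % pack_w == 0
    const int c_blks = (c + pack_c - 1) / pack_c;
    const int problem_size = n * c_blks * hw;

    const int nr_threads = 128;                // assumed; the kernel queries this
    const int nr_blocks =
            (problem_size + nr_threads * pack_w - 1) / (nr_threads * pack_w);
    std::printf("problem_size=%d blocks=%d threads=%d\n", problem_size,
                nr_blocks, nr_threads);

    for (int tid = 0; tid < nr_blocks * nr_threads; ++tid) {
        int thread_offset = tid * pack_w;
        int hw_idx = thread_offset % hw;       // position inside one (n, c-block)
        int nc_blk = thread_offset / hw;
        int n_idx = nc_blk / c_blks;
        int c_blk_idx = nc_blk % c_blks;
        if (n_idx >= n)
            continue;                          // same bound check as the kernel
        assert(c_blk_idx < c_blks && hw_idx + pack_w <= hw);
    }
    return 0;
}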