diff --git a/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu b/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu index 2df449f54ea5652e724262b7fb52b5906979d767..5d26536e38615bf46b97b047ad18003585decf81 100644 --- a/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu +++ b/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu @@ -30,7 +30,8 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( const int32_t fhfw = blockIdx.x * BLOCKSIZE_Y + threadIdx.x; if (fhfw < FHFW && icb < IC / 4) { - int src_value[4], dst_value[4]; + array_wrapper src_value; + int dst_value[4]; #pragma unroll for (int i = 0; i < 4; i++) { src_value[i] = *reinterpret_cast( @@ -38,7 +39,8 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( } // transpose 4x4 - transpose_int8_interleavedx4<4, int>(src_value, dst_value); + auto trans = transpose_int8_interleavedx4<4, int>(); + trans(src_value, dst_value); #pragma unroll for (int i = 0; i < 4; i++) { @@ -60,7 +62,7 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel( const int32_t icb = fhfw_icb % (IC / 4); if (ocb < OC / interleaved && fhfw < FHFW) { - int src_value[interleaved]; + array_wrapper src_value; vec_type dst_value[4]; #pragma unroll @@ -70,8 +72,8 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel( icb * 4); } - transpose_int8_interleavedx4(src_value, - dst_value); + auto trans = transpose_int8_interleavedx4(); + trans(src_value, dst_value); #pragma unroll for (int i = 0; i < 4; i++) { diff --git a/dnn/src/cuda/transpose_utils.cuh b/dnn/src/cuda/transpose_utils.cuh index a0a286f413a717459b8a33fe5cdf4e28d1a4ebfc..686a49797ef6a137b761fe623dc6e962a600c531 100644 --- a/dnn/src/cuda/transpose_utils.cuh +++ b/dnn/src/cuda/transpose_utils.cuh @@ -30,37 +30,51 @@ MEGDNN_DEVICE __forceinline__ void transpose_int8_4x4_impl( } template -MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4( - const int src[interleaved], vec_type (&dst)[4]); +struct transpose_int8_interleavedx4; template <> -MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<4, int>( - const int src[4], int (&dst)[4]) { - transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0], dst[1], - dst[2], dst[3]); -} +struct transpose_int8_interleavedx4<4, int> { + static constexpr uint32_t interleaved = 4; + using vec_type = int; + using Fragment = array_wrapper; + MEGDNN_DEVICE __forceinline__ void operator()(const Fragment src, + vec_type (&dst)[4]) { + transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0], dst[1], + dst[2], dst[3]); + } +}; template <> -MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<8, int2>( - const int src[8], int2 (&dst)[4]) { - transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x, - dst[2].x, dst[3].x); - transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y, - dst[2].y, dst[3].y); -} +struct transpose_int8_interleavedx4<8, int2> { + static constexpr uint32_t interleaved = 8; + using vec_type = int2; + using Fragment = array_wrapper; + MEGDNN_DEVICE __forceinline__ void operator()(const Fragment src, + vec_type (&dst)[4]) { + transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, + dst[1].x, dst[2].x, dst[3].x); + transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, + dst[1].y, dst[2].y, dst[3].y); + } +}; template <> -MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<16, int4>( - const int src[16], int4 (&dst)[4]) { - transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x, - dst[2].x, dst[3].x); - transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y, - dst[2].y, dst[3].y); - transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], dst[0].z, - dst[1].z, dst[2].z, dst[3].z); - transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], dst[0].w, - dst[1].w, dst[2].w, dst[3].w); -} +struct transpose_int8_interleavedx4<16, int4> { + static constexpr uint32_t interleaved = 16; + using vec_type = int4; + using Fragment = array_wrapper; + MEGDNN_DEVICE __forceinline__ void operator()(const Fragment src, + vec_type (&dst)[4]) { + transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, + dst[1].x, dst[2].x, dst[3].x); + transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, + dst[1].y, dst[2].y, dst[3].y); + transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], dst[0].z, + dst[1].z, dst[2].z, dst[3].z); + transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], dst[0].w, + dst[1].w, dst[2].w, dst[3].w); + } +}; } // namespace cuda } // namespace megdnn