From 7b855dc64ac8d2a61c2557bd77a8b98a11de17a0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 2 Sep 2021 11:57:20 +0800 Subject: [PATCH] fix(dnn/cuda): fix compilation for windows bazel GitOrigin-RevId: 2023dea19c04dbdd17f559f80cfa2e6b4be27a0e --- .../backward_data/deconv_int8_helper.cu | 12 ++-- dnn/src/cuda/transpose_utils.cuh | 64 +++++++++++-------- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu b/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu index 2df449f54..5d26536e3 100644 --- a/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu +++ b/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu @@ -30,7 +30,8 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( const int32_t fhfw = blockIdx.x * BLOCKSIZE_Y + threadIdx.x; if (fhfw < FHFW && icb < IC / 4) { - int src_value[4], dst_value[4]; + array_wrapper src_value; + int dst_value[4]; #pragma unroll for (int i = 0; i < 4; i++) { src_value[i] = *reinterpret_cast( @@ -38,7 +39,8 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( } // transpose 4x4 - transpose_int8_interleavedx4<4, int>(src_value, dst_value); + auto trans = transpose_int8_interleavedx4<4, int>(); + trans(src_value, dst_value); #pragma unroll for (int i = 0; i < 4; i++) { @@ -60,7 +62,7 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel( const int32_t icb = fhfw_icb % (IC / 4); if (ocb < OC / interleaved && fhfw < FHFW) { - int src_value[interleaved]; + array_wrapper src_value; vec_type dst_value[4]; #pragma unroll @@ -70,8 +72,8 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel( icb * 4); } - transpose_int8_interleavedx4(src_value, - dst_value); + auto trans = transpose_int8_interleavedx4(); + trans(src_value, dst_value); #pragma unroll for (int i = 0; i < 4; i++) { diff --git a/dnn/src/cuda/transpose_utils.cuh b/dnn/src/cuda/transpose_utils.cuh index a0a286f41..686a49797 100644 --- a/dnn/src/cuda/transpose_utils.cuh +++ b/dnn/src/cuda/transpose_utils.cuh @@ -30,37 +30,51 @@ MEGDNN_DEVICE __forceinline__ void transpose_int8_4x4_impl( } template -MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4( - const int src[interleaved], vec_type (&dst)[4]); +struct transpose_int8_interleavedx4; template <> -MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<4, int>( - const int src[4], int (&dst)[4]) { - transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0], dst[1], - dst[2], dst[3]); -} +struct transpose_int8_interleavedx4<4, int> { + static constexpr uint32_t interleaved = 4; + using vec_type = int; + using Fragment = array_wrapper; + MEGDNN_DEVICE __forceinline__ void operator()(const Fragment src, + vec_type (&dst)[4]) { + transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0], dst[1], + dst[2], dst[3]); + } +}; template <> -MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<8, int2>( - const int src[8], int2 (&dst)[4]) { - transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x, - dst[2].x, dst[3].x); - transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y, - dst[2].y, dst[3].y); -} +struct transpose_int8_interleavedx4<8, int2> { + static constexpr uint32_t interleaved = 8; + using vec_type = int2; + using Fragment = array_wrapper; + MEGDNN_DEVICE __forceinline__ void operator()(const Fragment src, + vec_type (&dst)[4]) { + transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, + dst[1].x, dst[2].x, dst[3].x); + transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, + dst[1].y, dst[2].y, dst[3].y); + } +}; template <> -MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<16, int4>( - const int src[16], int4 (&dst)[4]) { - transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x, - dst[2].x, dst[3].x); - transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y, - dst[2].y, dst[3].y); - transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], dst[0].z, - dst[1].z, dst[2].z, dst[3].z); - transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], dst[0].w, - dst[1].w, dst[2].w, dst[3].w); -} +struct transpose_int8_interleavedx4<16, int4> { + static constexpr uint32_t interleaved = 16; + using vec_type = int4; + using Fragment = array_wrapper; + MEGDNN_DEVICE __forceinline__ void operator()(const Fragment src, + vec_type (&dst)[4]) { + transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, + dst[1].x, dst[2].x, dst[3].x); + transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, + dst[1].y, dst[2].y, dst[3].y); + transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], dst[0].z, + dst[1].z, dst[2].z, dst[3].z); + transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], dst[0].w, + dst[1].w, dst[2].w, dst[3].w); + } +}; } // namespace cuda } // namespace megdnn -- GitLab