提交 7b855dc6 编写于 作者: M Megvii Engine Team

fix(dnn/cuda): fix compilation for windows bazel

GitOrigin-RevId: 2023dea19c04dbdd17f559f80cfa2e6b4be27a0e
上级 3abe0b24
...@@ -30,7 +30,8 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( ...@@ -30,7 +30,8 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel(
const int32_t fhfw = blockIdx.x * BLOCKSIZE_Y + threadIdx.x; const int32_t fhfw = blockIdx.x * BLOCKSIZE_Y + threadIdx.x;
if (fhfw < FHFW && icb < IC / 4) { if (fhfw < FHFW && icb < IC / 4) {
int src_value[4], dst_value[4]; array_wrapper<int, 4> src_value;
int dst_value[4];
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
src_value[i] = *reinterpret_cast<const int*>( src_value[i] = *reinterpret_cast<const int*>(
...@@ -38,7 +39,8 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel( ...@@ -38,7 +39,8 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel(
} }
// transpose 4x4 // transpose 4x4
transpose_int8_interleavedx4<4, int>(src_value, dst_value); auto trans = transpose_int8_interleavedx4<4, int>();
trans(src_value, dst_value);
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
...@@ -60,7 +62,7 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel( ...@@ -60,7 +62,7 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel(
const int32_t icb = fhfw_icb % (IC / 4); const int32_t icb = fhfw_icb % (IC / 4);
if (ocb < OC / interleaved && fhfw < FHFW) { if (ocb < OC / interleaved && fhfw < FHFW) {
int src_value[interleaved]; array_wrapper<int, interleaved> src_value;
vec_type dst_value[4]; vec_type dst_value[4];
#pragma unroll #pragma unroll
...@@ -70,8 +72,8 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel( ...@@ -70,8 +72,8 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel(
icb * 4); icb * 4);
} }
transpose_int8_interleavedx4<interleaved, vec_type>(src_value, auto trans = transpose_int8_interleavedx4<interleaved, vec_type>();
dst_value); trans(src_value, dst_value);
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
......
...@@ -30,37 +30,51 @@ MEGDNN_DEVICE __forceinline__ void transpose_int8_4x4_impl( ...@@ -30,37 +30,51 @@ MEGDNN_DEVICE __forceinline__ void transpose_int8_4x4_impl(
} }
template <uint32_t interleaved, typename vec_type> template <uint32_t interleaved, typename vec_type>
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4( struct transpose_int8_interleavedx4;
const int src[interleaved], vec_type (&dst)[4]);
template <> template <>
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<4, int>( struct transpose_int8_interleavedx4<4, int> {
const int src[4], int (&dst)[4]) { static constexpr uint32_t interleaved = 4;
using vec_type = int;
using Fragment = array_wrapper<int, interleaved>;
MEGDNN_DEVICE __forceinline__ void operator()(const Fragment src,
vec_type (&dst)[4]) {
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0], dst[1], transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0], dst[1],
dst[2], dst[3]); dst[2], dst[3]);
} }
};
template <> template <>
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<8, int2>( struct transpose_int8_interleavedx4<8, int2> {
const int src[8], int2 (&dst)[4]) { static constexpr uint32_t interleaved = 8;
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x, using vec_type = int2;
dst[2].x, dst[3].x); using Fragment = array_wrapper<int, interleaved>;
transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y, MEGDNN_DEVICE __forceinline__ void operator()(const Fragment src,
dst[2].y, dst[3].y); vec_type (&dst)[4]) {
} transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x,
dst[1].x, dst[2].x, dst[3].x);
transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y,
dst[1].y, dst[2].y, dst[3].y);
}
};
template <> template <>
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<16, int4>( struct transpose_int8_interleavedx4<16, int4> {
const int src[16], int4 (&dst)[4]) { static constexpr uint32_t interleaved = 16;
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x, using vec_type = int4;
dst[2].x, dst[3].x); using Fragment = array_wrapper<int, interleaved>;
transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y, MEGDNN_DEVICE __forceinline__ void operator()(const Fragment src,
dst[2].y, dst[3].y); vec_type (&dst)[4]) {
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x,
dst[1].x, dst[2].x, dst[3].x);
transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y,
dst[1].y, dst[2].y, dst[3].y);
transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], dst[0].z, transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], dst[0].z,
dst[1].z, dst[2].z, dst[3].z); dst[1].z, dst[2].z, dst[3].z);
transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], dst[0].w, transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], dst[0].w,
dst[1].w, dst[2].w, dst[3].w); dst[1].w, dst[2].w, dst[3].w);
} }
};
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册