提交 7b855dc6 编写于 作者: M Megvii Engine Team

fix(dnn/cuda): fix compilation for windows bazel

GitOrigin-RevId: 2023dea19c04dbdd17f559f80cfa2e6b4be27a0e
上级 3abe0b24
......@@ -30,7 +30,8 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel(
const int32_t fhfw = blockIdx.x * BLOCKSIZE_Y + threadIdx.x;
if (fhfw < FHFW && icb < IC / 4) {
int src_value[4], dst_value[4];
array_wrapper<int, 4> src_value;
int dst_value[4];
#pragma unroll
for (int i = 0; i < 4; i++) {
src_value[i] = *reinterpret_cast<const int*>(
......@@ -38,7 +39,8 @@ __global__ void reorder_filter_nc4hw4_to_n4hwc4_kernel(
}
// transpose 4x4
transpose_int8_interleavedx4<4, int>(src_value, dst_value);
auto trans = transpose_int8_interleavedx4<4, int>();
trans(src_value, dst_value);
#pragma unroll
for (int i = 0; i < 4; i++) {
......@@ -60,7 +62,7 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel(
const int32_t icb = fhfw_icb % (IC / 4);
if (ocb < OC / interleaved && fhfw < FHFW) {
int src_value[interleaved];
array_wrapper<int, interleaved> src_value;
vec_type dst_value[4];
#pragma unroll
......@@ -70,8 +72,8 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel(
icb * 4);
}
transpose_int8_interleavedx4<interleaved, vec_type>(src_value,
dst_value);
auto trans = transpose_int8_interleavedx4<interleaved, vec_type>();
trans(src_value, dst_value);
#pragma unroll
for (int i = 0; i < 4; i++) {
......
......@@ -30,37 +30,51 @@ MEGDNN_DEVICE __forceinline__ void transpose_int8_4x4_impl(
}
template <uint32_t interleaved, typename vec_type>
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4(
const int src[interleaved], vec_type (&dst)[4]);
struct transpose_int8_interleavedx4;
template <>
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<4, int>(
const int src[4], int (&dst)[4]) {
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0], dst[1],
dst[2], dst[3]);
}
struct transpose_int8_interleavedx4<4, int> {
static constexpr uint32_t interleaved = 4;
using vec_type = int;
using Fragment = array_wrapper<int, interleaved>;
MEGDNN_DEVICE __forceinline__ void operator()(const Fragment src,
vec_type (&dst)[4]) {
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0], dst[1],
dst[2], dst[3]);
}
};
template <>
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<8, int2>(
const int src[8], int2 (&dst)[4]) {
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x,
dst[2].x, dst[3].x);
transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y,
dst[2].y, dst[3].y);
}
struct transpose_int8_interleavedx4<8, int2> {
static constexpr uint32_t interleaved = 8;
using vec_type = int2;
using Fragment = array_wrapper<int, interleaved>;
MEGDNN_DEVICE __forceinline__ void operator()(const Fragment src,
vec_type (&dst)[4]) {
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x,
dst[1].x, dst[2].x, dst[3].x);
transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y,
dst[1].y, dst[2].y, dst[3].y);
}
};
template <>
MEGDNN_DEVICE __forceinline__ void transpose_int8_interleavedx4<16, int4>(
const int src[16], int4 (&dst)[4]) {
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x, dst[1].x,
dst[2].x, dst[3].x);
transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y, dst[1].y,
dst[2].y, dst[3].y);
transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], dst[0].z,
dst[1].z, dst[2].z, dst[3].z);
transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], dst[0].w,
dst[1].w, dst[2].w, dst[3].w);
}
struct transpose_int8_interleavedx4<16, int4> {
static constexpr uint32_t interleaved = 16;
using vec_type = int4;
using Fragment = array_wrapper<int, interleaved>;
MEGDNN_DEVICE __forceinline__ void operator()(const Fragment src,
vec_type (&dst)[4]) {
transpose_int8_4x4_impl(src[0], src[1], src[2], src[3], dst[0].x,
dst[1].x, dst[2].x, dst[3].x);
transpose_int8_4x4_impl(src[4], src[5], src[6], src[7], dst[0].y,
dst[1].y, dst[2].y, dst[3].y);
transpose_int8_4x4_impl(src[8], src[9], src[10], src[11], dst[0].z,
dst[1].z, dst[2].z, dst[3].z);
transpose_int8_4x4_impl(src[12], src[13], src[14], src[15], dst[0].w,
dst[1].w, dst[2].w, dst[3].w);
}
};
} // namespace cuda
} // namespace megdnn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册