提交 c74660ea 编写于 作者: M Megvii Engine Team

fix(dnn/cuda): fix invalid local read for relayout format kernel

GitOrigin-RevId: 5a77b82212626072059bbd95b624f57195f7b36e
上级 8fef78d0
......@@ -321,11 +321,11 @@ struct Translayout<2, 64, SrcType, dtype::QuantizedS4, dtype::QuantizedS4,
int* dst_frag = reinterpret_cast<int*>(dst_width);
#pragma unroll
for (int i = 0; i < 64; i += 8) {
#define unpack_int4x2(_idx) \
intermediate[_idx][0] = unpack_integer_4bits<true>( \
reinterpret_cast<unsigned&>(read_channel[i + _idx]), 0); \
intermediate[_idx][1] = unpack_integer_4bits<true>( \
reinterpret_cast<unsigned&>(read_channel[i + _idx]), 4);
#define unpack_int4x2(_idx) \
intermediate[_idx][0] = unpack_integer_4bits<true>( \
reinterpret_cast<uint8_t&>(read_channel[i + _idx]), 0); \
intermediate[_idx][1] = unpack_integer_4bits<true>( \
reinterpret_cast<uint8_t&>(read_channel[i + _idx]), 4);
// clang-format off
unpack_int4x2(0)
unpack_int4x2(1)
......@@ -336,7 +336,7 @@ struct Translayout<2, 64, SrcType, dtype::QuantizedS4, dtype::QuantizedS4,
unpack_int4x2(6)
unpack_int4x2(7)
// clang-format on
int frag_idx = i / 8;
dst_frag[0 * 8 + frag_idx] = pack_channel(0);
dst_frag[1 * 8 + frag_idx] = pack_channel(1);
......@@ -428,11 +428,11 @@ struct Translayout<2, 64, SrcType, dtype::Quantized4Asymm,
int* dst_frag = reinterpret_cast<int*>(dst_width);
#pragma unroll
for (int i = 0; i < 64; i += 8) {
#define unpack_int4x2(_idx) \
intermediate[_idx][0] = unpack_integer_4bits<false>( \
reinterpret_cast<unsigned&>(read_channel[i + _idx]), 0); \
intermediate[_idx][1] = unpack_integer_4bits<false>( \
reinterpret_cast<unsigned&>(read_channel[i + _idx]), 4);
#define unpack_int4x2(_idx) \
intermediate[_idx][0] = unpack_integer_4bits<false>( \
reinterpret_cast<uint8_t&>(read_channel[i + _idx]), 0); \
intermediate[_idx][1] = unpack_integer_4bits<false>( \
reinterpret_cast<uint8_t&>(read_channel[i + _idx]), 4);
// clang-format off
unpack_int4x2(0)
unpack_int4x2(1)
......@@ -1257,7 +1257,7 @@ private:
uint32_t mul;
uint32_t shr;
uint32_t mask[mask_size];
size_t stride[accesses];
size_t stride[lane_size_in_type / pack_size_in_type];
};
template <bool padding_, typename Type_, int pack_size_, int chan_blk_,
......
......@@ -445,24 +445,16 @@ MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8(
return reinterpret_cast<int const&>(out);
}
template <bool signedness>
MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(unsigned storage,
unsigned bits);
template <>
MEGDNN_DEVICE __forceinline__ int unpack_integer_4bits<true>(unsigned storage,
unsigned bits) {
uint8_t result = (uint8_t)((unsigned)(storage >> bits) & 0xf);
static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1);
return (result & uint8_t(1 << 3)) ? ((int)(result) | ~(int)(mask))
: (int)(result);
}
template <>
MEGDNN_DEVICE __forceinline__ int unpack_integer_4bits<false>(unsigned storage,
unsigned bits) {
uint8_t result = (uint8_t)((unsigned)(storage >> bits) & 0xf);
return (int)(result);
template <bool signedness, typename T>
MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage,
int bits) {
uint8_t result = (uint8_t)((storage >> bits) & 0xf);
if (signedness) {
static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1);
return (result & uint8_t(1 << 3)) ? ((int)(result) | ~(int)(mask))
: (int)(result);
}
return int(result);
}
MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册