提交 2d4e62ef 编写于 作者: M Megvii Engine Team

feat(dnn/cuda): add cuda uint4 pooling

GitOrigin-RevId: a7289772068d08deef1021f71984cb4ecfdc6702
上级 19919384
...@@ -40,8 +40,7 @@ void get_inner_layout(const TensorLayout& src, const TensorLayout& dst, ...@@ -40,8 +40,7 @@ void get_inner_layout(const TensorLayout& src, const TensorLayout& dst,
Handle* handle, Handle* handle,
PoolingForwardImpl::Param::Format format) { PoolingForwardImpl::Param::Format format) {
bool is_nchw = format == PoolingForwardImpl::Param::Format::NCHW; bool is_nchw = format == PoolingForwardImpl::Param::Format::NCHW;
if (src.dtype.enumv() == DTypeEnum::QuantizedS4 && if (is_nchw) {
dst.dtype.enumv() == DTypeEnum::QuantizedS4 && is_nchw) {
auto relayout_opr = handle->create_operator<RelayoutFormat>(); auto relayout_opr = handle->create_operator<RelayoutFormat>();
deduce_reformat_layout(relayout_opr, src, inner_src, deduce_reformat_layout(relayout_opr, src, inner_src,
RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1); RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1);
...@@ -66,8 +65,11 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle( ...@@ -66,8 +65,11 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle(
TensorLayout fsrc = src; TensorLayout fsrc = src;
TensorLayout fdst = dst; TensorLayout fdst = dst;
bool is_nchw = param().format == Param::Format::NCHW; bool is_nchw = param().format == Param::Format::NCHW;
if (src.dtype.enumv() == DTypeEnum::QuantizedS4 && if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
dst.dtype.enumv() == DTypeEnum::QuantizedS4 && is_nchw) { src.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
(dst.dtype.enumv() == DTypeEnum::QuantizedS4 ||
dst.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
is_nchw) {
get_inner_layout(src, dst, fsrc, fdst, handle(), param().format); get_inner_layout(src, dst, fsrc, fdst, handle(), param().format);
sizes.push_back(fsrc.span().dist_byte()); sizes.push_back(fsrc.span().dist_byte());
sizes.push_back(fdst.span().dist_byte()); sizes.push_back(fdst.span().dist_byte());
...@@ -97,8 +99,11 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, ...@@ -97,8 +99,11 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
bool is_nchw = param().format == Param::Format::NCHW; bool is_nchw = param().format == Param::Format::NCHW;
if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) { if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
ctypecvt.src_to_comp_type(ssrc, src).src_to_comp_type(sdst, dst); ctypecvt.src_to_comp_type(ssrc, src).src_to_comp_type(sdst, dst);
} else if (ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 && } else if ((ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
sdst.layout.dtype.enumv() == DTypeEnum::QuantizedS4 && is_nchw) { ssrc.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
(sdst.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
sdst.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
is_nchw) {
auto handle_ptr = handle(); auto handle_ptr = handle();
get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout, get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout,
handle_ptr, param().format); handle_ptr, param().format);
...@@ -166,8 +171,6 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, ...@@ -166,8 +171,6 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
kern_param, stream, static_cast<uint32_t>(param().mode)); kern_param, stream, static_cast<uint32_t>(param().mode));
} else if (param().format == Format::NCHW64 || } else if (param().format == Format::NCHW64 ||
inner_format == Format::NCHW64) { inner_format == Format::NCHW64) {
megdnn_assert(src.layout.dtype.enumv() == DTypeEnum::QuantizedS4,
"but %s", src.layout.dtype.name());
pooling2d::Param kern_param; pooling2d::Param kern_param;
size_t n = src.layout[0], hi = src.layout[2], wi = src.layout[3], size_t n = src.layout[0], hi = src.layout[2], wi = src.layout[3],
c = src.layout[1], ho = dst.layout[2], wo = dst.layout[3]; c = src.layout[1], ho = dst.layout[2], wo = dst.layout[3];
...@@ -180,16 +183,24 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, ...@@ -180,16 +183,24 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
kern_param.ph = ph, kern_param.pw = pw, kern_param.ph = ph, kern_param.pw = pw,
kern_param.window_h = window_h, kern_param.window_w = window_w, kern_param.window_h = window_h, kern_param.window_w = window_w,
kern_param.sh = sh, kern_param.sw = sw; kern_param.sh = sh, kern_param.sw = sw;
bool uint_case = false;
int zero_point = 0;
if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
uint_case = true;
zero_point = src.layout.dtype.param<dtype::Quantized4Asymm>()
.zero_point;
}
auto&& stream = cuda_stream(handle()); auto&& stream = cuda_stream(handle());
pooling2d::do_pooling2d_int4_ncdiv64hw64( pooling2d::do_pooling2d_int4_ncdiv64hw64(
(int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, kern_param, (int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, kern_param,
stream, static_cast<uint32_t>(param().mode)); stream, static_cast<uint32_t>(param().mode), uint_case,
if (sdst.layout.ndim == 4) { zero_point);
auto relayout_opr = handle()->create_operator<RelayoutFormat>(); if (sdst.layout.ndim == 4) {
RelayoutFormat::Param trans_param; auto relayout_opr = handle()->create_operator<RelayoutFormat>();
trans_param.mode = RelayoutFormat::Param::Mode::NCHW64_NCHW; RelayoutFormat::Param trans_param;
relayout_opr->param() = trans_param; trans_param.mode = RelayoutFormat::Param::Mode::NCHW64_NCHW;
relayout_opr->exec(dst, sdst,{}); relayout_opr->param() = trans_param;
relayout_opr->exec(dst, sdst, {});
} }
return; return;
} }
......
...@@ -29,53 +29,51 @@ __device__ __forceinline__ int pack_int8_to_int8x4(int8_t x, int8_t y, int8_t z, ...@@ -29,53 +29,51 @@ __device__ __forceinline__ int pack_int8_to_int8x4(int8_t x, int8_t y, int8_t z,
return ix; return ix;
} }
template <int regs, typename Dtype, typename OutDtype> template <int regs, int dtype_bits, typename OutDtype>
__device__ __forceinline__ OutDtype pack_int8(int8_t (&x)[regs]); __device__ __forceinline__ OutDtype pack_int8(int8_t (&x)[regs]);
template <> template <>
__device__ __forceinline__ int pack_int8<4, int8_t, int>(int8_t (&x)[4]) { __device__ __forceinline__ int pack_int8<4, 8, int>(int8_t (&x)[4]) {
return pack_int8_to_int8x4(x[0], x[1], x[2], x[3]); return pack_int8_to_int8x4(x[0], x[1], x[2], x[3]);
} }
template <> template <>
__device__ __forceinline__ int2 pack_int8<8, int8_t, int2>(int8_t (&x)[8]) { __device__ __forceinline__ int2 pack_int8<8, 8, int2>(int8_t (&x)[8]) {
int8_t x0[4]{x[0], x[1], x[2], x[3]}; int8_t x0[4]{x[0], x[1], x[2], x[3]};
int8_t x1[4]{x[4], x[5], x[6], x[7]}; int8_t x1[4]{x[4], x[5], x[6], x[7]};
return ::make_int2(pack_int8<4, int8_t, int>(x0), return ::make_int2(pack_int8<4, 8, int>(x0), pack_int8<4, 8, int>(x1));
pack_int8<4, int8_t, int>(x1));
} }
template <> template <>
__device__ __forceinline__ int4 pack_int8<16, int8_t, int4>(int8_t (&x)[16]) { __device__ __forceinline__ int4 pack_int8<16, 8, int4>(int8_t (&x)[16]) {
int8_t x0[4]{x[0], x[1], x[2], x[3]}; int8_t x0[4]{x[0], x[1], x[2], x[3]};
int8_t x1[4]{x[4], x[5], x[6], x[7]}; int8_t x1[4]{x[4], x[5], x[6], x[7]};
int8_t x2[4]{x[8], x[9], x[10], x[11]}; int8_t x2[4]{x[8], x[9], x[10], x[11]};
int8_t x3[4]{x[12], x[13], x[14], x[15]}; int8_t x3[4]{x[12], x[13], x[14], x[15]};
return ::make_int4( return ::make_int4(pack_int8<4, 8, int>(x0), pack_int8<4, 8, int>(x1),
pack_int8<4, int8_t, int>(x0), pack_int8<4, int8_t, int>(x1), pack_int8<4, 8, int>(x2), pack_int8<4, 8, int>(x3));
pack_int8<4, int8_t, int>(x2), pack_int8<4, int8_t, int>(x3));
} }
__device__ __forceinline__ int8_t pack_int8_to_int4x2(int8_t x0, int8_t x1) { __device__ __forceinline__ int8_t pack_int8_to_int4x2(int8_t x0, int8_t x1) {
return (x0 & 0xf) | (x1 << 4); return (x0 & 0xf) | (x1 << 4);
} }
template <> template <>
__device__ __forceinline__ int pack_int8<8, dt_qint4, int>(int8_t (&x)[8]) { __device__ __forceinline__ int pack_int8<8, 4, int>(int8_t (&x)[8]) {
int8_t x0 = pack_int8_to_int4x2(x[0], x[1]); int8_t x0 = pack_int8_to_int4x2(x[0], x[1]);
int8_t x1 = pack_int8_to_int4x2(x[2], x[3]); int8_t x1 = pack_int8_to_int4x2(x[2], x[3]);
int8_t x2 = pack_int8_to_int4x2(x[4], x[5]); int8_t x2 = pack_int8_to_int4x2(x[4], x[5]);
int8_t x3 = pack_int8_to_int4x2(x[6], x[7]); int8_t x3 = pack_int8_to_int4x2(x[6], x[7]);
return pack_int8_to_int8x4(x0, x1, x2, x3); return pack_int8_to_int8x4(x0, x1, x2, x3);
} }
template <> template <>
__device__ __forceinline__ int4 pack_int8<32, dt_qint4, int4>(int8_t (&x)[32]) { __device__ __forceinline__ int4 pack_int8<32, 4, int4>(int8_t (&x)[32]) {
int8_t x0[8]{x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]}; int8_t x0[8]{x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]};
int8_t x1[8]{x[8], x[9], x[10], x[11], x[12], x[13], x[14], x[15]}; int8_t x1[8]{x[8], x[9], x[10], x[11], x[12], x[13], x[14], x[15]};
int8_t x2[8]{x[16], x[17], x[18], x[19], x[20], x[21], x[22], x[23]}; int8_t x2[8]{x[16], x[17], x[18], x[19], x[20], x[21], x[22], x[23]};
int8_t x3[8]{x[24], x[25], x[26], x[27], x[28], x[29], x[30], x[31]}; int8_t x3[8]{x[24], x[25], x[26], x[27], x[28], x[29], x[30], x[31]};
return ::make_int4( return ::make_int4(pack_int8<8, 4, int>(x0), pack_int8<8, 4, int>(x1),
pack_int8<8, dt_qint4, int>(x0), pack_int8<8, dt_qint4, int>(x1), pack_int8<8, 4, int>(x2), pack_int8<8, 4, int>(x3));
pack_int8<8, dt_qint4, int>(x2), pack_int8<8, dt_qint4, int>(x3));
} }
template <typename Dtype> template <typename Dtype>
...@@ -88,6 +86,7 @@ struct TypeTrait<int8_t> { ...@@ -88,6 +86,7 @@ struct TypeTrait<int8_t> {
static constexpr int8_t min = -128; static constexpr int8_t min = -128;
static constexpr int elem_per_32bit = 32 / bit_width; static constexpr int elem_per_32bit = 32 / bit_width;
static constexpr int shift_fix_sign = 0; static constexpr int shift_fix_sign = 0;
static constexpr bool need_zero_pad = false;
}; };
template <> template <>
...@@ -97,6 +96,16 @@ struct TypeTrait<dt_qint4> { ...@@ -97,6 +96,16 @@ struct TypeTrait<dt_qint4> {
static constexpr int8_t min = -8; static constexpr int8_t min = -8;
static constexpr int elem_per_32bit = 32 / bit_width; static constexpr int elem_per_32bit = 32 / bit_width;
static constexpr int shift_fix_sign = 4; static constexpr int shift_fix_sign = 4;
static constexpr bool need_zero_pad = false;
};
template <>
struct TypeTrait<dt_quint4> {
static constexpr int bit_width = 4;
static constexpr int mask = 0xf;
static constexpr int8_t min = 0;
static constexpr int elem_per_32bit = 32 / bit_width;
static constexpr int shift_fix_sign = 0;
static constexpr bool need_zero_pad = true;
}; };
template <typename src_type, typename _feed_type> template <typename src_type, typename _feed_type>
...@@ -108,7 +117,7 @@ struct MaxPooler { ...@@ -108,7 +117,7 @@ struct MaxPooler {
static constexpr int shift_fix_sign = TypeTrait<src_type>::shift_fix_sign; static constexpr int shift_fix_sign = TypeTrait<src_type>::shift_fix_sign;
int8_t res[nr_results]; int8_t res[nr_results];
__device__ MaxPooler(int) {} __device__ MaxPooler(int, int) {}
__device__ __forceinline__ void init() { __device__ __forceinline__ void init() {
#pragma unroll #pragma unroll
for (int i = 0; i < nr_results; ++i) { for (int i = 0; i < nr_results; ++i) {
...@@ -137,7 +146,7 @@ struct MaxPooler { ...@@ -137,7 +146,7 @@ struct MaxPooler {
} }
__device__ __forceinline__ feed_type get_ans() { __device__ __forceinline__ feed_type get_ans() {
feed_type ans; feed_type ans;
ans = pack_int8<nr_results, src_type, feed_type>(res); ans = pack_int8<nr_results, bit_width, feed_type>(res);
return ans; return ans;
} }
}; };
...@@ -149,21 +158,27 @@ struct MeanIncludeRoundedPooler { ...@@ -149,21 +158,27 @@ struct MeanIncludeRoundedPooler {
static constexpr int nr_results = sizeof(feed_type) * 8 / bit_width; static constexpr int nr_results = sizeof(feed_type) * 8 / bit_width;
static constexpr int elem_per_32bit = TypeTrait<src_type>::elem_per_32bit; static constexpr int elem_per_32bit = TypeTrait<src_type>::elem_per_32bit;
static constexpr int shift_fix_sign = TypeTrait<src_type>::shift_fix_sign; static constexpr int shift_fix_sign = TypeTrait<src_type>::shift_fix_sign;
static constexpr bool need_zero_pad = TypeTrait<src_type>::need_zero_pad;
int32_t res[nr_results]; int32_t res[nr_results];
const int count; const int count;
const float fi_count; const float fi_count;
int real_fi_count;
const int zero_pad;
__device__ MeanIncludeRoundedPooler(int count) __device__ MeanIncludeRoundedPooler(int count, int zero_point)
: count{count}, fi_count{1.f / count} {} : count{count}, fi_count{1.f / count}, zero_pad{zero_point} {}
__device__ __forceinline__ void init() { __device__ __forceinline__ void init() {
#pragma unroll #pragma unroll
for (int i = 0; i < nr_results; ++i) { for (int i = 0; i < nr_results; ++i) {
res[i] = 0; res[i] = 0;
} }
if (need_zero_pad) {
real_fi_count = 0;
}
} }
__device__ __forceinline__ void feed(int x, int idx = 0) { __device__ __forceinline__ void feed(int x, int idx) {
constexpr int unroll_n = sizeof(int) * 8 / bit_width; constexpr int unroll_n = sizeof(int) * 8 / bit_width;
#pragma unroll #pragma unroll
for (int i = 0; i < unroll_n; i++) { for (int i = 0; i < unroll_n; i++) {
...@@ -173,15 +188,27 @@ struct MeanIncludeRoundedPooler { ...@@ -173,15 +188,27 @@ struct MeanIncludeRoundedPooler {
res[idx + i] += static_cast<int32_t>(temp); res[idx + i] += static_cast<int32_t>(temp);
} }
} }
__device__ __forceinline__ void feed(int x) {
feed(x, 0);
if (need_zero_pad) {
real_fi_count++;
}
}
__device__ __forceinline__ void feed(int2 x) { __device__ __forceinline__ void feed(int2 x) {
feed(x.x, 0 * elem_per_32bit); feed(x.x, 0 * elem_per_32bit);
feed(x.y, 1 * elem_per_32bit); feed(x.y, 1 * elem_per_32bit);
if (need_zero_pad) {
real_fi_count++;
}
} }
__device__ __forceinline__ void feed(int4 x) { __device__ __forceinline__ void feed(int4 x) {
feed(x.x, 0 * elem_per_32bit); feed(x.x, 0 * elem_per_32bit);
feed(x.y, 1 * elem_per_32bit); feed(x.y, 1 * elem_per_32bit);
feed(x.z, 2 * elem_per_32bit); feed(x.z, 2 * elem_per_32bit);
feed(x.w, 3 * elem_per_32bit); feed(x.w, 3 * elem_per_32bit);
if (need_zero_pad) {
real_fi_count++;
}
} }
__device__ __forceinline__ feed_type get_ans() { __device__ __forceinline__ feed_type get_ans() {
feed_type ans; feed_type ans;
...@@ -189,13 +216,18 @@ struct MeanIncludeRoundedPooler { ...@@ -189,13 +216,18 @@ struct MeanIncludeRoundedPooler {
#pragma unroll #pragma unroll
for (int i = 0; i < nr_results; i++) { for (int i = 0; i < nr_results; i++) {
float f32_res = roundf(static_cast<float>(res[i]) * fi_count); float f32_res = roundf(static_cast<float>(res[i]) * fi_count);
if (need_zero_pad) {
f32_res = roundf((static_cast<float>(res[i]) +
(count - real_fi_count) * zero_pad) *
fi_count);
}
int i8_res; int i8_res;
asm volatile("cvt.rni.s8.f32 %0, %1;" asm volatile("cvt.rni.s8.f32 %0, %1;"
: "=r"(i8_res) : "=r"(i8_res)
: "f"(f32_res)); : "f"(f32_res));
out_res[i] = i8_res; out_res[i] = i8_res;
} }
ans = pack_int8<nr_results, src_type, feed_type>(out_res); ans = pack_int8<nr_results, bit_width, feed_type>(out_res);
return ans; return ans;
} }
}; };
...@@ -209,7 +241,7 @@ struct MeanExcludeRoundedPooler { ...@@ -209,7 +241,7 @@ struct MeanExcludeRoundedPooler {
static constexpr int shift_fix_sign = TypeTrait<src_type>::shift_fix_sign; static constexpr int shift_fix_sign = TypeTrait<src_type>::shift_fix_sign;
int32_t res[nr_results]; int32_t res[nr_results];
int count; int count;
__device__ MeanExcludeRoundedPooler(int) {} __device__ MeanExcludeRoundedPooler(int, int) {}
__device__ __forceinline__ void init() { __device__ __forceinline__ void init() {
#pragma unroll #pragma unroll
...@@ -257,7 +289,7 @@ struct MeanExcludeRoundedPooler { ...@@ -257,7 +289,7 @@ struct MeanExcludeRoundedPooler {
: "f"(f32_res)); : "f"(f32_res));
out_res[i] = i8_res; out_res[i] = i8_res;
} }
ans = pack_int8<nr_results, src_type, feed_type>(out_res); ans = pack_int8<nr_results, bit_width, feed_type>(out_res);
return ans; return ans;
} }
}; };
...@@ -290,7 +322,7 @@ __global__ void pooling2d_device_template_int8_cdiv4hwn4( ...@@ -290,7 +322,7 @@ __global__ void pooling2d_device_template_int8_cdiv4hwn4(
packed_ch * output_pixels * npack + packed_ch * output_pixels * npack +
(ho * param.wo + wo) * npack; (ho * param.wo + wo) * npack;
Pooler pooler(param.window_h * param.window_w); Pooler pooler(param.window_h * param.window_w, 0);
pooler.init(); pooler.init();
for (int fh = 0; fh < param.window_h; fh++) { for (int fh = 0; fh < param.window_h; fh++) {
uint32_t ih = ho * param.sh + fh - param.ph; uint32_t ih = ho * param.sh + fh - param.ph;
...@@ -313,7 +345,7 @@ template <typename Pooler, int pack_size, int pack_byte, ...@@ -313,7 +345,7 @@ template <typename Pooler, int pack_size, int pack_byte,
int ldg_width_assert = 4> int ldg_width_assert = 4>
__global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src,
int8_t* __restrict__ dst, int8_t* __restrict__ dst,
Param param) { Param param, int zero_point) {
const int tid = blockIdx.x * blockDim.x + threadIdx.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x;
using ldg_type = typename Pooler::feed_type; using ldg_type = typename Pooler::feed_type;
static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t); static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t);
...@@ -348,7 +380,7 @@ __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, ...@@ -348,7 +380,7 @@ __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src,
dst + (batch * out_batch_stride + oc * out_channel_stride + dst + (batch * out_batch_stride + oc * out_channel_stride +
(oh * param.wo + ow) * pack_byte + sec * ldg_width_bytes); (oh * param.wo + ow) * pack_byte + sec * ldg_width_bytes);
Pooler pooler(param.window_h * param.window_w); Pooler pooler(param.window_h * param.window_w, zero_point);
pooler.init(); pooler.init();
for (int fh = 0; fh < param.window_h; fh++) { for (int fh = 0; fh < param.window_h; fh++) {
uint32_t ih = oh * param.sh + fh - param.ph; uint32_t ih = oh * param.sh + fh - param.ph;
...@@ -418,13 +450,12 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, ...@@ -418,13 +450,12 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src,
after_kernel_launch(); after_kernel_launch();
} }
void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(
int8_t* d_dst, const int8_t* d_src, int8_t* d_dst, const Param& param,
const Param& param, cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) {
cudaStream_t stream,
uint32_t mode) {
using Mode = megdnn::param_enumv::Pooling::Mode; using Mode = megdnn::param_enumv::Pooling::Mode;
void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param,
int zero_point);
constexpr int ldg_byte = 4; constexpr int ldg_byte = 4;
constexpr int elem_per_byte = 1; constexpr int elem_per_byte = 1;
constexpr int pack_size = 4; constexpr int pack_size = 4;
...@@ -455,17 +486,16 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, ...@@ -455,17 +486,16 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src,
uint32_t nr_threads = query_blocksize_for_kernel(kern); uint32_t nr_threads = query_blocksize_for_kernel(kern);
nr_threads = std::min(nr_threads, vthreads); nr_threads = std::min(nr_threads, vthreads);
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param); kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param, zero_point);
after_kernel_launch(); after_kernel_launch();
} }
void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src, void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(
int8_t* d_dst, const int8_t* d_src, int8_t* d_dst, const Param& param,
const Param& param, cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) {
cudaStream_t stream,
uint32_t mode) {
using Mode = megdnn::param_enumv::Pooling::Mode; using Mode = megdnn::param_enumv::Pooling::Mode;
void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param,
int zero_point);
constexpr int ldg_byte = 16; constexpr int ldg_byte = 16;
constexpr int elem_per_byte = 1; constexpr int elem_per_byte = 1;
constexpr int pack_size = 32; constexpr int pack_size = 32;
...@@ -494,17 +524,16 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src, ...@@ -494,17 +524,16 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src,
uint32_t nr_threads = query_blocksize_for_kernel(kern); uint32_t nr_threads = query_blocksize_for_kernel(kern);
nr_threads = std::min(nr_threads, vthreads); nr_threads = std::min(nr_threads, vthreads);
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param); kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param, zero_point);
after_kernel_launch(); after_kernel_launch();
} }
void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64(
int8_t* d_dst, const int8_t* d_src, int8_t* d_dst, const Param& param,
const Param& param, cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) {
cudaStream_t stream,
uint32_t mode) {
using Mode = megdnn::param_enumv::Pooling::Mode; using Mode = megdnn::param_enumv::Pooling::Mode;
void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param,
int zero_point);
constexpr int ldg_byte = 16; constexpr int ldg_byte = 16;
constexpr int elem_per_byte = 2; constexpr int elem_per_byte = 2;
constexpr int pack_size = 64; constexpr int pack_size = 64;
...@@ -512,28 +541,50 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, ...@@ -512,28 +541,50 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src,
constexpr int elem_per_thread = ldg_byte * elem_per_byte; constexpr int elem_per_thread = ldg_byte * elem_per_byte;
uint32_t vthreads = uint32_t vthreads =
param.n * param.c * param.ho * param.wo / elem_per_thread; param.n * param.c * param.ho * param.wo / elem_per_thread;
switch (mode) { if (uint_case) {
case Mode::MAX: switch (mode) {
kern = pooling2d_device_template_nchwc<MaxPooler<dt_qint4, int4>, case Mode::MAX:
pack_size, pack_byte>; kern = pooling2d_device_template_nchwc<
break; MaxPooler<dt_quint4, int4>, pack_size, pack_byte>;
case Mode::AVERAGE: break;
kern = pooling2d_device_template_nchwc< case Mode::AVERAGE:
MeanIncludeRoundedPooler<dt_qint4, int4, int32_t>, kern = pooling2d_device_template_nchwc<
pack_size, pack_byte>; MeanIncludeRoundedPooler<dt_quint4, int4, int32_t>,
break; pack_size, pack_byte>;
case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: break;
kern = pooling2d_device_template_nchwc< case Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
MeanExcludeRoundedPooler<dt_qint4, int4, int32_t>, kern = pooling2d_device_template_nchwc<
pack_size, pack_byte>; MeanExcludeRoundedPooler<dt_quint4, int4, int32_t>,
break; pack_size, pack_byte>;
default: break;
megdnn_assert(false, "invalid pooling mode"); default:
megdnn_assert(false, "invalid pooling mode");
}
} else {
switch (mode) {
case Mode::MAX:
kern = pooling2d_device_template_nchwc<MaxPooler<dt_qint4, int4>,
pack_size, pack_byte>;
break;
case Mode::AVERAGE:
kern = pooling2d_device_template_nchwc<
MeanIncludeRoundedPooler<dt_qint4, int4, int32_t>,
pack_size, pack_byte>;
break;
case Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
kern = pooling2d_device_template_nchwc<
MeanExcludeRoundedPooler<dt_qint4, int4, int32_t>,
pack_size, pack_byte>;
break;
default:
megdnn_assert(false, "invalid pooling mode");
}
} }
uint32_t nr_threads = query_blocksize_for_kernel(kern); uint32_t nr_threads = query_blocksize_for_kernel(kern);
nr_threads = std::min(nr_threads, vthreads); nr_threads = std::min(nr_threads, vthreads);
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param); kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param, zero_point);
after_kernel_launch(); after_kernel_launch();
} }
......
...@@ -27,15 +27,18 @@ void do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst, ...@@ -27,15 +27,18 @@ void do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst,
void do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, int8_t* d_dst, void do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, int8_t* d_dst,
const Param& param, cudaStream_t stream, const Param& param, cudaStream_t stream,
uint32_t mode); uint32_t mode, bool uint_case = false,
int zero_point = 0);
void do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src, int8_t* d_dst, void do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src, int8_t* d_dst,
const Param& param, cudaStream_t stream, const Param& param, cudaStream_t stream,
uint32_t mode); uint32_t mode, bool uint_case = false,
int zero_point = 0);
void do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, int8_t* d_dst, void do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, int8_t* d_dst,
const Param& param, cudaStream_t stream, const Param& param, cudaStream_t stream,
uint32_t mode); uint32_t mode, bool uint_case = false,
int zero_point = 0);
} // namespace pooling2d } // namespace pooling2d
} // namespace cuda } // namespace cuda
......
...@@ -254,6 +254,13 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW_Q4) { ...@@ -254,6 +254,13 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW_Q4) {
checker.set_param(param).exec({{20, 96, 22, 33}, {}}); checker.set_param(param).exec({{20, 96, 22, 33}, {}});
param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING; param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
checker.set_param(param).exec({{20, 24, 22, 33}, {}}); checker.set_param(param).exec({{20, 24, 22, 33}, {}});
checker.set_dtype(0, dtype::Quantized4Asymm(3.1415926f, 3));
param.format = Param::Format::NCHW;
checker.set_param(param).exec({{20, 64, 22, 33}, {}});
param.mode = Param::Mode::AVERAGE;
checker.set_param(param).exec({{20, 96, 22, 33}, {}});
param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
checker.set_param(param).exec({{20, 24, 22, 33}, {}});
} }
TEST_F(CUDA, POOLING_FORWARD_NCHW4) { TEST_F(CUDA, POOLING_FORWARD_NCHW4) {
...@@ -291,20 +298,36 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW32) { ...@@ -291,20 +298,36 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW32) {
} }
#endif #endif
TEST_F(CUDA, POOLING_FORWARD_NCHW64) { TEST_F(CUDA, POOLING_FORWARD_NCHW64_Q4) {
require_compute_capability(7, 5); require_compute_capability(7, 5);
using Param = param::Pooling; using Param = param::Pooling;
Checker<Pooling> checker(handle_cuda()); Checker<Pooling> checker(handle_cuda());
Param param{Param::Mode::MAX, 0, 0, 2, 2, 2, 2}; Param param{Param::Mode::MAX, 1, 1, 2, 2, 2, 2};
UniformIntRNG int_rng{-8, 7}; UniformIntRNG int_rng{-8, 7};
checker.set_dtype(0, dtype::QuantizedS4(1.f)); checker.set_dtype(0, dtype::QuantizedS4(1.f));
param.format = Param::Format::NCHW64; param.format = Param::Format::NCHW64;
checker.set_epsilon(1e-3).set_rng(0, &int_rng); checker.set_epsilon(1e-3).set_rng(0, &int_rng);
checker.set_param(param).exec({{64, 8, 28, 28, 64}, {}}); checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}});
param.mode = Param::Mode::AVERAGE; param.mode = Param::Mode::AVERAGE;
checker.set_param(param).exec({{64, 8, 28, 28, 64}, {}}); checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}});
param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING; param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
checker.set_param(param).exec({{64, 8, 28, 28, 64}, {}}); checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}});
}
TEST_F(CUDA, POOLING_FORWARD_NCHW64_U4) {
require_compute_capability(7, 5);
using Param = param::Pooling;
Checker<Pooling> checker(handle_cuda());
Param param{Param::Mode::MAX, 1, 1, 2, 2, 2, 2};
UniformIntRNG int_rng{0, 15};
checker.set_dtype(0, dtype::Quantized4Asymm(1.f, 3));
param.format = Param::Format::NCHW64;
checker.set_epsilon(1e-3).set_rng(0, &int_rng);
checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}});
param.mode = Param::Mode::AVERAGE;
checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}});
param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}});
} }
TEST_F(CUDA, POOLING_FORWARD_CHWN4) { TEST_F(CUDA, POOLING_FORWARD_CHWN4) {
......
...@@ -84,12 +84,12 @@ TEST_F(NAIVE, POOLING_QUANTIZED_Q4) { ...@@ -84,12 +84,12 @@ TEST_F(NAIVE, POOLING_QUANTIZED_Q4) {
} }
{ {
auto u4_dt = dtype::Quantized4Asymm(1.f, 0); auto u4_dt = dtype::Quantized4Asymm(0.1f, 3);
std::vector<int> u8_src_vec{1, 2, 3, std::vector<int> u8_src_vec{1, 2, 3,
4, 5, 6, 4, 5, 6,
7, 8, 9}; 7, 8, 9};
std::vector<int> u8_max_dst_vec{1, 3, 7, 9}; std::vector<int> u8_max_dst_vec{1, 3, 7, 9};
std::vector<int> u8_avg_dst_vec{0, 1, 3, 7}; std::vector<int> u8_avg_dst_vec{3, 3, 4, 7};
std::vector<int> u8_avg_exclu_dst_vec{1, 3, 6, 7}; std::vector<int> u8_avg_exclu_dst_vec{1, 3, 6, 7};
Pooling::Param param{Mode::MAX, 1, 1, 2, 2, 2, 2}; Pooling::Param param{Mode::MAX, 1, 1, 2, 2, 2, 2};
Testcase input{TensorValueLowbit4({1, 1, 3, 3}, u4_dt, u8_src_vec), {}}; Testcase input{TensorValueLowbit4({1, 1, 3, 3}, u4_dt, u8_src_vec), {}};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册