Commit 2398df07 authored by Megvii Engine Team

feat(dnn/cuda): add cuda int4 pooling

GitOrigin-RevId: 14ed4e6f0095231ca87cace41842b32df18d0818
Parent 2a2a7f45
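Editor's aside (not part of the commit): throughout this diff, int4 data is stored two elements per byte, low nibble first, exactly as `(a & 0xF) | (b << 4)` in `TensorValueLowbit4` below. A minimal C++ sketch of that packing, assuming two's-complement nibbles in [-8, 7]:

    #include <cstdint>

    // Pack two signed int4 values (each in [-8, 7]) into one byte, low nibble first.
    inline uint8_t pack_int4x2(int8_t lo, int8_t hi) {
        return static_cast<uint8_t>((lo & 0xF) | ((hi & 0xF) << 4));
    }

    // Recover both values; shifting left then right sign-extends each nibble.
    inline void unpack_int4x2(uint8_t b, int8_t& lo, int8_t& hi) {
        lo = static_cast<int8_t>(b << 4) >> 4;
        hi = static_cast<int8_t>(b) >> 4;
    }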
@@ -47,7 +47,8 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
} else if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW88 ||
-            param().format == Param::Format::NCHW32) {
+            param().format == Param::Format::NCHW32 ||
+            param().format == Param::Format::NCHW64) {
megdnn_assert(src.ndim == 5_z, "%s", errmsg_c);
spatial_pos = 2;
@@ -82,6 +83,9 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
if (param().format == Param::Format::NCHW32) {
c *= 32;
}
if (param().format == Param::Format::NCHW64) {
c *= 64;
}
size_t oh, ow;
size_t fh = this->param().window_h;
size_t fw = this->param().window_w;
@@ -109,6 +113,8 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
dst = TensorLayout{{n, c / 8, oh, ow, 8}, src.dtype, src.format};
} else if (param().format == Param::Format::NCHW32) {
dst = TensorLayout{{n, c / 32, oh, ow, 32}, src.dtype, src.format};
} else if (param().format == Param::Format::NCHW64) {
dst = TensorLayout{{n, c / 64, oh, ow, 64}, src.dtype, src.format};
} else if (param().format == Param::Format::CHWN4) {
dst = TensorLayout{{c / 4, oh, ow, n, 4}, src.dtype, src.format};
} else {
@@ -9,13 +9,50 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/pooling/opr_impl.h"
#include "src/cuda/relayout_format/opr_impl.h"
#include "./pooling2d_int8.cuh"
#include "./pooling2d_qint.cuh"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
namespace {
inline void deduce_reformat_layout(std::unique_ptr<RelayoutFormat>& relayout,
const TensorLayout& src_layout,
TensorLayout& dst_layout,
RelayoutFormat::Param::Mode mode,
const int oc = 0, const int group = 1) {
if (src_layout.ndim > 0) {
RelayoutFormat::Param trans_param;
trans_param.mode = mode;
trans_param.oc = oc;
trans_param.group = group;
relayout->param() = trans_param;
relayout->deduce_layout(src_layout, dst_layout);
} else {
dst_layout = src_layout;
}
}
void get_inner_layout(const TensorLayout& src, const TensorLayout& dst,
TensorLayout& inner_src, TensorLayout& inner_dst,
Handle* handle,
PoolingForwardImpl::Param::Format format) {
bool is_nchw = format == PoolingForwardImpl::Param::Format::NCHW;
if (src.dtype.enumv() == DTypeEnum::QuantizedS4 &&
dst.dtype.enumv() == DTypeEnum::QuantizedS4 && is_nchw) {
auto relayout_opr = handle->create_operator<RelayoutFormat>();
deduce_reformat_layout(relayout_opr, src, inner_src,
RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1);
deduce_reformat_layout(relayout_opr, dst, inner_dst,
RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1);
} else {
megdnn_assert(0, "not supported");
}
}
} // namespace
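// Editor's aside (hedged, not in the commit): for QuantizedS4 NCHW inputs,
// get_inner_layout above deduces the layout produced by
// RelayoutFormat::Param::Mode::NCHW_NCHW64, i.e.
//     (N, C, H, W)  ->  (N, ceil(C / 64), H, W, 64)
// with the channel dimension padded up to a multiple of 64. For example the
// NCHW_Q4 test shape {20, 64, 22, 33} relayouts to {20, 1, 22, 33, 64}.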
void PoolingForwardImpl::setup_descs(const TensorLayout& src,
const TensorLayout& dst) {
src_desc.set(src, param().format);
@@ -28,14 +65,22 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle(
SmallVector<size_t> sizes;
TensorLayout fsrc = src;
TensorLayout fdst = dst;
-    auto get_workspace = [&sizes](TensorLayout& layout) {
-        if (layout.dtype == dtype::BFloat16()) {
-            layout.dtype = dtype::Float32();
-            sizes.push_back(layout.span().dist_byte());
-        }
-    };
-    get_workspace(fsrc);
-    get_workspace(fdst);
+    bool is_nchw = param().format == Param::Format::NCHW;
+    if (src.dtype.enumv() == DTypeEnum::QuantizedS4 &&
+        dst.dtype.enumv() == DTypeEnum::QuantizedS4 && is_nchw) {
+        get_inner_layout(src, dst, fsrc, fdst, handle(), param().format);
+        sizes.push_back(fsrc.span().dist_byte());
+        sizes.push_back(fdst.span().dist_byte());
+    } else {
+        auto get_workspace = [&sizes](TensorLayout& layout) {
+            if (layout.dtype == dtype::BFloat16()) {
+                layout.dtype = dtype::Float32();
+                sizes.push_back(layout.span().dist_byte());
+            }
+        };
+        get_workspace(fsrc);
+        get_workspace(fdst);
+    }
return {ptr, std::move(sizes)};
}
@@ -44,12 +89,27 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
check_exec(ssrc.layout, sdst.layout, sworkspace.size);
TensorND src = ssrc;
TensorND dst = sdst;
Param::Format inner_format = param().format;
auto wsb =
get_workspace_bundle(sworkspace.raw_ptr, ssrc.layout, sdst.layout);
auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
concrete_handle(this->handle()), &wsb);
bool is_nchw = param().format == Param::Format::NCHW;
if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
ctypecvt.src_to_comp_type(ssrc, src).src_to_comp_type(sdst, dst);
} else if (ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
sdst.layout.dtype.enumv() == DTypeEnum::QuantizedS4 && is_nchw) {
auto handle_ptr = handle();
get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout,
handle_ptr, param().format);
src.raw_ptr = wsb.get(0);
dst.raw_ptr = wsb.get(1);
auto relayout_opr = handle_ptr->create_operator<RelayoutFormat>();
RelayoutFormat::Param trans_param;
trans_param.mode = RelayoutFormat::Param::Mode::NCHW_NCHW64;
relayout_opr->param() = trans_param;
relayout_opr->exec(ssrc, src, {});
inner_format = Param::Format::NCHW64;
}
{
using Format = param::Pooling::Format;
@@ -104,6 +164,34 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
return pooling2d::do_pooling2d_int8_ncdiv32hw32(
src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(),
kern_param, stream, static_cast<uint32_t>(param().mode));
} else if (param().format == Format::NCHW64 ||
inner_format == Format::NCHW64) {
megdnn_assert(src.layout.dtype.enumv() == DTypeEnum::QuantizedS4,
"expected QuantizedS4, but got %s", src.layout.dtype.name());
pooling2d::Param kern_param;
size_t n = src.layout[0], hi = src.layout[2], wi = src.layout[3],
c = src.layout[1], ho = dst.layout[2], wo = dst.layout[3];
c = c * 64;
size_t ph = param().pad_h, pw = param().pad_w;
size_t window_h = param().window_h, window_w = param().window_w;
size_t sh = param().stride_h, sw = param().stride_w;
kern_param.n = n, kern_param.c = c, kern_param.hi = hi,
kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo,
kern_param.ph = ph, kern_param.pw = pw,
kern_param.window_h = window_h, kern_param.window_w = window_w,
kern_param.sh = sh, kern_param.sw = sw;
auto&& stream = cuda_stream(handle());
pooling2d::do_pooling2d_int4_ncdiv64hw64(
(int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, kern_param,
stream, static_cast<uint32_t>(param().mode));
if (sdst.layout.ndim == 4) {
auto relayout_opr = handle()->create_operator<RelayoutFormat>();
RelayoutFormat::Param trans_param;
trans_param.mode = RelayoutFormat::Param::Mode::NCHW64_NCHW;
relayout_opr->param() = trans_param;
relayout_opr->exec(dst, sdst, {});
}
return;
}
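// Editor's summary of the QuantizedS4-NCHW fast path above (a sketch, not
// committed code): exec() relayouts ssrc from NCHW into the NCHW64 workspace
// buffer (wsb.get(0)), runs do_pooling2d_int4_ncdiv64hw64 on the NCHW64
// buffers, and, when the user-visible dst is still 4-dim NCHW, relayouts the
// pooled result back via NCHW64_NCHW into sdst.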
auto handle = cudnn_handle(this->handle());
setup_descs(src.layout, dst.layout);
@@ -114,7 +202,7 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
}
if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
ctypecvt.comp_to_dst_type(dst, sdst);
}
}
}
void PoolingBackwardImpl::setup_descs(const TensorLayout& src,
/**
- * \file dnn/src/cuda/pooling/pooling2d_int8.cuh
+ * \file dnn/src/cuda/pooling/pooling2d_qint.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -32,6 +32,11 @@ void do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, int8_t* d_dst,
void do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src, int8_t* d_dst,
const Param& param, cudaStream_t stream,
uint32_t mode);
void do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, int8_t* d_dst,
const Param& param, cudaStream_t stream,
uint32_t mode);
} // namespace pooling2d
} // namespace cuda
} // namespace megdnn
@@ -15,6 +15,7 @@
#include "megdnn/dtype.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "src/naive/lowbit_utils.h"
#include "midout.h"
MIDOUT_DECL(megdnn_naive_pooling)
@@ -190,6 +191,12 @@ struct NCHW32IdxGetter {
return (((n * (C >> 5) + (c >> 5)) * H + h) * W + w) * 32 + (c & 0x1f);
}
};
struct NCHW64IdxGetter {
static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t,
size_t C, size_t H, size_t W) {
return (((n * (C >> 6) + (c >> 6)) * H + h) * W + w) * 64 + (c & 0x3f);
}
};
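// Editor's sanity check (not in the commit) of the NCHW64 index formula:
// with C = 128, H = W = 4, the logical element (n=0, c=70, h=0, w=0) lives
// in channel block 70 >> 6 = 1 at inner slot 70 & 0x3f = 6, giving
//     (((0 * 2 + 1) * 4 + 0) * 4 + 0) * 64 + 6 = 1030.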
/*!
* Pooler for AVERAGE_COUNT_EXCLUDE_PADDING mode
*/
@@ -375,15 +382,81 @@ void pooling_backward_max_impl(const ctype* __restrict src,
namespace megdnn {
namespace naive {
WorkspaceBundle PoolingForwardImpl::get_workspace_bundle(
void* ptr, const TensorLayout& src, const TensorLayout& dst) const {
SmallVector<size_t> sizes;
TensorLayout fsrc = src;
TensorLayout fdst = dst;
auto get_workspace = [&sizes](TensorLayout& layout) {
if (layout.dtype.enumv() == DTypeEnum::Quantized4Asymm ||
layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
layout.dtype = dtype::Int8();
layout.format = TensorLayout::Format(layout.dtype);
sizes.push_back(layout.span().dist_byte());
}
};
get_workspace(fsrc);
get_workspace(fdst);
return {ptr, std::move(sizes)};
}
size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) {
return get_workspace_bundle(nullptr, src, dst).total_size_in_bytes();
}
namespace {
void post_process(const TensorND& dst, TensorND& comp_dst, Handle* handle,
WorkspaceBundle& workspace_bundle) {
if (dst.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
int8_to_int4(comp_dst, dst);
} else if (dst.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
uint8_to_uint4(comp_dst, dst);
}
}
} // namespace
void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
TensorND comp_src = src;
TensorND comp_dst = dst;
auto wsb = get_workspace_bundle(workspace.raw_ptr, src.layout, dst.layout);
if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
float scale = src.layout.dtype.param<dtype::QuantizedS4>().scale;
comp_src.layout.dtype = dtype::QuantizedS8(scale);
comp_src.layout.init_contiguous_stride();
comp_src.layout.format = TensorLayout::Format(comp_src.layout.dtype);
comp_src.raw_ptr = wsb.get(0);
comp_dst.layout.dtype = dtype::QuantizedS8(scale);
comp_dst.layout.format = TensorLayout::Format(comp_dst.layout.dtype);
comp_dst.layout.init_contiguous_stride();
comp_dst.raw_ptr = wsb.get(1);
int4_to_int8(src, comp_src);
} else if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
float scale = src.layout.dtype.param<dtype::Quantized4Asymm>().scale;
uint8_t zero_point =
src.layout.dtype.param<dtype::Quantized4Asymm>().zero_point;
comp_src.layout.dtype = dtype::Quantized8Asymm(scale, zero_point);
comp_src.layout.format = TensorLayout::Format(comp_src.layout.dtype);
comp_src.layout.init_contiguous_stride();
comp_src.raw_ptr = wsb.get(0);
comp_dst.layout.dtype = dtype::Quantized8Asymm(scale, zero_point);
comp_dst.layout.format = TensorLayout::Format(comp_dst.layout.dtype);
comp_dst.layout.init_contiguous_stride();
comp_dst.raw_ptr = wsb.get(1);
uint4_to_uint8(src, comp_src);
}
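// Editor's note (hedged): int4_to_int8 / uint4_to_uint8 come from
// src/naive/lowbit_utils.h; per packed byte b the signed widening amounts to
//     lo = int8_t(b << 4) >> 4;   // low nibble, sign-extended
//     hi = int8_t(b) >> 4;        // high nibble via arithmetic shift
// so every int4 pair becomes two addressable int8 elements before the
// generic pooling kernels below run, and post_process() narrows them back.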
size_t c_pos, spatial_pos, batch_pos = 0;
if (param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW44 ||
-        param().format == Param::Format::NCHW32) {
+        param().format == Param::Format::NCHW32 ||
+        param().format == Param::Format::NCHW64) {
c_pos = 1;
spatial_pos = 2;
} else if (param().format == Param::Format::NHWC) {
@@ -398,27 +471,35 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
c_pos = 2;
spatial_pos = 1;
}
-    size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos],
-           IH = src.layout.shape[spatial_pos + 0],
-           IW = src.layout.shape[spatial_pos + 1];
-    size_t OH = dst.layout.shape[spatial_pos + 0],
-           OW = dst.layout.shape[spatial_pos + 1];
-    if (param().format == Param::Format::NHWCD4) {
-        C *= 4;
-        IW = src.layout.shape[spatial_pos + 2];
-        OW = dst.layout.shape[spatial_pos + 2];
-    }
-    if (param().format == Param::Format::NCHW4 ||
-        param().format == Param::Format::NCHW44 ||
-        param().format == Param::Format::CHWN4) {
-        C *= 4;
-    }
-    if (param().format == Param::Format::NCHW88) {
-        C *= 8;
-    }
-    if (param().format == Param::Format::NCHW32) {
-        C *= 32;
+    size_t N = comp_src.layout.shape[batch_pos],
+           C = comp_src.layout.shape[c_pos],
+           IH = comp_src.layout.shape[spatial_pos + 0],
+           IW = comp_src.layout.shape[spatial_pos + 1];
+    size_t OH = comp_dst.layout.shape[spatial_pos + 0],
+           OW = comp_dst.layout.shape[spatial_pos + 1];
+    switch (param().format) {
+        case Param::Format::NHWCD4:
+            C *= 4;
+            IW = comp_src.layout.shape[spatial_pos + 2];
+            OW = comp_dst.layout.shape[spatial_pos + 2];
+            break;
+        case Param::Format::NCHW4:
+        case Param::Format::NCHW44:
+        case Param::Format::CHWN4:
+            C *= 4;
+            break;
+        case Param::Format::NCHW88:
+            C *= 8;
+            break;
+        case Param::Format::NCHW32:
+            C *= 32;
+            break;
+        case Param::Format::NCHW64:
+            C *= 64;
+            break;
+        default:;
+    }
size_t PH = param().pad_h, PW = param().pad_w;
size_t FH = param().window_h, FW = param().window_w;
size_t SH = param().stride_h, SW = param().stride_w;
@@ -427,8 +508,8 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(handle()), \
pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \
-                        sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, \
-                        PH, PW, SH, SW, FH, FW)); \
+                        sptr, dptr, comp_src.layout.dtype, N, C, IH, IW, OH, \
+                        OW, PH, PW, SH, SW, FH, FW)); \
} \
MIDOUT_END();
@@ -455,6 +536,9 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
case Param::Format::NCHW32: \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW32IdxGetter); \
break; \
case Param::Format::NCHW64: \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW64IdxGetter); \
break; \
case Param::Format::CHWN4: \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, CHWN4IdxGetter); \
break; \
@@ -462,30 +546,35 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
megdnn_throw("invalid pooling format"); \
}
-#define cb(DType) \
-    if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
-        using ctype = typename DTypeTrait<DType>::ctype; \
-        switch (param().mode) { \
-            case Mode::MAX: { \
-                auto sptr = src.ptr<ctype>(); \
-                auto dptr = dst.ptr<ctype>(); \
-                DISPATCH_WITH_POOLER(MaxPooler<ctype>); \
-                return; \
-            } \
-            case Mode::AVERAGE: { \
-                auto sptr = src.ptr<ctype>(); \
-                auto dptr = dst.ptr<ctype>(); \
-                DISPATCH_WITH_POOLER(MeanIncludePooler<ctype>); \
-                return; \
-            } \
-            case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: { \
-                auto sptr = src.ptr<ctype>(); \
-                auto dptr = dst.ptr<ctype>(); \
-                DISPATCH_WITH_POOLER(MeanExcludePooler<ctype>); \
-                return; \
-            } \
-        } \
+#define cb(DType) \
+    if (comp_src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
+        using ctype = typename DTypeTrait<DType>::ctype; \
+        switch (param().mode) { \
+            case Mode::MAX: { \
+                auto sptr = comp_src.ptr<ctype>(); \
+                auto dptr = comp_dst.ptr<ctype>(); \
+                DISPATCH_WITH_POOLER(MaxPooler<ctype>); \
+                break; \
+            } \
+            case Mode::AVERAGE: { \
+                auto sptr = comp_src.ptr<ctype>(); \
+                auto dptr = comp_dst.ptr<ctype>(); \
+                DISPATCH_WITH_POOLER(MeanIncludePooler<ctype>); \
+                break; \
+            } \
+            case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: { \
+                auto sptr = comp_src.ptr<ctype>(); \
+                auto dptr = comp_dst.ptr<ctype>(); \
+                DISPATCH_WITH_POOLER(MeanExcludePooler<ctype>); \
+                break; \
+            } \
+            default: \
+                megdnn_assert(0, "unsupported pooling mode"); \
+        } \
+        post_process(dst, comp_dst, handle(), wsb); \
+        return; \
+    }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
@@ -20,10 +20,12 @@ class PoolingForwardImpl: public PoolingForward {
using PoolingForward::PoolingForward;
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
-    size_t get_workspace_in_bytes(const TensorLayout &,
-                                  const TensorLayout &) override {
-        return 0;
-    }
+    size_t get_workspace_in_bytes(const TensorLayout&,
+                                  const TensorLayout&) override;
private:
WorkspaceBundle get_workspace_bundle(void* ptr, const TensorLayout&,
const TensorLayout&) const;
};
class PoolingBackwardImpl : public PoolingBackward {
@@ -414,34 +414,34 @@ TensorND TensorValue(const TensorShape& shape, T dtype,
template <typename T, typename U>
TensorND TensorValueLowbit4(const TensorShape& shape, T dtype,
                            std::vector<U> values) {
TensorND tensor;
tensor.layout = {shape, dtype};
tensor.raw_ptr =
static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte()));
megdnn_assert(values.size() == tensor.layout.total_nr_elems());
auto ptr = tensor.ptr<typename DTypeTrait<T>::ctype>();
-    size_t i;
-    for (i = 0; i + 1 < values.size(); i += 2) {
-        U val0 = values[i], val1 = values[i + 1];
-        megdnn_assert(val0 >= DTypeTrait<T>::min());
-        megdnn_assert(val1 <= DTypeTrait<T>::max());
-        ptr[i / 2] = typename DTypeTrait<T>::ctype((val0 & 0xF) | (val1 << 4));
-    }
-    if (i < values.size()) {
-        U val0 = values[i];
-        megdnn_assert(val0 >= DTypeTrait<T>::min() &&
-                      val0 <= DTypeTrait<T>::max());
-        if (i + 1 < values.size()) {
-            U val1 = values[i + 1];
-            megdnn_assert(val1 >= DTypeTrait<T>::min() &&
-                          val1 <= DTypeTrait<T>::max());
-            ptr[i / 2] = typename DTypeTrait<T>::ctype((val0 & 0xF) | (val1 << 4));
-        } else {
-            ptr[i / 2] = typename DTypeTrait<T>::ctype(val0 & 0xF);
+    auto layout = tensor.layout;
+    auto dim_in = shape[layout.ndim - 1];
+    auto elems = tensor.layout.total_nr_elems();
+    auto dim_out = elems / dim_in;
+    auto stride_out = div_ceil(dim_in, 2_z);
+    size_t in_offset = 0;
+    for (size_t i = 0; i < dim_out; ++i) {
+        for (size_t j = 0; j < dim_in; j += 2) {
+            U a = values[in_offset + j];
+            U b = 0;
+            if (j + 1 < dim_in)
+                b = values[in_offset + j + 1];
+            megdnn_assert(a >= DTypeTrait<T>::min());
+            megdnn_assert(a <= DTypeTrait<T>::max());
+            megdnn_assert(b >= DTypeTrait<T>::min());
+            megdnn_assert(b <= DTypeTrait<T>::max());
+            ptr[j / 2] = (a & 0xF) | (b << 4);
+        }
+        in_offset += dim_in;
+        ptr += stride_out;
+    }
return tensor;
}
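Usage as in the naive Q4 test below (editor's sketch; the scale is arbitrary):

    auto t = TensorValueLowbit4({1, 1, 3, 3}, dtype::QuantizedS4(1.f),
                                std::vector<int>{1, 2, 3, 4, 5, 6, 7, -1, -2});
    // Each 3-element row packs into div_ceil(3, 2) = 2 bytes; a trailing odd
    // element fills only the low nibble of its byte.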
@@ -242,6 +242,20 @@ TEST_F(CUDA, POOLING_BACKWARD)
.exec(TensorShapeArray{ilayout, olayout, olayout, ilayout});
}
}
TEST_F(CUDA, POOLING_FORWARD_NCHW_Q4) {
require_compute_capability(7, 5);
using Param = param::Pooling;
Checker<Pooling> checker(handle_cuda());
Param param{Param::Mode::MAX, 0, 0, 2, 2, 2, 2};
checker.set_dtype(0, dtype::QuantizedS4(0.1f));
param.format = Param::Format::NCHW;
checker.set_epsilon(1 + 1e-3);
checker.set_param(param).exec({{20, 64, 22, 33}, {}});
param.mode = Param::Mode::AVERAGE;
checker.set_param(param).exec({{20, 64, 22, 33}, {}});
param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
checker.set_param(param).exec({{20, 64, 22, 33}, {}});
}
TEST_F(CUDA, POOLING_FORWARD_NCHW4) {
require_compute_capability(7, 5);
@@ -252,6 +266,10 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW4) {
param.format = Param::Format::NCHW4;
checker.set_epsilon(1 + 1e-3);
checker.set_param(param).exec({{20, 3, 50, 50, 4}, {}});
param.mode = Param::Mode::AVERAGE;
checker.set_param(param).exec({{20, 3, 50, 50, 4}, {}});
param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
checker.set_param(param).exec({{20, 3, 50, 50, 4}, {}});
}
#if CUDNN_VERSION >= 7500
@@ -267,9 +285,29 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW32) {
param.format = Param::Format::NCHW32;
checker.set_epsilon(1e-3).set_rng(0, &int_rng);
checker.set_param(param).exec({{64, 8, 28, 28, 32}, {}});
param.mode = Param::Mode::AVERAGE;
checker.set_param(param).exec({{64, 8, 28, 28, 32}, {}});
param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
checker.set_param(param).exec({{64, 8, 28, 28, 32}, {}});
}
#endif
TEST_F(CUDA, POOLING_FORWARD_NCHW64) {
require_compute_capability(7, 5);
using Param = param::Pooling;
Checker<Pooling> checker(handle_cuda());
Param param{Param::Mode::MAX, 0, 0, 2, 2, 2, 2};
UniformIntRNG int_rng{-8, 7};
checker.set_dtype(0, dtype::QuantizedS4(1.f));
param.format = Param::Format::NCHW64;
checker.set_epsilon(1e-3).set_rng(0, &int_rng);
checker.set_param(param).exec({{64, 8, 28, 28, 64}, {}});
param.mode = Param::Mode::AVERAGE;
checker.set_param(param).exec({{64, 8, 28, 28, 64}, {}});
param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
checker.set_param(param).exec({{64, 8, 28, 28, 64}, {}});
}
TEST_F(CUDA, POOLING_FORWARD_CHWN4) {
require_compute_capability(6, 1);
using Param = param::Pooling;
@@ -50,4 +50,63 @@ TEST_F(NAIVE, POOLING_QUANTIZED) {
12306, 23333})});
}
TEST_F(NAIVE, POOLING_QUANTIZED_Q4) {
using Mode = Pooling::Param::Mode;
Checker<Pooling> checker(handle(), /* check_dispatch */ false);
{
auto q4_dt = dtype::QuantizedS4(1.f);
std::vector<int> i8_src_vec{1, 2, 3,
4, 5, 6,
7, -1, -2};
std::vector<int> i8_max_dst_vec{1, 3, 7, 6};
std::vector<int> i8_avg_dst_vec{0, 1, 3, 2};
std::vector<int> i8_avg_exclu_dst_vec{1, 3, 6, 2};
Pooling::Param param{Mode::MAX, 1, 1, 2, 2, 2, 2};
Testcase input{TensorValueLowbit4({1, 1, 3, 3}, q4_dt, i8_src_vec), {}};
checker.set_param(param).exect(
input, Testcase{{},
TensorValueLowbit4({1, 1, 2, 2}, q4_dt,
i8_max_dst_vec)});
param = {Mode::AVERAGE, 1, 1, 2, 2, 2, 2};
checker.set_param(param).exect(
input, Testcase{{},
TensorValueLowbit4({1, 1, 2, 2}, q4_dt,
i8_avg_dst_vec)});
param = {Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 1, 1, 2, 2, 2, 2};
checker.set_param(param).exect(
input, Testcase{{},
TensorValueLowbit4({1, 1, 2, 2}, q4_dt,
i8_avg_exclu_dst_vec)});
}
{
auto u4_dt = dtype::Quantized4Asymm(1.f, 0);
std::vector<int> u8_src_vec{1, 2, 3,
4, 5, 6,
7, 8, 9};
std::vector<int> u8_max_dst_vec{1, 3, 7, 9};
std::vector<int> u8_avg_dst_vec{0, 1, 3, 7};
std::vector<int> u8_avg_exclu_dst_vec{1, 3, 6, 7};
Pooling::Param param{Mode::MAX, 1, 1, 2, 2, 2, 2};
Testcase input{TensorValueLowbit4({1, 1, 3, 3}, u4_dt, u8_src_vec), {}};
checker.set_param(param).exect(
input, Testcase{{},
TensorValueLowbit4({1, 1, 2, 2}, u4_dt,
u8_max_dst_vec)});
param = {Mode::AVERAGE, 1, 1, 2, 2, 2, 2};
checker.set_param(param).exect(
input, Testcase{{},
TensorValueLowbit4({1, 1, 2, 2}, u4_dt,
u8_avg_dst_vec)});
param = {Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 1, 1, 2, 2, 2, 2};
checker.set_param(param).exect(
input, Testcase{{},
TensorValueLowbit4({1, 1, 2, 2}, u4_dt,
u8_avg_exclu_dst_vec)});
}
}
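Editor's worked example for the expected vectors above (signed case: pad 1, window 2x2, stride 2, scale 1):

    window at dst(0, 0): real elements {1} + 3 pads
        MAX = 1;  AVERAGE = (0 + 0 + 0 + 1) / 4 = 0.25 -> 0;  EXCLUDE_PADDING = 1 / 1 = 1
    window at dst(1, 0): real elements {4, 7} + 2 pads
        MAX = 7;  AVERAGE = 11 / 4 = 2.75 -> 3;  EXCLUDE_PADDING = 11 / 2 = 5.5 -> 6

The expected values are consistent with rounding halves away from zero.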
// vim: syntax=cpp.doxygen