Commit 19919384 authored by Megvii Engine Team

feat(dnn/cuda): add cuda uint warp perspective

GitOrigin-RevId: 2aec72010f81ad92b726924fcdc4b65069ec0cab
Parent 01354337
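For orientation, here is a minimal standalone sketch (not MegDNN code; all names are made up for illustration) of the 4-bit packing scheme that the new `signedness` template flag in the diff switches between: signed QuantizedS4 lanes are sign-extended from bit 3 and cover [-8, 7], while unsigned Quantized4Asymm lanes are zero-extended and cover [0, 15].

#include <cstdint>
#include <cstdio>

// Pack eight 4-bit lanes into one 32-bit word; lane i occupies bits [4*i, 4*i + 4).
static uint32_t pack_bit4x8(const int (&v)[8]) {
    uint32_t out = 0;
    for (int i = 0; i < 8; ++i)
        out |= (static_cast<uint32_t>(v[i]) & 0xFu) << (i * 4);
    return out;
}

// Unpack lane i; signed lanes sign-extend from bit 3, unsigned lanes do not.
template <bool signedness>
static int unpack_bit4(uint32_t packed, int i) {
    int nibble = static_cast<int>((packed >> (i * 4)) & 0xFu);
    return (signedness && (nibble & 0x8)) ? nibble - 16 : nibble;
}

int main() {
    int signed_vals[8] = {-8, -1, 0, 7, 3, -4, 5, -6};
    uint32_t s = pack_bit4x8(signed_vals);
    for (int i = 0; i < 8; ++i)
        printf("%d ", unpack_bit4<true>(s, i));   // round-trips the signed lanes
    printf("\n");
    int unsigned_vals[8] = {0, 15, 3, 8, 12, 1, 7, 9};
    uint32_t u = pack_bit4x8(unsigned_vals);
    for (int i = 0; i < 8; ++i)
        printf("%d ", unpack_bit4<false>(u, i));  // round-trips the unsigned lanes
    printf("\n");
    return 0;
}

This mirrors the packed layout that the NCHW64 kernels read with int4 vector loads; the actual helpers dispatched on in the diff are transform_int8_to_int4x8 / transform_int8_to_uint4x8 and their inverses.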
......@@ -86,6 +86,16 @@ struct RoundingConverter<dt_qint4> {
}
};
template <>
struct RoundingConverter<dt_quint4> {
__host__ __device__ __forceinline__ dt_quint4 operator()(float x) const {
#if MEGDNN_CC_HOST
using std::round;
#endif
return static_cast<dt_quint4>(round(x));
}
};
} // namespace rounding
} // namespace megdnn
......
......@@ -73,9 +73,10 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src,
src.dtype.enumv() == DTypeEnum::Uint8 ||
(src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm) ||
src.dtype.enumv() == DTypeEnum::QuantizedS4,
src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
src.dtype.enumv() == DTypeEnum::Quantized4Asymm,
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8" DNN_FLOAT16_SELECT(
"Float32/Int8/Uint8/QInt8/QUint8/QInt4/QUInt4" DNN_FLOAT16_SELECT(
"/Float16/BFloat16", "") ".");
megdnn_assert(
(src.dtype.category() == DTypeCategory::FLOAT &&
......@@ -118,8 +119,9 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src,
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::ISOLATED);
} else if (param().format == param::WarpPerspective::Format::NCHW64) {
megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS4,
"src expected QuantizedS4, but got %s",
megdnn_assert((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
src.dtype.enumv() == DTypeEnum::Quantized4Asymm),
"src expected QuantizedS4/Quantized4Asymm, but got %s",
src.dtype.name());
megdnn_assert(mat.dtype == dtype::Float32(),
"matrix dtype expected float, got %s",
......
......@@ -44,8 +44,9 @@ void get_inner_layout(const TensorLayout& src, const TensorLayout& dst,
TensorLayout& inner_src, TensorLayout& inner_dst,
Handle* handle,
WarpPerspectiveForwardImpl::Param::Format format) {
if (src.dtype.enumv() == DTypeEnum::QuantizedS4 &&
dst.dtype.enumv() == DTypeEnum::QuantizedS4 &&
if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
src.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
dst.dtype.enumv() == src.dtype.enumv() &&
format == param::WarpPerspective::Format::NCHW) {
auto relayout_opr = handle->create_operator<RelayoutFormat>();
deduce_reformat_layout(relayout_opr, src, inner_src,
......@@ -130,7 +131,8 @@ WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
TensorLayout fsrc = src;
TensorLayout fmat = mat;
TensorLayout fdst = dst;
if (src.dtype.enumv() == DTypeEnum::QuantizedS4 &&
if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
src.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
param().format == param::WarpPerspective::Format::NCHW) {
get_inner_layout(src, dst, fsrc, fdst, handle(), param().format);
sizes.push_back(fsrc.span().dist_byte());
......@@ -177,7 +179,8 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
ctypecvt.src_to_comp_type(ssrc, src)
.src_to_comp_type(smat, mat)
.src_to_comp_type(sdst, dst);
} else if (ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
} else if ((ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
ssrc.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
param().format == Param::Format::NCHW) {
auto handle_ptr = handle();
get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout,
......@@ -330,7 +333,7 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
param().format == Param::Format::NCHW64 ||
param().format == Param::Format::NCHW,
"WarpPerspective on CUDA supports NCHW64 or NCHW+ "
"QuantizedS4 only");
"QuantizedS4");
bval = roundf(bval);
bval = fmin(fmax(-8.f, bval), 7.f);
warp_perspective::forward_proxy_nchw64<dt_qint4>(
......@@ -352,6 +355,34 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
relayout_opr->param() = trans_param;
relayout_opr->exec(dst, sdst, {});
}
} else if (src.layout.dtype.enumv() ==
DTypeEnum::Quantized4Asymm) {
megdnn_assert(
param().format == Param::Format::NCHW64 ||
param().format == Param::Format::NCHW,
"WarpPerspective on CUDA supports NCHW64 or NCHW+ "
"Quantized4Asymm");
bval = roundf(bval);
bval = fmin(fmax(0, bval), 15);
warp_perspective::forward_proxy_nchw64<dt_quint4>(
src.compatible_ptr<dt_quint4>(),
mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
dst.compatible_ptr<dt_quint4>(), src.layout[0],
mat.layout[0], C, IH, IW, OH, OW,
static_cast<dt_quint4>(bval), bmode,
async_error_info(handle()), m_error_tracker,
stream);
if (param().format == Param::Format::NCHW) {
auto relayout_opr =
handle()->create_operator<RelayoutFormat>();
RelayoutFormat::Param trans_param;
trans_param.mode =
RelayoutFormat::Param::Mode::NCHW64_NCHW;
trans_param.oc = sdst.layout[1];
relayout_opr->param() = trans_param;
relayout_opr->exec(dst, sdst, {});
}
}
} else if ((src.layout.dtype.enumv() ==
DTypeEnum::Quantized8Asymm ||
......
......@@ -144,25 +144,68 @@ __global__ void kern_general_nchw4(SrcVisitor src, const float* __restrict mat,
}
}
#define warp_perspective_transform(idx) \
template <bool signedness>
MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8(int s0, int s1,
int s2, int s3,
int s4, int s5,
int s6, int s7);
template <>
MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<true>(
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7);
}
template <>
MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<false>(
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7);
}
template <bool signedness>
MEGDNN_DEVICE __forceinline__ void
transform_bit4x8_to_int8(int (&result)[8], const int& source);
template <>
MEGDNN_DEVICE __forceinline__ void
transform_bit4x8_to_int8<true>(int (&result)[8], const int& source){
transform_int4x8_to_int8(result, source);
}
template <>
MEGDNN_DEVICE __forceinline__ void
transform_bit4x8_to_int8<false>(int (&result)[8], const int& source){
transform_uint4x8_to_int8(result, source);
}
template <bool signedness, typename OutputConverter>
MEGDNN_DEVICE __forceinline__ int pack_output_func(
OutputConverter& output_converter, int (&s00)[8], int (&s01)[8],
int (&s10)[8], int (&s11)[8], float palpha, float pbeta, float nalpha,
float nbeta) {
#define warp_perspective_transform(idx) \
static_cast<int>(output_converter(s00[idx] * nalpha * nbeta + \
s01[idx] * nalpha * pbeta + \
s10[idx] * palpha * nbeta + \
s11[idx] * palpha * pbeta) \
.as_int8())
#define pack_output \
transform_int8_to_int4x8( \
warp_perspective_transform(0), warp_perspective_transform(1), \
warp_perspective_transform(2), warp_perspective_transform(3), \
warp_perspective_transform(4), warp_perspective_transform(5), \
warp_perspective_transform(6), warp_perspective_transform(7))
.as_storage())
return transform_int8_to_bit4x8<signedness>(
warp_perspective_transform(0), warp_perspective_transform(1),
warp_perspective_transform(2), warp_perspective_transform(3),
warp_perspective_transform(4), warp_perspective_transform(5),
warp_perspective_transform(6), warp_perspective_transform(7));
#undef warp_perspective_transform
}
template <typename ctype, typename Getter, typename SrcVisitor,
typename OutputConverter>
__global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat,
ctype* __restrict dst, int C, int IH,
int IW, int OH, int OW) {
constexpr bool signedness = std::is_same<ctype, dt_qint4>::value;
Getter getter;
OutputConverter output_converter;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -199,29 +242,37 @@ __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat,
s[2] = __ldg(sptr_int4 + i_coor_10 + c1);
s[3] = __ldg(sptr_int4 + i_coor_11 + c1);
transform_int4x8_to_int8(s00, s[0].x);
transform_int4x8_to_int8(s01, s[1].x);
transform_int4x8_to_int8(s10, s[2].x);
transform_int4x8_to_int8(s11, s[3].x);
d.x = pack_output;
transform_int4x8_to_int8(s00, s[0].y);
transform_int4x8_to_int8(s01, s[1].y);
transform_int4x8_to_int8(s10, s[2].y);
transform_int4x8_to_int8(s11, s[3].y);
d.y = pack_output;
transform_int4x8_to_int8(s00, s[0].z);
transform_int4x8_to_int8(s01, s[1].z);
transform_int4x8_to_int8(s10, s[2].z);
transform_int4x8_to_int8(s11, s[3].z);
d.z = pack_output;
transform_int4x8_to_int8(s00, s[0].w);
transform_int4x8_to_int8(s01, s[1].w);
transform_int4x8_to_int8(s10, s[2].w);
transform_int4x8_to_int8(s11, s[3].w);
d.w = pack_output;
transform_bit4x8_to_int8<signedness>(s00, s[0].x);
transform_bit4x8_to_int8<signedness>(s01, s[1].x);
transform_bit4x8_to_int8<signedness>(s10, s[2].x);
transform_bit4x8_to_int8<signedness>(s11, s[3].x);
d.x = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, palpha, pbeta, nalpha,
nbeta);
transform_bit4x8_to_int8<signedness>(s00, s[0].y);
transform_bit4x8_to_int8<signedness>(s01, s[1].y);
transform_bit4x8_to_int8<signedness>(s10, s[2].y);
transform_bit4x8_to_int8<signedness>(s11, s[3].y);
d.y = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, palpha, pbeta, nalpha,
nbeta);
transform_bit4x8_to_int8<signedness>(s00, s[0].z);
transform_bit4x8_to_int8<signedness>(s01, s[1].z);
transform_bit4x8_to_int8<signedness>(s10, s[2].z);
transform_bit4x8_to_int8<signedness>(s11, s[3].z);
d.z = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, palpha, pbeta, nalpha,
nbeta);
transform_bit4x8_to_int8<signedness>(s00, s[0].w);
transform_bit4x8_to_int8<signedness>(s01, s[1].w);
transform_bit4x8_to_int8<signedness>(s10, s[2].w);
transform_bit4x8_to_int8<signedness>(s11, s[3].w);
d.w = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, palpha, pbeta, nalpha,
nbeta);
dst_int4[o_coor + c1] = d;
sptr_int4 += IH * IW * 2;
......@@ -320,15 +371,25 @@ __global__ void kern_const_border_nchw4(SrcVisitor src,
}
}
}
template <bool signedness>
MEGDNN_DEVICE __forceinline__ static void transform_bit4x8_to_int8(
int (&result)[8], const int& source) {
#pragma unroll
for (int i = 0; i < 8; i++) {
result[i] = unpack_integer_4bits<signedness>(
reinterpret_cast<unsigned const&>(source), (i << 2));
}
}
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_const_border_nchw64(SrcVisitor src,
const float* __restrict mat,
ctype* __restrict dst, int C, int IH,
int IW, int OH, int OW, ctype bval) {
constexpr bool signedness = std::is_same<ctype, dt_qint4>::value;
OutputConverter output_converter;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int c1 = ow %2;
int c1 = ow % 2;
ow = ow / 2;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW / 2);
......@@ -359,9 +420,9 @@ __global__ void kern_const_border_nchw64(SrcVisitor src,
int i_coor_11 = (ih1 * IW + iw1) << 1;
bool flag00 = okh0 && okw0, flag01 = okh0 && okw1,
flag10 = okh1 && okw0, flag11 = okh1 && okw1;
int8_t bval_4 = bval.as_int8() & 0xF;
int bval_8 = transform_int8_to_int4x8(bval_4, bval_4, bval_4, bval_4,
bval_4, bval_4, bval_4, bval_4);
int8_t bval_4 = bval.as_storage() & 0xF;
int bval_8 = transform_int8_to_bit4x8<signedness>(
bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4);
int4 bval_int4;
bval_int4.x = bval_8;
bval_int4.y = bval_8;
......@@ -391,29 +452,37 @@ __global__ void kern_const_border_nchw64(SrcVisitor src,
s[3] = bval_int4;
}
transform_int4x8_to_int8(s00, s[0].x);
transform_int4x8_to_int8(s01, s[1].x);
transform_int4x8_to_int8(s10, s[2].x);
transform_int4x8_to_int8(s11, s[3].x);
d.x = pack_output;
transform_int4x8_to_int8(s00, s[0].y);
transform_int4x8_to_int8(s01, s[1].y);
transform_int4x8_to_int8(s10, s[2].y);
transform_int4x8_to_int8(s11, s[3].y);
d.y = pack_output;
transform_int4x8_to_int8(s00, s[0].z);
transform_int4x8_to_int8(s01, s[1].z);
transform_int4x8_to_int8(s10, s[2].z);
transform_int4x8_to_int8(s11, s[3].z);
d.z = pack_output;
transform_int4x8_to_int8(s00, s[0].w);
transform_int4x8_to_int8(s01, s[1].w);
transform_int4x8_to_int8(s10, s[2].w);
transform_int4x8_to_int8(s11, s[3].w);
d.w = pack_output;
transform_bit4x8_to_int8<signedness>(s00, s[0].x);
transform_bit4x8_to_int8<signedness>(s01, s[1].x);
transform_bit4x8_to_int8<signedness>(s10, s[2].x);
transform_bit4x8_to_int8<signedness>(s11, s[3].x);
d.x = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, palpha, pbeta, nalpha,
nbeta);
transform_bit4x8_to_int8<signedness>(s00, s[0].y);
transform_bit4x8_to_int8<signedness>(s01, s[1].y);
transform_bit4x8_to_int8<signedness>(s10, s[2].y);
transform_bit4x8_to_int8<signedness>(s11, s[3].y);
d.y = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, palpha, pbeta, nalpha,
nbeta);
transform_bit4x8_to_int8<signedness>(s00, s[0].z);
transform_bit4x8_to_int8<signedness>(s01, s[1].z);
transform_bit4x8_to_int8<signedness>(s10, s[2].z);
transform_bit4x8_to_int8<signedness>(s11, s[3].z);
d.z = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, palpha, pbeta, nalpha,
nbeta);
transform_bit4x8_to_int8<signedness>(s00, s[0].w);
transform_bit4x8_to_int8<signedness>(s01, s[1].w);
transform_bit4x8_to_int8<signedness>(s10, s[2].w);
transform_bit4x8_to_int8<signedness>(s11, s[3].w);
d.w = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, palpha, pbeta, nalpha,
nbeta);
dst_int4[o_coor + c1] = d;
sptr_int4 += IH * IW * 2;
......@@ -1448,6 +1517,7 @@ INST(int8_t)
void*, cudaStream_t);
INST(dt_qint4)
INST(dt_quint4)
#undef INST
template <typename src_dtype, typename src_ctype, typename dst_ctype>
......
......@@ -249,6 +249,7 @@ void WarpPerspectiveForwardImpl::kern_naive_nhwcd4(
MIDOUT_END();
}
template <typename ctype, typename mtype>
void WarpPerspectiveForwardImpl::kern_naive_int4(
const KernParam<ctype, mtype>& kern_param, size_t task_id) {
......@@ -257,6 +258,7 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param);
MEGDNN_MARK_USED_VAR(N_MAT);
uint8_t c_shift, c_mask, iw_shift = 0, ow_shift = 0;
constexpr bool signedness = std::is_same<ctype, dt_qint4>::value;
switch (param().format) {
case Format::NCHW:
c_shift = 0;
......@@ -282,8 +284,13 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
<< c_shift) +
(c & c_mask);
uint8_t result =
(sptr[index / 2].as_int8() >> (4 * (index % 2))) & 0xF;
return result & uint8_t(1 << 3) ? result | ~mask : result;
(sptr[index / 2].as_storage() >> (4 * (index % 2))) & 0xF;
if (signedness) {
return result & uint8_t(1 << 3) ? result | ~mask : result;
} else {
megdnn_assert((std::is_same<ctype, dt_quint4>::value));
return result;
}
};
auto visit_src_bd = [&sptr, sstrd, border_val, c_shift, c_mask](
size_t c, int h, int w) -> float {
......@@ -292,8 +299,14 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
<< c_shift) +
(c & c_mask);
uint8_t result =
(sptr[index / 2].as_int8() >> (4 * (index % 2))) & 0xF;
return result & uint8_t(1 << 3) ? result | ~mask : result;
(sptr[index / 2].as_storage() >> (4 * (index % 2))) &
0xF;
if (signedness) {
return result & uint8_t(1 << 3) ? result | ~mask : result;
} else {
megdnn_assert((std::is_same<ctype, dt_quint4>::value));
return result;
}
} else
return border_val;
};
......@@ -302,9 +315,9 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
size_t index = ((dstrd[0] * (c >> c_shift) + dstrd[1] * h + w)
<< c_shift) +
(c & c_mask);
dptr[index / 2] =
(dptr[index / 2].as_int8() & (0xF0 >> (4 * (index % 2)))) |
(v.as_int8() << (4 * (index % 2)));
dptr[index / 2] = (dptr[index / 2].as_storage() &
(0xF0 >> (4 * (index % 2)))) |
(v.as_storage() << (4 * (index % 2)));
};
rounding::RoundingConverter<ctype> output_converter;
......@@ -334,21 +347,20 @@ void WarpPerspectiveForwardImpl::kern_naive_int4(
int iw1 = get_real_coord(std::floor(alphaw) + 1, IW);
int ih0 = get_real_coord(std::floor(alphah) + 0, IH);
int ih1 = get_real_coord(std::floor(alphah) + 1, IH);
alphaw -= floor(alphaw);
alphah -= floor(alphah);
if (bmode != BorderMode::CONSTANT) {
rep(c, C) {
set_visit_dst(
c, oh, ow,
output_converter(
visit_src(c, ih0, iw0) * (1.0f - alphaw) *
auto val = visit_src(c, ih0, iw0) * (1.0f - alphaw) *
(1.0f - alphah) +
visit_src(c, ih0, iw1) * alphaw *
(1.0f - alphah) +
visit_src(c, ih1, iw0) * (1.0f - alphaw) *
alphah +
visit_src(c, ih1, iw1) * alphaw * alphah));
visit_src(c, ih1, iw1) * alphaw * alphah;
set_visit_dst(
c, oh, ow,
output_converter(val));
}
} else {
rep(c, C) {
......@@ -613,6 +625,13 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
"WarpPerspective: %s",
src.layout.dtype.name())
.c_str());
} else if (src.layout.dtype.enumv() ==
DTypeTrait<dtype::Quantized4Asymm>::enumv) {
DISPATCH_ST(dtype::Quantized4Asymm, dt_quint4, float, KERN_INT4);
megdnn_throw(ssprintf("Unsupported input DType in "
"WarpPerspective: %s",
src.layout.dtype.name())
.c_str());
}
bool is_fusion_dtype = src.layout.dtype.enumv() != dst.layout.dtype.enumv();
......
......@@ -107,7 +107,8 @@ protected:
ret.mptr = mat.ptr<mtype>();
ret.dptr = dst.compatible_ptr<ctype>();
} else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
ret.sptr = src.compatible_ptr<ctype>();
ret.mptr = mat.ptr<mtype>();
ret.dptr = dst.compatible_ptr<ctype>();
......
......@@ -647,6 +647,31 @@ TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_QINT4) {
}
}
TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_QUINT4) {
using Param = WarpPerspective::Param;
Checker<WarpPerspectiveForward> checker(handle_cuda());
WarpPerspectiveMatRNG rng;
checker.set_rng(1, &rng);
checker.set_dtype(0, dtype::Quantized4Asymm(1.25f, 0))
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Quantized4Asymm(1.25f, 0));
for (auto bmode : {WarpPerspective::BorderMode::WRAP,
WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
WarpPerspective::Param param;
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NCHW;
checker.set_param(param);
checker.set_epsilon(1 + 1e-3);
checker.execs({{1, 64, 11, 11}, {1, 3, 3}, {1, 64, 11, 11}});
checker.execs({{20, 640, 11, 12}, {20, 3, 3}, {20, 640, 11, 12}});
}
}
TEST_F(CUDA, WARP_PERSPECTIVE_BACKWARD_DATA_BFLOAT16) {
Checker<WarpPerspectiveBackwardData> checker(handle_cuda());
WarpPerspectiveMatRNG rng;
......@@ -701,7 +726,7 @@ TEST_F(CUDA, WARP_PERSPECTIVE_MAT_IDX) {
warp_perspective::run_mat_idx_test(handle_cuda());
}
TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64) {
TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64_QINT4) {
using Param = WarpPerspective::Param;
WarpPerspective::Param param;
Checker<WarpPerspectiveForward> checker(handle_cuda());
......@@ -767,6 +792,72 @@ TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64) {
}
}
TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64_QUINT4) {
using Param = WarpPerspective::Param;
WarpPerspective::Param param;
Checker<WarpPerspectiveForward> checker(handle_cuda());
WarpPerspectiveMatRNG_V2 rng;
checker.set_dtype(0, dtype::Quantized4Asymm(0.1f, 3));
checker.set_dtype(2, dtype::Quantized4Asymm(0.1f, 3));
for (auto bmode : {WarpPerspective::BorderMode::WRAP,
WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;
param.format = Param::Format::NCHW64;
checker.set_param(param);
checker.set_epsilon(1 + 1e-3);
rng.set_hw(10, 11);
checker.set_rng(1, &rng);
checker.execs({{2, 1, 10, 11, 64}, {2, 3, 3}, {2, 1, 11, 12, 64}});
checker.execs(
{{20, 300, 10, 11, 64}, {20, 3, 3}, {20, 300, 11, 12, 64}});
checker.execs(
{{2200, 3, 10, 11, 64}, {2200, 3, 3}, {2200, 3, 11, 12, 64}});
rng.set_hw(25, 25);
checker.set_rng(1, &rng);
checker.execs({{1, 25, 25, 25, 64}, {1, 3, 3}, {1, 25, 25, 51, 64}});
rng.set_hw(25, 510);
checker.set_rng(1, &rng);
checker.execs({{1, 1, 25, 510, 64}, {1, 3, 3}, {1, 1, 25, 25, 64}});
rng.set_hw(25, 25);
checker.set_rng(1, &rng);
checker.execs({{1, 1, 25, 25, 64}, {1, 3, 3}, {1, 1, 51, 51, 64}});
rng.set_hw(51, 51);
checker.set_rng(1, &rng);
checker.execs({{1, 1, 51, 51, 64}, {1, 3, 3}, {1, 1, 25, 25, 64}});
}
{
Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> checker(
handle_cuda());
constexpr int N_SRC = 5;
UniformIntRNG mat_idx_rng{0, N_SRC - 1};
checker.set_dtype(0, dtype::Quantized4Asymm(0.1f, 3));
checker.set_rng(1, &rng);
checker.set_dtype(2, dtype::Int32());
checker.set_rng(2, &mat_idx_rng);
checker.set_dtype(3, dtype::Quantized4Asymm(0.1f, 3));
param.bmode = WarpPerspective::Param::BorderMode::REFLECT;
param.imode = param::WarpPerspective::InterpolationMode::LINEAR;
checker.set_param(param);
checker.set_epsilon(1 + 1e-3);
rng.set_hw(10, 11);
checker.set_rng(1, &rng);
checker.execs(
{{N_SRC, 3, 10, 11, 64}, {2, 3, 3}, {2}, {2, 3, 11, 12, 64}});
rng.set_hw(17, 13);
checker.set_rng(1, &rng);
checker.execs({{N_SRC, 14, 17, 13, 64},
{123, 3, 3},
{123},
{123, 14, 16, 15, 64}});
}
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_WARP_PERSPECTIVE_NCHW4) {
......
......@@ -196,8 +196,8 @@ TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW_QINT4) {
param.imode = WarpPerspective::Param::InterpolationMode::LINEAR;
param.format = WarpPerspective::Param::Format::NCHW;
std::vector<int> input_values = {1, 3, 2, 2, 0, 0, 0, 0, 2},
output_values = {1, 2, 2, 2};
std::vector<int> input_values = {-1, -3, -2, -2, 0, 0, 0, 0, -2},
output_values = {-1, -2, -2, -2};
checker.set_param(param).exect(
Testcase{TensorValueLowbit4({1, 1, 3, 3}, dtype::QuantizedS4(0.1),
......@@ -212,6 +212,31 @@ TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW_QINT4) {
output_values)});
}
TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW_QUINT4) {
Checker<WarpPerspective> checker(handle(), false);
WarpPerspective::Param param;
param.bmode = WarpPerspective::Param::BorderMode::BORDER_REFLECT;
param.imode = WarpPerspective::Param::InterpolationMode::LINEAR;
param.format = WarpPerspective::Param::Format::NCHW;
std::vector<int> input_values = {4, 13, 0, 0, 0, 0, 0, 0, 0},
output_values = {6, 8, 8, 9};
checker.set_param(param).exect(
Testcase{TensorValueLowbit4({1, 1, 3, 3},
dtype::Quantized4Asymm(0.1, 3),
input_values),
TensorValue({1, 3, 3}, dtype::Float32{},
{1.2f, 1.2f, 0.6f, -1.05f, -2.0f, -0.7f, 1.3f,
1.5f, 3.0f}),
{}},
Testcase{{},
{},
TensorValueLowbit4({1, 1, 2, 2},
dtype::Quantized4Asymm(0.1, 3),
output_values)});
}
TEST_F(NAIVE_MULTI_THREADS, WARP_PERSPECTIVE_NCHW4) {
using Param = WarpPerspective::Param;
......
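As a companion to the naive-kernel and test changes above, a small host-side sketch (a hypothetical helper, not the MegDNN API) of the per-pixel bilinear blend followed by round-and-saturate; the saturation bounds mirror the border-value clamps used in the CUDA path ([-8, 7] for QuantizedS4, [0, 15] for Quantized4Asymm), and treating the final cast as saturating is an assumption here.

#include <algorithm>
#include <cmath>
#include <cstdio>

// Hypothetical helper: s00..s11 are the four neighbouring source samples already
// widened to float; alphaw/alphah are the fractional parts of the source coordinate.
static int lerp2d_round_q4(float s00, float s01, float s10, float s11,
                           float alphaw, float alphah, bool is_signed) {
    float v = s00 * (1.0f - alphaw) * (1.0f - alphah) +
              s01 * alphaw * (1.0f - alphah) +
              s10 * (1.0f - alphaw) * alphah +
              s11 * alphaw * alphah;
    int r = static_cast<int>(std::round(v));
    // Assumed saturation to the 4-bit storage range (matches the bval clamps above);
    // the real code goes through rounding::RoundingConverter<ctype> instead.
    int lo = is_signed ? -8 : 0;
    int hi = is_signed ? 7 : 15;
    return std::min(std::max(r, lo), hi);
}

int main() {
    // A quarter of the way from 4 towards 13 along the width axis:
    // 4 * 0.75 + 13 * 0.25 = 6.25, which rounds to 6 and stays inside [0, 15].
    printf("%d\n", lerp2d_round_q4(4.f, 13.f, 4.f, 13.f, 0.25f, 0.f, /*is_signed=*/false));
    return 0;
}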