提交 1e576e32 编写于 作者: M Megvii Engine Team

feat(dnn/aarch64-arm_common): add mat_idx warppespective for aarch64/arm_common/naive

GitOrigin-RevId: 9eb0cdda5c3c4f4a766c67f51b9888960c26876e
上级 714cb232
...@@ -28,8 +28,8 @@ void WarpPerspectiveImpl::exec(_megdnn_tensor_in src, ...@@ -28,8 +28,8 @@ void WarpPerspectiveImpl::exec(_megdnn_tensor_in src,
check_exec(src.layout, mat.layout, mat_idx.layout, dst.layout, check_exec(src.layout, mat.layout, mat_idx.layout, dst.layout,
workspace.size); workspace.size);
if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode,
param().format) && !mat_idx.layout.ndim) { param().format)) {
warp_perspective_cv_exec(src, mat, dst, param().border_val, warp_perspective_cv_exec(src, mat, mat_idx, dst, param().border_val,
param().bmode, param().imode, handle()); param().bmode, param().imode, handle());
} else { } else {
//! Use arm_common implementation //! Use arm_common implementation
......
...@@ -190,9 +190,9 @@ void warp_perspective_cv(const Mat<T>& src, Mat<T>& dst, const float* trans, ...@@ -190,9 +190,9 @@ void warp_perspective_cv(const Mat<T>& src, Mat<T>& dst, const float* trans,
} }
} // anonymous namespace } // anonymous namespace
void megdnn::aarch64::warp_perspective_cv_exec( void megdnn::aarch64::warp_perspective_cv_exec(
_megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst, _megdnn_tensor_in src, _megdnn_tensor_in trans,
float border_value, BorderMode bmode, InterpolationMode imode, _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, float border_value,
Handle* handle) { BorderMode bmode, InterpolationMode imode, Handle* handle) {
size_t ch = dst.layout[3]; size_t ch = dst.layout[3];
size_t width = dst.layout[2]; size_t width = dst.layout[2];
size_t height = dst.layout[1]; size_t height = dst.layout[1];
...@@ -208,13 +208,26 @@ void megdnn::aarch64::warp_perspective_cv_exec( ...@@ -208,13 +208,26 @@ void megdnn::aarch64::warp_perspective_cv_exec(
"unsupported src channel: %zu, avaiable channel size: 1/2/3", "unsupported src channel: %zu, avaiable channel size: 1/2/3",
ch); ch);
const float* trans_ptr = trans.ptr<dt_float32>(); const float* trans_ptr = trans.ptr<dt_float32>();
if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { const int* midx_ptr = nullptr;
if (mat_idx.raw_ptr) {
megdnn_assert(mat_idx.layout.ndim == 1);
midx_ptr = mat_idx.ptr<int>();
}
if (dst.layout.dtype.enumv() == DTypeEnum::Float32) {
#define cb(_imode, _bmode, _ch) \ #define cb(_imode, _bmode, _ch) \
auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ auto task = [src, trans_ptr, midx_ptr, dst, border_value, \
size_t index, size_t) { \ parallelism_batch](size_t index, size_t) { \
size_t batch_id = index / parallelism_batch; \ size_t batch_id = index / parallelism_batch; \
size_t task_id = index % parallelism_batch; \ size_t task_id = index % parallelism_batch; \
Mat<float> src_mat = TensorND2Mat<float>(src, batch_id); \ size_t src_id = batch_id; \
if (midx_ptr) { \
src_id = midx_ptr[batch_id]; \
megdnn_assert( \
src_id < src.layout.shape[0], \
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \
batch_id, src_id, src.layout.shape[0]); \
} \
Mat<float> src_mat = TensorND2Mat<float>(src, src_id); \
Mat<float> dst_mat = TensorND2Mat<float>(dst, batch_id); \ Mat<float> dst_mat = TensorND2Mat<float>(dst, batch_id); \
const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \
warp_perspective_cv<float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \ warp_perspective_cv<float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \
...@@ -230,11 +243,19 @@ void megdnn::aarch64::warp_perspective_cv_exec( ...@@ -230,11 +243,19 @@ void megdnn::aarch64::warp_perspective_cv_exec(
#undef cb #undef cb
} else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) {
#define cb(_imode, _bmode, _ch) \ #define cb(_imode, _bmode, _ch) \
auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ auto task = [src, trans_ptr, midx_ptr, dst, border_value, \
size_t index, size_t) { \ parallelism_batch](size_t index, size_t) { \
size_t batch_id = index / parallelism_batch; \ size_t batch_id = index / parallelism_batch; \
size_t task_id = index % parallelism_batch; \ size_t task_id = index % parallelism_batch; \
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, batch_id); \ size_t src_id = batch_id; \
if (midx_ptr) { \
src_id = midx_ptr[batch_id]; \
megdnn_assert( \
src_id < src.layout.shape[0], \
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \
batch_id, src_id, src.layout.shape[0]); \
} \
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, src_id); \
Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, batch_id); \ Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, batch_id); \
const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \
warp_perspective_cv<uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \ warp_perspective_cv<uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \
...@@ -250,8 +271,7 @@ void megdnn::aarch64::warp_perspective_cv_exec( ...@@ -250,8 +271,7 @@ void megdnn::aarch64::warp_perspective_cv_exec(
#undef cb #undef cb
} else { } else {
megdnn_throw( megdnn_throw(
megdnn_mangle("Unsupported datatype of WarpAffine optr.")); megdnn_mangle("Unsupported datatype of WarpPerspective optr."));
} }
} }
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -21,7 +21,8 @@ namespace aarch64 { ...@@ -21,7 +21,8 @@ namespace aarch64 {
* \brief Used if the format is NHWC, transfer from megcv * \brief Used if the format is NHWC, transfer from megcv
*/ */
void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans,
_megdnn_tensor_in dst, float border_value, _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst,
float border_value,
param::WarpPerspective::BorderMode border_mode, param::WarpPerspective::BorderMode border_mode,
param::WarpPerspective::InterpolationMode imode, param::WarpPerspective::InterpolationMode imode,
Handle* handle); Handle* handle);
......
...@@ -28,10 +28,9 @@ void WarpPerspectiveImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, ...@@ -28,10 +28,9 @@ void WarpPerspectiveImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
check_exec_allow_nhwc_mat_idx(src.layout, mat.layout, mat_idx.layout, check_exec_allow_nhwc_mat_idx(src.layout, mat.layout, mat_idx.layout,
dst.layout, workspace.size); dst.layout, workspace.size);
if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode,
param().format) && param().format)) {
!mat_idx.layout.ndim) {
MIDOUT_BEGIN(megdnn_arm_warpperspective, void) { MIDOUT_BEGIN(megdnn_arm_warpperspective, void) {
warp_perspective_cv_exec(src, mat, dst, param().border_val, warp_perspective_cv_exec(src, mat, mat_idx, dst, param().border_val,
param().bmode, param().imode, handle()); param().bmode, param().imode, handle());
} }
MIDOUT_END(); MIDOUT_END();
......
...@@ -149,9 +149,9 @@ void warp_perspective_cv(const Mat<T>& src, Mat<T>& dst, const float* trans, ...@@ -149,9 +149,9 @@ void warp_perspective_cv(const Mat<T>& src, Mat<T>& dst, const float* trans,
} // anonymous namespace } // anonymous namespace
void megdnn::arm_common::warp_perspective_cv_exec( void megdnn::arm_common::warp_perspective_cv_exec(
_megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst, _megdnn_tensor_in src, _megdnn_tensor_in trans,
float border_value, BorderMode bmode, InterpolationMode imode, _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, float border_value,
Handle* handle) { BorderMode bmode, InterpolationMode imode, Handle* handle) {
size_t ch = dst.layout[3]; size_t ch = dst.layout[3];
size_t width = dst.layout[2]; size_t width = dst.layout[2];
size_t height = dst.layout[1]; size_t height = dst.layout[1];
...@@ -167,13 +167,26 @@ void megdnn::arm_common::warp_perspective_cv_exec( ...@@ -167,13 +167,26 @@ void megdnn::arm_common::warp_perspective_cv_exec(
"unsupported src channel: %zu, avaiable channel size: 1/2/3", "unsupported src channel: %zu, avaiable channel size: 1/2/3",
ch); ch);
const float* trans_ptr = trans.ptr<dt_float32>(); const float* trans_ptr = trans.ptr<dt_float32>();
const int* midx_ptr = nullptr;
if (mat_idx.raw_ptr) {
megdnn_assert(mat_idx.layout.ndim == 1);
midx_ptr = mat_idx.ptr<int>();
}
if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { if (dst.layout.dtype.enumv() == DTypeEnum::Float32) {
#define cb(_imode, _bmode, _ch) \ #define cb(_imode, _bmode, _ch) \
auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ auto task = [src, trans_ptr, midx_ptr, dst, border_value, \
size_t index, size_t) { \ parallelism_batch](size_t index, size_t) { \
size_t batch_id = index / parallelism_batch; \ size_t batch_id = index / parallelism_batch; \
size_t task_id = index % parallelism_batch; \ size_t task_id = index % parallelism_batch; \
Mat<float> src_mat = TensorND2Mat<float>(src, batch_id); \ size_t src_id = batch_id; \
if (midx_ptr) { \
src_id = midx_ptr[batch_id]; \
megdnn_assert( \
src_id < src.layout.shape[0], \
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \
batch_id, src_id, src.layout.shape[0]); \
} \
Mat<float> src_mat = TensorND2Mat<float>(src, src_id); \
Mat<float> dst_mat = TensorND2Mat<float>(dst, batch_id); \ Mat<float> dst_mat = TensorND2Mat<float>(dst, batch_id); \
const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \
warp_perspective_cv<float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \ warp_perspective_cv<float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \
...@@ -189,11 +202,19 @@ void megdnn::arm_common::warp_perspective_cv_exec( ...@@ -189,11 +202,19 @@ void megdnn::arm_common::warp_perspective_cv_exec(
#undef cb #undef cb
} else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) {
#define cb(_imode, _bmode, _ch) \ #define cb(_imode, _bmode, _ch) \
auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ auto task = [src, trans_ptr, midx_ptr, dst, border_value, \
size_t index, size_t) { \ parallelism_batch](size_t index, size_t) { \
size_t batch_id = index / parallelism_batch; \ size_t batch_id = index / parallelism_batch; \
size_t task_id = index % parallelism_batch; \ size_t task_id = index % parallelism_batch; \
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, batch_id); \ size_t src_id = batch_id; \
if (midx_ptr) { \
src_id = midx_ptr[batch_id]; \
megdnn_assert( \
src_id < src.layout.shape[0], \
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \
batch_id, src_id, src.layout.shape[0]); \
} \
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, src_id); \
Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, batch_id); \ Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, batch_id); \
const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \
warp_perspective_cv<uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \ warp_perspective_cv<uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \
......
...@@ -21,7 +21,8 @@ namespace arm_common { ...@@ -21,7 +21,8 @@ namespace arm_common {
* \brief Used if the format is NHWC, transfer from megcv * \brief Used if the format is NHWC, transfer from megcv
*/ */
void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans,
_megdnn_tensor_in dst, float border_value, _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst,
float border_value,
param::WarpPerspective::BorderMode border_mode, param::WarpPerspective::BorderMode border_mode,
param::WarpPerspective::InterpolationMode imode, param::WarpPerspective::InterpolationMode imode,
Handle* handle); Handle* handle);
......
...@@ -236,10 +236,6 @@ void WarpPerspectiveForward::check_exec(const TensorLayout &src, ...@@ -236,10 +236,6 @@ void WarpPerspectiveForward::check_exec(const TensorLayout &src,
size_t workspace_in_bytes) size_t workspace_in_bytes)
{ {
check_exec_allow_nhwc_mat_idx(src, mat, mat_idx, dst, workspace_in_bytes); check_exec_allow_nhwc_mat_idx(src, mat, mat_idx, dst, workspace_in_bytes);
if (param().format == Param::Format::NHWC) {
megdnn_assert(!mat_idx.ndim,
"mat_idx not supported for current format");
}
} }
void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx( void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
......
...@@ -320,10 +320,9 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src, ...@@ -320,10 +320,9 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
.c_str()); .c_str());
} }
if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode,
param().format) && param().format)) {
!mat_idx.layout.ndim) {
MIDOUT_BEGIN(megdnn_naive_warpperspective, void) { MIDOUT_BEGIN(megdnn_naive_warpperspective, void) {
warp_perspective_cv_exec(src, mat, dst, param().border_val, warp_perspective_cv_exec(src, mat, mat_idx, dst, param().border_val,
param().bmode, param().imode, handle()); param().bmode, param().imode, handle());
} }
MIDOUT_END(); MIDOUT_END();
......
...@@ -151,9 +151,9 @@ void warp_perspective_cv(const Mat<T>& src, Mat<T>& dst, const float* trans, ...@@ -151,9 +151,9 @@ void warp_perspective_cv(const Mat<T>& src, Mat<T>& dst, const float* trans,
} // anonymous namespace } // anonymous namespace
void megdnn::naive::warp_perspective_cv_exec( void megdnn::naive::warp_perspective_cv_exec(
_megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst, _megdnn_tensor_in src, _megdnn_tensor_in trans,
float border_value, BorderMode bmode, InterpolationMode imode, _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, float border_value,
Handle* handle) { BorderMode bmode, InterpolationMode imode, Handle* handle) {
size_t ch = dst.layout[3]; size_t ch = dst.layout[3];
size_t width = dst.layout[2]; size_t width = dst.layout[2];
size_t height = dst.layout[1]; size_t height = dst.layout[1];
...@@ -169,13 +169,26 @@ void megdnn::naive::warp_perspective_cv_exec( ...@@ -169,13 +169,26 @@ void megdnn::naive::warp_perspective_cv_exec(
"unsupported src channel: %zu, avaiable channel size: 1/2/3", "unsupported src channel: %zu, avaiable channel size: 1/2/3",
ch); ch);
const float* trans_ptr = trans.ptr<dt_float32>(); const float* trans_ptr = trans.ptr<dt_float32>();
const int* midx_ptr = nullptr;
if (mat_idx.raw_ptr) {
megdnn_assert(mat_idx.layout.ndim == 1);
midx_ptr = mat_idx.ptr<int>();
}
if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { if (dst.layout.dtype.enumv() == DTypeEnum::Float32) {
#define cb(_imode, _bmode, _ch) \ #define cb(_imode, _bmode, _ch) \
auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ auto task = [src, trans_ptr, midx_ptr, dst, border_value, \
size_t index, size_t) { \ parallelism_batch](size_t index, size_t) { \
size_t batch_id = index / parallelism_batch; \ size_t batch_id = index / parallelism_batch; \
size_t task_id = index % parallelism_batch; \ size_t task_id = index % parallelism_batch; \
Mat<float> src_mat = TensorND2Mat<float>(src, batch_id); \ size_t src_id = batch_id; \
if (midx_ptr) { \
src_id = midx_ptr[batch_id]; \
megdnn_assert( \
src_id < src.layout.shape[0], \
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \
batch_id, src_id, src.layout.shape[0]); \
} \
Mat<float> src_mat = TensorND2Mat<float>(src, src_id); \
Mat<float> dst_mat = TensorND2Mat<float>(dst, batch_id); \ Mat<float> dst_mat = TensorND2Mat<float>(dst, batch_id); \
const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \
warp_perspective_cv<float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \ warp_perspective_cv<float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \
...@@ -191,11 +204,19 @@ void megdnn::naive::warp_perspective_cv_exec( ...@@ -191,11 +204,19 @@ void megdnn::naive::warp_perspective_cv_exec(
#undef cb #undef cb
} else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) {
#define cb(_imode, _bmode, _ch) \ #define cb(_imode, _bmode, _ch) \
auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ auto task = [src, trans_ptr, midx_ptr, dst, border_value, \
size_t index, size_t) { \ parallelism_batch](size_t index, size_t) { \
size_t batch_id = index / parallelism_batch; \ size_t batch_id = index / parallelism_batch; \
size_t task_id = index % parallelism_batch; \ size_t task_id = index % parallelism_batch; \
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, batch_id); \ size_t src_id = batch_id; \
if (midx_ptr) { \
src_id = midx_ptr[batch_id]; \
megdnn_assert( \
src_id < src.layout.shape[0], \
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \
batch_id, src_id, src.layout.shape[0]); \
} \
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, src_id); \
Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, batch_id); \ Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, batch_id); \
const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \
warp_perspective_cv<uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \ warp_perspective_cv<uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \
...@@ -210,7 +231,8 @@ void megdnn::naive::warp_perspective_cv_exec( ...@@ -210,7 +231,8 @@ void megdnn::naive::warp_perspective_cv_exec(
DISPATCH_IMODE(imode, bmode, ch, cb) DISPATCH_IMODE(imode, bmode, ch, cb)
#undef cb #undef cb
} else { } else {
megdnn_throw(megdnn_mangle("Unsupported datatype of WarpAffine optr.")); megdnn_throw(
megdnn_mangle("Unsupported datatype of WarpPerspective optr."));
} }
} }
......
...@@ -21,7 +21,8 @@ namespace naive { ...@@ -21,7 +21,8 @@ namespace naive {
* \brief Used if the format is NHWC, transfer from megcv * \brief Used if the format is NHWC, transfer from megcv
*/ */
void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans,
_megdnn_tensor_in dst, float border_value, _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst,
float border_value,
param::WarpPerspective::BorderMode border_mode, param::WarpPerspective::BorderMode border_mode,
param::WarpPerspective::InterpolationMode imode, param::WarpPerspective::InterpolationMode imode,
Handle* handle); Handle* handle);
......
...@@ -27,8 +27,8 @@ void WarpPerspectiveImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, ...@@ -27,8 +27,8 @@ void WarpPerspectiveImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
dst.layout, workspace.size); dst.layout, workspace.size);
if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode,
param().format) && param().format) &&
!mat_idx.layout.ndim && is_supported(SIMDType::SSE4_2)) { is_supported(SIMDType::SSE4_2)) {
warp_perspective_cv_exec(src, mat, dst, param().border_val, warp_perspective_cv_exec(src, mat, mat_idx, dst, param().border_val,
param().bmode, param().imode, handle()); param().bmode, param().imode, handle());
} else { } else {
//! Use fallback implementation //! Use fallback implementation
......
...@@ -59,7 +59,6 @@ ...@@ -59,7 +59,6 @@
* --------------------------------------------------------------------------- * ---------------------------------------------------------------------------
*/ */
#include "src/x86/warp_perspective/warp_perspective_cv.h" #include "src/x86/warp_perspective/warp_perspective_cv.h"
#include "src/common/cv/common.h" #include "src/common/cv/common.h"
#include "src/common/cv/helper.h" #include "src/common/cv/helper.h"
...@@ -154,12 +153,10 @@ void warp_perspective_cv(const Mat<T>& src, Mat<T>& dst, const float* trans, ...@@ -154,12 +153,10 @@ void warp_perspective_cv(const Mat<T>& src, Mat<T>& dst, const float* trans,
} }
} // anonymous namespace } // anonymous namespace
void megdnn::x86::warp_perspective_cv_exec(_megdnn_tensor_in src, void megdnn::x86::warp_perspective_cv_exec(
_megdnn_tensor_in trans, _megdnn_tensor_in src, _megdnn_tensor_in trans,
_megdnn_tensor_in dst, _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, float border_value,
float border_value, BorderMode bmode, BorderMode bmode, InterpolationMode imode, Handle* handle) {
InterpolationMode imode,
Handle* handle) {
size_t ch = dst.layout[3]; size_t ch = dst.layout[3];
size_t width = dst.layout[2]; size_t width = dst.layout[2];
size_t height = dst.layout[1]; size_t height = dst.layout[1];
...@@ -175,13 +172,26 @@ void megdnn::x86::warp_perspective_cv_exec(_megdnn_tensor_in src, ...@@ -175,13 +172,26 @@ void megdnn::x86::warp_perspective_cv_exec(_megdnn_tensor_in src,
"unsupported src channel: %zu, avaiable channel size: 1/2/3", "unsupported src channel: %zu, avaiable channel size: 1/2/3",
ch); ch);
const float* trans_ptr = trans.ptr<dt_float32>(); const float* trans_ptr = trans.ptr<dt_float32>();
const int* midx_ptr = nullptr;
if (mat_idx.raw_ptr) {
megdnn_assert(mat_idx.layout.ndim == 1);
midx_ptr = mat_idx.ptr<int>();
}
if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { if (dst.layout.dtype.enumv() == DTypeEnum::Float32) {
#define cb(_imode, _bmode, _ch) \ #define cb(_imode, _bmode, _ch) \
auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ auto task = [src, trans_ptr, midx_ptr, dst, border_value, \
size_t index, size_t) { \ parallelism_batch](size_t index, size_t) { \
size_t batch_id = index / parallelism_batch; \ size_t batch_id = index / parallelism_batch; \
size_t task_id = index % parallelism_batch; \ size_t task_id = index % parallelism_batch; \
Mat<float> src_mat = TensorND2Mat<float>(src, batch_id); \ size_t src_id = batch_id; \
if (midx_ptr) { \
src_id = midx_ptr[batch_id]; \
megdnn_assert( \
src_id < src.layout.shape[0], \
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \
batch_id, src_id, src.layout.shape[0]); \
} \
Mat<float> src_mat = TensorND2Mat<float>(src, src_id); \
Mat<float> dst_mat = TensorND2Mat<float>(dst, batch_id); \ Mat<float> dst_mat = TensorND2Mat<float>(dst, batch_id); \
const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \
warp_perspective_cv<float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \ warp_perspective_cv<float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \
...@@ -197,11 +207,19 @@ void megdnn::x86::warp_perspective_cv_exec(_megdnn_tensor_in src, ...@@ -197,11 +207,19 @@ void megdnn::x86::warp_perspective_cv_exec(_megdnn_tensor_in src,
#undef cb #undef cb
} else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) {
#define cb(_imode, _bmode, _ch) \ #define cb(_imode, _bmode, _ch) \
auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ auto task = [src, trans_ptr, midx_ptr, dst, border_value, \
size_t index, size_t) { \ parallelism_batch](size_t index, size_t) { \
size_t batch_id = index / parallelism_batch; \ size_t batch_id = index / parallelism_batch; \
size_t task_id = index % parallelism_batch; \ size_t task_id = index % parallelism_batch; \
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, batch_id); \ size_t src_id = batch_id; \
if (midx_ptr) { \
src_id = midx_ptr[batch_id]; \
megdnn_assert( \
src_id < src.layout.shape[0], \
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \
batch_id, src_id, src.layout.shape[0]); \
} \
Mat<uchar> src_mat = TensorND2Mat<uchar>(src, src_id); \
Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, batch_id); \ Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, batch_id); \
const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \
warp_perspective_cv<uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \ warp_perspective_cv<uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode \
......
...@@ -21,12 +21,13 @@ namespace x86 { ...@@ -21,12 +21,13 @@ namespace x86 {
* \brief Used if the format is NHWC, transfer from megcv * \brief Used if the format is NHWC, transfer from megcv
*/ */
void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans,
_megdnn_tensor_in dst, float border_value, _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst,
float border_value,
param::WarpPerspective::BorderMode border_mode, param::WarpPerspective::BorderMode border_mode,
param::WarpPerspective::InterpolationMode imode, param::WarpPerspective::InterpolationMode imode,
Handle* handle); Handle* handle);
} // x86 } // x86
} // megdnn } // megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -25,7 +25,7 @@ namespace test { ...@@ -25,7 +25,7 @@ namespace test {
TEST_F(AARCH64, WARP_PERSPECTIVE_CV) { TEST_F(AARCH64, WARP_PERSPECTIVE_CV) {
//! Just for the format NHWC //! Just for the format NHWC
Checker<WarpPerspective> checker(handle()); Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> checker(handle());
param::WarpPerspective param; param::WarpPerspective param;
class ResizeMatRNG : public RNG { class ResizeMatRNG : public RNG {
void gen(const TensorND& tensor_) override { void gen(const TensorND& tensor_) override {
...@@ -82,7 +82,10 @@ TEST_F(AARCH64, WARP_PERSPECTIVE_CV) { ...@@ -82,7 +82,10 @@ TEST_F(AARCH64, WARP_PERSPECTIVE_CV) {
param.bmode = mode; param.bmode = mode;
param.border_val = 1.737; param.border_val = 1.737;
checker.set_param(param); checker.set_param(param);
checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); UniformIntRNG rng(0, 1);
checker.set_rng(2, &rng);
checker.set_dtype(2, dtype::Int32());
checker.exec({{2, 5, 5, 1}, {4, 3, 3}, {4}, {4, 5, 5, 1}});
} }
// resize nan case // resize nan case
UniformFloatRNG rng_zero(0, 0); UniformFloatRNG rng_zero(0, 0);
...@@ -91,7 +94,11 @@ TEST_F(AARCH64, WARP_PERSPECTIVE_CV) { ...@@ -91,7 +94,11 @@ TEST_F(AARCH64, WARP_PERSPECTIVE_CV) {
param.bmode = BMode::CONSTANT; param.bmode = BMode::CONSTANT;
param.border_val = 1.737; param.border_val = 1.737;
checker.set_param(param); checker.set_param(param);
checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); UniformIntRNG rng(0, 999);
checker.set_rng(2, &rng);
checker.set_dtype(2, dtype::Int32());
checker.exec(
{{1000, 2, 10, 3}, {2000, 3, 3}, {2000}, {2000, 2, 12, 3}});
} }
// add linear test // add linear test
...@@ -101,7 +108,10 @@ TEST_F(AARCH64, WARP_PERSPECTIVE_CV) { ...@@ -101,7 +108,10 @@ TEST_F(AARCH64, WARP_PERSPECTIVE_CV) {
param.bmode = mode; param.bmode = mode;
param.border_val = 1.737; param.border_val = 1.737;
checker.set_param(param); checker.set_param(param);
checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); UniformIntRNG rng(0, 9);
checker.set_rng(2, &rng);
checker.set_dtype(2, dtype::Int32());
checker.exec({{10, 128, 108, 3}, {20, 3, 3}, {20}, {20, 56, 128, 3}});
} }
// resize nan case // resize nan case
checker.set_rng(1, &rng_zero); checker.set_rng(1, &rng_zero);
...@@ -109,24 +119,34 @@ TEST_F(AARCH64, WARP_PERSPECTIVE_CV) { ...@@ -109,24 +119,34 @@ TEST_F(AARCH64, WARP_PERSPECTIVE_CV) {
param.bmode = BMode::CONSTANT; param.bmode = BMode::CONSTANT;
param.border_val = 1.737; param.border_val = 1.737;
checker.set_param(param); checker.set_param(param);
checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); UniformIntRNG rng(0, 999);
checker.set_rng(2, &rng);
checker.set_dtype(2, dtype::Int32());
checker.exec(
{{1000, 2, 10, 3}, {2000, 3, 3}, {2000}, {2000, 2, 12, 3}});
} }
auto args = warp_perspective::get_cv_args(); auto args = warp_perspective::get_cv_args();
for (auto&& arg : args) { for (auto&& arg : args) {
ConstValue rng(0.f);
checker.set_param(arg.param) checker.set_param(arg.param)
.set_rng(2, &rng)
.set_dtype(0, dtype::Uint8()) .set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Float32()) .set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Uint8()) .set_dtype(2, dtype::Int32())
.execs({arg.src, arg.trans, arg.dst}); .set_dtype(3, dtype::Uint8())
.execs({arg.src, arg.trans, arg.mat_idx, arg.dst});
} }
for (auto&& arg : args) { for (auto&& arg : args) {
ConstValue rng(0.f);
checker.set_param(arg.param) checker.set_param(arg.param)
.set_rng(2, &rng)
.set_dtype(0, dtype::Float32()) .set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32()) .set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32()) .set_dtype(2, dtype::Int32())
.execs({arg.src, arg.trans, arg.dst}); .set_dtype(3, dtype::Float32())
.execs({arg.src, arg.trans, arg.mat_idx, arg.dst});
} }
} }
......
...@@ -25,7 +25,7 @@ namespace test { ...@@ -25,7 +25,7 @@ namespace test {
TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) { TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) {
//! Just for the format NHWC //! Just for the format NHWC
Checker<WarpPerspective> checker(handle()); Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> checker(handle());
param::WarpPerspective param; param::WarpPerspective param;
class ResizeMatRNG : public RNG { class ResizeMatRNG : public RNG {
void gen(const TensorND& tensor_) override { void gen(const TensorND& tensor_) override {
...@@ -82,7 +82,10 @@ TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) { ...@@ -82,7 +82,10 @@ TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) {
param.bmode = mode; param.bmode = mode;
param.border_val = 1.737; param.border_val = 1.737;
checker.set_param(param); checker.set_param(param);
checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); UniformIntRNG rng(0, 9);
checker.set_rng(2, &rng);
checker.set_dtype(2, dtype::Int32());
checker.exec({{10, 128, 108, 3}, {20, 3, 3}, {20}, {20, 56, 128, 3}});
} }
// resize nan case // resize nan case
UniformFloatRNG rng_zero(0, 0); UniformFloatRNG rng_zero(0, 0);
...@@ -91,7 +94,11 @@ TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) { ...@@ -91,7 +94,11 @@ TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) {
param.bmode = BMode::CONSTANT; param.bmode = BMode::CONSTANT;
param.border_val = 1.737; param.border_val = 1.737;
checker.set_param(param); checker.set_param(param);
checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); UniformIntRNG rng(0, 999);
checker.set_rng(2, &rng);
checker.set_dtype(2, dtype::Int32());
checker.exec(
{{1000, 2, 10, 3}, {1000, 3, 3}, {1000}, {1000, 2, 12, 3}});
} }
// add linear test // add linear test
...@@ -101,7 +108,10 @@ TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) { ...@@ -101,7 +108,10 @@ TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) {
param.bmode = mode; param.bmode = mode;
param.border_val = 1.737; param.border_val = 1.737;
checker.set_param(param); checker.set_param(param);
checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); UniformIntRNG rng(0, 9);
checker.set_rng(2, &rng);
checker.set_dtype(2, dtype::Int32());
checker.exec({{10, 128, 108, 3}, {20, 3, 3}, {20}, {20, 56, 128, 3}});
} }
// resize nan case // resize nan case
checker.set_rng(1, &rng_zero); checker.set_rng(1, &rng_zero);
...@@ -109,30 +119,40 @@ TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) { ...@@ -109,30 +119,40 @@ TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) {
param.bmode = BMode::CONSTANT; param.bmode = BMode::CONSTANT;
param.border_val = 1.737; param.border_val = 1.737;
checker.set_param(param); checker.set_param(param);
checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); UniformIntRNG rng(0, 999);
checker.set_rng(2, &rng);
checker.set_dtype(2, dtype::Int32());
checker.exec(
{{1000, 2, 10, 3}, {2000, 3, 3}, {2000}, {2000, 2, 12, 3}});
} }
auto args = warp_perspective::get_cv_args(); auto args = warp_perspective::get_cv_args();
for (auto&& arg : args) { for (auto&& arg : args) {
ConstValue rng(0.f);
checker.set_param(arg.param) checker.set_param(arg.param)
.set_rng(2, &rng)
.set_dtype(0, dtype::Uint8()) .set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Float32()) .set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Uint8()) .set_dtype(2, dtype::Int32())
.execs({arg.src, arg.trans, arg.dst}); .set_dtype(3, dtype::Uint8())
.execs({arg.src, arg.trans, arg.mat_idx, arg.dst});
} }
for (auto&& arg : args) { for (auto&& arg : args) {
ConstValue rng(0.f);
checker.set_param(arg.param) checker.set_param(arg.param)
.set_rng(2, &rng)
.set_dtype(0, dtype::Float32()) .set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32()) .set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32()) .set_dtype(2, dtype::Int32())
.execs({arg.src, arg.trans, arg.dst}); .set_dtype(3, dtype::Float32())
.execs({arg.src, arg.trans, arg.mat_idx, arg.dst});
} }
} }
TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) { TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) {
//! Just for the format NHWC //! Just for the format NHWC
Checker<WarpPerspective> checker(handle()); Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> checker(handle());
param::WarpPerspective param; param::WarpPerspective param;
class ResizeMatRNG : public RNG { class ResizeMatRNG : public RNG {
void gen(const TensorND& tensor_) override { void gen(const TensorND& tensor_) override {
...@@ -189,7 +209,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) { ...@@ -189,7 +209,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) {
param.bmode = mode; param.bmode = mode;
param.border_val = 1.737; param.border_val = 1.737;
checker.set_param(param); checker.set_param(param);
checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); UniformIntRNG rng(0, 9);
checker.set_rng(2, &rng);
checker.set_dtype(2, dtype::Int32());
checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10}, {10, 56, 128, 3}});
} }
// resize nan case // resize nan case
UniformFloatRNG rng_zero(0, 0); UniformFloatRNG rng_zero(0, 0);
...@@ -198,7 +221,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) { ...@@ -198,7 +221,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) {
param.bmode = BMode::CONSTANT; param.bmode = BMode::CONSTANT;
param.border_val = 1.737; param.border_val = 1.737;
checker.set_param(param); checker.set_param(param);
checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); UniformIntRNG rng(0, 999);
checker.set_rng(2, &rng);
checker.set_dtype(2, dtype::Int32());
checker.exec(
{{1000, 2, 10, 3}, {2000, 3, 3}, {2000}, {2000, 2, 12, 3}});
} }
// add linear test // add linear test
...@@ -208,7 +235,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) { ...@@ -208,7 +235,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) {
param.bmode = mode; param.bmode = mode;
param.border_val = 1.737; param.border_val = 1.737;
checker.set_param(param); checker.set_param(param);
checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); UniformIntRNG rng(0, 9);
checker.set_rng(2, &rng);
checker.set_dtype(2, dtype::Int32());
checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10}, {10, 56, 128, 3}});
} }
// resize nan case // resize nan case
checker.set_rng(1, &rng_zero); checker.set_rng(1, &rng_zero);
...@@ -216,24 +246,34 @@ TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) { ...@@ -216,24 +246,34 @@ TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) {
param.bmode = BMode::CONSTANT; param.bmode = BMode::CONSTANT;
param.border_val = 1.737; param.border_val = 1.737;
checker.set_param(param); checker.set_param(param);
checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); UniformIntRNG rng(0, 999);
checker.set_rng(2, &rng);
checker.set_dtype(2, dtype::Int32());
checker.exec(
{{1000, 2, 10, 3}, {1000, 3, 3}, {1000}, {1000, 2, 12, 3}});
} }
auto args = warp_perspective::get_cv_args(); auto args = warp_perspective::get_cv_args();
for (auto&& arg : args) { for (auto&& arg : args) {
ConstValue rng(0.f);
checker.set_param(arg.param) checker.set_param(arg.param)
.set_rng(2, &rng)
.set_dtype(0, dtype::Uint8()) .set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Float32()) .set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Uint8()) .set_dtype(2, dtype::Int32())
.execs({arg.src, arg.trans, arg.dst}); .set_dtype(3, dtype::Uint8())
.execs({arg.src, arg.trans, arg.mat_idx, arg.dst});
} }
for (auto&& arg : args) { for (auto&& arg : args) {
ConstValue rng(0.f);
checker.set_param(arg.param) checker.set_param(arg.param)
.set_rng(2, &rng)
.set_dtype(0, dtype::Float32()) .set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32()) .set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32()) .set_dtype(2, dtype::Int32())
.execs({arg.src, arg.trans, arg.dst}); .set_dtype(3, dtype::Float32())
.execs({arg.src, arg.trans, arg.mat_idx, arg.dst});
} }
} }
......
...@@ -56,24 +56,24 @@ std::vector<TestArg> warp_perspective::get_cv_args() { ...@@ -56,24 +56,24 @@ std::vector<TestArg> warp_perspective::get_cv_args() {
cur_param.imode = imode; cur_param.imode = imode;
args.emplace_back(cur_param, TensorShape{1, i, i, ic}, args.emplace_back(cur_param, TensorShape{1, i, i, ic},
TensorShape{1, 3, 3}, TensorShape{1, 3, 3}, TensorShape{1},
TensorShape{1, i, i, ic}); TensorShape{1, i, i, ic});
args.emplace_back(cur_param, TensorShape{1, i, i * 2, ic}, args.emplace_back(cur_param, TensorShape{1, i, i * 2, ic},
TensorShape{1, 3, 3}, TensorShape{1, 3, 3}, TensorShape{1},
TensorShape{1, i, i * 2, ic}); TensorShape{1, i, i * 2, ic});
args.emplace_back(cur_param, TensorShape{1, i * 3, i, ic}, args.emplace_back(cur_param, TensorShape{1, i * 3, i, ic},
TensorShape{1, 3, 3}, TensorShape{1, 3, 3}, TensorShape{1},
TensorShape{1, i * 3, i, ic}); TensorShape{1, i * 3, i, ic});
cur_param.border_val = 0.78f; cur_param.border_val = 0.78f;
args.emplace_back(cur_param, TensorShape{1, i, i, ic}, args.emplace_back(cur_param, TensorShape{1, i, i, ic},
TensorShape{1, 3, 3}, TensorShape{1, 3, 3}, TensorShape{1},
TensorShape{1, 8, 8, ic}); TensorShape{1, 8, 8, ic});
args.emplace_back(cur_param, TensorShape{1, i, i * 2, ic}, args.emplace_back(cur_param, TensorShape{1, i, i * 2, ic},
TensorShape{1, 3, 3}, TensorShape{1, 3, 3}, TensorShape{1},
TensorShape{1, 8, 8, ic}); TensorShape{1, 8, 8, ic});
args.emplace_back(cur_param, TensorShape{1, i * 3, i, ic}, args.emplace_back(cur_param, TensorShape{1, i * 3, i, ic},
TensorShape{1, 3, 3}, TensorShape{1, 3, 3}, TensorShape{1},
TensorShape{1, 8, 8, ic}); TensorShape{1, 8, 8, ic});
} }
} }
...@@ -101,7 +101,10 @@ void warp_perspective::run_mat_idx_test(Handle* handle) { ...@@ -101,7 +101,10 @@ void warp_perspective::run_mat_idx_test(Handle* handle) {
// test NHWC // test NHWC
param.format = WarpPerspective::Param::Format::NHWC; param.format = WarpPerspective::Param::Format::NHWC;
checker.set_param(param); checker.set_param(param)
.set_rng(2, &mat_idx_rng)
.set_epsilon(1e-1)
.set_dtype(2, dtype::Int32());
checker.execs({{N_SRC, 10, 11, 3}, {2, 3, 3}, {2}, {2, 11, 12, 3}}); checker.execs({{N_SRC, 10, 11, 3}, {2, 3, 3}, {2}, {2, 11, 12, 3}});
} }
......
...@@ -57,10 +57,11 @@ struct TestArg { ...@@ -57,10 +57,11 @@ struct TestArg {
param::WarpPerspective param; param::WarpPerspective param;
TensorShape src; TensorShape src;
TensorShape trans; TensorShape trans;
TensorShape mat_idx;
TensorShape dst; TensorShape dst;
TestArg(param::WarpPerspective param_, TensorShape src_, TensorShape trans_, TestArg(param::WarpPerspective param_, TensorShape src_, TensorShape trans_, TensorShape mat_idx_,
TensorShape dst_) TensorShape dst_)
: param(param_), src(src_), trans(trans_), dst(dst_) {} : param(param_), src(src_), trans(trans_), mat_idx(mat_idx_), dst(dst_) {}
}; };
//! Test args for the WarpPerspective with format NHWC //! Test args for the WarpPerspective with format NHWC
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册