diff --git a/dnn/src/aarch64/warp_perspective/opr_impl.cpp b/dnn/src/aarch64/warp_perspective/opr_impl.cpp index 5616462a8a2f59ecc12bf2769e434c33c16d7676..9b9e6a7917a4fc1804835e6600192019fc153b23 100644 --- a/dnn/src/aarch64/warp_perspective/opr_impl.cpp +++ b/dnn/src/aarch64/warp_perspective/opr_impl.cpp @@ -28,8 +28,8 @@ void WarpPerspectiveImpl::exec(_megdnn_tensor_in src, check_exec(src.layout, mat.layout, mat_idx.layout, dst.layout, workspace.size); if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, - param().format) && !mat_idx.layout.ndim) { - warp_perspective_cv_exec(src, mat, dst, param().border_val, + param().format)) { + warp_perspective_cv_exec(src, mat, mat_idx, dst, param().border_val, param().bmode, param().imode, handle()); } else { //! Use arm_common implementation diff --git a/dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp b/dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp index c20f3054876ddb62346311f21a38e9c8003b61da..60c57873a057dfb7675a5d74984c9d5b777391fa 100644 --- a/dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp +++ b/dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp @@ -190,9 +190,9 @@ void warp_perspective_cv(const Mat& src, Mat& dst, const float* trans, } } // anonymous namespace void megdnn::aarch64::warp_perspective_cv_exec( - _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst, - float border_value, BorderMode bmode, InterpolationMode imode, - Handle* handle) { + _megdnn_tensor_in src, _megdnn_tensor_in trans, + _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, float border_value, + BorderMode bmode, InterpolationMode imode, Handle* handle) { size_t ch = dst.layout[3]; size_t width = dst.layout[2]; size_t height = dst.layout[1]; @@ -208,13 +208,26 @@ void megdnn::aarch64::warp_perspective_cv_exec( "unsupported src channel: %zu, avaiable channel size: 1/2/3", ch); const float* trans_ptr = trans.ptr(); - if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { + const int* midx_ptr = nullptr; + if (mat_idx.raw_ptr) { + megdnn_assert(mat_idx.layout.ndim == 1); + midx_ptr = mat_idx.ptr(); + } + if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { #define cb(_imode, _bmode, _ch) \ - auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ - size_t index, size_t) { \ + auto task = [src, trans_ptr, midx_ptr, dst, border_value, \ + parallelism_batch](size_t index, size_t) { \ size_t batch_id = index / parallelism_batch; \ size_t task_id = index % parallelism_batch; \ - Mat src_mat = TensorND2Mat(src, batch_id); \ + size_t src_id = batch_id; \ + if (midx_ptr) { \ + src_id = midx_ptr[batch_id]; \ + megdnn_assert( \ + src_id < src.layout.shape[0], \ + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \ + batch_id, src_id, src.layout.shape[0]); \ + } \ + Mat src_mat = TensorND2Mat(src, src_id); \ Mat dst_mat = TensorND2Mat(dst, batch_id); \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ warp_perspective_cv src_mat = TensorND2Mat(src, batch_id); \ + size_t src_id = batch_id; \ + if (midx_ptr) { \ + src_id = midx_ptr[batch_id]; \ + megdnn_assert( \ + src_id < src.layout.shape[0], \ + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \ + batch_id, src_id, src.layout.shape[0]); \ + } \ + Mat src_mat = TensorND2Mat(src, src_id); \ Mat dst_mat = TensorND2Mat(dst, batch_id); \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ warp_perspective_cv& src, Mat& dst, const float* trans, } // anonymous namespace void megdnn::arm_common::warp_perspective_cv_exec( - _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst, - float border_value, BorderMode bmode, InterpolationMode imode, - Handle* handle) { + _megdnn_tensor_in src, _megdnn_tensor_in trans, + _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, float border_value, + BorderMode bmode, InterpolationMode imode, Handle* handle) { size_t ch = dst.layout[3]; size_t width = dst.layout[2]; size_t height = dst.layout[1]; @@ -167,13 +167,26 @@ void megdnn::arm_common::warp_perspective_cv_exec( "unsupported src channel: %zu, avaiable channel size: 1/2/3", ch); const float* trans_ptr = trans.ptr(); + const int* midx_ptr = nullptr; + if (mat_idx.raw_ptr) { + megdnn_assert(mat_idx.layout.ndim == 1); + midx_ptr = mat_idx.ptr(); + } if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { #define cb(_imode, _bmode, _ch) \ - auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ - size_t index, size_t) { \ + auto task = [src, trans_ptr, midx_ptr, dst, border_value, \ + parallelism_batch](size_t index, size_t) { \ size_t batch_id = index / parallelism_batch; \ size_t task_id = index % parallelism_batch; \ - Mat src_mat = TensorND2Mat(src, batch_id); \ + size_t src_id = batch_id; \ + if (midx_ptr) { \ + src_id = midx_ptr[batch_id]; \ + megdnn_assert( \ + src_id < src.layout.shape[0], \ + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \ + batch_id, src_id, src.layout.shape[0]); \ + } \ + Mat src_mat = TensorND2Mat(src, src_id); \ Mat dst_mat = TensorND2Mat(dst, batch_id); \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ warp_perspective_cv src_mat = TensorND2Mat(src, batch_id); \ + size_t src_id = batch_id; \ + if (midx_ptr) { \ + src_id = midx_ptr[batch_id]; \ + megdnn_assert( \ + src_id < src.layout.shape[0], \ + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \ + batch_id, src_id, src.layout.shape[0]); \ + } \ + Mat src_mat = TensorND2Mat(src, src_id); \ Mat dst_mat = TensorND2Mat(dst, batch_id); \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ warp_perspective_cv& src, Mat& dst, const float* trans, } // anonymous namespace void megdnn::naive::warp_perspective_cv_exec( - _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst, - float border_value, BorderMode bmode, InterpolationMode imode, - Handle* handle) { + _megdnn_tensor_in src, _megdnn_tensor_in trans, + _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, float border_value, + BorderMode bmode, InterpolationMode imode, Handle* handle) { size_t ch = dst.layout[3]; size_t width = dst.layout[2]; size_t height = dst.layout[1]; @@ -169,13 +169,26 @@ void megdnn::naive::warp_perspective_cv_exec( "unsupported src channel: %zu, avaiable channel size: 1/2/3", ch); const float* trans_ptr = trans.ptr(); + const int* midx_ptr = nullptr; + if (mat_idx.raw_ptr) { + megdnn_assert(mat_idx.layout.ndim == 1); + midx_ptr = mat_idx.ptr(); + } if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { #define cb(_imode, _bmode, _ch) \ - auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ - size_t index, size_t) { \ + auto task = [src, trans_ptr, midx_ptr, dst, border_value, \ + parallelism_batch](size_t index, size_t) { \ size_t batch_id = index / parallelism_batch; \ size_t task_id = index % parallelism_batch; \ - Mat src_mat = TensorND2Mat(src, batch_id); \ + size_t src_id = batch_id; \ + if (midx_ptr) { \ + src_id = midx_ptr[batch_id]; \ + megdnn_assert( \ + src_id < src.layout.shape[0], \ + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \ + batch_id, src_id, src.layout.shape[0]); \ + } \ + Mat src_mat = TensorND2Mat(src, src_id); \ Mat dst_mat = TensorND2Mat(dst, batch_id); \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ warp_perspective_cv src_mat = TensorND2Mat(src, batch_id); \ + size_t src_id = batch_id; \ + if (midx_ptr) { \ + src_id = midx_ptr[batch_id]; \ + megdnn_assert( \ + src_id < src.layout.shape[0], \ + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \ + batch_id, src_id, src.layout.shape[0]); \ + } \ + Mat src_mat = TensorND2Mat(src, src_id); \ Mat dst_mat = TensorND2Mat(dst, batch_id); \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ warp_perspective_cv& src, Mat& dst, const float* trans, } } // anonymous namespace -void megdnn::x86::warp_perspective_cv_exec(_megdnn_tensor_in src, - _megdnn_tensor_in trans, - _megdnn_tensor_in dst, - float border_value, BorderMode bmode, - InterpolationMode imode, - Handle* handle) { +void megdnn::x86::warp_perspective_cv_exec( + _megdnn_tensor_in src, _megdnn_tensor_in trans, + _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, float border_value, + BorderMode bmode, InterpolationMode imode, Handle* handle) { size_t ch = dst.layout[3]; size_t width = dst.layout[2]; size_t height = dst.layout[1]; @@ -175,13 +172,26 @@ void megdnn::x86::warp_perspective_cv_exec(_megdnn_tensor_in src, "unsupported src channel: %zu, avaiable channel size: 1/2/3", ch); const float* trans_ptr = trans.ptr(); + const int* midx_ptr = nullptr; + if (mat_idx.raw_ptr) { + megdnn_assert(mat_idx.layout.ndim == 1); + midx_ptr = mat_idx.ptr(); + } if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { #define cb(_imode, _bmode, _ch) \ - auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ - size_t index, size_t) { \ + auto task = [src, trans_ptr, midx_ptr, dst, border_value, \ + parallelism_batch](size_t index, size_t) { \ size_t batch_id = index / parallelism_batch; \ size_t task_id = index % parallelism_batch; \ - Mat src_mat = TensorND2Mat(src, batch_id); \ + size_t src_id = batch_id; \ + if (midx_ptr) { \ + src_id = midx_ptr[batch_id]; \ + megdnn_assert( \ + src_id < src.layout.shape[0], \ + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \ + batch_id, src_id, src.layout.shape[0]); \ + } \ + Mat src_mat = TensorND2Mat(src, src_id); \ Mat dst_mat = TensorND2Mat(dst, batch_id); \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ warp_perspective_cv src_mat = TensorND2Mat(src, batch_id); \ + size_t src_id = batch_id; \ + if (midx_ptr) { \ + src_id = midx_ptr[batch_id]; \ + megdnn_assert( \ + src_id < src.layout.shape[0], \ + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \ + batch_id, src_id, src.layout.shape[0]); \ + } \ + Mat src_mat = TensorND2Mat(src, src_id); \ Mat dst_mat = TensorND2Mat(dst, batch_id); \ const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ warp_perspective_cv checker(handle()); + Checker checker(handle()); param::WarpPerspective param; class ResizeMatRNG : public RNG { void gen(const TensorND& tensor_) override { @@ -82,7 +82,10 @@ TEST_F(AARCH64, WARP_PERSPECTIVE_CV) { param.bmode = mode; param.border_val = 1.737; checker.set_param(param); - checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); + UniformIntRNG rng(0, 1); + checker.set_rng(2, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.exec({{2, 5, 5, 1}, {4, 3, 3}, {4}, {4, 5, 5, 1}}); } // resize nan case UniformFloatRNG rng_zero(0, 0); @@ -91,7 +94,11 @@ TEST_F(AARCH64, WARP_PERSPECTIVE_CV) { param.bmode = BMode::CONSTANT; param.border_val = 1.737; checker.set_param(param); - checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); + UniformIntRNG rng(0, 999); + checker.set_rng(2, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.exec( + {{1000, 2, 10, 3}, {2000, 3, 3}, {2000}, {2000, 2, 12, 3}}); } // add linear test @@ -101,7 +108,10 @@ TEST_F(AARCH64, WARP_PERSPECTIVE_CV) { param.bmode = mode; param.border_val = 1.737; checker.set_param(param); - checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); + UniformIntRNG rng(0, 9); + checker.set_rng(2, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.exec({{10, 128, 108, 3}, {20, 3, 3}, {20}, {20, 56, 128, 3}}); } // resize nan case checker.set_rng(1, &rng_zero); @@ -109,24 +119,34 @@ TEST_F(AARCH64, WARP_PERSPECTIVE_CV) { param.bmode = BMode::CONSTANT; param.border_val = 1.737; checker.set_param(param); - checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); + UniformIntRNG rng(0, 999); + checker.set_rng(2, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.exec( + {{1000, 2, 10, 3}, {2000, 3, 3}, {2000}, {2000, 2, 12, 3}}); } auto args = warp_perspective::get_cv_args(); for (auto&& arg : args) { + ConstValue rng(0.f); checker.set_param(arg.param) + .set_rng(2, &rng) .set_dtype(0, dtype::Uint8()) .set_dtype(1, dtype::Float32()) - .set_dtype(2, dtype::Uint8()) - .execs({arg.src, arg.trans, arg.dst}); + .set_dtype(2, dtype::Int32()) + .set_dtype(3, dtype::Uint8()) + .execs({arg.src, arg.trans, arg.mat_idx, arg.dst}); } for (auto&& arg : args) { + ConstValue rng(0.f); checker.set_param(arg.param) + .set_rng(2, &rng) .set_dtype(0, dtype::Float32()) .set_dtype(1, dtype::Float32()) - .set_dtype(2, dtype::Float32()) - .execs({arg.src, arg.trans, arg.dst}); + .set_dtype(2, dtype::Int32()) + .set_dtype(3, dtype::Float32()) + .execs({arg.src, arg.trans, arg.mat_idx, arg.dst}); } } diff --git a/dnn/test/arm_common/warp_perspective.cpp b/dnn/test/arm_common/warp_perspective.cpp index 262ca6b59f9db7cdd674932155bfcef4d221a6d5..83c2c38af06dcca1cb8e377974d8b2a2a045fa7a 100644 --- a/dnn/test/arm_common/warp_perspective.cpp +++ b/dnn/test/arm_common/warp_perspective.cpp @@ -25,7 +25,7 @@ namespace test { TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) { //! Just for the format NHWC - Checker checker(handle()); + Checker checker(handle()); param::WarpPerspective param; class ResizeMatRNG : public RNG { void gen(const TensorND& tensor_) override { @@ -82,7 +82,10 @@ TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) { param.bmode = mode; param.border_val = 1.737; checker.set_param(param); - checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); + UniformIntRNG rng(0, 9); + checker.set_rng(2, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.exec({{10, 128, 108, 3}, {20, 3, 3}, {20}, {20, 56, 128, 3}}); } // resize nan case UniformFloatRNG rng_zero(0, 0); @@ -91,7 +94,11 @@ TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) { param.bmode = BMode::CONSTANT; param.border_val = 1.737; checker.set_param(param); - checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); + UniformIntRNG rng(0, 999); + checker.set_rng(2, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.exec( + {{1000, 2, 10, 3}, {1000, 3, 3}, {1000}, {1000, 2, 12, 3}}); } // add linear test @@ -101,7 +108,10 @@ TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) { param.bmode = mode; param.border_val = 1.737; checker.set_param(param); - checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); + UniformIntRNG rng(0, 9); + checker.set_rng(2, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.exec({{10, 128, 108, 3}, {20, 3, 3}, {20}, {20, 56, 128, 3}}); } // resize nan case checker.set_rng(1, &rng_zero); @@ -109,30 +119,40 @@ TEST_F(ARM_COMMON, WARP_PERSPECTIVE_CV) { param.bmode = BMode::CONSTANT; param.border_val = 1.737; checker.set_param(param); - checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); + UniformIntRNG rng(0, 999); + checker.set_rng(2, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.exec( + {{1000, 2, 10, 3}, {2000, 3, 3}, {2000}, {2000, 2, 12, 3}}); } auto args = warp_perspective::get_cv_args(); for (auto&& arg : args) { + ConstValue rng(0.f); checker.set_param(arg.param) + .set_rng(2, &rng) .set_dtype(0, dtype::Uint8()) .set_dtype(1, dtype::Float32()) - .set_dtype(2, dtype::Uint8()) - .execs({arg.src, arg.trans, arg.dst}); + .set_dtype(2, dtype::Int32()) + .set_dtype(3, dtype::Uint8()) + .execs({arg.src, arg.trans, arg.mat_idx, arg.dst}); } for (auto&& arg : args) { + ConstValue rng(0.f); checker.set_param(arg.param) + .set_rng(2, &rng) .set_dtype(0, dtype::Float32()) .set_dtype(1, dtype::Float32()) - .set_dtype(2, dtype::Float32()) - .execs({arg.src, arg.trans, arg.dst}); + .set_dtype(2, dtype::Int32()) + .set_dtype(3, dtype::Float32()) + .execs({arg.src, arg.trans, arg.mat_idx, arg.dst}); } } TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) { //! Just for the format NHWC - Checker checker(handle()); + Checker checker(handle()); param::WarpPerspective param; class ResizeMatRNG : public RNG { void gen(const TensorND& tensor_) override { @@ -189,7 +209,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) { param.bmode = mode; param.border_val = 1.737; checker.set_param(param); - checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); + UniformIntRNG rng(0, 9); + checker.set_rng(2, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10}, {10, 56, 128, 3}}); } // resize nan case UniformFloatRNG rng_zero(0, 0); @@ -198,7 +221,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) { param.bmode = BMode::CONSTANT; param.border_val = 1.737; checker.set_param(param); - checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); + UniformIntRNG rng(0, 999); + checker.set_rng(2, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.exec( + {{1000, 2, 10, 3}, {2000, 3, 3}, {2000}, {2000, 2, 12, 3}}); } // add linear test @@ -208,7 +235,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) { param.bmode = mode; param.border_val = 1.737; checker.set_param(param); - checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10, 56, 128, 3}}); + UniformIntRNG rng(0, 9); + checker.set_rng(2, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.exec({{10, 128, 108, 3}, {10, 3, 3}, {10}, {10, 56, 128, 3}}); } // resize nan case checker.set_rng(1, &rng_zero); @@ -216,24 +246,34 @@ TEST_F(ARM_COMMON_MULTI_THREADS, WARP_PERSPECTIVE_CV) { param.bmode = BMode::CONSTANT; param.border_val = 1.737; checker.set_param(param); - checker.exec({{1000, 2, 10, 3}, {1000, 3, 3}, {1000, 2, 12, 3}}); + UniformIntRNG rng(0, 999); + checker.set_rng(2, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.exec( + {{1000, 2, 10, 3}, {1000, 3, 3}, {1000}, {1000, 2, 12, 3}}); } auto args = warp_perspective::get_cv_args(); for (auto&& arg : args) { + ConstValue rng(0.f); checker.set_param(arg.param) + .set_rng(2, &rng) .set_dtype(0, dtype::Uint8()) .set_dtype(1, dtype::Float32()) - .set_dtype(2, dtype::Uint8()) - .execs({arg.src, arg.trans, arg.dst}); + .set_dtype(2, dtype::Int32()) + .set_dtype(3, dtype::Uint8()) + .execs({arg.src, arg.trans, arg.mat_idx, arg.dst}); } for (auto&& arg : args) { + ConstValue rng(0.f); checker.set_param(arg.param) + .set_rng(2, &rng) .set_dtype(0, dtype::Float32()) .set_dtype(1, dtype::Float32()) - .set_dtype(2, dtype::Float32()) - .execs({arg.src, arg.trans, arg.dst}); + .set_dtype(2, dtype::Int32()) + .set_dtype(3, dtype::Float32()) + .execs({arg.src, arg.trans, arg.mat_idx, arg.dst}); } } diff --git a/dnn/test/common/warp_perspective.cpp b/dnn/test/common/warp_perspective.cpp index b9f42a044c287518e4d6b0ec40301efea034d6ef..23e45b80173f5bdf18af8465bb847a72cb4f4511 100644 --- a/dnn/test/common/warp_perspective.cpp +++ b/dnn/test/common/warp_perspective.cpp @@ -56,24 +56,24 @@ std::vector warp_perspective::get_cv_args() { cur_param.imode = imode; args.emplace_back(cur_param, TensorShape{1, i, i, ic}, - TensorShape{1, 3, 3}, + TensorShape{1, 3, 3}, TensorShape{1}, TensorShape{1, i, i, ic}); args.emplace_back(cur_param, TensorShape{1, i, i * 2, ic}, - TensorShape{1, 3, 3}, + TensorShape{1, 3, 3}, TensorShape{1}, TensorShape{1, i, i * 2, ic}); args.emplace_back(cur_param, TensorShape{1, i * 3, i, ic}, - TensorShape{1, 3, 3}, + TensorShape{1, 3, 3}, TensorShape{1}, TensorShape{1, i * 3, i, ic}); cur_param.border_val = 0.78f; args.emplace_back(cur_param, TensorShape{1, i, i, ic}, - TensorShape{1, 3, 3}, + TensorShape{1, 3, 3}, TensorShape{1}, TensorShape{1, 8, 8, ic}); args.emplace_back(cur_param, TensorShape{1, i, i * 2, ic}, - TensorShape{1, 3, 3}, + TensorShape{1, 3, 3}, TensorShape{1}, TensorShape{1, 8, 8, ic}); args.emplace_back(cur_param, TensorShape{1, i * 3, i, ic}, - TensorShape{1, 3, 3}, + TensorShape{1, 3, 3}, TensorShape{1}, TensorShape{1, 8, 8, ic}); } } @@ -101,7 +101,10 @@ void warp_perspective::run_mat_idx_test(Handle* handle) { // test NHWC param.format = WarpPerspective::Param::Format::NHWC; - checker.set_param(param); + checker.set_param(param) + .set_rng(2, &mat_idx_rng) + .set_epsilon(1e-1) + .set_dtype(2, dtype::Int32()); checker.execs({{N_SRC, 10, 11, 3}, {2, 3, 3}, {2}, {2, 11, 12, 3}}); } diff --git a/dnn/test/common/warp_perspective.h b/dnn/test/common/warp_perspective.h index 31bfbaff87ea42ba53e518ac1e2a84444a1f6f73..68931763386e63b5184a40f8439b172f248f87be 100644 --- a/dnn/test/common/warp_perspective.h +++ b/dnn/test/common/warp_perspective.h @@ -57,10 +57,11 @@ struct TestArg { param::WarpPerspective param; TensorShape src; TensorShape trans; + TensorShape mat_idx; TensorShape dst; - TestArg(param::WarpPerspective param_, TensorShape src_, TensorShape trans_, + TestArg(param::WarpPerspective param_, TensorShape src_, TensorShape trans_, TensorShape mat_idx_, TensorShape dst_) - : param(param_), src(src_), trans(trans_), dst(dst_) {} + : param(param_), src(src_), trans(trans_), mat_idx(mat_idx_), dst(dst_) {} }; //! Test args for the WarpPerspective with format NHWC