diff --git a/dnn/test/common/elemwise.cpp b/dnn/test/common/elemwise.cpp index 930ddfb22bbe6f92d5398a139232aa1481182823..fcc06005c3080d99e86f6e762dc1c093990a89f6 100644 --- a/dnn/test/common/elemwise.cpp +++ b/dnn/test/common/elemwise.cpp @@ -924,6 +924,131 @@ DEF_TEST(all_modes) { #undef run } +#define UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int8()}, {}}); \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int16()}, {}}); \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int32()}, {}}); + +#define UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(_optr) \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Float32()}, {}}); + +#define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(RELU); \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(ABS); + +#define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ABS) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(LOG) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(COS) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIN) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FLOOR) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(CEIL) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIGMOID) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(EXP) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(RELU) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ROUND) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(TANH) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FAST_TANH) + +DEF_TEST(unary_negative_stride) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle); + BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT; + + UniformFloatRNG rng(1e-2, 6e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT; +} + +#undef UNARY_NEGATIVE_STRIDE_TEST_CASE_INT +#undef UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT +#undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT +#undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT + +#define BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int8()}, \ + {{1, 4, 1}, dtype::Int8()}, \ + {}}); \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int16()}, \ + {{1, 4, 1}, dtype::Int16()}, \ + {}}); \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int32()}, \ + {{1, 4, 1}, dtype::Int32()}, \ + {}}); + +#define BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(_optr) \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Float32()}, \ + {{1, 4, 1}, dtype::Float32()}, \ + {}}); + +#define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(ADD) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MUL) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MAX) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MIN) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(SUB) + +#define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(POW) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(TRUE_DIV) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_SIGMOID) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_TANH) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_RELU) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_H_SWISH) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FAST_TANH_GRAD) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(H_SWISH_GRAD) + +DEF_TEST(binary_negative_stride) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle); + BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT; + + UniformFloatRNG rng(1e-2, 2e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32; +} + +#undef BINARY_NEGATIVE_STRIDE_TEST_CASE_INT +#undef BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 +#undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT +#undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 + +DEF_TEST(ternary_negative_stride) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle); + checker.set_param(Mode::FUSE_MUL_ADD3); + checker.execl({{{1, 7}, {-7, -1}, dtype::Int8()}, + {{1, 7}, {-3, -1}, dtype::Int8()}, + {{1, 7}, {-7, -1}, dtype::Int8()}, + {}}); + checker.execl({{{1, 7}, {-7, -1}, dtype::Int16()}, + {{1, 7}, {-3, -1}, dtype::Int16()}, + {{1, 7}, {-7, -1}, dtype::Int16()}, + {}}); + checker.execl({{{1, 7}, {-7, -1}, dtype::Int32()}, + {{1, 7}, {-3, -1}, dtype::Int32()}, + {{1, 7}, {-7, -1}, dtype::Int32()}, + {}}); + + UniformFloatRNG rng(1e-2, 2e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + checker.execl({{{1, 7}, {-7, -1}, dtype::Float32()}, + {{1, 7}, {-3, -1}, dtype::Float32()}, + {{1, 7}, {-7, -1}, dtype::Float32()}, + {}}); +} + TEST(TEST_ELEMWISE, MODE_TRAIT) { using M = Elemwise::Mode; using T = Elemwise::ModeTrait; diff --git a/dnn/test/common/elemwise.h b/dnn/test/common/elemwise.h index 6bf66c45595112453f3297c9bcae373f7a1dd63f..e41c854df428ed47687d3885fab1b9af86b37f65 100644 --- a/dnn/test/common/elemwise.h +++ b/dnn/test/common/elemwise.h @@ -40,6 +40,9 @@ namespace elemwise { cb(unary3) \ cb(binary3) \ cb(all_modes) \ + cb(unary_negative_stride) \ + cb(binary_negative_stride) \ + cb(ternary_negative_stride) \ #define FOREACH_ELEMWISE_CASE(cb) \ cb(FIRST_ELEMWISE_CASE) \