From 0e8b81c20ebba04c8c324891afbbe07734671f41 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 18 Jan 2021 13:55:51 +0800 Subject: [PATCH] fix(dnn/opencl): fix elemwise negative stride support GitOrigin-RevId: 506d7e61043d84923a317722d295a7e3cf591341 --- dnn/test/common/elemwise.cpp | 125 +++++++++++++++++++++++++++++++++++ dnn/test/common/elemwise.h | 3 + 2 files changed, 128 insertions(+) diff --git a/dnn/test/common/elemwise.cpp b/dnn/test/common/elemwise.cpp index 930ddfb22..fcc06005c 100644 --- a/dnn/test/common/elemwise.cpp +++ b/dnn/test/common/elemwise.cpp @@ -924,6 +924,131 @@ DEF_TEST(all_modes) { #undef run } +#define UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int8()}, {}}); \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int16()}, {}}); \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int32()}, {}}); + +#define UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(_optr) \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Float32()}, {}}); + +#define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(RELU); \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(ABS); + +#define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ABS) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(LOG) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(COS) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIN) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FLOOR) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(CEIL) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIGMOID) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(EXP) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(RELU) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ROUND) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(TANH) \ + UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FAST_TANH) + +DEF_TEST(unary_negative_stride) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle); + BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT; + + UniformFloatRNG rng(1e-2, 6e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT; +} + +#undef UNARY_NEGATIVE_STRIDE_TEST_CASE_INT +#undef UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT +#undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT +#undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT + +#define BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int8()}, \ + {{1, 4, 1}, dtype::Int8()}, \ + {}}); \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int16()}, \ + {{1, 4, 1}, dtype::Int16()}, \ + {}}); \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int32()}, \ + {{1, 4, 1}, dtype::Int32()}, \ + {}}); + +#define BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(_optr) \ + checker.set_param(Mode::_optr) \ + .execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Float32()}, \ + {{1, 4, 1}, dtype::Float32()}, \ + {}}); + +#define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(ADD) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MUL) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MAX) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MIN) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(SUB) + +#define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(POW) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(TRUE_DIV) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_SIGMOID) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_TANH) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_RELU) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_H_SWISH) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FAST_TANH_GRAD) \ + BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(H_SWISH_GRAD) + +DEF_TEST(binary_negative_stride) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle); + BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT; + + UniformFloatRNG rng(1e-2, 2e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32; +} + +#undef BINARY_NEGATIVE_STRIDE_TEST_CASE_INT +#undef BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 +#undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT +#undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 + +DEF_TEST(ternary_negative_stride) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle); + checker.set_param(Mode::FUSE_MUL_ADD3); + checker.execl({{{1, 7}, {-7, -1}, dtype::Int8()}, + {{1, 7}, {-3, -1}, dtype::Int8()}, + {{1, 7}, {-7, -1}, dtype::Int8()}, + {}}); + checker.execl({{{1, 7}, {-7, -1}, dtype::Int16()}, + {{1, 7}, {-3, -1}, dtype::Int16()}, + {{1, 7}, {-7, -1}, dtype::Int16()}, + {}}); + checker.execl({{{1, 7}, {-7, -1}, dtype::Int32()}, + {{1, 7}, {-3, -1}, dtype::Int32()}, + {{1, 7}, {-7, -1}, dtype::Int32()}, + {}}); + + UniformFloatRNG rng(1e-2, 2e1); + checker.set_rng(0, &rng); + checker.set_epsilon(1e-5); + checker.execl({{{1, 7}, {-7, -1}, dtype::Float32()}, + {{1, 7}, {-3, -1}, dtype::Float32()}, + {{1, 7}, {-7, -1}, dtype::Float32()}, + {}}); +} + TEST(TEST_ELEMWISE, MODE_TRAIT) { using M = Elemwise::Mode; using T = Elemwise::ModeTrait; diff --git a/dnn/test/common/elemwise.h b/dnn/test/common/elemwise.h index 6bf66c455..e41c854df 100644 --- a/dnn/test/common/elemwise.h +++ b/dnn/test/common/elemwise.h @@ -40,6 +40,9 @@ namespace elemwise { cb(unary3) \ cb(binary3) \ cb(all_modes) \ + cb(unary_negative_stride) \ + cb(binary_negative_stride) \ + cb(ternary_negative_stride) \ #define FOREACH_ELEMWISE_CASE(cb) \ cb(FIRST_ELEMWISE_CASE) \ -- GitLab