提交 0e8b81c2 编写于 作者: M Megvii Engine Team

fix(dnn/opencl): fix elemwise negative stride support

GitOrigin-RevId: 506d7e61043d84923a317722d295a7e3cf591341
上级 dbb3dd68
......@@ -924,6 +924,131 @@ DEF_TEST(all_modes) {
#undef run
}
#define UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int8()}, {}}); \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int16()}, {}}); \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Int32()}, {}});
#define UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(_optr) \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 1}, {-12, -4, -1}, dtype::Float32()}, {}});
#define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT \
UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(RELU); \
UNARY_NEGATIVE_STRIDE_TEST_CASE_INT(ABS);
#define BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ABS) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(LOG) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(COS) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIN) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FLOOR) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(CEIL) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(SIGMOID) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(EXP) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(RELU) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(ROUND) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(TANH) \
UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT(FAST_TANH)
DEF_TEST(unary_negative_stride) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle);
BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT;
UniformFloatRNG rng(1e-2, 6e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT;
}
#undef UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
#undef UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
#undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_INT
#undef BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT
#define BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(_optr) \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int8()}, \
{{1, 4, 1}, dtype::Int8()}, \
{}}); \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int16()}, \
{{1, 4, 1}, dtype::Int16()}, \
{}}); \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Int32()}, \
{{1, 4, 1}, dtype::Int32()}, \
{}});
#define BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(_optr) \
checker.set_param(Mode::_optr) \
.execl({{{3, 4, 7}, {-28, -7, -1}, dtype::Float32()}, \
{{1, 4, 1}, dtype::Float32()}, \
{}});
#define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT \
BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(ADD) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MUL) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MAX) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(MIN) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_INT(SUB)
#define BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32 \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(POW) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(TRUE_DIV) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_SIGMOID) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_TANH) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_RELU) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FUSE_ADD_H_SWISH) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(FAST_TANH_GRAD) \
BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32(H_SWISH_GRAD)
DEF_TEST(binary_negative_stride) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle);
BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT;
UniformFloatRNG rng(1e-2, 2e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32;
}
#undef BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
#undef BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
#undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_INT
#undef BUILD_BINARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT32
DEF_TEST(ternary_negative_stride) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle);
checker.set_param(Mode::FUSE_MUL_ADD3);
checker.execl({{{1, 7}, {-7, -1}, dtype::Int8()},
{{1, 7}, {-3, -1}, dtype::Int8()},
{{1, 7}, {-7, -1}, dtype::Int8()},
{}});
checker.execl({{{1, 7}, {-7, -1}, dtype::Int16()},
{{1, 7}, {-3, -1}, dtype::Int16()},
{{1, 7}, {-7, -1}, dtype::Int16()},
{}});
checker.execl({{{1, 7}, {-7, -1}, dtype::Int32()},
{{1, 7}, {-3, -1}, dtype::Int32()},
{{1, 7}, {-7, -1}, dtype::Int32()},
{}});
UniformFloatRNG rng(1e-2, 2e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.execl({{{1, 7}, {-7, -1}, dtype::Float32()},
{{1, 7}, {-3, -1}, dtype::Float32()},
{{1, 7}, {-7, -1}, dtype::Float32()},
{}});
}
TEST(TEST_ELEMWISE, MODE_TRAIT) {
using M = Elemwise::Mode;
using T = Elemwise::ModeTrait;
......
......@@ -40,6 +40,9 @@ namespace elemwise {
cb(unary3) \
cb(binary3) \
cb(all_modes) \
cb(unary_negative_stride) \
cb(binary_negative_stride) \
cb(ternary_negative_stride) \
#define FOREACH_ELEMWISE_CASE(cb) \
cb(FIRST_ELEMWISE_CASE) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册