提交 a1e67207 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(dnn): enable bool comparison

GitOrigin-RevId: 735693b81e46189db15be0b7f98fa64973c3035e
上级 8aa34e4a
......@@ -30,6 +30,6 @@ MODES = {
'FUSE_ADD_H_SWISH'],
(3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
(1, 'BOOL'): ['NOT'],
(2, 'BOOL'): ['AND', 'OR', 'XOR'],
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'],
(3, 'BOOL'): []
}
......@@ -45,6 +45,9 @@
MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) \
#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \
......
......@@ -173,6 +173,9 @@ namespace megdnn {
DEF_KERN_ALL(LT, x < y);
DEF_KERN_ALL(LEQ, x <= y);
DEF_KERN_ALL(EQ, x == y);
DEF_KERN(dt_bool, LT, x < y);
DEF_KERN(dt_bool, LEQ, x <= y);
DEF_KERN(dt_bool, EQ, x == y);
DEF_KERN_INT(FLOOR_DIV, x / y);
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y));
......
/**
* \file dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
/**
* \file dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
/**
* \file dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
/**
* \file dnn/src/naive/elemwise/kimpl/EQ_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
/**
* \file dnn/src/naive/elemwise/kimpl/LEQ_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
/**
* \file dnn/src/naive/elemwise/kimpl/LT_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
......@@ -812,6 +812,9 @@ TEST_OPR_BASIC_ARITH_UNARY_BOOL(NOT, !)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(AND, &&)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(OR, ||)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(XOR, ^)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(LT, <)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(LEQ, <=)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(EQ, ==)
TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) {
using Checker = AutoOprChecker<3, 1>;
......
......@@ -27,6 +27,13 @@ DEF_TRAIT(OR, x || y)
DEF_TRAIT(XOR, x ^ y)
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
#define _ALLOW_INT true
#define _ALLOW_FLOAT true
DEF_TRAIT(EQ, x == y)
DEF_TRAIT(LEQ, x <= y)
DEF_TRAIT(LT, x < y)
#undef _ALLOW_BOOL
#define _ALLOW_BOOL false
......@@ -44,10 +51,6 @@ DEF_TRAIT(SUB, x - y)
DEF_TRAIT(SWITCH_GT0, x > 0 ? y : 0)
DEF_TRAIT(TANH_GRAD, (1 - x * x) * y)
DEF_TRAIT(EQ, x == y)
DEF_TRAIT(LEQ, x <= y)
DEF_TRAIT(LT, x < y)
DEF_TRAIT(FUSE_ADD_RELU, std::max<ctype>(x + y, 0))
#undef _ALLOW_INT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册