Commit bbafe699 authored by Megvii Engine Team

feat(dnn): add elemwise COND_LT_MOV

GitOrigin-RevId: 444cd6825a775bed21562ebf5443b153b130745e
Parent: ed92b9c1
......@@ -17,7 +17,7 @@ MODES = {
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD',
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'],
3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'],
}
QINT4_MODES = {
......@@ -26,7 +26,7 @@ QINT4_MODES = {
2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0',
'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH',
'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'],
3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'],
}
QINT32_MODES = {
......
......@@ -16,7 +16,7 @@ MODES = {
(2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL',
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ',
'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'],
(3, 'INT'): ['COND_LEQ_MOV'],
(3, 'INT'): ['COND_LEQ_MOV', 'COND_LT_MOV'],
(1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS',
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN',
......@@ -28,7 +28,7 @@ MODES = {
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD',
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD',
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD'],
(3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'],
(3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3'],
(1, 'BOOL'): ['NOT'],
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'],
(3, 'BOOL'): []
......
......@@ -420,6 +420,7 @@ pdef('Elemwise').add_enum(
Doc('SILU_GRAD = 57', 'binary: grad(x / (1 + exp(-x))'),
Doc('GELU = 58', 'unary: x Phi(x)'),
Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'),
Doc('COND_LT_MOV = 60', 'ternary: x < y ? z : 0'),
)
pdef('ElemwiseMultiType').add_enum(
......@@ -510,7 +511,8 @@ pdef('ElemwiseMultiType').add_enum(
'and the result is float32.'),
Doc('FUSE_MUL_ADD3_UINT8xF32xF32xF32 = 56',
'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and '
'``c`` float32, and the result is float32.')
'``c`` float32, and the result is float32.'),
Doc('QCOND_LT_MOV = 57', 'quantized cond_lt_mov'),
)
pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
......
......@@ -92,7 +92,9 @@
// Enumerates the ternary elemwise modes available for each dtype class;
// `cb` is invoked once per enabled mode.  COND_LT_MOV (x < y ? z : 0, see the
// kernel definition) is enabled for both float and int, mirroring COND_LEQ_MOV.
#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \
    MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb)      \
    MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)       \
    MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb)
// NOTE(review): the raw diff kept both the pre-change body (a lone
// COND_LEQ_MOV line without a trailing backslash) and the post-change body of
// the INT macro, which would truncate the macro and leave stray tokens; only
// the post-change body (COND_LEQ_MOV + COND_LT_MOV) is retained here.
#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \
    MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb)    \
    MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
......@@ -265,6 +265,7 @@ DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
// int and float
DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0));
DEF_KERN_ALL(COND_LT_MOV, x < y ? z : ctype(0));
DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z);
#undef KERN_SIG
......
......@@ -219,6 +219,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
CB_MODE(Mode::SILU_GRAD);
CB_MODE(Mode::GELU);
CB_MODE(Mode::GELU_GRAD);
CB_MODE(Mode::COND_LT_MOV);
default:
megdnn_assert(
0,
......
......@@ -239,6 +239,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
SET(init_quantized_ternary_op, QFUSE_MUL_ADD3);
SET(init_quantized_ternary_op, QCOND_LEQ_MOV);
SET(init_quantized_ternary_op, QCOND_LT_MOV);
#undef SET
}
......
......@@ -95,6 +95,7 @@ void ElemwiseMultiTypeImplHelper::exec(
ON_QUANTIZED_MODE(FUSE_MUL_ADD3, 3);
ON_QUANTIZED_MODE(COND_LEQ_MOV, 3);
ON_QUANTIZED_MODE(COND_LT_MOV, 3);
default:
megdnn_throw("invalid mode");
}
......
// NOTE(review): each stanza below is a separate auto-generated file (one per
// ctype) produced by gen_elemwise_kern_impls.py; the diff view concatenates
// them, so the repeated KERN_IMPL_* defines do not clash in practice — each
// lives in its own translation unit.  Every stanza instantiates the ternary
// (arity-3) COND_LT_MOV kernel for one element type; the fp16/bf16 stanzas
// are guarded so they compile only when half-float support is enabled.
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
// NOTE(review): the stanzas below repeat the same per-dtype set — presumably
// from a second backend directory (e.g. a different device implementation);
// TODO confirm the originating paths against the full commit.
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
......@@ -25,6 +25,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(
DISPATCH(FUSE_MUL_ADD3);
DISPATCH(COND_LEQ_MOV);
DISPATCH(COND_LT_MOV);
#undef DISPATCH
default:
megdnn_assert_internal(0);
......
// NOTE(review): each stanza below is a separate auto-generated file (one per
// ctype) produced by gen_elemwise_kern_impls.py, concatenated by the diff
// view; each instantiates the ternary (arity-3) COND_LT_MOV kernel for one
// element type, with the half-float variants guarded by
// MEGDNN_DISABLE_FLOAT16.
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_bfloat16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#if !MEGDNN_DISABLE_FLOAT16
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float16
#include "../kern_impl.inl"
#endif
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_float32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int16
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int32
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_int8
#include "../kern_impl.inl"
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb)
#define KERN_IMPL_ARITY 3
#define KERN_IMPL_CTYPE dt_uint8
#include "../kern_impl.inl"
......@@ -179,6 +179,35 @@ DEF_TEST(ternary_non_contig) {
checker.execl({ly, ly, ly, {{2, 3}, dtype::Float32()}});
}
DEF_TEST(ternary_lt) {
    // Exercises the ternary COND_LT_MOV mode: broadcasting shapes, explicit
    // float32/float16 dtypes, and rejection of incompatible broadcasts.
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);
    checker.set_param(Mode::COND_LT_MOV);
    // default dtype, all three inputs broadcast to {2, 3, 4}
    checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
    // explicit float32 on the three inputs
    checker.set_dtype(0, dtype::Float32());
    checker.set_dtype(1, dtype::Float32());
    checker.set_dtype(2, dtype::Float32());
    checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
    // float16 on all four tensors; this dtype state carries over to the
    // two broadcast cases below, matching the original test's behaviour
    checker.set_dtype(0, dtype::Float16());
    checker.set_dtype(1, dtype::Float16());
    checker.set_dtype(2, dtype::Float16());
    checker.set_dtype(3, dtype::Float16());
    checker.execs({{1, 3, 4}, {2, 1, 4}, {2, 3, 1}, {2, 3, 4}});
    checker.execs({{2, 1, 1, 5}, {4, 5}, {3, 1, 1}, {2, 3, 4, 5}});
    checker.execs({{3, 1, 1}, {5}, {4, 1}, {3, 4, 5}});
    // shapes that cannot broadcast to the output must throw
    ASSERT_THROW(checker.execs({{2, 3, 4}, {4, 1}, {1}, {2, 3, 4}}), MegDNNError);
    ASSERT_THROW(checker.execs({{2, 4, 4}, {4, 1}, {3, 1, 1}, {2, 3, 4}}), MegDNNError);
}
DEF_TEST(ternary_lt_non_contig) {
    // COND_LT_MOV on non-contiguous inputs: a {2, 3} layout whose row stride
    // (4) exceeds the row length (3), leaving a one-element gap per row.
    using Mode = ElemwiseForward::Param::Mode;
    Checker<ElemwiseForward> checker(handle);
    checker.set_param(Mode::COND_LT_MOV);
    TensorLayout noncontig{{2, 3}, dtype::Float32()};
    noncontig.stride[0] = 4;
    checker.execl({noncontig, noncontig, noncontig, {{2, 3}, dtype::Float32()}});
}
DEF_TEST(fuse_mul_add3) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle);
......
......@@ -16,6 +16,8 @@ namespace elemwise {
cb(binary_non_contig) \
cb(ternary) \
cb(ternary_non_contig) \
cb(ternary_lt) \
cb(ternary_lt_non_contig) \
cb(fuse_mul_add3) \
cb(fuse_mul_add3_non_contig) \
cb(fuse_mul_add4) \
......
......@@ -207,7 +207,7 @@ TEST_F(CUDA, ELEMWISE_QUANTIZED_MODE_TENARY) {
using Mode = ElemwiseMultiType::Param::Mode;
Checker<ElemwiseMultiType> checker(handle_cuda());
for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV}) {
for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV, Mode::QCOND_LT_MOV}) {
UniformIntRNG rng_int8{-127, 127};
UniformIntRNG rng_uint8{0, 225};
checker.set_param({mode})
......@@ -368,7 +368,7 @@ TEST_F(CUDA, BENCHMARK_ELEMWISE_QUANTIZED_MODE_TENARY) {
CUBenchmarker<ElemwiseMultiType> bencher(handle_cuda());
UniformIntRNG rng{-128, 127};
for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV}) {
for (auto mode : {Mode::QFUSE_MUL_ADD3, Mode::QCOND_LEQ_MOV, Mode::QCOND_LT_MOV}) {
printf("Benchmark mode: %d\n", (int)mode);
bencher.set_param({mode})
.set_rng(0, &rng)
......
......@@ -59,6 +59,7 @@ Elemwise::Mode get_elem_mode(ElemwiseMultiType::Mode mode) {
MODE(FAST_TANH_GRAD);
MODE(ATAN2);
MODE(COND_LEQ_MOV);
MODE(COND_LT_MOV);
MODE(H_SWISH_GRAD);
MODE(FUSE_ADD_H_SWISH);
......@@ -231,7 +232,9 @@ TEST_F(NAIVE, ELEMWISE_QUANTIZED_MODE_TERNARY) {
.set_dtype(1, dtype::QuantizedS8(0.2f))
.set_dtype(2, dtype::QuantizedS8(0.3f));
for (auto mode : {Param::Mode::QFUSE_MUL_ADD3, Param::Mode::QCOND_LEQ_MOV}) {
for (auto mode :
{Param::Mode::QFUSE_MUL_ADD3, Param::Mode::QCOND_LEQ_MOV,
Param::Mode::QCOND_LT_MOV}) {
Param param{mode};
checker.set_param(param);
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to comment