Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e258812f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
e258812f
编写于
7月 22, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): add bool dtype
GitOrigin-RevId: 98c8a092b4872f4f0ef6068c58e47871d9f042be
上级
734c498d
变更
42
隐藏空白更改
内联
并排
Showing
42 changed file
with
424 addition
and
32 deletion
+424
-32
dnn/include/megdnn/dtype.h
dnn/include/megdnn/dtype.h
+8
-4
dnn/include/megdnn/oprs/general.h
dnn/include/megdnn/oprs/general.h
+2
-1
dnn/scripts/gen_elemwise_utils.py
dnn/scripts/gen_elemwise_utils.py
+4
-0
dnn/scripts/opr_param_defs.py
dnn/scripts/opr_param_defs.py
+6
-1
dnn/src/common/cond_take/predicate.cuh
dnn/src/common/cond_take/predicate.cuh
+1
-0
dnn/src/common/elemwise/each_mode.inl
dnn/src/common/elemwise/each_mode.inl
+8
-0
dnn/src/common/elemwise/kern_defs.cuh
dnn/src/common/elemwise/kern_defs.cuh
+4
-0
dnn/src/common/elemwise/opr_impl.cpp
dnn/src/common/elemwise/opr_impl.cpp
+17
-1
dnn/src/common/elemwise/opr_impl_body.inl
dnn/src/common/elemwise/opr_impl_body.inl
+18
-1
dnn/src/common/elemwise/opr_impl_class_def.inl
dnn/src/common/elemwise/opr_impl_class_def.inl
+3
-0
dnn/src/common/type_cvt.cpp
dnn/src/common/type_cvt.cpp
+4
-2
dnn/src/cuda/cond_take/kimpl/dt_bool.cu
dnn/src/cuda/cond_take/kimpl/dt_bool.cu
+27
-0
dnn/src/cuda/elemwise/kern_wrapper.cuh
dnn/src/cuda/elemwise/kern_wrapper.cuh
+18
-12
dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu
dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu
+15
-0
dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu
dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu
+15
-0
dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu
dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu
+15
-0
dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu
dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu
+15
-0
dnn/src/cuda/elemwise_helper.cpp
dnn/src/cuda/elemwise_helper.cpp
+7
-0
dnn/src/cuda/elemwise_helper.cuh
dnn/src/cuda/elemwise_helper.cuh
+4
-0
dnn/src/cuda/type_cvt/kern.cu
dnn/src/cuda/type_cvt/kern.cu
+10
-4
dnn/src/fallback/type_cvt/opr_impl.cpp
dnn/src/fallback/type_cvt/opr_impl.cpp
+8
-3
dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp
dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp
+15
-0
dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp
dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp
+15
-0
dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp
dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp
+15
-0
dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp
dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp
+15
-0
dnn/src/naive/type_cvt/opr_impl.cpp
dnn/src/naive/type_cvt/opr_impl.cpp
+2
-0
dnn/test/common/elemwise.cpp
dnn/test/common/elemwise.cpp
+2
-0
python_module/src/cpp/megbrain_wrap.cpp
python_module/src/cpp/megbrain_wrap.cpp
+1
-0
src/core/include/megbrain/dtype.h
src/core/include/megbrain/dtype.h
+1
-0
src/jit/impl/ast_c.cpp
src/jit/impl/ast_c.cpp
+2
-2
src/jit/impl/halide/ast_hl.cpp
src/jit/impl/halide/ast_hl.cpp
+8
-0
src/opr/impl/basic_arith.cpp
src/opr/impl/basic_arith.cpp
+6
-0
src/opr/impl/loop/forward.cpp
src/opr/impl/loop/forward.cpp
+2
-0
src/opr/impl/loop/impl.cpp
src/opr/impl/loop/impl.cpp
+2
-0
src/opr/include/megbrain/opr/basic_arith_wrapper.h
src/opr/include/megbrain/opr/basic_arith_wrapper.h
+4
-0
src/opr/test/basic_arith/elemwise.cpp
src/opr/test/basic_arith/elemwise.cpp
+62
-1
src/opr/test/basic_arith/elemwise_binary_trait_def.inl
src/opr/test/basic_arith/elemwise_binary_trait_def.inl
+12
-0
src/opr/test/basic_arith/elemwise_ternary_trait_def.inl
src/opr/test/basic_arith/elemwise_ternary_trait_def.inl
+2
-0
src/opr/test/basic_arith/elemwise_unary_trait_def.inl
src/opr/test/basic_arith/elemwise_unary_trait_def.inl
+11
-0
src/serialization/impl/dtype.fbs
src/serialization/impl/dtype.fbs
+1
-0
test/src/helper.cpp
test/src/helper.cpp
+15
-0
test/src/include/megbrain/test/helper.h
test/src/include/megbrain/test/helper.h
+22
-0
未找到文件。
dnn/include/megdnn/dtype.h
浏览文件 @
e258812f
...
...
@@ -52,6 +52,7 @@ namespace megdnn {
MEGDNN_INC_FLOAT16(cb(Float16)) \
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
cb(UintB4) \
cb(Bool) \
/*!
* \brief iterate through each full byte dtype
...
...
@@ -65,6 +66,7 @@ namespace megdnn {
cb(Byte) \
MEGDNN_INC_FLOAT16(cb(Float16)) \
MEGDNN_INC_FLOAT16(cb(BFloat16)) \
cb(Bool) \
/*!
* \brief iterate through each fractional byte dtype
...
...
@@ -122,7 +124,7 @@ namespace megdnn {
*/
#define MEGDNN_FOREACH_COMPUTING_DTYPE(cb) \
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb)
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb)
\
//! In order to avoid an unnecessary increase in binary size, we just
//! use QuantizedS16 dtype in winograd_filter_preprocess now. So I didn't add
...
...
@@ -348,6 +350,7 @@ typedef int32_t dt_int32;
typedef
int16_t
dt_int16
;
typedef
int8_t
dt_int8
;
typedef
uint8_t
dt_uint8
;
typedef
bool
dt_bool
;
MEGDNN_INC_FLOAT16
(
typedef
half_float
::
half
dt_float16
;)
MEGDNN_INC_FLOAT16
(
typedef
half_bfloat16
::
bfloat16
dt_bfloat16
;)
...
...
@@ -375,7 +378,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
#if !MEGDNN_DISABLE_FLOAT16
BFloat16
=
11
,
#endif
Bool
=
12
,
#define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE,
#define D(_name) _name,
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2
(
FST
,
D
)
...
...
@@ -392,7 +395,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
#if MEGDNN_CC_HOST
//! dtype numeric category fo
enum
class
DTypeCategory
:
int
{
OTHER
,
FLOAT
,
INT
,
LOWBIT
,
QUANTIZED
OTHER
,
FLOAT
,
INT
,
LOWBIT
,
QUANTIZED
,
BOOL
};
//! dtype signedness
enum
class
DTypeSignedness
:
int
{
...
...
@@ -401,7 +404,7 @@ MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;)
#else
struct
DTypeCategory
{
enum
Ev
{
OTHER
,
FLOAT
,
INT
,
LOWBIT
,
QUANTIZED
OTHER
,
FLOAT
,
INT
,
LOWBIT
,
QUANTIZED
,
BOOL
};
int
ev
;
};
...
...
@@ -707,6 +710,7 @@ MEGDNN_DEF_DT(Int32, dt_int32, INT, SIGNED, INT32_MIN, INT32_MAX);
MEGDNN_DEF_DT
(
Int16
,
dt_int16
,
INT
,
SIGNED
,
INT16_MIN
,
INT16_MAX
);
MEGDNN_DEF_DT
(
Int8
,
dt_int8
,
INT
,
SIGNED
,
INT8_MIN
,
INT8_MAX
);
MEGDNN_DEF_DT
(
Uint8
,
dt_uint8
,
INT
,
UNSIGNED
,
0
,
UINT8_MAX
);
MEGDNN_DEF_DT
(
Bool
,
dt_bool
,
BOOL
,
UNSIGNED
,
false
,
true
);
MEGDNN_INC_FLOAT16
(
MEGDNN_DEF_DT
(
Float16
,
dt_float16
,
FLOAT
,
SIGNED
,
std
::
numeric_limits
<
dt_float16
>::
lowest
(),
std
::
numeric_limits
<
dt_float16
>::
max
()));
...
...
dnn/include/megdnn/oprs/general.h
浏览文件 @
e258812f
...
...
@@ -39,11 +39,12 @@ class ElemwiseForward: public OperatorBase {
bool
commutable
;
//!< whether arity == 2 and inputs commutable
bool
allow_int
;
//!< whether int inputs allowed
bool
allow_float
;
//!< whether float inputs allowed
bool
allow_bool
;
//!< whether bool inputs allowed
const
char
*
name
;
//!< name of the mode
ModeTrait
()
:
arity
(
0
),
commutable
(
0
),
allow_int
(
0
),
allow_float
(
0
),
arity
(
0
),
commutable
(
0
),
allow_int
(
0
),
allow_float
(
0
),
allow_bool
(
0
),
name
(
NULL
)
{}
...
...
dnn/scripts/gen_elemwise_utils.py
浏览文件 @
e258812f
...
...
@@ -5,6 +5,7 @@ DTYPES = {'dt_int32': ('Int32', 'INT'),
'dt_uint8'
:
(
'Uint8'
,
'INT'
),
'dt_int8'
:
(
'Int8'
,
'INT'
),
'dt_int16'
:
(
'Int16'
,
'INT'
),
'dt_bool'
:
(
'Bool'
,
'BOOL'
),
'dt_float32'
:
(
'Float32'
,
'FLOAT'
),
'dt_float16'
:
(
'Float16'
,
'FLOAT'
),
'dt_bfloat16'
:
(
'BFloat16'
,
'FLOAT'
)
...
...
@@ -28,4 +29,7 @@ MODES = {
'FUSE_ADD_SIGMOID'
,
'ATAN2'
,
'H_SWISH_GRAD'
,
'FUSE_ADD_H_SWISH'
],
(
3
,
'FLOAT'
):
[
'COND_LEQ_MOV'
,
'FUSE_MUL_ADD3'
],
(
1
,
'BOOL'
):
[
'NOT'
],
(
2
,
'BOOL'
):
[
'AND'
,
'OR'
,
'XOR'
],
(
3
,
'BOOL'
):
[]
}
dnn/scripts/opr_param_defs.py
浏览文件 @
e258812f
...
...
@@ -314,7 +314,12 @@ pdef('Elemwise').add_enum(
Doc
(
'ERFCINV'
,
'unary: inverse function of erfc(x)'
),
Doc
(
'H_SWISH'
,
'unary: x * clip(x + 3, 0, 6) / 6'
),
Doc
(
'H_SWISH_GRAD'
,
'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'
),
Doc
(
'FUSE_ADD_H_SWISH'
,
'binary: hswish(x+y)'
)
Doc
(
'FUSE_ADD_H_SWISH'
,
'binary: hswish(x+y)'
),
Doc
(
'NOT'
,
'unary: !x'
),
Doc
(
'AND'
,
'binary: x && y'
),
Doc
(
'OR'
,
'binary: x || y'
),
Doc
(
'XOR'
,
'binary: x ^ y'
)
)
pdef
(
'ElemwiseMultiType'
).
add_enum
(
...
...
dnn/src/common/cond_take/predicate.cuh
浏览文件 @
e258812f
...
...
@@ -68,6 +68,7 @@ namespace cond_take {
#define inst_eq_i(_dt) do_inst_eq_i(DTypeTrait<_dt>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
(
inst_eq_f
)
MEGDNN_FOREACH_COMPUTING_DTYPE_INT
(
inst_eq_i
)
inst_eq_i
(
::
megdnn
::
dtype
::
Bool
)
#undef inst_eq_f
#undef inst_eq_i
...
...
dnn/src/common/elemwise/each_mode.inl
浏览文件 @
e258812f
...
...
@@ -9,6 +9,9 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_each_mode.py
#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb) \
#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \
...
...
@@ -38,6 +41,11 @@
MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \
#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \
#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \
MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) \
...
...
dnn/src/common/elemwise/kern_defs.cuh
浏览文件 @
e258812f
...
...
@@ -139,6 +139,7 @@ namespace megdnn {
DEF_KERN_FLOAT
(
H_SWISH
,
x
*
min
(
max
(
x
+
3
,
0.
f
),
6.
f
)
*
(
1.
f
/
6.
f
));
// int only
DEF_KERN
(
dt_bool
,
NOT
,
x
^
1
);
#undef KERN_SIG
...
...
@@ -156,6 +157,9 @@ namespace megdnn {
DEF_KERN_ALL
(
MAX
,
x
>
y
?
x
:
y
);
DEF_KERN_ALL
(
MIN
,
x
<
y
?
x
:
y
);
DEF_KERN_ALL
(
MUL
,
x
*
y
);
DEF_KERN
(
dt_bool
,
AND
,
x
&&
y
);
DEF_KERN
(
dt_bool
,
OR
,
x
||
y
);
DEF_KERN
(
dt_bool
,
XOR
,
x
^
y
);
DEF_KERN_INT
(
RMULH
,
round_mulh_saturate
(
x
,
y
));
DEF_KERN_ALL
(
SIGMOID_GRAD
,
x
*
(
ctype
(
1
)
-
x
)
*
y
);
DEF_KERN_ALL
(
SUB
,
x
-
y
);
...
...
dnn/src/common/elemwise/opr_impl.cpp
浏览文件 @
e258812f
...
...
@@ -72,6 +72,15 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT
(
cb
);
#undef cb
#define cb(_m) \
MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \
get(Mode::_m).allow_bool = true; \
} \
MIDOUT_END();
MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL
(
cb
);
MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL
(
cb
);
#undef cb
#define cb(_m) \
MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \
auto&& t = get(Mode::_m); \
...
...
@@ -82,10 +91,12 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
#define _a 1
MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT
(
cb
);
MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT
(
cb
);
MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL
(
cb
);
#undef _a
#define _a 2
MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT
(
cb
);
MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT
(
cb
);
MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL
(
cb
);
#undef _a
#define _a 3
MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT
(
cb
);
...
...
@@ -98,6 +109,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
auto&& t = get(Mode::_m); \
t.allow_int = true; \
t.allow_float = true; \
t.allow_bool = true; \
t.arity = _arity; \
t.name = megdnn_mangle(#_m); \
} \
...
...
@@ -129,7 +141,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) {
#if MEGDNN_ELEMWISE_MODE_ENABLE_ALL
for
(
auto
&&
i
:
traits
)
{
megdnn_assert
(
i
.
arity
&&
(
i
.
allow_int
||
i
.
allow_float
)
&&
megdnn_assert
(
i
.
arity
&&
(
i
.
allow_int
||
i
.
allow_float
||
i
.
allow_bool
)
&&
(
!
i
.
commutable
||
i
.
arity
==
2
));
}
#else
...
...
@@ -282,6 +294,10 @@ void ElemwiseForward::check_dtype(DType dtype) {
megdnn_assert
(
trait
.
allow_int
,
"unsupport mode %s for int
\n
"
,
trait
.
name
);
break
;
case
DTypeCategory
::
BOOL
:
megdnn_assert
(
trait
.
allow_bool
,
"unsupport mode %s for bool
\n
"
,
trait
.
name
);
break
;
default:
megdnn_throw
(
"bad dtype"
);
}
...
...
dnn/src/common/elemwise/opr_impl_body.inl
浏览文件 @
e258812f
...
...
@@ -15,6 +15,15 @@
template
<
int
arity
>
void
ElemwiseForwardImpl
::
on_arity_dispatched
()
{
auto
src
=
make_elemwise_op_param
<
arity
>
();
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
(
on_arity_dispatched_cb_dtype
)
MEGDNN_FOREACH_COMPUTING_DTYPE_INT
(
on_arity_dispatched_cb_dtype
)
on_arity_dispatched_cb_dtype
(
::
megdnn
::
dtype
::
Bool
)
megdnn_throw
(
"bad dtype"
);
}
template
<
int
arity
>
void
ElemwiseForwardImpl
::
on_arity_dispatched_no_bool
()
{
auto
src
=
make_elemwise_op_param
<
arity
>
();
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
(
on_arity_dispatched_cb_dtype
)
MEGDNN_FOREACH_COMPUTING_DTYPE_INT
(
on_arity_dispatched_cb_dtype
)
...
...
@@ -45,6 +54,14 @@ IMPL_MODE_DISPATCHER(2, DTypeCategory::FLOAT);
IMPL_MODE_DISPATCHER
(
3
,
DTypeCategory
::
FLOAT
);
#undef FOREACH
#define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL
IMPL_MODE_DISPATCHER
(
1
,
DTypeCategory
::
BOOL
);
#undef FOREACH
#define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL
IMPL_MODE_DISPATCHER
(
2
,
DTypeCategory
::
BOOL
);
#undef FOREACH
void
ElemwiseForwardImpl
::
exec
(
const
TensorNDArray
&
src
,
_megdnn_tensor_out
dst
)
{
...
...
@@ -97,8 +114,8 @@ void ElemwiseForwardImpl::exec(
#define D(_n) case _n: return on_arity_dispatched<_n>()
D
(
1
);
D
(
2
);
D
(
3
);
#undef D
case
3
:
return
on_arity_dispatched_no_bool
<
3
>
();
default:
megdnn_throw
(
"bad size of input tensors"
);
}
...
...
dnn/src/common/elemwise/opr_impl_class_def.inl
浏览文件 @
e258812f
...
...
@@ -13,6 +13,9 @@
template
<
int
arity
>
void
on_arity_dispatched
();
template
<
int
arity
>
void
on_arity_dispatched_no_bool
();
template
<
int
arity
,
DTypeCategory
dtype_cat
,
typename
ctype
>
struct
ModeDispatcher
;
...
...
dnn/src/common/type_cvt.cpp
浏览文件 @
e258812f
...
...
@@ -19,10 +19,12 @@ void TypeCvt::check_exec(const TensorLayout &src, const TensorLayout &dst) {
megdnn_assert_eq_shape
(
src
,
dst
);
auto
cat
=
src
.
dtype
.
category
();
megdnn_assert
(
cat
==
DTypeCategory
::
FLOAT
||
cat
==
DTypeCategory
::
INT
||
cat
==
DTypeCategory
::
QUANTIZED
);
cat
==
DTypeCategory
::
QUANTIZED
||
cat
==
DTypeCategory
::
BOOL
);
cat
=
dst
.
dtype
.
category
();
megdnn_assert
(
cat
==
DTypeCategory
::
FLOAT
||
cat
==
DTypeCategory
::
INT
||
cat
==
DTypeCategory
::
QUANTIZED
);
cat
==
DTypeCategory
::
QUANTIZED
||
cat
==
DTypeCategory
::
BOOL
);
}
}
// namespace megdnn
...
...
dnn/src/cuda/cond_take/kimpl/dt_bool.cu
0 → 100644
浏览文件 @
e258812f
/**
* \file dnn/src/cuda/cond_take/kimpl/dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_cond_take_kern_impls.py
#include "../kern.inl"
namespace
megdnn
{
namespace
cuda
{
namespace
cond_take
{
inst_genidx
(
::
megdnn
::
dtype
::
Bool
)
#undef inst_genidx
inst_copy
(
::
megdnn
::
dtype
::
Bool
)
#undef inst_copy
#undef inst_copy_
}
// cond_take
}
// cuda
}
// megdnn
dnn/src/cuda/elemwise/kern_wrapper.cuh
浏览文件 @
e258812f
...
...
@@ -25,8 +25,9 @@ namespace cuda {
1
,
KernImpl
,
typename
std
::
enable_if
<
!
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_int8
>::
value
&&
!
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_uint8
>::
value
>::
type
>
{
!
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_uint8
>::
value
&&
!
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_bool
>::
value
>::
type
>
{
typedef
typename
KernImpl
::
ctype
ctype
;
ctype
*
dst
;
...
...
@@ -41,8 +42,9 @@ namespace cuda {
2
,
KernImpl
,
typename
std
::
enable_if
<
!
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_int8
>::
value
&&
!
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_uint8
>::
value
>::
type
>
{
!
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_uint8
>::
value
&&
!
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_bool
>::
value
>::
type
>
{
typedef
typename
KernImpl
::
ctype
ctype
;
ctype
*
dst
;
...
...
@@ -57,8 +59,9 @@ namespace cuda {
3
,
KernImpl
,
typename
std
::
enable_if
<
!
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_int8
>::
value
&&
!
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_uint8
>::
value
>::
type
>
{
!
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_uint8
>::
value
&&
!
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_bool
>::
value
>::
type
>
{
typedef
typename
KernImpl
::
ctype
ctype
;
ctype
*
dst
;
...
...
@@ -74,8 +77,9 @@ namespace cuda {
1
,
KernImpl
,
typename
std
::
enable_if
<
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_int8
>::
value
||
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_uint8
>::
value
>::
type
>
{
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_uint8
>::
value
||
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_bool
>::
value
>::
type
>
{
typedef
typename
KernImpl
::
ctype
ctype
;
using
VectTypeTrait
=
elemwise_intl
::
VectTypeTrait
<
ctype
>
;
typedef
typename
VectTypeTrait
::
vect_type
vect_type
;
...
...
@@ -99,8 +103,9 @@ namespace cuda {
2
,
KernImpl
,
typename
std
::
enable_if
<
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_int8
>::
value
||
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_uint8
>::
value
>::
type
>
{
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_uint8
>::
value
||
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_bool
>::
value
>::
type
>
{
typedef
typename
KernImpl
::
ctype
ctype
;
using
VectTypeTrait
=
elemwise_intl
::
VectTypeTrait
<
ctype
>
;
typedef
typename
VectTypeTrait
::
vect_type
vect_type
;
...
...
@@ -126,8 +131,9 @@ namespace cuda {
3
,
KernImpl
,
typename
std
::
enable_if
<
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_int8
>::
value
||
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_uint8
>::
value
>::
type
>
{
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_uint8
>::
value
||
std
::
is_same
<
typename
KernImpl
::
ctype
,
dt_bool
>::
value
>::
type
>
{
typedef
typename
KernImpl
::
ctype
ctype
;
using
VectTypeTrait
=
elemwise_intl
::
VectTypeTrait
<
ctype
>
;
typedef
typename
VectTypeTrait
::
vect_type
vect_type
;
...
...
dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu
0 → 100644
浏览文件 @
e258812f
/**
* \file dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu
0 → 100644
浏览文件 @
e258812f
/**
* \file dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu
0 → 100644
浏览文件 @
e258812f
/**
* \file dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu
0 → 100644
浏览文件 @
e258812f
/**
* \file dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
dnn/src/cuda/elemwise_helper.cpp
浏览文件 @
e258812f
...
...
@@ -169,6 +169,9 @@ INST_FOR_CTYPE
#define ct dt_qint32
INST_FOR_CTYPE
#undef ct
#define ct dt_bool
INST_FOR_CTYPE
#undef ct
#undef INST_FOR_CTYPE
#undef INST
...
...
@@ -216,6 +219,9 @@ INST_FOR_CTYPE
#define ct dt_qint32
INST_FOR_CTYPE
#undef ct
#define ct dt_bool
INST_FOR_CTYPE
#undef ct
#undef ndim_cb
...
...
@@ -225,6 +231,7 @@ INST_FOR_CTYPE
#define INST(dt_ibyte) template class ParamVectVisitor<4, dt_ibyte, BCAST_1010>
INST
(
dt_int8
);
INST
(
dt_uint8
);
INST
(
dt_bool
);
INST
(
dt_qint8
);
INST
(
dt_quint8
);
#undef dt_ibyte
...
...
dnn/src/cuda/elemwise_helper.cuh
浏览文件 @
e258812f
...
...
@@ -102,6 +102,7 @@ INST(dt_float16, half4);
INST
(
dt_bfloat16
,
bhalf4
);
INST
(
dt_int32
,
int4
);
INST
(
dt_int16
,
short4
);
INST
(
dt_bool
,
uchar4
);
#undef as_raw
#define as_raw(x) x.as_int8()
INST
(
dt_qint8
,
char4
);
...
...
@@ -454,6 +455,7 @@ INST_DT_IBYTE(dt_int8);
INST_DT_IBYTE
(
dt_uint8
);
INST_DT_IBYTE
(
dt_qint8
);
INST_DT_IBYTE
(
dt_quint8
);
INST_DT_IBYTE
(
dt_bool
);
#undef INST_DT_IBYTE
#undef DEVICE_WRAPPER
#undef INST_PARAM_VECT_VISITOR
...
...
@@ -913,6 +915,7 @@ INST_DT_IBYTE(dt_int8);
INST_DT_IBYTE
(
dt_uint8
);
INST_DT_IBYTE
(
dt_qint8
);
INST_DT_IBYTE
(
dt_quint8
);
INST_DT_IBYTE
(
dt_bool
);
#undef INST_DT_IBYTE
//! implement general case by UserOpInvokerToSameNdim
...
...
@@ -1259,6 +1262,7 @@ INST_DT_IBYTE(dt_int8);
INST_DT_IBYTE
(
dt_uint8
);
INST_DT_IBYTE
(
dt_qint8
);
INST_DT_IBYTE
(
dt_quint8
);
INST_DT_IBYTE
(
dt_bool
);
#undef INST_DT_IBYTE
#endif
...
...
dnn/src/cuda/type_cvt/kern.cu
浏览文件 @
e258812f
...
...
@@ -62,7 +62,8 @@ template <typename ctype_dest, typename ctype_src>
struct
TypeCvtOp
<
ctype_dest
,
ctype_src
,
typename
std
::
enable_if
<
std
::
is_same
<
ctype_src
,
dt_int8
>::
value
||
std
::
is_same
<
ctype_src
,
dt_uint8
>::
value
>::
type
>
{
std
::
is_same
<
ctype_src
,
dt_uint8
>::
value
||
std
::
is_same
<
ctype_src
,
dt_bool
>::
value
>::
type
>
{
ctype_dest
*
dest
;
using
src_vect_type
=
typename
VectTypeTrait
<
ctype_src
>::
vect_type
;
using
dst_vect_type
=
typename
VectTypeTrait
<
ctype_dest
>::
vect_type
;
...
...
@@ -85,7 +86,8 @@ struct TypeCvtOpToQuantized<
ctype_dest
,
ctype_src
,
typename
std
::
enable_if
<
std
::
is_same
<
ctype_src
,
dt_int8
>::
value
||
std
::
is_same
<
ctype_src
,
dt_uint8
>::
value
>::
type
>
{
std
::
is_same
<
ctype_src
,
dt_uint8
>::
value
||
std
::
is_same
<
ctype_src
,
dt_bool
>::
value
>::
type
>
{
ctype_dest
*
dest
;
CudaDTypeParam
<
ctype_dest
>
param
;
using
src_vect_type
=
typename
VectTypeTrait
<
ctype_src
>::
vect_type
;
...
...
@@ -109,7 +111,8 @@ struct TypeCvtOpFromQuantized<
ctype_dest
,
ctype_src
,
typename
std
::
enable_if
<
std
::
is_same
<
ctype_src
,
dt_qint8
>::
value
||
std
::
is_same
<
ctype_src
,
dt_quint8
>::
value
>::
type
>
{
std
::
is_same
<
ctype_src
,
dt_quint8
>::
value
||
std
::
is_same
<
ctype_src
,
dt_bool
>::
value
>::
type
>
{
ctype_dest
*
dest
;
CudaDTypeParam
<
ctype_src
>
param
;
using
src_vect_type
=
typename
VectTypeTrait
<
ctype_src
>::
vect_type
;
...
...
@@ -137,7 +140,8 @@ struct TypeCvtOpBetweenQuantized<
ctype_dest
,
ctype_src
,
typename
std
::
enable_if
<
std
::
is_same
<
ctype_src
,
dt_qint8
>::
value
||
std
::
is_same
<
ctype_src
,
dt_quint8
>::
value
>::
type
>
{
std
::
is_same
<
ctype_src
,
dt_quint8
>::
value
||
std
::
is_same
<
ctype_src
,
dt_bool
>::
value
>::
type
>
{
ctype_dest
*
dest
;
CudaDTypeParam
<
ctype_src
>
src_param
;
CudaDTypeParam
<
ctype_dest
>
dst_param
;
...
...
@@ -243,6 +247,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
cb(dtype_src, dt_float32) \
cb(dtype_src, dt_float16) \
cb(dtype_src, dt_bfloat16) \
cb(dtype_src, dt_bool) \
#define MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \
cb(dtype_src, dt_quint8) \
...
...
@@ -265,6 +270,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
cb(dt_float32) \
cb(dt_float16) \
cb(dt_bfloat16) \
cb(dt_bool) \
#define MEGDNN_FOREACH_QUANTIZED_CTYPE(cb) \
cb(dt_quint8) \
...
...
dnn/src/fallback/type_cvt/opr_impl.cpp
浏览文件 @
e258812f
...
...
@@ -138,7 +138,8 @@ void do_cvt_s8_normal(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
dctype
*
__restrict
dptr
=
dst
.
ptr
<
dctype
>
();
float
scale
=
src
.
layout
.
dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
;
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
dptr
[
i
]
=
static_cast
<
dctype
>
(
sptr
[
i
]
*
scale
);
auto
val
=
sptr
[
i
]
*
scale
;
dptr
[
i
]
=
static_cast
<
dctype
>
(
val
);
}
}
...
...
@@ -150,7 +151,8 @@ void do_cvt_s32_normal(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
dctype
*
__restrict
dptr
=
dst
.
ptr
<
dctype
>
();
float
scale
=
src
.
layout
.
dtype
.
param
<
dtype
::
QuantizedS32
>
().
scale
;
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
dptr
[
i
]
=
static_cast
<
dctype
>
(
sptr
[
i
]
*
scale
);
auto
val
=
sptr
[
i
]
*
scale
;
dptr
[
i
]
=
static_cast
<
dctype
>
(
val
);
}
}
...
...
@@ -163,7 +165,8 @@ void do_cvt_asymm8_normal(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
float
scale
=
src
.
layout
.
dtype
.
param
<
dtype
::
Quantized8Asymm
>
().
scale
;
uint8_t
zp
=
src
.
layout
.
dtype
.
param
<
dtype
::
Quantized8Asymm
>
().
zero_point
;
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
dptr
[
i
]
=
static_cast
<
dctype
>
((
sptr
[
i
]
-
zp
)
*
scale
);
auto
val
=
(
sptr
[
i
]
-
zp
)
*
scale
;
dptr
[
i
]
=
static_cast
<
dctype
>
(
val
);
}
}
...
...
@@ -310,6 +313,7 @@ void on_dest_ctype(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
break; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE
(
cb
)
cb
(
::
megdnn
::
dtype
::
Bool
)
case
DTypeEnum
::
QuantizedS8
:
MIDOUT_BEGIN
(
megdnn_fb_typecvt_src_dtype
,
midout_iv
(
DTypeEnum
::
QuantizedS8
))
{
...
...
@@ -467,6 +471,7 @@ void run_contiguous(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
}
MEGDNN_FOREACH_COMPUTING_DTYPE
(
cb
)
cb
(
::
megdnn
::
dtype
::
Bool
)
case
DTypeEnum
::
QuantizedS8
:
MIDOUT_BEGIN
(
megdnn_fb_typecvt_dst_dtype
,
midout_iv
(
DTypeEnum
::
QuantizedS8
))
{
...
...
dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp
0 → 100644
浏览文件 @
e258812f
/**
* \file dnn/src/naive/elemwise/kimpl/AND_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp
0 → 100644
浏览文件 @
e258812f
/**
* \file dnn/src/naive/elemwise/kimpl/NOT_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb)
#define KERN_IMPL_ARITY 1
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp
0 → 100644
浏览文件 @
e258812f
/**
* \file dnn/src/naive/elemwise/kimpl/OR_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp
0 → 100644
浏览文件 @
e258812f
/**
* \file dnn/src/naive/elemwise/kimpl/XOR_dt_bool.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// generated by gen_elemwise_kern_impls.py
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb)
#define KERN_IMPL_ARITY 2
#define KERN_IMPL_CTYPE dt_bool
#include "../kern_impl.inl"
dnn/src/naive/type_cvt/opr_impl.cpp
浏览文件 @
e258812f
...
...
@@ -82,6 +82,7 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest,
MEGDNN_FOREACH_COMPUTING_DTYPE
(
cb
)
MEGDNN_FOREACH_QUANTIZED_DTYPE
(
cb
)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE
(
cb
)
cb
(
::
megdnn
::
dtype
::
Bool
)
#undef cb
default:
megdnn_throw
(
"bad dtype"
);
...
...
@@ -103,6 +104,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
MEGDNN_FOREACH_COMPUTING_DTYPE
(
cb
)
MEGDNN_FOREACH_QUANTIZED_DTYPE
(
cb
)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE
(
cb
)
cb
(
::
megdnn
::
dtype
::
Bool
)
#undef cb
default:
megdnn_throw
(
"bad dtype"
);
...
...
dnn/test/common/elemwise.cpp
浏览文件 @
e258812f
...
...
@@ -942,6 +942,8 @@ TEST(TEST_ELEMWISE, MODE_TRAIT) {
ASSERT_TRUE
(
T
::
from_mode
(
M
::
RMULH
).
commutable
);
ASSERT_FALSE
(
T
::
from_mode
(
M
::
RMULH
).
allow_float
);
ASSERT_TRUE
(
T
::
from_mode
(
M
::
XOR
).
allow_bool
);
}
}
// namespace elemwise
...
...
python_module/src/cpp/megbrain_wrap.cpp
浏览文件 @
e258812f
...
...
@@ -916,6 +916,7 @@ SymbolVar fill_retain_dtype(SymbolVar var, PyObject *value) {
case
DTypeEnum
::
QuantizedS4
:
case
DTypeEnum
::
Byte
:
case
DTypeEnum
::
QuantizedS16
:
case
DTypeEnum
::
Bool
:
break
;
#define cb(low_bit, size) \
case DTypeEnum::low_bit##size: \
...
...
src/core/include/megbrain/dtype.h
浏览文件 @
e258812f
...
...
@@ -27,6 +27,7 @@ using ::megdnn::dt_int32;
using
::
megdnn
::
dt_quint8
;
using
::
megdnn
::
dt_qint8
;
using
::
megdnn
::
dt_qint32
;
using
::
megdnn
::
dt_bool
;
using
::
megdnn
::
DType
;
using
::
megdnn
::
DTypeEnum
;
using
::
megdnn
::
DTypeTrait
;
...
...
src/jit/impl/ast_c.cpp
浏览文件 @
e258812f
...
...
@@ -145,9 +145,9 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
0.
f
})
/
6.
f
),
};
mgb_assert
(
map
.
size
()
+
8
==
opr
::
Elemwise
::
Param
::
MODE_NR_MEMBER
);
mgb_assert
(
map
.
size
()
+
12
==
opr
::
Elemwise
::
Param
::
MODE_NR_MEMBER
);
// unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
// ERFINV, ERFCINV
// ERFINV, ERFCINV
, NOT, AND, OR, XOR
return
map
;
#undef ADD_OPR
}
...
...
src/jit/impl/halide/ast_hl.cpp
浏览文件 @
e258812f
...
...
@@ -193,6 +193,14 @@ Halide::Expr dispatch_elemwise_mode(
return
Halide
::
round
(
inp
(
0
));
case
Mode
::
RMULH
:
return
(
inp
(
0
)
*
inp
(
1
))
>>
Halide
::
popcount
(
inp
(
0
));
case
Mode
::
NOT
:
return
cv
(
1
)
-
cv
(
inp
(
0
)
!=
cv
(
0
));
case
Mode
::
AND
:
return
cv
(
inp
(
0
)
!=
cv
(
0
))
*
cv
(
inp
(
1
)
!=
cv
(
0
));
case
Mode
::
OR
:
return
cv
(
cv
(
inp
(
0
)
!=
cv
(
0
))
+
cv
(
inp
(
1
)
!=
cv
(
0
))
>
cv
(
0
));
case
Mode
::
XOR
:
return
cv
(
cv
(
inp
(
0
)
!=
cv
(
0
))
+
cv
(
inp
(
1
)
!=
cv
(
0
))
==
cv
(
1
));
default:
mgb_throw
(
InternalError
,
"unsupported Elemwise mode(%d)"
,
static_cast
<
int
>
(
mode
));
...
...
src/opr/impl/basic_arith.cpp
浏览文件 @
e258812f
...
...
@@ -631,6 +631,8 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
RET
(
EL2
(
H_SWISH_GRAD
,
i0
,
og
));
case
Mode
::
FUSE_ADD_H_SWISH
:
RET
(
EL2
(
H_SWISH_GRAD
,
(
i0
+
i1
),
og
));
case
Mode
::
NOT
:
return
nullptr
;
// binary
case
Mode
::
ABS_GRAD
:
...
...
@@ -693,6 +695,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
return
nullptr
;
case
Mode
::
EQ
:
RET_INVALID
();
case
Mode
::
OR
:
case
Mode
::
XOR
:
case
Mode
::
AND
:
return
nullptr
;
// ternary
case
Mode
::
COND_LEQ_MOV
:
...
...
src/opr/impl/loop/forward.cpp
浏览文件 @
e258812f
...
...
@@ -408,6 +408,8 @@ cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const {
break
;
case
DTypeEnum
::
UintB4
:
break
;
case
DTypeEnum
::
Bool
:
break
;
#define cb(x) case DTypeEnum::x: break;
MEGDNN_FOREACH_PARAMETERIZED_DTYPE
(
cb
)
...
...
src/opr/impl/loop/impl.cpp
浏览文件 @
e258812f
...
...
@@ -247,6 +247,8 @@ MGB_DEFINE_OPR_CLASS(LoopImpl::DescImplBase::LoopCondManager::GetCondOpr,
break
;
case
DTypeEnum
::
UintB4
:
break
;
case
DTypeEnum
::
Bool
:
break
;
#define cb(_dt) \
case DTypeEnum::_dt: \
break;
...
...
src/opr/include/megbrain/opr/basic_arith_wrapper.h
浏览文件 @
e258812f
...
...
@@ -32,6 +32,7 @@ namespace opr {
EL1
(
exp
,
EXP
)
EL1
(
log
,
LOG
)
EL1
(
abs
,
ABS
)
EL1
(
not_
,
NOT
)
#undef EL1
...
...
@@ -53,6 +54,9 @@ namespace opr {
EL2
(
min
,
MIN
)
EL2
(
switch_gt0
,
SWITCH_GT0
)
EL2
(
eq
,
EQ
)
EL2
(
and_
,
AND
)
EL2
(
or_
,
OR
)
EL2
(
xor_
,
XOR
)
#undef EL2
...
...
src/opr/test/basic_arith/elemwise.cpp
浏览文件 @
e258812f
...
...
@@ -206,6 +206,7 @@ namespace {
static constexpr Mode MODE = Mode::_mode; \
static constexpr bool ALLOW_INT = _ALLOW_INT; \
static constexpr bool ALLOW_FLOAT = _ALLOW_FLOAT; \
static constexpr bool ALLOW_BOOL = _ALLOW_BOOL; \
static constexpr const char* NAME = #_mode; \
template<typename ctype> \
static inline ctype apply( \
...
...
@@ -588,6 +589,14 @@ namespace {
struct
enable_for_dtype_impl
<
dtype
::
Int32
,
void
>
{
static
constexpr
bool
value
=
false
;
};
template
<
class
Trait
>
struct
enable_for_dtype_impl
<
dtype
::
Bool
,
Trait
>
{
static
constexpr
bool
value
=
Trait
::
ALLOW_BOOL
;
};
template
<
>
struct
enable_for_dtype_impl
<
dtype
::
Bool
,
void
>
{
static
constexpr
bool
value
=
false
;
};
}
//! whether to enable test for specific dtype and Trait
...
...
@@ -749,8 +758,60 @@ TYPED_TEST(TestOprBasicArithTernaryElemwise, Float32) {
TEST
(
TestOprBasicArithElemwise
,
CheckAllModeTested
)
{
size_t
nr_member
=
opr
::
Elemwise
::
Param
::
MODE_NR_MEMBER
;
ASSERT_EQ
(
nr_member
,
tested_mode
.
size
());
ASSERT_EQ
(
nr_member
,
tested_mode
.
size
()
+
4
);
// Not using TestRunner: NOT, AND, OR, XOR
}
#define TEST_OPR_BASIC_ARITH_UNARY_BOOL(_mode, _op) \
TEST(TestOprBasicArithElemwise, _mode) { \
HostTensorGenerator<dtype::Bool> gen; \
auto host_x = gen({2, 1}); \
auto ptr = host_x->ptr<dt_bool>(); \
for (size_t i = 0; i < 2; ++i) { \
ptr[i] = (i & 1); \
} \
auto graph = ComputingGraph::make(); \
using Mode = opr::Elemwise::Mode; \
auto x = opr::Host2DeviceCopy::make(*graph, host_x), \
y = opr::Elemwise::make({x}, Mode::_mode); \
HostTensorND host_y; \
auto func = graph->compile({make_callback_copy(y, host_y)}); \
func->execute(); \
ASSERT_EQ(TensorShape({2, 1}), host_y.shape()); \
auto ptry = host_y.ptr<dt_bool>(); \
for (int i = 0;i < 2;i ++) { \
ASSERT_EQ(_op ptr[i], ptry[i]); \
} \
} \
TEST_OPR_BASIC_ARITH_UNARY_BOOL
(
NOT
,
!
)
#define TEST_OPR_BASIC_ARITH_BINARY_BOOL(_mode, _op) \
TEST(TestOprBasicArithElemwise, _mode) { \
HostTensorGenerator<dtype::Bool> gen; \
auto host_x1 = gen({2, 2}), host_x2 = gen({2, 2}); \
auto ptr1 = host_x1->ptr<dt_bool>(), ptr2 = host_x2->ptr<dt_bool>(); \
for (size_t i = 0; i < 4; ++i) { \
ptr1[i] = (i < 2); \
ptr2[i] = (i & 1); \
} \
auto graph = ComputingGraph::make(); \
using Mode = opr::Elemwise::Mode; \
auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1), \
x2 = opr::Host2DeviceCopy::make(*graph, host_x2), \
y = opr::Elemwise::make({x1, x2}, Mode::_mode); \
HostTensorND host_y; \
auto func = graph->compile({make_callback_copy(y, host_y)}); \
func->execute(); \
ASSERT_EQ(TensorShape({2, 2}), host_y.shape()); \
auto ptry = host_y.ptr<dt_bool>(); \
for (int i = 0;i < 4;i ++) { \
ASSERT_EQ(ptr1[i] _op ptr2[i], ptry[i]); \
} \
} \
TEST_OPR_BASIC_ARITH_BINARY_BOOL
(
AND
,
&&
)
TEST_OPR_BASIC_ARITH_BINARY_BOOL
(
OR
,
||
)
TEST_OPR_BASIC_ARITH_BINARY_BOOL
(
XOR
,
^
)
TEST
(
TestOprBasicArithElemwise
,
FuseMulAdd3Shapes
)
{
using
Checker
=
AutoOprChecker
<
3
,
1
>
;
...
...
src/opr/test/basic_arith/elemwise_binary_trait_def.inl
浏览文件 @
e258812f
...
...
@@ -19,6 +19,17 @@
ctype x = inp[0][idx]; \
ctype y = inp[1][idx]
#define _ALLOW_BOOL true
#define _ALLOW_FLOAT false
#define _ALLOW_INT false
DEF_TRAIT
(
AND
,
x
&&
y
)
DEF_TRAIT
(
OR
,
x
||
y
)
DEF_TRAIT
(
XOR
,
x
^
y
)
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
#undef _ALLOW_BOOL
#define _ALLOW_BOOL false
#define _ALLOW_FLOAT true
#define _ALLOW_INT true
DEF_TRAIT
(
ABS_GRAD
,
x
>
0
?
y
:
-
y
)
...
...
@@ -60,6 +71,7 @@ DEF_TRAIT(SHR, do_shr(x, y))
DEF_TRAIT
(
RMULH
,
do_round_mulh_saturate
(
x
,
y
))
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
#undef _ALLOW_BOOL
#undef _CUR_ARITY
#undef _EXPAND_PARAMS
...
...
src/opr/test/basic_arith/elemwise_ternary_trait_def.inl
浏览文件 @
e258812f
...
...
@@ -20,6 +20,7 @@
ctype y = inp[1][idx]; \
ctype z = inp[2][idx]
#define _ALLOW_BOOL false
#define _ALLOW_FLOAT true
#define _ALLOW_INT true
DEF_TRAIT
(
COND_LEQ_MOV
,
x
<=
y
?
z
:
0
)
...
...
@@ -46,5 +47,6 @@ DEF_TRAIT(FUSE_MUL_ADD4, i0 * i1 + i2 * i3)
#undef _CUR_ARITY
#undef _EXPAND_PARAMS
#undef _ALLOW_BOOL
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/test/basic_arith/elemwise_unary_trait_def.inl
浏览文件 @
e258812f
...
...
@@ -18,6 +18,15 @@
#define _EXPAND_PARAMS \
ctype x = inp[0][idx]
#define _ALLOW_BOOL true
#define _ALLOW_FLOAT false
#define _ALLOW_INT false
DEF_TRAIT
(
NOT
,
!
x
)
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
#undef _ALLOW_BOOL
#define _ALLOW_BOOL false
#define _ALLOW_FLOAT true
...
...
@@ -51,6 +60,8 @@ DEF_TRAIT(H_SWISH, do_h_swish(x))
#undef _ALLOW_FLOAT
#undef _ALLOW_BOOL
#undef _CUR_ARITY
#undef _EXPAND_PARAMS
...
...
src/serialization/impl/dtype.fbs
浏览文件 @
e258812f
...
...
@@ -21,6 +21,7 @@ enum DTypeEnum : byte {
QuantizedS4,
QuantizedS16,
BFloat16,
Bool,
}
table LinearQuantizationParam {
...
...
test/src/helper.cpp
浏览文件 @
e258812f
...
...
@@ -140,6 +140,21 @@ namespace mgb {
dtype
::
Int32
,
RandomDistribution
::
UNIFORM
>;
template
class
HostTensorGenerator
<
dtype
::
Int32
,
RandomDistribution
::
CONSTANT
>;
std
::
shared_ptr
<
HostTensorND
>
HostTensorGenerator
<
dtype
::
Bool
,
RandomDistribution
::
UNIFORM
>::
operator
()(
const
TensorShape
&
shape
,
CompNode
cn
)
{
if
(
!
cn
.
valid
())
cn
=
CompNode
::
load
(
"xpu0"
);
auto
dtype
=
dtype
::
Bool
();
std
::
shared_ptr
<
HostTensorND
>
ret
=
std
::
make_shared
<
HostTensorND
>
(
cn
,
shape
,
dtype
);
auto
ptr
=
ret
->
ptr
<
dt_bool
>
();
for
(
size_t
i
=
0
,
it
=
shape
.
total_nr_elems
();
i
<
it
;
++
i
)
{
ptr
[
i
]
=
(
i
%
2
==
1
);
}
return
ret
;
}
std
::
shared_ptr
<
HostTensorND
>
HostTensorGenerator
<
dtype
::
QuantizedS8
,
RandomDistribution
::
UNIFORM
>::
operator
()(
const
TensorShape
&
shape
,
CompNode
cn
)
{
...
...
test/src/include/megbrain/test/helper.h
浏览文件 @
e258812f
...
...
@@ -202,6 +202,10 @@ struct RandomDistributionDTypeDefault<dtype::Int32> {
static
constexpr
auto
dist
=
RandomDistribution
::
UNIFORM
;
};
template
<
>
struct
RandomDistributionDTypeDefault
<
dtype
::
Bool
>
{
static
constexpr
auto
dist
=
RandomDistribution
::
UNIFORM
;
};
template
<
>
struct
RandomDistributionDTypeDefault
<
dtype
::
QuantizedS8
>
{
static
constexpr
auto
dist
=
RandomDistribution
::
UNIFORM
;
};
...
...
@@ -251,6 +255,10 @@ struct UniformRNGDefaultRange<dtype::Uint8> {
static
constexpr
dt_uint8
LO
=
0
,
HI
=
255
;
};
template
<
>
struct
UniformRNGDefaultRange
<
dtype
::
Bool
>
{
static
constexpr
dt_bool
LO
=
false
,
HI
=
true
;
};
template
<
>
struct
UniformRNGDefaultRange
<
dtype
::
Int16
>
{
static
constexpr
dt_int16
LO
=
-
32767
,
HI
=
32767
;
};
...
...
@@ -341,6 +349,20 @@ class HostTensorGenerator<dtype, RandomDistribution::CONSTANT> final:
private:
ctype
m_default_val
;
};
template
<
>
class
HostTensorGenerator
<
dtype
::
Bool
,
RandomDistribution
::
UNIFORM
>
final
:
public
HostTensorGeneratorBase
{
public:
using
ctype
=
typename
DTypeTrait
<
dtype
::
Bool
>::
ctype
;
HostTensorGenerator
(
uint64_t
seed
=
next_rand_seed
())
:
HostTensorGeneratorBase
{
seed
}
{}
std
::
shared_ptr
<
HostTensorND
>
operator
()(
const
TensorShape
&
shape
,
CompNode
cn
=
{})
override
;
using
HostTensorGeneratorBase
::
operator
();
};
template
<
>
class
HostTensorGenerator
<
dtype
::
QuantizedS8
,
RandomDistribution
::
UNIFORM
>
final
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录