Xiaomi / Mace
Commit 1f499285
Authored Aug 31, 2020 by 卢旭辉

Merge branch 'elu' into 'master'

Add onnx's Elu operator

See merge request applied-machine-learning/sysml/mace!1294

Parents: 2eca3a0f, c6118a22
Showing 13 changed files with 172 additions and 24 deletions (+172 −24)
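For context, ELU (exponential linear unit) with scale alpha is defined as ELU(x) = x for x >= 0 and alpha * (exp(x) - 1) for x < 0. A minimal scalar reference in C++ — an illustration of the definition only, not code from this commit:

    #include <cmath>

    // Reference ELU: identity for non-negative x, scaled exp(x) - 1 otherwise.
    inline float EluRef(float x, float alpha) {
      return x >= 0.0f ? x : alpha * (std::exp(x) - 1.0f);
    }

The diffs below wire exactly this rule into the CPU loop, the OpenCL kernels, and the ONNX converter.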
docs/user_guide/op_lists.rst                         +1  −0
mace/ops/activation.cc                               +4  −4
mace/ops/activation.h                                +16 −9
mace/ops/common/activation_type.h                    +1  −0
mace/ops/opencl/cl/activation.cl                     +4  −4
mace/ops/opencl/cl/common.h                          +6  −3
mace/ops/opencl/image/activation.cc                  +6  −1
mace/ops/opencl/image/winograd_conv2d.cc             +4  −0
test/ccbenchmark/mace/ops/activation_benchmark.cc    +64 −0
test/ccunit/mace/ops/activation_test.cc              +53 −0
tools/python/transform/base_converter.py             +1  −0
tools/python/transform/onnx_converter.py             +8  −2
tools/python/transform/transformer.py                +4  −1
docs/user_guide/op_lists.rst

@@ -21,6 +21,7 @@ Operator lists
    "DEPTH_TO_SPACE","Y",""
    "DEQUANTIZE","Y","Model quantization will be supported later."
    "ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/SQRT/EQUAL/FLOOR_DIV"
+   "ELU","Y",""
    "EMBEDDING_LOOKUP","Y",""
    "EXPANDDIMS","Y","Only CPU and TensorFlow is supported."
    "FILL","Y","Only CPU and TensorFlow is supported."
mace/ops/activation.cc

@@ -54,7 +54,7 @@ class ActivationOp<DeviceType::CPU, T> : public Operation {
     const Tensor *input = this->Input(0);
     Tensor *output = this->Output(0);
-    if (activation_type_ == PRELU) {
+    if (activation_type_ == PRELU || activation_type_ == ELU) {
       MACE_RETURN_IF_ERROR(output->ResizeLike(input));
       const T *input_ptr = input->data<T>();
       T *output_ptr = output->mutable_data<T>();

@@ -63,8 +63,8 @@ class ActivationOp<DeviceType::CPU, T> : public Operation {
       const T *alpha_ptr = alpha->data<T>();
       const index_t outer_size = output->dim(0);
       const index_t inner_size = output->dim(2) * output->dim(3);
-      PReLUActivation(context, input_ptr, outer_size, input->dim(1),
-                      inner_size, alpha_ptr, output_ptr);
+      ActivationWithAlpha(context, input_ptr, outer_size, input->dim(1),
+                          inner_size, alpha_ptr, activation_type_, output_ptr);
     } else {
       activation_delegator_->Compute(context, input, output);
     }

@@ -96,7 +96,7 @@ class ActivationOp<DeviceType::GPU, float> : public Operation {
     } else {
       MACE_NOT_IMPLEMENTED;
     }
-    if (type == ActivationType::PRELU) {
+    if (type == ActivationType::PRELU || type == ActivationType::ELU) {
       MACE_CHECK(TransformFilter(context, operator_def_.get(), 1,
                                  OpenCLBufferType::ARGUMENT, mem_type)
                  == MaceStatus::MACE_SUCCESS);
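In short: the CPU operator now routes both PRELU and ELU through the per-channel-alpha path (resize the output, read the Alpha tensor, call the shared loop), while all other activation types still go through activation_delegator_; on the GPU side, the Alpha tensor is transformed into an OpenCL ARGUMENT buffer for both types.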
mace/ops/activation.h

@@ -42,6 +42,8 @@ inline ActivationType StringToActivationType(const std::string type) {
     return ActivationType::NOOP;
   } else if (type == "LEAKYRELU") {
     return ActivationType::LEAKYRELU;
+  } else if (type == "ELU") {
+    return ActivationType::ELU;
   } else {
     LOG(FATAL) << "Unknown activation type: " << type;
   }

@@ -49,13 +51,14 @@ inline ActivationType StringToActivationType(const std::string type) {
 }

 template <typename T>
-void PReLUActivation(const OpContext *context,
-                     const T *input_ptr,
-                     const index_t outer_size,
-                     const index_t input_chan,
-                     const index_t inner_size,
-                     const T *alpha_ptr,
-                     T *output_ptr) {
+void ActivationWithAlpha(const OpContext *context,
+                         const T *input_ptr,
+                         const index_t outer_size,
+                         const index_t input_chan,
+                         const index_t inner_size,
+                         const T *alpha_ptr,
+                         const index_t activation_type,
+                         T *output_ptr) {
   utils::ThreadPool &thread_pool =
       context->device()->cpu_runtime()->thread_pool();

@@ -66,7 +69,12 @@ void PReLUActivation(const OpContext *context,
         for (index_t j = 0; j < inner_size; ++j) {
           index_t idx = i * input_chan * inner_size + chan_idx * inner_size + j;
           if (input_ptr[idx] < 0) {
-            output_ptr[idx] = input_ptr[idx] * alpha_ptr[chan_idx];
+            if (activation_type == ActivationType::PRELU) {
+              output_ptr[idx] = input_ptr[idx] * alpha_ptr[chan_idx];
+            } else if (activation_type == ActivationType::ELU) {
+              output_ptr[idx] =
+                  (std::exp(input_ptr[idx]) - 1) * alpha_ptr[chan_idx];
+            }
           } else {
             output_ptr[idx] = input_ptr[idx];
           }

@@ -75,7 +83,6 @@ void PReLUActivation(const OpContext *context,
       }
     },
     0, outer_size, 1, 0, input_chan, 1);
 }

 }  // namespace ops
 }  // namespace mace
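The rename from PReLUActivation to ActivationWithAlpha is the heart of the CPU change: the channel-major loop is untouched, and only the negative branch dispatches on the new activation_type parameter. Schematically, for an element x in channel c (a restatement of the hunk above, not new code):

    // x >= 0:        out = x                              (both types)
    // x < 0, PRELU:  out = x * alpha[c]
    // x < 0, ELU:    out = (std::exp(x) - 1) * alpha[c]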
mace/ops/common/activation_type.h

@@ -26,6 +26,7 @@ enum ActivationType {
   TANH = 4,
   SIGMOID = 5,
   LEAKYRELU = 6,
+  ELU = 7,
 };

 }  // namespace ops
mace/ops/opencl/cl/activation.cl

@@ -3,7 +3,7 @@
 __kernel void activation(OUT_OF_RANGE_PARAMS
                          GLOBAL_WORK_GROUP_SIZE_DIM3
                          __read_only image2d_t input,
-#ifdef USE_PRELU
+#if defined (USE_PRELU) || defined (USE_ELU)
                          __read_only image2d_t alpha,
 #endif
                          __private const float relux_max_limit,

@@ -23,9 +23,9 @@ __kernel void activation(OUT_OF_RANGE_PARAMS
   const int pos = mad24(ch_blk, width, w);
   DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
-#ifdef USE_PRELU
-  DATA_TYPE4 prelu_alpha = READ_IMAGET(alpha, SAMPLER, (int2)(ch_blk, 0));
-  DATA_TYPE4 out = do_activation(in, prelu_alpha, relux_max_limit,
+#if defined (USE_PRELU) || defined (USE_ELU)
+  DATA_TYPE4 activation_alpha = READ_IMAGET(alpha, SAMPLER, (int2)(ch_blk, 0));
+  DATA_TYPE4 out = do_activation(in, activation_alpha, relux_max_limit,
                                  leakyrelu_coefficient);
 #else
   DATA_TYPE4 out = do_activation(in, relux_max_limit, leakyrelu_coefficient);
 #endif
mace/ops/opencl/cl/common.h

@@ -83,8 +83,8 @@ inline float4 do_sigmoid(float4 in) {
 #ifdef DATA_TYPE
 inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
-#ifdef USE_PRELU
-                                DATA_TYPE4 prelu_alpha,
+#if defined (USE_PRELU) || defined (USE_ELU)
+                                DATA_TYPE4 alpha,
 #endif
                                 __private const float relux_max_limit,
                                 __private const float leakyrelu_coefficient) {

@@ -96,7 +96,10 @@ inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
   out = clamp(in, (DATA_TYPE4)0, relux_max_limit);
 #endif
 #ifdef USE_PRELU
-  out = select(prelu_alpha * in, in, in >= (DATA_TYPE)0);
+  out = select(alpha * in, in, in >= (DATA_TYPE)0);
+#endif
+#ifdef USE_ELU
+  out = select(alpha * (native_exp(in) - 1.0f), in, in >= (DATA_TYPE)0);
 #endif
 #ifdef USE_TANH
   out = tanh(in);
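In OpenCL C, select(a, b, cond) returns b where cond is true (component-wise), so the new USE_ELU line is the vectorized form of the scalar rule; a single-lane C-style equivalent, for illustration only:

    out = (in >= 0) ? in : alpha * (native_exp(in) - 1.0f);

native_exp is the device's fast, reduced-precision exponential; the unit test below compares against double-precision expectations with a 1e-5 tolerance, which leaves headroom for it.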
mace/ops/opencl/image/activation.cc

@@ -58,6 +58,11 @@ MaceStatus ActivationKernel::Compute(
       built_options.emplace("-DUSE_PRELU");
       break;
     }
+    case ELU: {
+      tuning_key_prefix_ = "elu_opencl_kernel";
+      built_options.emplace("-DUSE_ELU");
+      break;
+    }
     case TANH: {
       tuning_key_prefix_ = "tanh_opencl_kernel";
       built_options.emplace("-DUSE_TANH");

@@ -94,7 +99,7 @@ MaceStatus ActivationKernel::Compute(
     MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
     MACE_SET_3D_GWS_ARGS(kernel_, gws);
     kernel_.setArg(idx++, *(input->opencl_image()));
-    if (activation_ == PRELU) {
+    if (activation_ == PRELU || activation_ == ELU) {
       MACE_CHECK_NOTNULL(alpha);
       kernel_.setArg(idx++, *(alpha->opencl_image()));
     }
mace/ops/opencl/image/winograd_conv2d.cc

@@ -161,6 +161,10 @@ MaceStatus WinogradOutputTransform(OpContext *context,
       built_options.emplace("-DUSE_PRELU");
       break;
     }
+    case ELU: {
+      built_options.emplace("-DUSE_ELU");
+      break;
+    }
     case TANH: {
       built_options.emplace("-DUSE_TANH");
       break;
test/ccbenchmark/mace/ops/activation_benchmark.cc

@@ -208,6 +208,70 @@ MACE_BM_PRELU(1, 3, 512, 512);
 MACE_BM_PRELU(1, 32, 112, 112);
 MACE_BM_PRELU(1, 64, 256, 256);

+namespace {
+template <DeviceType D, typename T>
+void EluBenchmark(int iters, int batch, int channels, int height, int width) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+
+  // Add input data
+  if (D == DeviceType::CPU) {
+    net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
+  } else if (D == DeviceType::GPU) {
+    net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
+  } else {
+    MACE_NOT_IMPLEMENTED;
+  }
+  net.AddRandomInput<D, T>("Alpha", {channels}, true);
+
+  OpDefBuilder("Activation", "EluBM")
+      .Input("Input")
+      .Input("Alpha")
+      .Output("Output")
+      .AddStringArg("activation", "ELU")
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .Finalize(net.NewOperatorDef());
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+  net.Sync();
+}
+}  // namespace
+
+#define MACE_BM_ELU_MACRO(N, C, H, W, TYPE, DEVICE)                  \
+  static void MACE_BM_ELU_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
+      int iters) {                                                   \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));              \
+    EluBenchmark<DEVICE, TYPE>(iters, N, C, H, W);                   \
+  }                                                                  \
+  MACE_BENCHMARK(MACE_BM_ELU_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
+
+#ifdef MACE_ENABLE_OPENCL
+#define MACE_BM_ELU(N, C, H, W)              \
+  MACE_BM_ELU_MACRO(N, C, H, W, float, CPU); \
+  MACE_BM_ELU_MACRO(N, C, H, W, float, GPU); \
+  MACE_BM_ELU_MACRO(N, C, H, W, half, GPU)
+#else
+#define MACE_BM_ELU(N, C, H, W) \
+  MACE_BM_ELU_MACRO(N, C, H, W, float, CPU)
+#endif
+
+MACE_BM_ELU(1, 1, 512, 512);
+MACE_BM_ELU(1, 3, 128, 128);
+MACE_BM_ELU(1, 3, 512, 512);
+MACE_BM_ELU(1, 32, 112, 112);
+MACE_BM_ELU(1, 64, 256, 256);
+
 namespace {
 template <DeviceType D, typename T>
 void TanhBenchmark(int iters, int batch, int channels, int height, int width) {
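Each MACE_BM_ELU(N, C, H, W) line expands, via MACE_BM_ELU_MACRO, into one registered benchmark per device/type pair — e.g. MACE_BM_ELU(1, 1, 512, 512) registers MACE_BM_ELU_1_1_512_512_float_CPU, plus float and half GPU variants when MACE_ENABLE_OPENCL is defined — each reporting throughput as iters * N * C * H * W * sizeof(TYPE) bytes.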
test/ccunit/mace/ops/activation_test.cc

@@ -235,6 +235,59 @@ TEST_F(ActivationOpTest, OPENCLSimplePrelu) {
   TestSimplePrelu<DeviceType::GPU>();
 }

+namespace {
+template <DeviceType D>
+void TestSimpleElu() {
+  OpsTestNet net;
+
+  // Add input data
+  net.AddInputFromArray<D, float>(
+      "Input", {2, 2, 2, 2},
+      {-7, 7, -6, 6, -5, -5, -4, -4, -3, 3, -2, 2, -1, -1, 0, 0});
+  net.AddInputFromArray<D, float>("Alpha", {2}, {2.0, 3.0}, true);
+
+  if (D == DeviceType::GPU) {
+    OpDefBuilder("Activation", "EluTest")
+        .Input("Input")
+        .Input("Alpha")
+        .Output("Output")
+        .AddStringArg("activation", "ELU")
+        .Finalize(net.NewOperatorDef());
+
+    // Run
+    net.RunOp(D);
+  } else {
+    net.TransformDataFormat<D, float>("Input", DataFormat::NHWC, "InputNCHW",
+                                      DataFormat::NCHW);
+    OpDefBuilder("Activation", "EluTest")
+        .Input("InputNCHW")
+        .Input("Alpha")
+        .Output("OutputNCHW")
+        .AddStringArg("activation", "ELU")
+        .Finalize(net.NewOperatorDef());
+
+    // Run
+    net.RunOp(D);
+    net.TransformDataFormat<D, float>("OutputNCHW", DataFormat::NCHW, "Output",
+                                      DataFormat::NHWC);
+  }
+
+  auto expected = net.CreateTensor<float>(
+      {2, 2, 2, 2},
+      {-1.998176236068891, 7, -1.9950424956466672, 6,
+       -1.986524106001829, -2.9797861590027437,
+       -1.9633687222225316, -2.9450530833337973,
+       -1.900425863264272, 3, -1.7293294335267746, 2,
+       -1.2642411176571153, -1.896361676485673, 0, 0});
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
+}
+}  // namespace
+
+TEST_F(ActivationOpTest, CPUSimpleElu) { TestSimpleElu<DeviceType::CPU>(); }
+
+TEST_F(ActivationOpTest, OPENCLSimpleElu) { TestSimpleElu<DeviceType::GPU>(); }
+
 namespace {
 template <DeviceType D>
 void TestSimpleTanh() {
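The expected values follow directly from the ELU definition with the channel-wise Alpha = {2, 3} (channels are the last NHWC dimension here, so entries alternate between the two alphas): 2 * (exp(-7) - 1) ≈ -1.998176236068891 and 3 * (exp(-5) - 1) ≈ -2.9797861590027437, matching the first and sixth entries. A throwaway check, not part of the commit:

    #include <cmath>
    #include <cstdio>

    int main() {
      std::printf("%.15f\n", 2.0 * (std::exp(-7.0) - 1.0));  // -1.998176236068891
      std::printf("%.15f\n", 3.0 * (std::exp(-5.0) - 1.0));  // -2.979786159002744
      return 0;
    }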
tools/python/transform/base_converter.py

@@ -49,6 +49,7 @@ class ActivationType(Enum):
     TANH = 4
     SIGMOID = 5
     LEAKYRELU = 6
+    ELU = 7


 class EltwiseType(Enum):
tools/python/transform/onnx_converter.py

@@ -85,7 +85,7 @@ OnnxSupportedOps = [
     'Div',
     'Dropout',
     'DynamicLSTM',
-    # 'Elu',
+    'Elu',
     'Equal',
     # 'Exp',
     # 'Expand',

@@ -323,6 +323,7 @@ class OnnxConverter(base_converter.ConverterInterface):
         OnnxOpType.Relu.name: ActivationType.RELU,
         OnnxOpType.LeakyRelu.name: ActivationType.LEAKYRELU,
         OnnxOpType.PRelu.name: ActivationType.PRELU,
+        OnnxOpType.Elu.name: ActivationType.ELU,
         OnnxOpType.Tanh.name: ActivationType.TANH,
         OnnxOpType.Sigmoid.name: ActivationType.SIGMOID,
     }

@@ -348,6 +349,7 @@ class OnnxConverter(base_converter.ConverterInterface):
             OnnxOpType.Dropout.name: self.convert_dropout,
             OnnxOpType.DimRange.name: self.convert_dim_range,
             OnnxOpType.Div.name: self.convert_eltwise,
+            OnnxOpType.Elu.name: self.convert_activation,
             OnnxOpType.Equal.name: self.convert_eltwise,
             OnnxOpType.ExtractPooling.name: self.convert_extract_pooling,
             OnnxOpType.Flatten.name: self.convert_flatten,

@@ -627,7 +629,11 @@ class OnnxConverter(base_converter.ConverterInterface):
         type_arg.s = six.b(self.activation_type[node.op_type].name)

         if "alpha" in node.attrs:
-            alpha_value = node.attrs["alpha"]
+            alpha_tensor_name = node.name + '_alpha'
+            alpha_value = np.array([node.attrs["alpha"]])
+            self.add_tensor(alpha_tensor_name, alpha_value.reshape(-1).shape,
+                            mace_pb2.DT_FLOAT, alpha_value)
+            op.input.extend([alpha_tensor_name])
         else:
             if node.op_type == OnnxOpType.LeakyRelu.name:
                 alpha_value = 0.01
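Net effect in the converter: an ONNX Elu node's scalar alpha attribute is materialized as a one-element DT_FLOAT tensor named <node name>_alpha and appended as a second input of the MACE Activation op, which is how the kernels above receive their Alpha argument.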
tools/python/transform/transformer.py

@@ -977,7 +977,10 @@ class Transformer(base_converter.ConverterInterface):
                         [ActivationType.RELU.name, ActivationType.RELUX.name])
             else:
-                fold_consumer = (act_type != ActivationType.PRELU.name)
+                fold_consumer = (act_type != ActivationType.PRELU.name
+                                 and act_type != ActivationType.ELU.name)
             # during quantization, only fold relu/relux
             if (self._option.quantize_stat or self._option.quantize) \
                     and act_type not in [ActivationType.RELU.name,
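ELU now joins PRELU in being excluded from activation-folding into a consumer op — presumably because both carry the extra Alpha input that a fused consumer would not account for.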