Commit 6e132756
Authored Sep 12, 2018 by yejianwu
Parent: 9ffd51d9

    support differ input_size and hidden_units, add lstmcell cpu benchmark

Showing 6 changed files with 282 additions and 247 deletions
mace/kernels/opencl/cl/lstmcell.cl    +65   -80
mace/kernels/opencl/lstmcell.cc       +11   -7
mace/ops/BUILD                        +4    -4
mace/ops/lstmcell_benchmark.cc        +44   -32
mace/ops/lstmcell_test.cc             +17   -124
mace/ops/lstmcell_test_util.h         +141  -0
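A note on shapes, inferred from the AddRandomInput calls in the benchmark and test changes below (before this commit both dimensions were the single lstm_step parameter, so input_size == hidden_units was implicitly assumed):

    \mathrm{Input} \in \mathbb{R}^{batch \times input\_size}
    \mathrm{PreOutput},\ \mathrm{PreCell},\ \mathrm{Cell},\ \mathrm{Output} \in \mathbb{R}^{batch \times hidden\_units}
    \mathrm{Weight} \in \mathbb{R}^{(input\_size + hidden\_units) \times 4\,hidden\_units}
    \mathrm{Bias} \in \mathbb{R}^{4\,hidden\_units}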
mace/kernels/opencl/cl/lstmcell.cl

@@ -9,6 +9,8 @@ __kernel void lstmcell(KERNEL_ERROR_PARAMS
                        __read_only image2d_t pre_cell,
                        __private const float forget_bias,
                        __private const int width,
+                       __private const int hidden_units,
+                       __private const int in_w_blk,
                        __write_only image2d_t cell,
                        __write_only image2d_t output) {
   const int w_blk_idx = get_global_id(0);
@@ -25,115 +27,98 @@ __kernel void lstmcell(KERNEL_ERROR_PARAMS
   DATA_TYPE4 fc_res0 = 0.0, fc_res1 = 0.0, fc_res2 = 0.0, fc_res3 = 0.0;
   DATA_TYPE4 in, pre_h;
   DATA_TYPE4 w0, w1, w2, w3;
+  int k_offset;
 
   // concat matmul
-  for (short i = 0; i < global_size_dim0; ++i) {
+  for (short i = 0; i < in_w_blk; ++i) {
     in = READ_IMAGET(input, SAMPLER, (int2)(i, h_idx));
 
-    short k = 4 * i;
+    int k = i << 2;
     w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
     w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
     w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
     w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
     fc_res0 += in.x * w0;
     fc_res1 += in.x * w1;
     fc_res2 += in.x * w2;
     fc_res3 += in.x * w3;
 
-    k = 4 * i + 1;
-    if (k < width) {
-      w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
-      w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
-      w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
-      w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
+    k += 1;
+    k_offset = select(-1, k, k < width);
+    w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k_offset));
+    w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k_offset));
+    w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k_offset));
+    w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k_offset));
     fc_res0 += in.y * w0;
     fc_res1 += in.y * w1;
     fc_res2 += in.y * w2;
     fc_res3 += in.y * w3;
-    }
 
-    k = 4 * i + 2;
-    if (k < width) {
-      w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
-      w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
-      w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
-      w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
+    k += 1;
+    k_offset = select(-1, k, k < width);
+    w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k_offset));
+    w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k_offset));
+    w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k_offset));
+    w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k_offset));
     fc_res0 += in.z * w0;
     fc_res1 += in.z * w1;
     fc_res2 += in.z * w2;
     fc_res3 += in.z * w3;
-    }
 
-    k = 4 * i + 3;
-    if (k < width) {
-      w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
-      w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
-      w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
-      w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
+    k += 1;
+    k_offset = select(-1, k, k < width);
+    w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k_offset));
+    w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k_offset));
+    w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k_offset));
+    w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k_offset));
     fc_res0 += in.w * w0;
     fc_res1 += in.w * w1;
     fc_res2 += in.w * w2;
     fc_res3 += in.w * w3;
-    }
   }
 
   for (short i = 0; i < global_size_dim0; ++i) {
     pre_h = READ_IMAGET(pre_output, SAMPLER, (int2)(i, h_idx));
 
-    short k = 4 * (i + global_size_dim0);
-    short k_limit = 4 * global_size_dim0 + width;
+    int k = (i << 2) + width;
     w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
     w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
     w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
     w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
     fc_res0 += pre_h.x * w0;
     fc_res1 += pre_h.x * w1;
     fc_res2 += pre_h.x * w2;
     fc_res3 += pre_h.x * w3;
 
-    k = 4 * (i + global_size_dim0) + 1;
-    if (k < k_limit) {
+    k += 1;
     w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
     w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
     w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
     w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
     fc_res0 += pre_h.y * w0;
     fc_res1 += pre_h.y * w1;
     fc_res2 += pre_h.y * w2;
     fc_res3 += pre_h.y * w3;
-    }
 
-    k = 4 * (i + global_size_dim0) + 2;
-    if (k < k_limit) {
+    k += 1;
     w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
     w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
     w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
     w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
     fc_res0 += pre_h.z * w0;
     fc_res1 += pre_h.z * w1;
     fc_res2 += pre_h.z * w2;
     fc_res3 += pre_h.z * w3;
-    }
 
-    k = 4 * (i + global_size_dim0) + 3;
-    if (k < k_limit) {
+    k += 1;
     w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
     w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
     w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
     w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
     fc_res0 += pre_h.w * w0;
     fc_res1 += pre_h.w * w1;
     fc_res2 += pre_h.w * w2;
     fc_res3 += pre_h.w * w3;
-    }
   }
 
   // bias
   DATA_TYPE4 b0, b1, b2, b3;
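An aside on the indexing above (my own reading of the kernel, not text from the commit): each work-item owns one 4-wide block of the 4 * hidden_units output columns, the first loop walks weight rows 0..width-1 against the current input, the second loop walks rows width..width+hidden_units-1 against the previous output, and the select(-1, k, k < width) lines replace the old per-lane if (k < width) branches for widths that are not a multiple of 4 (the out-of-range coordinate falls back to the sampler's border handling). A scalar C++ sketch of the same concatenated matmul for a single output element, as an illustration only:

    // Illustration only (not part of this commit): the scalar math the kernel's
    // two loops implement for one output column, assuming a row-major weight of
    // shape [input_size + hidden_units, 4 * hidden_units].
    #include <vector>

    float ConcatMatMulOneElement(const std::vector<float> &x,       // [input_size]
                                 const std::vector<float> &pre_h,   // [hidden_units]
                                 const std::vector<float> &weight,  // [(I + H) * 4H]
                                 int input_size, int hidden_units, int col) {
      const int out_width = 4 * hidden_units;
      float acc = 0.f;
      // Rows [0, input_size) multiply the current input ...
      for (int k = 0; k < input_size; ++k) {
        acc += x[k] * weight[k * out_width + col];
      }
      // ... rows [input_size, input_size + hidden_units) multiply the previous output.
      for (int k = 0; k < hidden_units; ++k) {
        acc += pre_h[k] * weight[(input_size + k) * out_width + col];
      }
      return acc;
    }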
mace/kernels/opencl/lstmcell.cc

@@ -31,12 +31,13 @@ MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()(
     Tensor *cell,
     Tensor *output,
     StatsFuture *future) {
-  MACE_CHECK(input->dim_size() == 2 && input->dim(1) % 4 == 0,
-             "LSTM step should be a multiple of 4");
+  MACE_CHECK(pre_output->dim_size() == 2 && pre_output->dim(1) % 4 == 0,
+             "LSTM hidden units should be a multiple of 4");
 
   const index_t height = input->dim(0);
   const index_t width = input->dim(1);
-  const index_t width_blocks = width / 4;
+  const index_t hidden_units = pre_output->dim(1);
+  const index_t w_blocks = hidden_units >> 2;
 
   auto runtime = context_->device()->opencl_runtime();
@@ -57,17 +58,18 @@ MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()(
         static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
   }
 
-  const uint32_t gws[2] = {static_cast<uint32_t>(width_blocks),
+  const uint32_t gws[2] = {static_cast<uint32_t>(w_blocks),
                            static_cast<uint32_t>(height)};
 
   if (!IsVecEqual(input_shape_, input->shape())) {
-    std::vector<index_t> output_shape_padded = {height, 1, 1, width};
+    std::vector<index_t> output_shape_padded = {height, 1, 1, hidden_units};
     std::vector<size_t> output_image_shape;
     CalImage2DShape(output_shape_padded, BufferType::IN_OUT_CHANNEL,
                     &output_image_shape);
-    MACE_RETURN_IF_ERROR(output->ResizeImage(input->shape(),
+    MACE_RETURN_IF_ERROR(output->ResizeImage(pre_output->shape(),
                                              output_image_shape));
-    MACE_RETURN_IF_ERROR(cell->ResizeImage(input->shape(),
+    MACE_RETURN_IF_ERROR(cell->ResizeImage(pre_cell->shape(),
                                            output_image_shape));
 
     uint32_t idx = 0;
     OUT_OF_RANGE_SET_ARG;
@@ -79,6 +81,8 @@ MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()(
     kernel_.setArg(idx++, *(pre_cell->opencl_image()));
     kernel_.setArg(idx++, static_cast<float>(forget_bias_));
     kernel_.setArg(idx++, static_cast<int32_t>(width));
+    kernel_.setArg(idx++, static_cast<int32_t>(hidden_units));
+    kernel_.setArg(idx++, static_cast<int32_t>(RoundUpDiv4(width)));
    kernel_.setArg(idx++, *(cell->opencl_image()));
    kernel_.setArg(idx++, *(output->opencl_image()));
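The work-group grid is now sized by the hidden units rather than by the input width, and the kernel additionally receives hidden_units and in_w_blk. Illustration only, not part of the commit: the values the host code above would compute for the smallest shape used in the new tests (batch = 1, input_size = 3, hidden_units = 8), assuming RoundUpDiv4 is MACE's usual (x + 3) / 4 rounding helper.

    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t width = 3;         // input->dim(1), i.e. input_size
      const int64_t hidden_units = 8;  // pre_output->dim(1)
      const int64_t w_blocks = hidden_units >> 2;  // 2: one work-item per 4 output columns
      const int64_t in_w_blk = (width + 3) / 4;    // assumed RoundUpDiv4(width) == 1
      const uint32_t gws[2] = {static_cast<uint32_t>(w_blocks),
                               1u /* height == batch */};
      std::printf("gws = {%u, %u}, width = %lld, in_w_blk = %lld\n",
                  gws[0], gws[1], (long long)width, (long long)in_w_blk);
      return 0;
    }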
mace/ops/BUILD

@@ -20,9 +20,9 @@ load(
 cc_library(
     name = "test",
     testonly = 1,
-    hdrs = [
-        "ops_test_util.h",
-    ],
+    hdrs = glob([
+        "*_test_util.h",
+    ]),
     srcs = [
         "ops_test_util.cc",
     ],
@@ -67,7 +67,7 @@ cc_library(
     ),
     hdrs = glob(
        ["*.h"],
-        exclude = ["ops_test_util.h"],
+        exclude = glob(["*_test_util.h"]),
    ),
    copts = [
        "-Werror",
mace/ops/lstmcell_benchmark.cc

@@ -15,6 +15,7 @@
 #include "mace/core/operator.h"
 #include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/lstmcell_test_util.h"
 #include "mace/ops/ops_test_util.h"
 
 namespace mace {
@@ -23,23 +24,31 @@ namespace test {
 namespace {
 template <DeviceType D, typename T>
-void LSTMCell(int iters, int batch, int lstm_step) {
+void LSTMCell(int iters, int batch, int input_size, int hidden_units) {
   mace::testing::StopTiming();
 
   OpsTestNet net;
 
   // Add input data
-  if (D == DeviceType::GPU) {
-    net.AddRandomInput<D, T>("Input", {batch, lstm_step});
-    net.AddRandomInput<D, T>("PreOutput", {batch, lstm_step});
-    net.AddRandomInput<D, T>("Weight", {2 * lstm_step, 4 * lstm_step});
-    net.AddRandomInput<D, T>("Bias", {4 * lstm_step});
-    net.AddRandomInput<D, T>("PreCell", {batch, lstm_step});
-  } else {
-    MACE_NOT_IMPLEMENTED;
-  }
+  net.AddRandomInput<D, float>("Input", {batch, input_size});
+  net.AddRandomInput<D, float>("PreOutput", {batch, hidden_units});
+  net.AddRandomInput<D, float>("Weight",
+                               {input_size + hidden_units, 4 * hidden_units});
+  net.AddRandomInput<D, float>("Bias", {4 * hidden_units});
+  net.AddRandomInput<D, float>("PreCell", {batch, hidden_units});
 
+  const float &forget_add = 0.0f;
+
+  if (D == DeviceType::CPU) {
+    net.CopyData<DeviceType::CPU, float>("Input", "InputCPU");
+    net.CopyData<DeviceType::CPU, float>("PreOutput", "PreOutputCPU");
+    net.CopyData<DeviceType::CPU, float>("Weight", "WeightCPU");
+    net.CopyData<DeviceType::CPU, float>("Bias", "BiasCPU");
+    net.CopyData<DeviceType::CPU, float>("PreCell", "PreCellCPU");
+    LSTMCellCPU<float>(&net, "InputCPU", "PreOutputCPU", "WeightCPU",
+                       "BiasCPU", "PreCellCPU", forget_add,
+                       "CellCPU", "OutputCPU");
-  if (D == DeviceType::GPU) {
+  } else if (D == DeviceType::GPU) {
     BufferToImage<D, T>(&net, "Input", "InputImage",
                         kernels::BufferType::IN_OUT_CHANNEL);
     BufferToImage<D, T>(&net, "PreOutput", "PreOutputImage",
@@ -57,7 +66,7 @@ void LSTMCell(int iters, int batch, int lstm_step) {
         .Input("WeightImage")
         .Input("BiasImage")
         .Input("PreCellImage")
-        .AddFloatArg("forget_add", 0.0f)
+        .AddFloatArg("scalar_input", forget_add)
         .Output("CellImage")
         .Output("OutputImage")
         .Finalize(net.NewOperatorDef());
@@ -79,27 +88,30 @@ void LSTMCell(int iters, int batch, int lstm_step) {
 }
 }  // namespace
 
-#define MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, TYPE, DEVICE)                   \
-  static void MACE_BM_LSTMCELL_##N##_##LSTM_STEP##_##TYPE##_##DEVICE(       \
+#define MACE_BM_LSTMCELL_MACRO(N, INPUT_SIZE, HIDDEN_UNITS, TYPE, DEVICE)    \
+  static void                                                                \
+  MACE_BM_LSTMCELL_##N##_##INPUT_SIZE##_##HIDDEN_UNITS##_##TYPE##_##DEVICE(  \
       int iters) {                                                           \
     const int64_t macc =                                                     \
-        static_cast<int64_t>(iters) * N * 2 * LSTM_STEP * 4 * LSTM_STEP;     \
-    const int64_t tot = static_cast<int64_t>(iters) * N * LSTM_STEP;         \
+        static_cast<int64_t>(                                                \
+            iters) * N * (INPUT_SIZE + HIDDEN_UNITS) * 4 * HIDDEN_UNITS;     \
+    const int64_t tot = static_cast<int64_t>(iters) * N * INPUT_SIZE;        \
     mace::testing::MaccProcessed(macc);                                      \
     mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                      \
-    LSTMCell<DEVICE, TYPE>(iters, N, LSTM_STEP);                             \
+    LSTMCell<DEVICE, TYPE>(iters, N, INPUT_SIZE, HIDDEN_UNITS);              \
  }                                                                           \
-  MACE_BENCHMARK(MACE_BM_LSTMCELL_##N##_##LSTM_STEP##_##TYPE##_##DEVICE)
+  MACE_BENCHMARK(                                                            \
+      MACE_BM_LSTMCELL_##N##_##INPUT_SIZE##_##HIDDEN_UNITS##_##TYPE##_##DEVICE)
 
-#define MACE_BM_LSTMCELL(N, LSTM_STEP)                \
-  MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, float, GPU);   \
-  MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, half, GPU);
+#define MACE_BM_LSTMCELL(N, INPUT_SIZE, HIDDEN_UNITS)                 \
+  MACE_BM_LSTMCELL_MACRO(N, INPUT_SIZE, HIDDEN_UNITS, float, CPU);    \
+  MACE_BM_LSTMCELL_MACRO(N, INPUT_SIZE, HIDDEN_UNITS, float, GPU);    \
+  MACE_BM_LSTMCELL_MACRO(N, INPUT_SIZE, HIDDEN_UNITS, half, GPU);
 
-MACE_BM_LSTMCELL(1, 200);
-MACE_BM_LSTMCELL(20, 200);
-MACE_BM_LSTMCELL(20, 320);
-MACE_BM_LSTMCELL(32, 400);
-MACE_BM_LSTMCELL(32, 640);
+MACE_BM_LSTMCELL(1, 64, 256);
+MACE_BM_LSTMCELL(30, 64, 256);
+MACE_BM_LSTMCELL(50, 64, 256);
+MACE_BM_LSTMCELL(80, 64, 256);
 
 }  // namespace test
 }  // namespace ops
 }  // namespace mace
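Each MACE_BM_LSTMCELL(N, INPUT_SIZE, HIDDEN_UNITS) line above now registers three benchmarks: float on CPU (driving the LSTMCellCPU reference graph from lstmcell_test_util.h) plus float and half on GPU. A quick sanity check of the updated MACC formula, as an illustration only:

    // Illustration only: the MACC count the updated macro reports per benchmark
    // iteration for the MACE_BM_LSTMCELL(1, 64, 256) instantiation above.
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t iters = 1, N = 1, INPUT_SIZE = 64, HIDDEN_UNITS = 256;
      const int64_t macc =
          iters * N * (INPUT_SIZE + HIDDEN_UNITS) * 4 * HIDDEN_UNITS;
      std::printf("macc per iteration = %lld\n", (long long)macc);  // 327680
      return 0;
    }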
mace/ops/lstmcell_test.cc

@@ -14,6 +14,7 @@
 #include "mace/core/operator.h"
 #include "mace/kernels/eltwise.h"
+#include "mace/ops/lstmcell_test_util.h"
 #include "mace/ops/ops_test_util.h"
 
 namespace mace {
@@ -23,128 +24,20 @@ namespace test {
 class LSTMCellTest : public OpsTestBase {};
 
 namespace {
-template <typename T>
-void LSTMCellCPU(OpsTestNet *net,
-                 const std::string &input_name,
-                 const std::string &pre_output_name,
-                 const std::string &weight_name,
-                 const std::string &bias_name,
-                 const std::string &pre_cell_name,
-                 const float &forget_add_name,
-                 const std::string &cell_name,
-                 const std::string &output_name) {
-  [body removed verbatim; it is the same Concat -> MatMul -> BiasAdd -> Split ->
-   Sigmoid/Tanh -> Eltwise graph builder that this commit adds to
-   mace/ops/lstmcell_test_util.h, reproduced in full in that file's section below]
-}
-
 template <DeviceType D, typename T>
 void TestLSTMCell(const uint32_t &batch,
-                  const uint32_t &lstm_step,
+                  const uint32_t &input_size,
+                  const uint32_t &hidden_units,
                   const float &forget_add) {
   // Construct graph
   OpsTestNet net;
-  net.AddRandomInput<D, float>("Input", {batch, lstm_step});
-  net.AddRandomInput<D, float>("PreOutput", {batch, lstm_step});
-  net.AddRandomInput<D, float>("Weight", {2 * lstm_step, 4 * lstm_step});
-  net.AddRandomInput<D, float>("Bias", {4 * lstm_step});
-  net.AddRandomInput<D, float>("PreCell", {batch, lstm_step});
+  net.AddRandomInput<D, float>("Input", {batch, input_size});
+  net.AddRandomInput<D, float>("PreOutput", {batch, hidden_units});
+  net.AddRandomInput<D, float>("Weight",
+                               {input_size + hidden_units, 4 * hidden_units});
+  net.AddRandomInput<D, float>("Bias", {4 * hidden_units});
+  net.AddRandomInput<D, float>("PreCell", {batch, hidden_units});
 
   net.CopyData<DeviceType::CPU, float>("Input", "InputCPU");
   net.CopyData<DeviceType::CPU, float>("PreOutput", "PreOutputCPU");
@@ -205,17 +98,17 @@ void TestLSTMCell(const uint32_t &batch,
 }
 }  // namespace
 
 TEST_F(LSTMCellTest, OPENCLRandomHalf) {
-  TestLSTMCell<GPU, half>(1, 4, 0.0f);
-  TestLSTMCell<GPU, half>(2, 16, 0.0f);
-  TestLSTMCell<GPU, half>(2, 200, 0.5f);
-  TestLSTMCell<GPU, half>(20, 320, 0.5f);
+  TestLSTMCell<GPU, half>(1, 3, 8, 0.0f);
+  TestLSTMCell<GPU, half>(2, 16, 24, 0.0f);
+  TestLSTMCell<GPU, half>(2, 200, 280, 0.5f);
+  TestLSTMCell<GPU, half>(20, 320, 512, 0.5f);
 }
 
 TEST_F(LSTMCellTest, OPENCLRandomFloat) {
-  TestLSTMCell<GPU, float>(1, 4, 0.0f);
-  TestLSTMCell<GPU, float>(2, 16, 0.0f);
-  TestLSTMCell<GPU, float>(2, 200, 0.5f);
-  TestLSTMCell<GPU, float>(20, 320, 0.5f);
+  TestLSTMCell<GPU, float>(1, 3, 8, 0.0f);
+  TestLSTMCell<GPU, float>(2, 16, 24, 0.0f);
+  TestLSTMCell<GPU, float>(2, 200, 280, 0.5f);
+  TestLSTMCell<GPU, float>(20, 320, 512, 0.5f);
 }
 }  // namespace test
mace/ops/lstmcell_test_util.h  (new file, mode 100644)

// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MACE_OPS_LSTMCELL_TEST_UTIL_H_
#define MACE_OPS_LSTMCELL_TEST_UTIL_H_

#include <string>

#include "mace/core/operator.h"
#include "mace/kernels/eltwise.h"
#include "mace/ops/ops_test_util.h"

namespace mace {
namespace ops {
namespace test {

template <typename T>
void LSTMCellCPU(OpsTestNet *net,
                 const std::string &input_name,
                 const std::string &pre_output_name,
                 const std::string &weight_name,
                 const std::string &bias_name,
                 const std::string &pre_cell_name,
                 const float &forget_add_name,
                 const std::string &cell_name,
                 const std::string &output_name) {
  OpDefBuilder("Concat", "Concat")
      .Input(input_name)
      .Input(pre_output_name)
      .AddIntArg("axis", 1)
      .Output("ConcatOutput")
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("MatMul", "MatMul")
      .Input("ConcatOutput")
      .Input(weight_name)
      .Output("MatMulOutput")
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("BiasAdd", "BiasAdd")
      .Input("MatMulOutput")
      .Input(bias_name)
      .Output("BiasOutput")
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("Split", "FCSplit")
      .Input("BiasOutput")
      .AddIntArg("axis", 1)
      .Output("SplitOutput0")
      .Output("SplitOutput1")
      .Output("SplitOutput2")
      .Output("SplitOutput3")
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("Activation", "InputSigmoid")
      .Input("SplitOutput0")
      .AddStringArg("activation", "SIGMOID")
      .Output("InputSigmoid")
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("Activation", "NewInputTanh")
      .Input("SplitOutput1")
      .AddStringArg("activation", "TANH")
      .Output("NewInputTanh")
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("Eltwise", "RememberMul")
      .Input("InputSigmoid")
      .Input("NewInputTanh")
      .AddIntArg("T", DataTypeToEnum<T>::v())
      .AddIntArg("type", static_cast<int>(kernels::EltwiseType::PROD))
      .Output("RememberMul")
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("Eltwise", "ForgetAdd")
      .Input("SplitOutput2")
      .AddFloatArg("scalar_input", forget_add_name)
      .AddIntArg("T", DataTypeToEnum<T>::v())
      .AddIntArg("type", static_cast<int>(kernels::EltwiseType::SUM))
      .Output("ForgetAdd")
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("Activation", "ForgetSigmoid")
      .Input("ForgetAdd")
      .AddStringArg("activation", "SIGMOID")
      .Output("ForgetSigmoid")
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("Eltwise", "ForgetMul")
      .Input("ForgetSigmoid")
      .Input(pre_cell_name)
      .AddIntArg("T", DataTypeToEnum<T>::v())
      .AddIntArg("type", static_cast<int>(kernels::EltwiseType::PROD))
      .Output("ForgetMulPreCell")
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("Eltwise", "Cell")
      .Input("RememberMul")
      .Input("ForgetMulPreCell")
      .AddIntArg("T", DataTypeToEnum<T>::v())
      .AddIntArg("type", static_cast<int>(kernels::EltwiseType::SUM))
      .Output(cell_name)
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("Activation", "CellTanh")
      .Input(cell_name)
      .AddStringArg("activation", "TANH")
      .Output("CellTanh")
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("Activation", "OutputSigmoid")
      .Input("SplitOutput3")
      .AddStringArg("activation", "SIGMOID")
      .Output("OutputSigmoid")
      .Finalize(net->AddNewOperatorDef());

  OpDefBuilder("Eltwise", "FinalMul")
      .Input("OutputSigmoid")
      .Input("CellTanh")
      .AddIntArg("T", DataTypeToEnum<T>::v())
      .AddIntArg("type", static_cast<int>(kernels::EltwiseType::PROD))
      .Output(output_name)
      .Finalize(net->AddNewOperatorDef());
}

}  // namespace test
}  // namespace ops
}  // namespace mace

#endif  // MACE_OPS_LSTMCELL_TEST_UTIL_H_
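Read end to end, the LSTMCellCPU graph above is the standard LSTM cell update expressed with MACE's primitive ops. With z = [x_t, h_{t-1}] W + b split into four equal slices along axis 1 and forget_add acting as an extra forget-gate bias, it computes (my reading of the graph, not documentation shipped with the commit):

    i_t = \sigma(z_0), \quad g_t = \tanh(z_1), \quad
    f_t = \sigma(z_2 + \mathrm{forget\_add}), \quad o_t = \sigma(z_3)
    c_t = i_t \odot g_t + f_t \odot c_{t-1}
    h_t = o_t \odot \tanh(c_t)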