Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
0eefad0a
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0eefad0a
编写于
2月 26, 2019
作者:
T
tensor-tang
提交者:
ceci3
3月 04, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix jitcodekey and refine test
test=develop
上级
ce4cc482
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
113 addition
and
158 deletion
+113
-158
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+22
-5
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+91
-153
未找到文件。
paddle/fluid/operators/jit/kernel_key.cc
浏览文件 @
0eefad0a
...
...
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -23,14 +24,30 @@ size_t JitCodeKey<int>(const int& d) {
return
d
;
}
// TODO(TJ): refine and benchmark JitCodeKey generatation
constexpr
int
act_type_shift
=
3
;
// suppot 2^3 act types
static
inline
int
act_type_convert
(
KernelType
type
)
{
if
(
type
==
kVIdentity
)
{
return
0
;
}
else
if
(
type
==
kVExp
)
{
return
1
;
}
else
if
(
type
==
kVRelu
)
{
return
2
;
}
else
if
(
type
==
kVSigmoid
)
{
return
3
;
}
else
if
(
type
==
kVTanh
)
{
return
4
;
}
PADDLE_THROW
(
"Unsupported act type %d"
,
type
);
return
0
;
}
template
<
>
size_t
JitCodeKey
<
lstm_attr_t
>
(
const
lstm_attr_t
&
attr
)
{
size_t
key
=
attr
.
d
;
int
gate_key
=
static_cast
<
int
>
(
attr
.
act_gate
)
<<
1
;
int
cand_key
=
static_cast
<
int
>
(
attr
.
act_cand
)
<<
(
1
+
act_type_shift
);
int
cell_key
=
static_cast
<
int
>
(
attr
.
act_cell
)
<<
(
1
+
act_type_shift
*
2
);
int
gate_key
=
act_type_convert
(
attr
.
act_gate
)
<<
1
;
int
cand_key
=
act_type_convert
(
attr
.
act_cand
)
<<
(
1
+
act_type_shift
);
int
cell_key
=
act_type_convert
(
attr
.
act_cell
)
<<
(
1
+
act_type_shift
*
2
);
return
(
key
<<
(
1
+
act_type_shift
*
3
))
+
gate_key
+
cand_key
+
cell_key
+
attr
.
use_peephole
;
}
...
...
@@ -38,8 +55,8 @@ size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
template
<
>
size_t
JitCodeKey
<
gru_attr_t
>
(
const
gru_attr_t
&
attr
)
{
size_t
key
=
attr
.
d
;
return
(
key
<<
(
act_type_shift
*
2
))
+
static_cast
<
int
>
(
attr
.
act_gate
)
+
(
static_cast
<
int
>
(
attr
.
act_cand
)
<<
act_type_shift
);
return
(
key
<<
(
act_type_shift
*
2
))
+
act_type_convert
(
attr
.
act_gate
)
+
(
act_type_convert
(
attr
.
act_cand
)
<<
act_type_shift
);
}
template
<
>
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
0eefad0a
...
...
@@ -40,11 +40,11 @@ template <typename T>
void
ExpectEQ
(
const
T
*
target
,
const
T
*
refer
,
size_t
n
)
{
if
(
std
::
is_floating_point
<
T
>::
value
)
{
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
EXPECT_NEAR
(
target
[
i
],
refer
[
i
],
FLAGS_acc
);
EXPECT_NEAR
(
target
[
i
],
refer
[
i
],
FLAGS_acc
)
<<
" at index : "
<<
i
;
}
}
else
{
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
EXPECT_EQ
(
target
[
i
],
refer
[
i
]);
EXPECT_EQ
(
target
[
i
],
refer
[
i
])
<<
" at index : "
<<
i
;
}
}
}
...
...
@@ -447,7 +447,7 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
XYZNKernel
()
{
void
Test
KernelXYZNTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
XYZNTuples
<
T
>>
();
...
...
@@ -480,7 +480,7 @@ void TestXYZNKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
AXYNKernel
()
{
void
Test
KernelAXYNTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
AXYNTuples
<
T
>>
();
...
...
@@ -506,7 +506,7 @@ void TestAXYNKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
XRNKernel
()
{
void
Test
KernelXRNTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
auto
last_acc
=
FLAGS_acc
;
FLAGS_acc
=
1e-4
;
...
...
@@ -524,7 +524,7 @@ void TestXRNKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
XYNKernel
()
{
void
Test
KernelXYNTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
XYNTuples
<
T
>>
();
...
...
@@ -549,10 +549,12 @@ void TestXYNKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
LSTMKernel
()
{
void
Test
KernelLSTMTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
std
::
vector
<
std
::
string
>
all_acts
=
{
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
};
for
(
int
d
:
TestSizes
())
{
auto
test_sizes
=
TestSizes
();
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
1000
));
for
(
int
d
:
test_sizes
)
{
for
(
bool
use_peephole
:
{
true
,
false
})
{
for
(
auto
&
act_gate
:
all_acts
)
{
for
(
auto
&
act_cand
:
all_acts
)
{
...
...
@@ -599,10 +601,12 @@ void TestLSTMKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
GRUKernel
()
{
void
Test
KernelGRUTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
std
::
vector
<
std
::
string
>
all_acts
=
{
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
};
for
(
int
d
:
TestSizes
())
{
auto
test_sizes
=
TestSizes
();
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
1000
));
for
(
int
d
:
test_sizes
)
{
for
(
auto
&
act_gate
:
all_acts
)
{
for
(
auto
&
act_cand
:
all_acts
)
{
const
jit
::
gru_attr_t
attr
(
d
,
jit
::
to_kerneltype
(
act_gate
),
...
...
@@ -633,14 +637,16 @@ void TestGRUKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
SeqPoolKernel
()
{
void
Test
KernelSeqPoolTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
,
jit
::
SeqPoolType
::
kAvg
,
jit
::
SeqPoolType
::
kSqrt
};
auto
test_sizes
=
TestSizes
();
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
1000
));
for
(
auto
type
:
pool_types
)
{
for
(
int
w
:
TestSizes
()
)
{
for
(
int
w
:
test_sizes
)
{
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
for
(
int
h
:
TestSizes
()
)
{
for
(
int
h
:
test_sizes
)
{
attr
.
h
=
h
;
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
SeqPoolTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
...
...
@@ -658,11 +664,11 @@ void TestSeqPoolKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
MatMulKernel
()
{
void
Test
KernelMatMulTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
auto
last_acc
=
FLAGS_acc
;
//
TODO(intel): fix MKL acc issue
//
https://github.com/PaddlePaddle/Paddle/issues/15447
//
export MKL_CBWR=AVX would make MKL force to use AVX
//
export KMP_DETERMINISTIC_REDUCTION=yes would make the result deterministic
FLAGS_acc
=
1e-3
;
for
(
int
m
:
{
1
,
2
,
3
,
4
})
{
for
(
int
n
:
{
1
,
2
,
3
,
4
})
{
...
...
@@ -686,7 +692,7 @@ void TestMatMulKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
SoftmaxKernel
()
{
void
Test
KernelSoftmaxTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
for
(
int
bs
:
{
1
,
2
,
10
})
{
for
(
int
n
:
TestSizes
())
{
...
...
@@ -711,12 +717,14 @@ void TestSoftmaxKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
EmbSeqPoolKernel
()
{
void
Test
KernelEmbSeqPoolTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
int64_t
tbl_h
=
1e4
;
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
};
// only support sum yet
for
(
int
tbl_w
:
TestSizes
())
{
auto
test_sizes
=
TestSizes
();
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
1000
));
for
(
int
tbl_w
:
test_sizes
)
{
std
::
vector
<
T
>
table
(
tbl_h
*
tbl_w
);
RandomVec
<
T
>
(
tbl_h
*
tbl_w
,
table
.
data
(),
-
2.
f
,
2.
f
);
const
T
*
table_data
=
table
.
data
();
...
...
@@ -745,7 +753,7 @@ void TestEmbSeqPoolKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
SgdKernel
()
{
void
Test
KernelSgdTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
const
T
lr
=
0.1
;
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
...
...
@@ -799,7 +807,7 @@ void TestSgdKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
NCHW16CMulNCKernel
()
{
void
Test
KernelNCHW16CMulNCTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
const
int
n
=
3
,
c
=
16
*
4
,
h
=
10
,
w
=
10
;
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
NCHW16CMulNCTuples
<
T
>>
();
...
...
@@ -852,7 +860,7 @@ void TestNCHW16CMulNCKernel() {
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
LayerNormKernel
()
{
void
Test
KernelLayerNormTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
const
T
epsilon
=
9.99999975e-06
;
for
(
int
n
:
{
1
,
2
,
10
})
{
...
...
@@ -891,11 +899,13 @@ void TestLayerNormKernel() {
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
CRFDecodingKernel
()
{
void
Test
KernelCRFDecodingTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
constexpr
int
state_trans_base_idx
=
2
;
auto
test_sizes
=
TestSizes
();
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
1000
));
for
(
int
seq_len
:
{
1
,
11
,
17
,
50
})
{
for
(
int
tag_num
:
TestSizes
()
)
{
for
(
int
tag_num
:
test_sizes
)
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
CRFDecodingTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
int
x_sz
=
seq_len
*
tag_num
;
...
...
@@ -916,148 +926,76 @@ void TestCRFDecodingKernel() {
}
}
// XYZNTuple
TEST
(
JITKernel
,
kVMul
)
{
TestXYZNKernel
<
jit
::
kVMul
,
float
,
CPUPlace
>
();
TestXYZNKernel
<
jit
::
kVMul
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVAdd
)
{
TestXYZNKernel
<
jit
::
kVAdd
,
float
,
CPUPlace
>
();
TestXYZNKernel
<
jit
::
kVAdd
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVAddRelu
)
{
TestXYZNKernel
<
jit
::
kVAddRelu
,
float
,
CPUPlace
>
();
TestXYZNKernel
<
jit
::
kVAddRelu
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVSub
)
{
TestXYZNKernel
<
jit
::
kVSub
,
float
,
CPUPlace
>
();
TestXYZNKernel
<
jit
::
kVSub
,
double
,
CPUPlace
>
();
}
// AXYNTuples
TEST
(
JITKernel
,
kVScal
)
{
TestAXYNKernel
<
jit
::
kVScal
,
float
,
CPUPlace
>
();
TestAXYNKernel
<
jit
::
kVScal
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVAddBias
)
{
TestAXYNKernel
<
jit
::
kVAddBias
,
float
,
CPUPlace
>
();
TestAXYNKernel
<
jit
::
kVAddBias
,
double
,
CPUPlace
>
();
}
// XRNTuples
TEST
(
JITKernel
,
kHMax
)
{
TestXRNKernel
<
jit
::
kHMax
,
float
,
CPUPlace
>
();
TestXRNKernel
<
jit
::
kHMax
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kHSum
)
{
TestXRNKernel
<
jit
::
kHSum
,
float
,
CPUPlace
>
();
TestXRNKernel
<
jit
::
kHSum
,
double
,
CPUPlace
>
();
}
// XYNTuples
TEST
(
JITKernel
,
kVRelu
)
{
TestXYNKernel
<
jit
::
kVRelu
,
float
,
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVRelu
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVIdentity
)
{
TestXYNKernel
<
jit
::
kVIdentity
,
float
,
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVIdentity
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVSquare
)
{
TestXYNKernel
<
jit
::
kVSquare
,
float
,
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVSquare
,
double
,
CPUPlace
>
();
}
#define TEST_CPU_KERNEL(test_tuple, kernel_type) \
TEST(JITKernel, kernel_type) { \
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
}
TEST
(
JITKernel
,
kVExp
)
{
TestXYNKernel
<
jit
::
kVExp
,
float
,
CPUPlace
>
(
);
TestXYNKernel
<
jit
::
kVExp
,
double
,
CPUPlace
>
(
);
}
TEST
_CPU_KERNEL
(
XYZNTuples
,
kVMul
);
TEST_CPU_KERNEL
(
XYZNTuples
,
kVAdd
);
TEST_CPU_KERNEL
(
XYZNTuples
,
kVAddRelu
);
TEST_CPU_KERNEL
(
XYZNTuples
,
kVSub
);
TEST
(
JITKernel
,
kVSigmoid
)
{
TestXYNKernel
<
jit
::
kVSigmoid
,
float
,
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVSigmoid
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
AXYNTuples
,
kVScal
);
TEST_CPU_KERNEL
(
AXYNTuples
,
kVAddBias
);
TEST
(
JITKernel
,
kVTanh
)
{
TestXYNKernel
<
jit
::
kVTanh
,
float
,
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVTanh
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
XRNTuples
,
kHMax
);
TEST_CPU_KERNEL
(
XRNTuples
,
kHSum
);
// LSTM
TEST
(
JITKernel
,
kLSTMCtHt
)
{
TestLSTMKernel
<
jit
::
kLSTMCtHt
,
float
,
CPUPlace
>
();
TestLSTMKernel
<
jit
::
kLSTMCtHt
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
XYNTuples
,
kVRelu
);
TEST_CPU_KERNEL
(
XYNTuples
,
kVIdentity
);
TEST_CPU_KERNEL
(
XYNTuples
,
kVSquare
);
TEST_CPU_KERNEL
(
XYNTuples
,
kVExp
);
TEST_CPU_KERNEL
(
XYNTuples
,
kVSigmoid
);
TEST_CPU_KERNEL
(
XYNTuples
,
kVTanh
);
TEST
(
JITKernel
,
kLSTMC1H1
)
{
TestLSTMKernel
<
jit
::
kLSTMC1H1
,
float
,
CPUPlace
>
();
TestLSTMKernel
<
jit
::
kLSTMC1H1
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
LSTMTuples
,
kLSTMCtHt
);
TEST_CPU_KERNEL
(
LSTMTuples
,
kLSTMC1H1
);
// GRU
TEST
(
JITKernel
,
kGRUH1
)
{
TestGRUKernel
<
jit
::
kGRUH1
,
float
,
CPUPlace
>
();
TestGRUKernel
<
jit
::
kGRUH1
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
GRUTuples
,
kGRUH1
);
TEST_CPU_KERNEL
(
GRUTuples
,
kGRUHtPart1
);
TEST_CPU_KERNEL
(
GRUTuples
,
kGRUHtPart2
);
TEST
(
JITKernel
,
kGRUHtPart1
)
{
TestGRUKernel
<
jit
::
kGRUHtPart1
,
float
,
CPUPlace
>
();
TestGRUKernel
<
jit
::
kGRUHtPart1
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
NCHW16CMulNCTuples
,
kNCHW16CMulNC
);
TEST
(
JITKernel
,
kGRUHtPart2
)
{
TestGRUKernel
<
jit
::
kGRUHtPart2
,
float
,
CPUPlace
>
();
TestGRUKernel
<
jit
::
kGRUHtPart2
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
SeqPoolTuples
,
kSeqPool
);
TEST_CPU_KERNEL
(
MatMulTuples
,
kMatMul
);
TEST_CPU_KERNEL
(
SoftmaxTuples
,
kSoftmax
);
TEST_CPU_KERNEL
(
EmbSeqPoolTuples
,
kEmbSeqPool
);
TEST_CPU_KERNEL
(
SgdTuples
,
kSgd
);
TEST_CPU_KERNEL
(
LayerNormTuples
,
kLayerNorm
);
TEST_CPU_KERNEL
(
CRFDecodingTuples
,
kCRFDecoding
);
TEST
(
JITKernel
,
kSeqPool
)
{
TestSeqPoolKernel
<
jit
::
kSeqPool
,
float
,
CPUPlace
>
();
TestSeqPoolKernel
<
jit
::
kSeqPool
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kMatMul
)
{
TestMatMulKernel
<
jit
::
kMatMul
,
float
,
CPUPlace
>
();
TestMatMulKernel
<
jit
::
kMatMul
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kSoftmax
)
{
TestSoftmaxKernel
<
jit
::
kSoftmax
,
float
,
CPUPlace
>
();
TestSoftmaxKernel
<
jit
::
kSoftmax
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel_key
,
lstm
)
{
jit
::
lstm_attr_t
attr1
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr2
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr3
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr4
(
9
,
jit
::
kVRelu
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
TEST
(
JITKernel
,
kEmbSeqPool
)
{
TestEmbSeqPoolKernel
<
jit
::
kEmbSeqPool
,
float
,
CPUPlace
>
(
);
TestEmbSeqPoolKernel
<
jit
::
kEmbSeqPool
,
double
,
CPUPlace
>
(
);
}
auto
key1
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr4
);
TEST
(
JITKernel
,
kSgd
)
{
TestSgdKernel
<
jit
::
kSgd
,
float
,
CPUPlace
>
(
);
TestSgdKernel
<
jit
::
kSgd
,
double
,
CPUPlace
>
(
);
EXPECT_TRUE
(
key1
!=
key2
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key3
!=
key4
);
}
TEST
(
JITKernel
,
kNCHW16CMulNC
)
{
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
float
,
CPUPlace
>
();
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel_key
,
gru
)
{
jit
::
gru_attr_t
attr1
(
8
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr2
(
9
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr3
(
9
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr4
(
9
,
jit
::
kVSigmoid
,
jit
::
kVIdentity
);
TEST
(
JITKernel
,
kLayerNorm
)
{
TestLayerNormKernel
<
jit
::
kLayerNorm
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestLayerNormKernel
<
jit
::
kLayerNorm
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kCRFDecoding
)
{
TestCRFDecodingKernel
<
jit
::
kCRFDecoding
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestCRFDecodingKernel
<
jit
::
kCRFDecoding
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
auto
key1
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr4
);
TEST
(
JITKernel
,
pool
)
{
// TODO(TJ): add some test
EXPECT_TRUE
(
key1
!=
key2
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key3
!=
key4
);
}
// TODO(TJ): add more test about key and pool
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录