Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
8bc63815
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8bc63815
编写于
2月 26, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix jitcodekey and refine test
test=develop
上级
7044cfa7
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
113 addition
and
158 deletion
+113
-158
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+22
-5
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+91
-153
未找到文件。
paddle/fluid/operators/jit/kernel_key.cc
浏览文件 @
8bc63815
...
...
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -23,14 +24,30 @@ size_t JitCodeKey<int>(const int& d) {
return
d
;
}
// TODO(TJ): refine and benchmark JitCodeKey generatation
constexpr
int
act_type_shift
=
3
;
// suppot 2^3 act types
static
inline
int
act_type_convert
(
KernelType
type
)
{
if
(
type
==
kVIdentity
)
{
return
0
;
}
else
if
(
type
==
kVExp
)
{
return
1
;
}
else
if
(
type
==
kVRelu
)
{
return
2
;
}
else
if
(
type
==
kVSigmoid
)
{
return
3
;
}
else
if
(
type
==
kVTanh
)
{
return
4
;
}
PADDLE_THROW
(
"Unsupported act type %d"
,
type
);
return
0
;
}
template
<
>
size_t
JitCodeKey
<
lstm_attr_t
>
(
const
lstm_attr_t
&
attr
)
{
size_t
key
=
attr
.
d
;
int
gate_key
=
static_cast
<
int
>
(
attr
.
act_gate
)
<<
1
;
int
cand_key
=
static_cast
<
int
>
(
attr
.
act_cand
)
<<
(
1
+
act_type_shift
);
int
cell_key
=
static_cast
<
int
>
(
attr
.
act_cell
)
<<
(
1
+
act_type_shift
*
2
);
int
gate_key
=
act_type_convert
(
attr
.
act_gate
)
<<
1
;
int
cand_key
=
act_type_convert
(
attr
.
act_cand
)
<<
(
1
+
act_type_shift
);
int
cell_key
=
act_type_convert
(
attr
.
act_cell
)
<<
(
1
+
act_type_shift
*
2
);
return
(
key
<<
(
1
+
act_type_shift
*
3
))
+
gate_key
+
cand_key
+
cell_key
+
attr
.
use_peephole
;
}
...
...
@@ -38,8 +55,8 @@ size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
template
<
>
size_t
JitCodeKey
<
gru_attr_t
>
(
const
gru_attr_t
&
attr
)
{
size_t
key
=
attr
.
d
;
return
(
key
<<
(
act_type_shift
*
2
))
+
static_cast
<
int
>
(
attr
.
act_gate
)
+
(
static_cast
<
int
>
(
attr
.
act_cand
)
<<
act_type_shift
);
return
(
key
<<
(
act_type_shift
*
2
))
+
act_type_convert
(
attr
.
act_gate
)
+
(
act_type_convert
(
attr
.
act_cand
)
<<
act_type_shift
);
}
template
<
>
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
8bc63815
...
...
@@ -40,11 +40,11 @@ template <typename T>
void
ExpectEQ
(
const
T
*
target
,
const
T
*
refer
,
size_t
n
)
{
if
(
std
::
is_floating_point
<
T
>::
value
)
{
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
EXPECT_NEAR
(
target
[
i
],
refer
[
i
],
FLAGS_acc
);
EXPECT_NEAR
(
target
[
i
],
refer
[
i
],
FLAGS_acc
)
<<
" at index : "
<<
i
;
}
}
else
{
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
EXPECT_EQ
(
target
[
i
],
refer
[
i
]);
EXPECT_EQ
(
target
[
i
],
refer
[
i
])
<<
" at index : "
<<
i
;
}
}
}
...
...
@@ -447,7 +447,7 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
XYZNKernel
()
{
void
Test
KernelXYZNTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
XYZNTuples
<
T
>>
();
...
...
@@ -480,7 +480,7 @@ void TestXYZNKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
AXYNKernel
()
{
void
Test
KernelAXYNTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
AXYNTuples
<
T
>>
();
...
...
@@ -506,7 +506,7 @@ void TestAXYNKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
XRNKernel
()
{
void
Test
KernelXRNTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
auto
last_acc
=
FLAGS_acc
;
FLAGS_acc
=
1e-4
;
...
...
@@ -524,7 +524,7 @@ void TestXRNKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
XYNKernel
()
{
void
Test
KernelXYNTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
XYNTuples
<
T
>>
();
...
...
@@ -549,10 +549,12 @@ void TestXYNKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
LSTMKernel
()
{
void
Test
KernelLSTMTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
std
::
vector
<
std
::
string
>
all_acts
=
{
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
};
for
(
int
d
:
TestSizes
())
{
auto
test_sizes
=
TestSizes
();
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
1000
));
for
(
int
d
:
test_sizes
)
{
for
(
bool
use_peephole
:
{
true
,
false
})
{
for
(
auto
&
act_gate
:
all_acts
)
{
for
(
auto
&
act_cand
:
all_acts
)
{
...
...
@@ -599,10 +601,12 @@ void TestLSTMKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
GRUKernel
()
{
void
Test
KernelGRUTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
std
::
vector
<
std
::
string
>
all_acts
=
{
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
};
for
(
int
d
:
TestSizes
())
{
auto
test_sizes
=
TestSizes
();
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
1000
));
for
(
int
d
:
test_sizes
)
{
for
(
auto
&
act_gate
:
all_acts
)
{
for
(
auto
&
act_cand
:
all_acts
)
{
const
jit
::
gru_attr_t
attr
(
d
,
jit
::
to_kerneltype
(
act_gate
),
...
...
@@ -633,14 +637,16 @@ void TestGRUKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
SeqPoolKernel
()
{
void
Test
KernelSeqPoolTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
,
jit
::
SeqPoolType
::
kAvg
,
jit
::
SeqPoolType
::
kSqrt
};
auto
test_sizes
=
TestSizes
();
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
1000
));
for
(
auto
type
:
pool_types
)
{
for
(
int
w
:
TestSizes
()
)
{
for
(
int
w
:
test_sizes
)
{
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
for
(
int
h
:
TestSizes
()
)
{
for
(
int
h
:
test_sizes
)
{
attr
.
h
=
h
;
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
SeqPoolTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
...
...
@@ -658,11 +664,11 @@ void TestSeqPoolKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
MatMulKernel
()
{
void
Test
KernelMatMulTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
auto
last_acc
=
FLAGS_acc
;
//
TODO(intel): fix MKL acc issue
//
https://github.com/PaddlePaddle/Paddle/issues/15447
//
export MKL_CBWR=AVX would make MKL force to use AVX
//
export KMP_DETERMINISTIC_REDUCTION=yes would make the result deterministic
FLAGS_acc
=
1e-3
;
for
(
int
m
:
{
1
,
2
,
3
,
4
})
{
for
(
int
n
:
{
1
,
2
,
3
,
4
})
{
...
...
@@ -686,7 +692,7 @@ void TestMatMulKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
SoftmaxKernel
()
{
void
Test
KernelSoftmaxTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
for
(
int
bs
:
{
1
,
2
,
10
})
{
for
(
int
n
:
TestSizes
())
{
...
...
@@ -711,12 +717,14 @@ void TestSoftmaxKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
EmbSeqPoolKernel
()
{
void
Test
KernelEmbSeqPoolTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
int64_t
tbl_h
=
1e4
;
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
};
// only support sum yet
for
(
int
tbl_w
:
TestSizes
())
{
auto
test_sizes
=
TestSizes
();
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
1000
));
for
(
int
tbl_w
:
test_sizes
)
{
std
::
vector
<
T
>
table
(
tbl_h
*
tbl_w
);
RandomVec
<
T
>
(
tbl_h
*
tbl_w
,
table
.
data
(),
-
2.
f
,
2.
f
);
const
T
*
table_data
=
table
.
data
();
...
...
@@ -745,7 +753,7 @@ void TestEmbSeqPoolKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
SgdKernel
()
{
void
Test
KernelSgdTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
const
T
lr
=
0.1
;
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
...
...
@@ -799,7 +807,7 @@ void TestSgdKernel() {
}
template
<
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
NCHW16CMulNCKernel
()
{
void
Test
KernelNCHW16CMulNCTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
const
int
n
=
3
,
c
=
16
*
4
,
h
=
10
,
w
=
10
;
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
NCHW16CMulNCTuples
<
T
>>
();
...
...
@@ -852,7 +860,7 @@ void TestNCHW16CMulNCKernel() {
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
LayerNormKernel
()
{
void
Test
KernelLayerNormTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
const
T
epsilon
=
9.99999975e-06
;
for
(
int
n
:
{
1
,
2
,
10
})
{
...
...
@@ -891,11 +899,13 @@ void TestLayerNormKernel() {
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
Test
CRFDecodingKernel
()
{
void
Test
KernelCRFDecodingTuples
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
constexpr
int
state_trans_base_idx
=
2
;
auto
test_sizes
=
TestSizes
();
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
1000
));
for
(
int
seq_len
:
{
1
,
11
,
17
,
50
})
{
for
(
int
tag_num
:
TestSizes
()
)
{
for
(
int
tag_num
:
test_sizes
)
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
CRFDecodingTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
int
x_sz
=
seq_len
*
tag_num
;
...
...
@@ -916,148 +926,76 @@ void TestCRFDecodingKernel() {
}
}
// XYZNTuple
TEST
(
JITKernel
,
kVMul
)
{
TestXYZNKernel
<
jit
::
kVMul
,
float
,
CPUPlace
>
();
TestXYZNKernel
<
jit
::
kVMul
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVAdd
)
{
TestXYZNKernel
<
jit
::
kVAdd
,
float
,
CPUPlace
>
();
TestXYZNKernel
<
jit
::
kVAdd
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVAddRelu
)
{
TestXYZNKernel
<
jit
::
kVAddRelu
,
float
,
CPUPlace
>
();
TestXYZNKernel
<
jit
::
kVAddRelu
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVSub
)
{
TestXYZNKernel
<
jit
::
kVSub
,
float
,
CPUPlace
>
();
TestXYZNKernel
<
jit
::
kVSub
,
double
,
CPUPlace
>
();
}
// AXYNTuples
TEST
(
JITKernel
,
kVScal
)
{
TestAXYNKernel
<
jit
::
kVScal
,
float
,
CPUPlace
>
();
TestAXYNKernel
<
jit
::
kVScal
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVAddBias
)
{
TestAXYNKernel
<
jit
::
kVAddBias
,
float
,
CPUPlace
>
();
TestAXYNKernel
<
jit
::
kVAddBias
,
double
,
CPUPlace
>
();
}
// XRNTuples
TEST
(
JITKernel
,
kHMax
)
{
TestXRNKernel
<
jit
::
kHMax
,
float
,
CPUPlace
>
();
TestXRNKernel
<
jit
::
kHMax
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kHSum
)
{
TestXRNKernel
<
jit
::
kHSum
,
float
,
CPUPlace
>
();
TestXRNKernel
<
jit
::
kHSum
,
double
,
CPUPlace
>
();
}
// XYNTuples
TEST
(
JITKernel
,
kVRelu
)
{
TestXYNKernel
<
jit
::
kVRelu
,
float
,
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVRelu
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVIdentity
)
{
TestXYNKernel
<
jit
::
kVIdentity
,
float
,
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVIdentity
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVSquare
)
{
TestXYNKernel
<
jit
::
kVSquare
,
float
,
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVSquare
,
double
,
CPUPlace
>
();
}
#define TEST_CPU_KERNEL(test_tuple, kernel_type) \
TEST(JITKernel, kernel_type) { \
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
}
TEST
(
JITKernel
,
kVExp
)
{
TestXYNKernel
<
jit
::
kVExp
,
float
,
CPUPlace
>
(
);
TestXYNKernel
<
jit
::
kVExp
,
double
,
CPUPlace
>
(
);
}
TEST
_CPU_KERNEL
(
XYZNTuples
,
kVMul
);
TEST_CPU_KERNEL
(
XYZNTuples
,
kVAdd
);
TEST_CPU_KERNEL
(
XYZNTuples
,
kVAddRelu
);
TEST_CPU_KERNEL
(
XYZNTuples
,
kVSub
);
TEST
(
JITKernel
,
kVSigmoid
)
{
TestXYNKernel
<
jit
::
kVSigmoid
,
float
,
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVSigmoid
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
AXYNTuples
,
kVScal
);
TEST_CPU_KERNEL
(
AXYNTuples
,
kVAddBias
);
TEST
(
JITKernel
,
kVTanh
)
{
TestXYNKernel
<
jit
::
kVTanh
,
float
,
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVTanh
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
XRNTuples
,
kHMax
);
TEST_CPU_KERNEL
(
XRNTuples
,
kHSum
);
// LSTM
TEST
(
JITKernel
,
kLSTMCtHt
)
{
TestLSTMKernel
<
jit
::
kLSTMCtHt
,
float
,
CPUPlace
>
();
TestLSTMKernel
<
jit
::
kLSTMCtHt
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
XYNTuples
,
kVRelu
);
TEST_CPU_KERNEL
(
XYNTuples
,
kVIdentity
);
TEST_CPU_KERNEL
(
XYNTuples
,
kVSquare
);
TEST_CPU_KERNEL
(
XYNTuples
,
kVExp
);
TEST_CPU_KERNEL
(
XYNTuples
,
kVSigmoid
);
TEST_CPU_KERNEL
(
XYNTuples
,
kVTanh
);
TEST
(
JITKernel
,
kLSTMC1H1
)
{
TestLSTMKernel
<
jit
::
kLSTMC1H1
,
float
,
CPUPlace
>
();
TestLSTMKernel
<
jit
::
kLSTMC1H1
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
LSTMTuples
,
kLSTMCtHt
);
TEST_CPU_KERNEL
(
LSTMTuples
,
kLSTMC1H1
);
// GRU
TEST
(
JITKernel
,
kGRUH1
)
{
TestGRUKernel
<
jit
::
kGRUH1
,
float
,
CPUPlace
>
();
TestGRUKernel
<
jit
::
kGRUH1
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
GRUTuples
,
kGRUH1
);
TEST_CPU_KERNEL
(
GRUTuples
,
kGRUHtPart1
);
TEST_CPU_KERNEL
(
GRUTuples
,
kGRUHtPart2
);
TEST
(
JITKernel
,
kGRUHtPart1
)
{
TestGRUKernel
<
jit
::
kGRUHtPart1
,
float
,
CPUPlace
>
();
TestGRUKernel
<
jit
::
kGRUHtPart1
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
NCHW16CMulNCTuples
,
kNCHW16CMulNC
);
TEST
(
JITKernel
,
kGRUHtPart2
)
{
TestGRUKernel
<
jit
::
kGRUHtPart2
,
float
,
CPUPlace
>
();
TestGRUKernel
<
jit
::
kGRUHtPart2
,
double
,
CPUPlace
>
();
}
TEST_CPU_KERNEL
(
SeqPoolTuples
,
kSeqPool
);
TEST_CPU_KERNEL
(
MatMulTuples
,
kMatMul
);
TEST_CPU_KERNEL
(
SoftmaxTuples
,
kSoftmax
);
TEST_CPU_KERNEL
(
EmbSeqPoolTuples
,
kEmbSeqPool
);
TEST_CPU_KERNEL
(
SgdTuples
,
kSgd
);
TEST_CPU_KERNEL
(
LayerNormTuples
,
kLayerNorm
);
TEST_CPU_KERNEL
(
CRFDecodingTuples
,
kCRFDecoding
);
TEST
(
JITKernel
,
kSeqPool
)
{
TestSeqPoolKernel
<
jit
::
kSeqPool
,
float
,
CPUPlace
>
();
TestSeqPoolKernel
<
jit
::
kSeqPool
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kMatMul
)
{
TestMatMulKernel
<
jit
::
kMatMul
,
float
,
CPUPlace
>
();
TestMatMulKernel
<
jit
::
kMatMul
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel
,
kSoftmax
)
{
TestSoftmaxKernel
<
jit
::
kSoftmax
,
float
,
CPUPlace
>
();
TestSoftmaxKernel
<
jit
::
kSoftmax
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel_key
,
lstm
)
{
jit
::
lstm_attr_t
attr1
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr2
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr3
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr4
(
9
,
jit
::
kVRelu
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
TEST
(
JITKernel
,
kEmbSeqPool
)
{
TestEmbSeqPoolKernel
<
jit
::
kEmbSeqPool
,
float
,
CPUPlace
>
(
);
TestEmbSeqPoolKernel
<
jit
::
kEmbSeqPool
,
double
,
CPUPlace
>
(
);
}
auto
key1
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr4
);
TEST
(
JITKernel
,
kSgd
)
{
TestSgdKernel
<
jit
::
kSgd
,
float
,
CPUPlace
>
(
);
TestSgdKernel
<
jit
::
kSgd
,
double
,
CPUPlace
>
(
);
EXPECT_TRUE
(
key1
!=
key2
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key3
!=
key4
);
}
TEST
(
JITKernel
,
kNCHW16CMulNC
)
{
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
float
,
CPUPlace
>
();
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
double
,
CPUPlace
>
();
}
TEST
(
JITKernel_key
,
gru
)
{
jit
::
gru_attr_t
attr1
(
8
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr2
(
9
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr3
(
9
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr4
(
9
,
jit
::
kVSigmoid
,
jit
::
kVIdentity
);
TEST
(
JITKernel
,
kLayerNorm
)
{
TestLayerNormKernel
<
jit
::
kLayerNorm
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestLayerNormKernel
<
jit
::
kLayerNorm
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kCRFDecoding
)
{
TestCRFDecodingKernel
<
jit
::
kCRFDecoding
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestCRFDecodingKernel
<
jit
::
kCRFDecoding
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
auto
key1
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr4
);
TEST
(
JITKernel
,
pool
)
{
// TODO(TJ): add some test
EXPECT_TRUE
(
key1
!=
key2
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key3
!=
key4
);
}
// TODO(TJ): add more test about key and pool
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录