Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
cfc83c14
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cfc83c14
编写于
3月 11, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine jitcodekey and enhance unit tests
test=develop
上级
6ff230a6
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
211 addition
and
91 deletion
+211
-91
paddle/fluid/operators/jit/gen/act.cc
paddle/fluid/operators/jit/gen/act.cc
+1
-0
paddle/fluid/operators/jit/gen/blas.cc
paddle/fluid/operators/jit/gen/blas.cc
+1
-0
paddle/fluid/operators/jit/gen/embseqpool.cc
paddle/fluid/operators/jit/gen/embseqpool.cc
+1
-0
paddle/fluid/operators/jit/gen/gru.cc
paddle/fluid/operators/jit/gen/gru.cc
+1
-0
paddle/fluid/operators/jit/gen/hopv.cc
paddle/fluid/operators/jit/gen/hopv.cc
+1
-0
paddle/fluid/operators/jit/gen/lstm.cc
paddle/fluid/operators/jit/gen/lstm.cc
+1
-0
paddle/fluid/operators/jit/gen/matmul.cc
paddle/fluid/operators/jit/gen/matmul.cc
+1
-1
paddle/fluid/operators/jit/gen/seqpool.cc
paddle/fluid/operators/jit/gen/seqpool.cc
+1
-0
paddle/fluid/operators/jit/gen/sgd.cc
paddle/fluid/operators/jit/gen/sgd.cc
+1
-0
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+1
-1
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+17
-44
paddle/fluid/operators/jit/kernel_key.h
paddle/fluid/operators/jit/kernel_key.h
+1
-1
paddle/fluid/operators/jit/kernel_pool.h
paddle/fluid/operators/jit/kernel_pool.h
+3
-3
paddle/fluid/operators/jit/registry.h
paddle/fluid/operators/jit/registry.h
+1
-0
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+179
-41
未找到文件。
paddle/fluid/operators/jit/gen/act.cc
浏览文件 @
cfc83c14
...
...
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/act.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/blas.cc
浏览文件 @
cfc83c14
...
...
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/blas.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/embseqpool.cc
浏览文件 @
cfc83c14
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/embseqpool.h"
#include <stddef.h> // offsetof
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
...
...
paddle/fluid/operators/jit/gen/gru.cc
浏览文件 @
cfc83c14
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/hopv.cc
浏览文件 @
cfc83c14
...
...
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/hopv.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/lstm.cc
浏览文件 @
cfc83c14
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/matmul.cc
浏览文件 @
cfc83c14
...
...
@@ -14,8 +14,8 @@
#include "paddle/fluid/operators/jit/gen/matmul.h"
#include <stddef.h> // offsetof
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/seqpool.cc
浏览文件 @
cfc83c14
...
...
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include <memory>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/sgd.cc
浏览文件 @
cfc83c14
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/sgd.h"
#include <stddef.h> // offsetof
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
cfc83c14
...
...
@@ -36,7 +36,7 @@ inline typename std::enable_if<
const
Kernel
*>::
type
GetJitCode
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
using
Attr
=
typename
KernelTuple
::
attr_type
;
size
_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
int64
_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
auto
&
codes
=
JitCodePool
<
KernelTuple
::
kernel_type
>::
Instance
();
if
(
codes
.
Has
(
key
))
{
return
codes
.
AllKernels
().
at
(
key
).
get
();
...
...
paddle/fluid/operators/jit/kernel_key.cc
浏览文件 @
cfc83c14
...
...
@@ -13,7 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include <xxhash.h>
#include <xxhash.h>
// XXH64: 13.8 GB/s
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
...
...
@@ -21,73 +21,46 @@ namespace operators {
namespace
jit
{
template
<
>
size
_t
JitCodeKey
<
int
>
(
const
int
&
d
)
{
int64
_t
JitCodeKey
<
int
>
(
const
int
&
d
)
{
return
d
;
}
template
<
>
size
_t
JitCodeKey
<
int64_t
>
(
const
int64_t
&
d
)
{
int64
_t
JitCodeKey
<
int64_t
>
(
const
int64_t
&
d
)
{
return
d
;
}
// TODO(TJ): refine and benchmark JitCodeKey generatation
constexpr
int
act_type_shift
=
3
;
// suppot 2^3 act types
static
inline
int
act_type_convert
(
KernelType
type
)
{
if
(
type
==
kVIdentity
)
{
return
0
;
}
else
if
(
type
==
kVExp
)
{
return
1
;
}
else
if
(
type
==
kVRelu
)
{
return
2
;
}
else
if
(
type
==
kVSigmoid
)
{
return
3
;
}
else
if
(
type
==
kVTanh
)
{
return
4
;
}
PADDLE_THROW
(
"Unsupported act type %d"
,
type
);
return
0
;
}
template
<
>
size_t
JitCodeKey
<
lstm_attr_t
>
(
const
lstm_attr_t
&
attr
)
{
// XXH64: 13.8 GB/s
size_t
key
=
attr
.
d
;
int
gate_key
=
act_type_convert
(
attr
.
act_gate
)
<<
1
;
int
cand_key
=
act_type_convert
(
attr
.
act_cand
)
<<
(
1
+
act_type_shift
);
int
cell_key
=
act_type_convert
(
attr
.
act_cell
)
<<
(
1
+
act_type_shift
*
2
);
return
(
key
<<
(
1
+
act_type_shift
*
3
))
+
gate_key
+
cand_key
+
cell_key
+
attr
.
use_peephole
;
int64_t
JitCodeKey
<
gru_attr_t
>
(
const
gru_attr_t
&
attr
)
{
return
XXH64
(
&
attr
,
sizeof
(
gru_attr_t
),
0
);
}
template
<
>
size_t
JitCodeKey
<
gru_attr_t
>
(
const
gru_attr_t
&
attr
)
{
size_t
key
=
attr
.
d
;
return
(
key
<<
(
act_type_shift
*
2
))
+
act_type_convert
(
attr
.
act_gate
)
+
(
act_type_convert
(
attr
.
act_cand
)
<<
act_type_shift
);
int64_t
JitCodeKey
<
lstm_attr_t
>
(
const
lstm_attr_t
&
attr
)
{
int
keys
[
5
]
=
{
attr
.
d
,
static_cast
<
int
>
(
attr
.
act_gate
),
static_cast
<
int
>
(
attr
.
act_cand
),
static_cast
<
int
>
(
attr
.
act_cell
),
static_cast
<
int
>
(
attr
.
use_peephole
)};
return
XXH64
(
keys
,
sizeof
(
int
)
*
5
,
0
);
}
template
<
>
size_t
JitCodeKey
<
seq_pool_attr_t
>
(
const
seq_pool_attr_t
&
attr
)
{
size_t
key
=
attr
.
w
;
constexpr
int
pool_type_shift
=
3
;
return
(
key
<<
pool_type_shift
)
+
static_cast
<
int
>
(
attr
.
type
);
int64_t
JitCodeKey
<
seq_pool_attr_t
>
(
const
seq_pool_attr_t
&
attr
)
{
int
keys
[
2
]
=
{
attr
.
w
,
static_cast
<
int
>
(
attr
.
type
)};
return
XXH64
(
keys
,
sizeof
(
int
)
*
2
,
0
);
}
template
<
>
size_t
JitCodeKey
<
matmul_attr_t
>
(
const
matmul_attr_t
&
attr
)
{
size_t
key
=
attr
.
m
;
constexpr
int
shift
=
21
;
return
(
key
<<
shift
*
2
)
+
((
static_cast
<
size_t
>
(
attr
.
n
))
<<
shift
)
+
attr
.
k
;
int64_t
JitCodeKey
<
matmul_attr_t
>
(
const
matmul_attr_t
&
attr
)
{
return
XXH64
(
&
attr
,
sizeof
(
int
)
*
3
,
0
);
// m, n, k
}
template
<
>
size
_t
JitCodeKey
<
emb_seq_pool_attr_t
>
(
const
emb_seq_pool_attr_t
&
attr
)
{
int64
_t
JitCodeKey
<
emb_seq_pool_attr_t
>
(
const
emb_seq_pool_attr_t
&
attr
)
{
return
attr
.
table_width
;
}
template
<
>
size
_t
JitCodeKey
<
sgd_attr_t
>
(
const
sgd_attr_t
&
attr
)
{
int64
_t
JitCodeKey
<
sgd_attr_t
>
(
const
sgd_attr_t
&
attr
)
{
return
attr
.
grad_width
;
}
...
...
paddle/fluid/operators/jit/kernel_key.h
浏览文件 @
cfc83c14
...
...
@@ -46,7 +46,7 @@ struct KernelKey {
// Every JitCode should have a method to get the key from attribution
template
<
typename
Attr
>
size
_t
JitCodeKey
(
const
Attr
&
attr
);
int64
_t
JitCodeKey
(
const
Attr
&
attr
);
}
// namespace jit
}
// namespace operators
...
...
paddle/fluid/operators/jit/kernel_pool.h
浏览文件 @
cfc83c14
...
...
@@ -30,7 +30,7 @@ namespace jit {
template
<
KernelType
KT
>
class
JitCodePool
{
typedef
std
::
unique_ptr
<
GenBase
>
GenBasePtr
;
typedef
std
::
unordered_map
<
size
_t
,
GenBasePtr
>
JitCodeMap
;
typedef
std
::
unordered_map
<
int64
_t
,
GenBasePtr
>
JitCodeMap
;
public:
JitCodePool
()
=
default
;
...
...
@@ -41,9 +41,9 @@ class JitCodePool {
const
JitCodeMap
&
AllKernels
()
{
return
codes_
;
}
bool
Has
(
size
_t
key
)
const
{
return
codes_
.
find
(
key
)
!=
codes_
.
end
();
}
bool
Has
(
int64
_t
key
)
const
{
return
codes_
.
find
(
key
)
!=
codes_
.
end
();
}
void
Insert
(
size
_t
key
,
GenBasePtr
value
)
{
void
Insert
(
int64
_t
key
,
GenBasePtr
value
)
{
codes_
.
emplace
(
key
,
std
::
move
(
value
));
}
...
...
paddle/fluid/operators/jit/registry.h
浏览文件 @
cfc83c14
...
...
@@ -17,6 +17,7 @@
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility> // for std::move
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
cfc83c14
...
...
@@ -886,7 +886,11 @@ void TestKernelVBroadcast() {
// test pool
TEST
(
JITKernel_pool
,
jitcreator
)
{
const
auto
&
jitcreators
=
jit
::
JitCodeCreatorPool
::
Instance
().
AllCreators
();
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ
(
jitcreators
.
size
(),
0UL
);
#else
EXPECT_EQ
(
jitcreators
.
size
(),
25UL
);
#endif
}
TEST
(
JITKernel_pool
,
jitpool
)
{
...
...
@@ -894,13 +898,25 @@ TEST(JITKernel_pool, jitpool) {
const
auto
&
kers
=
jit
::
JitCodePool
<
jit
::
kVAdd
>
().
Instance
().
AllKernels
();
EXPECT_EQ
(
kers
.
size
(),
0UL
);
jit
::
GetAllCandidateKernels
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>
(
3
);
// after call GetAllCandidateKernels, it will create jitcode Automatically
// after call GetAllCandidateKernels, it will create jitcode Automatically
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ
(
kers
.
size
(),
0UL
);
#else
EXPECT_EQ
(
kers
.
size
(),
1UL
);
#endif
}
TEST
(
JITKernel_pool
,
more
)
{
const
auto
&
kers
=
jit
::
KernelPool
::
Instance
().
AllKernels
();
#if defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ
(
kers
.
size
(),
10UL
);
#else
#ifdef PADDLE_WITH_MKLML
EXPECT_EQ
(
kers
.
size
(),
21UL
);
#else
EXPECT_EQ
(
kers
.
size
(),
8UL
);
#endif
#endif
}
TEST
(
JITKernel_pool
,
refer
)
{
...
...
@@ -915,7 +931,11 @@ TEST(JITKernel_helper, GetAllCandidateKernels) {
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE
(
fp_kers
.
size
(),
1UL
);
// refer
#else
#ifdef PADDLE_WITH_MKLML
EXPECT_GE
(
fp_kers
.
size
(),
3UL
);
// jitcode, mkl, refer
#else
EXPECT_GE
(
fp_kers
.
size
(),
2UL
);
// jitcode, refer
#endif
#endif
auto
db_kers
=
...
...
@@ -923,18 +943,48 @@ TEST(JITKernel_helper, GetAllCandidateKernels) {
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE
(
db_kers
.
size
(),
1UL
);
// refer
#else
#ifdef PADDLE_WITH_MKLML
EXPECT_GE
(
db_kers
.
size
(),
2UL
);
// mkl, refer
#else
EXPECT_GE
(
db_kers
.
size
(),
1UL
);
// refer
#endif
#endif
}
TEST
(
JITKernel_helper
,
GetAllCandidateFuncsWithTypes
)
{
auto
fp_kers
=
jit
::
GetAllCandidateFuncsWithTypes
<
jit
::
VExpTuple
<
float
>
,
CPUPlace
>
(
10
);
#if defined(__APPLE__) || defined(__OSX__)
EXPECT_GE
(
fp_kers
.
size
(),
1UL
);
// refer
#else
#if !defined(PADDLE_WITH_MKLML) || defined(_WIN32)
EXPECT_GE
(
fp_kers
.
size
(),
2UL
);
// jitcode/mkl, refer
#else
EXPECT_GE
(
fp_kers
.
size
(),
3UL
);
// jitcode, mkl, refer
#endif
#endif
auto
db_kers
=
jit
::
GetAllCandidateFuncsWithTypes
<
jit
::
VExpTuple
<
double
>
,
CPUPlace
>
(
10
);
#if defined(__APPLE__) || defined(__OSX__) || !defined(PADDLE_WITH_MKLML)
EXPECT_GE
(
db_kers
.
size
(),
1UL
);
// refer
#else
EXPECT_GE
(
db_kers
.
size
(),
2UL
);
// mkl, refer
#endif
}
TEST
(
JITKernel_helper
,
KernelFuncs
)
{
auto
f1
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
().
At
(
3
);
auto
f2
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
()[
3
];
EXPECT_TRUE
(
f1
!=
nullptr
);
EXPECT_TRUE
(
f1
==
f2
);
auto
f3
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
()[
5
];
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_TRUE
(
f2
==
f3
);
#else
EXPECT_TRUE
(
f2
!=
f3
);
#endif
}
TEST
(
JITKernel_helper
,
GetAllCandidateFuncs
)
{
...
...
@@ -1011,6 +1061,134 @@ TEST(JITKernel_helper, attr) {
EXPECT_EQ
(
out
.
str
().
size
(),
14
);
}
// test keys
TEST
(
JITKernel_key
,
int
)
{
EXPECT_TRUE
(
jit
::
JitCodeKey
<
int
>
(
2
)
==
jit
::
JitCodeKey
<
int
>
(
2
));
EXPECT_TRUE
(
jit
::
JitCodeKey
<
int
>
(
2
)
==
jit
::
JitCodeKey
<
int64_t
>
(
2
));
EXPECT_TRUE
(
jit
::
JitCodeKey
<
int
>
(
2
)
!=
jit
::
JitCodeKey
<
int
>
(
3
));
}
TEST
(
JITKernel_key
,
gru
)
{
jit
::
gru_attr_t
attr1
(
8
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr2
(
8
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr3
(
9
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr4
(
9
,
jit
::
kVSigmoid
,
jit
::
kVIdentity
);
jit
::
gru_attr_t
attr5
(
9
,
jit
::
kVTanh
,
jit
::
kVIdentity
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr4
);
auto
key5
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr5
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
!=
key3
);
EXPECT_TRUE
(
key2
!=
key4
);
EXPECT_TRUE
(
key2
!=
key5
);
EXPECT_TRUE
(
key3
!=
key4
);
EXPECT_TRUE
(
key3
!=
key5
);
EXPECT_TRUE
(
key4
!=
key5
);
}
TEST
(
JITKernel_key
,
lstm
)
{
jit
::
lstm_attr_t
attr1
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr2
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr3
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr4
(
9
,
jit
::
kVRelu
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr5
(
9
,
jit
::
kVRelu
,
jit
::
kVSigmoid
,
jit
::
kVTanh
,
true
);
jit
::
lstm_attr_t
attr6
(
9
,
jit
::
kVRelu
,
jit
::
kVSigmoid
,
jit
::
kVTanh
,
true
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr4
);
auto
key5
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr5
);
auto
key6
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr6
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
!=
key3
);
EXPECT_TRUE
(
key2
!=
key4
);
EXPECT_TRUE
(
key2
!=
key5
);
EXPECT_TRUE
(
key3
!=
key4
);
EXPECT_TRUE
(
key3
!=
key5
);
EXPECT_TRUE
(
key4
!=
key5
);
EXPECT_TRUE
(
key5
==
key6
);
}
TEST
(
JITKernel_key
,
seq_pool
)
{
jit
::
seq_pool_attr_t
attr1
(
2
,
jit
::
SeqPoolType
::
kSum
,
1
);
jit
::
seq_pool_attr_t
attr2
(
2
,
jit
::
SeqPoolType
::
kSum
,
3
);
jit
::
seq_pool_attr_t
attr3
(
3
,
jit
::
SeqPoolType
::
kSum
,
3
);
jit
::
seq_pool_attr_t
attr4
(
3
,
jit
::
SeqPoolType
::
kAvg
,
3
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
seq_pool_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
seq_pool_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
seq_pool_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
seq_pool_attr_t
>
(
attr4
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
!=
key3
);
EXPECT_TRUE
(
key2
!=
key4
);
EXPECT_TRUE
(
key3
!=
key4
);
}
TEST
(
JITKernel_key
,
matmul
)
{
jit
::
matmul_attr_t
attr1
(
1
,
2
,
3
);
jit
::
matmul_attr_t
attr2
(
1
,
2
,
3
);
jit
::
matmul_attr_t
attr3
(
1
,
3
,
3
);
jit
::
matmul_attr_t
attr4
(
2
,
3
,
4
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
matmul_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
matmul_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
matmul_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
matmul_attr_t
>
(
attr4
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
!=
key3
);
EXPECT_TRUE
(
key2
!=
key4
);
EXPECT_TRUE
(
key3
!=
key4
);
}
TEST
(
JITKernel_key
,
emb_seq_pool
)
{
jit
::
emb_seq_pool_attr_t
attr1
(
1
,
2
,
3
,
4
,
5
,
jit
::
SeqPoolType
::
kSum
);
jit
::
emb_seq_pool_attr_t
attr2
(
1
,
2
,
3
,
4
,
5
,
jit
::
SeqPoolType
::
kSum
);
jit
::
emb_seq_pool_attr_t
attr3
(
10
,
2
,
9
,
8
,
7
,
jit
::
SeqPoolType
::
kAvg
);
jit
::
emb_seq_pool_attr_t
attr4
(
10
,
3
,
9
,
8
,
7
,
jit
::
SeqPoolType
::
kSum
);
jit
::
emb_seq_pool_attr_t
attr5
(
1
,
6
,
3
,
4
,
5
,
jit
::
SeqPoolType
::
kSum
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
emb_seq_pool_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
emb_seq_pool_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
emb_seq_pool_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
emb_seq_pool_attr_t
>
(
attr4
);
auto
key5
=
jit
::
JitCodeKey
<
jit
::
emb_seq_pool_attr_t
>
(
attr5
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key2
!=
key4
);
EXPECT_TRUE
(
key2
!=
key5
);
EXPECT_TRUE
(
key4
!=
key5
);
}
TEST
(
JITKernel_key
,
sgd
)
{
jit
::
sgd_attr_t
attr1
(
1
,
2
,
3
,
4
,
5
);
jit
::
sgd_attr_t
attr2
(
1
,
2
,
3
,
4
,
5
);
jit
::
sgd_attr_t
attr3
(
9
,
8
,
7
,
4
,
6
);
jit
::
sgd_attr_t
attr4
(
1
,
2
,
3
,
6
,
5
);
jit
::
sgd_attr_t
attr5
(
10
,
9
,
8
,
7
,
6
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
sgd_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
sgd_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
sgd_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
sgd_attr_t
>
(
attr4
);
auto
key5
=
jit
::
JitCodeKey
<
jit
::
sgd_attr_t
>
(
attr5
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key3
!=
key4
);
EXPECT_TRUE
(
key3
!=
key5
);
EXPECT_TRUE
(
key4
!=
key5
);
}
// test kernerls
#define TestKernelVMul TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN
...
...
@@ -1080,43 +1258,3 @@ TEST_CPU_KERNEL(MatMul);
TEST_CPU_KERNEL
(
Softmax
);
TEST_CPU_KERNEL
(
Sgd
);
TEST_CPU_KERNEL
(
VBroadcast
);
TEST
(
JITKernel
,
kernel_func
)
{
auto
f1
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
().
At
(
3
);
auto
f2
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
()[
3
];
EXPECT_TRUE
(
f1
!=
nullptr
);
EXPECT_TRUE
(
f1
==
f2
);
// TODO(TJ): check not equal
}
TEST
(
JITKernel_key
,
lstm
)
{
jit
::
lstm_attr_t
attr1
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr2
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr3
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr4
(
9
,
jit
::
kVRelu
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr4
);
EXPECT_TRUE
(
key1
!=
key2
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key3
!=
key4
);
}
TEST
(
JITKernel_key
,
gru
)
{
jit
::
gru_attr_t
attr1
(
8
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr2
(
9
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr3
(
9
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr4
(
9
,
jit
::
kVSigmoid
,
jit
::
kVIdentity
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr4
);
EXPECT_TRUE
(
key1
!=
key2
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key3
!=
key4
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录