Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cfc83c14
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cfc83c14
编写于
3月 11, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine jitcodekey and enhance unit tests
test=develop
上级
6ff230a6
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
211 addition
and
91 deletion
+211
-91
paddle/fluid/operators/jit/gen/act.cc
paddle/fluid/operators/jit/gen/act.cc
+1
-0
paddle/fluid/operators/jit/gen/blas.cc
paddle/fluid/operators/jit/gen/blas.cc
+1
-0
paddle/fluid/operators/jit/gen/embseqpool.cc
paddle/fluid/operators/jit/gen/embseqpool.cc
+1
-0
paddle/fluid/operators/jit/gen/gru.cc
paddle/fluid/operators/jit/gen/gru.cc
+1
-0
paddle/fluid/operators/jit/gen/hopv.cc
paddle/fluid/operators/jit/gen/hopv.cc
+1
-0
paddle/fluid/operators/jit/gen/lstm.cc
paddle/fluid/operators/jit/gen/lstm.cc
+1
-0
paddle/fluid/operators/jit/gen/matmul.cc
paddle/fluid/operators/jit/gen/matmul.cc
+1
-1
paddle/fluid/operators/jit/gen/seqpool.cc
paddle/fluid/operators/jit/gen/seqpool.cc
+1
-0
paddle/fluid/operators/jit/gen/sgd.cc
paddle/fluid/operators/jit/gen/sgd.cc
+1
-0
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+1
-1
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+17
-44
paddle/fluid/operators/jit/kernel_key.h
paddle/fluid/operators/jit/kernel_key.h
+1
-1
paddle/fluid/operators/jit/kernel_pool.h
paddle/fluid/operators/jit/kernel_pool.h
+3
-3
paddle/fluid/operators/jit/registry.h
paddle/fluid/operators/jit/registry.h
+1
-0
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+179
-41
未找到文件。
paddle/fluid/operators/jit/gen/act.cc
浏览文件 @
cfc83c14
...
...
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/act.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/blas.cc
浏览文件 @
cfc83c14
...
...
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/blas.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/embseqpool.cc
浏览文件 @
cfc83c14
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/embseqpool.h"
#include <stddef.h> // offsetof
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
...
...
paddle/fluid/operators/jit/gen/gru.cc
浏览文件 @
cfc83c14
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/hopv.cc
浏览文件 @
cfc83c14
...
...
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/hopv.h"
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/lstm.cc
浏览文件 @
cfc83c14
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h> // offsetof
#include <memory>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/matmul.cc
浏览文件 @
cfc83c14
...
...
@@ -14,8 +14,8 @@
#include "paddle/fluid/operators/jit/gen/matmul.h"
#include <stddef.h> // offsetof
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/seqpool.cc
浏览文件 @
cfc83c14
...
...
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include <memory>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/gen/sgd.cc
浏览文件 @
cfc83c14
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/operators/jit/gen/sgd.h"
#include <stddef.h> // offsetof
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
cfc83c14
...
...
@@ -36,7 +36,7 @@ inline typename std::enable_if<
const
Kernel
*>::
type
GetJitCode
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
using
Attr
=
typename
KernelTuple
::
attr_type
;
size
_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
int64
_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
auto
&
codes
=
JitCodePool
<
KernelTuple
::
kernel_type
>::
Instance
();
if
(
codes
.
Has
(
key
))
{
return
codes
.
AllKernels
().
at
(
key
).
get
();
...
...
paddle/fluid/operators/jit/kernel_key.cc
浏览文件 @
cfc83c14
...
...
@@ -13,7 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include <xxhash.h>
#include <xxhash.h>
// XXH64: 13.8 GB/s
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
...
...
@@ -21,73 +21,46 @@ namespace operators {
namespace
jit
{
template
<
>
size
_t
JitCodeKey
<
int
>
(
const
int
&
d
)
{
int64
_t
JitCodeKey
<
int
>
(
const
int
&
d
)
{
return
d
;
}
template
<
>
size
_t
JitCodeKey
<
int64_t
>
(
const
int64_t
&
d
)
{
int64
_t
JitCodeKey
<
int64_t
>
(
const
int64_t
&
d
)
{
return
d
;
}
// TODO(TJ): refine and benchmark JitCodeKey generatation
constexpr
int
act_type_shift
=
3
;
// suppot 2^3 act types
static
inline
int
act_type_convert
(
KernelType
type
)
{
if
(
type
==
kVIdentity
)
{
return
0
;
}
else
if
(
type
==
kVExp
)
{
return
1
;
}
else
if
(
type
==
kVRelu
)
{
return
2
;
}
else
if
(
type
==
kVSigmoid
)
{
return
3
;
}
else
if
(
type
==
kVTanh
)
{
return
4
;
}
PADDLE_THROW
(
"Unsupported act type %d"
,
type
);
return
0
;
}
template
<
>
size_t
JitCodeKey
<
lstm_attr_t
>
(
const
lstm_attr_t
&
attr
)
{
// XXH64: 13.8 GB/s
size_t
key
=
attr
.
d
;
int
gate_key
=
act_type_convert
(
attr
.
act_gate
)
<<
1
;
int
cand_key
=
act_type_convert
(
attr
.
act_cand
)
<<
(
1
+
act_type_shift
);
int
cell_key
=
act_type_convert
(
attr
.
act_cell
)
<<
(
1
+
act_type_shift
*
2
);
return
(
key
<<
(
1
+
act_type_shift
*
3
))
+
gate_key
+
cand_key
+
cell_key
+
attr
.
use_peephole
;
int64_t
JitCodeKey
<
gru_attr_t
>
(
const
gru_attr_t
&
attr
)
{
return
XXH64
(
&
attr
,
sizeof
(
gru_attr_t
),
0
);
}
template
<
>
size_t
JitCodeKey
<
gru_attr_t
>
(
const
gru_attr_t
&
attr
)
{
size_t
key
=
attr
.
d
;
return
(
key
<<
(
act_type_shift
*
2
))
+
act_type_convert
(
attr
.
act_gate
)
+
(
act_type_convert
(
attr
.
act_cand
)
<<
act_type_shift
);
int64_t
JitCodeKey
<
lstm_attr_t
>
(
const
lstm_attr_t
&
attr
)
{
int
keys
[
5
]
=
{
attr
.
d
,
static_cast
<
int
>
(
attr
.
act_gate
),
static_cast
<
int
>
(
attr
.
act_cand
),
static_cast
<
int
>
(
attr
.
act_cell
),
static_cast
<
int
>
(
attr
.
use_peephole
)};
return
XXH64
(
keys
,
sizeof
(
int
)
*
5
,
0
);
}
template
<
>
size_t
JitCodeKey
<
seq_pool_attr_t
>
(
const
seq_pool_attr_t
&
attr
)
{
size_t
key
=
attr
.
w
;
constexpr
int
pool_type_shift
=
3
;
return
(
key
<<
pool_type_shift
)
+
static_cast
<
int
>
(
attr
.
type
);
int64_t
JitCodeKey
<
seq_pool_attr_t
>
(
const
seq_pool_attr_t
&
attr
)
{
int
keys
[
2
]
=
{
attr
.
w
,
static_cast
<
int
>
(
attr
.
type
)};
return
XXH64
(
keys
,
sizeof
(
int
)
*
2
,
0
);
}
template
<
>
size_t
JitCodeKey
<
matmul_attr_t
>
(
const
matmul_attr_t
&
attr
)
{
size_t
key
=
attr
.
m
;
constexpr
int
shift
=
21
;
return
(
key
<<
shift
*
2
)
+
((
static_cast
<
size_t
>
(
attr
.
n
))
<<
shift
)
+
attr
.
k
;
int64_t
JitCodeKey
<
matmul_attr_t
>
(
const
matmul_attr_t
&
attr
)
{
return
XXH64
(
&
attr
,
sizeof
(
int
)
*
3
,
0
);
// m, n, k
}
template
<
>
size
_t
JitCodeKey
<
emb_seq_pool_attr_t
>
(
const
emb_seq_pool_attr_t
&
attr
)
{
int64
_t
JitCodeKey
<
emb_seq_pool_attr_t
>
(
const
emb_seq_pool_attr_t
&
attr
)
{
return
attr
.
table_width
;
}
template
<
>
size
_t
JitCodeKey
<
sgd_attr_t
>
(
const
sgd_attr_t
&
attr
)
{
int64
_t
JitCodeKey
<
sgd_attr_t
>
(
const
sgd_attr_t
&
attr
)
{
return
attr
.
grad_width
;
}
...
...
paddle/fluid/operators/jit/kernel_key.h
浏览文件 @
cfc83c14
...
...
@@ -46,7 +46,7 @@ struct KernelKey {
// Every JitCode should have a method to get the key from attribution
template
<
typename
Attr
>
size
_t
JitCodeKey
(
const
Attr
&
attr
);
int64
_t
JitCodeKey
(
const
Attr
&
attr
);
}
// namespace jit
}
// namespace operators
...
...
paddle/fluid/operators/jit/kernel_pool.h
浏览文件 @
cfc83c14
...
...
@@ -30,7 +30,7 @@ namespace jit {
template
<
KernelType
KT
>
class
JitCodePool
{
typedef
std
::
unique_ptr
<
GenBase
>
GenBasePtr
;
typedef
std
::
unordered_map
<
size
_t
,
GenBasePtr
>
JitCodeMap
;
typedef
std
::
unordered_map
<
int64
_t
,
GenBasePtr
>
JitCodeMap
;
public:
JitCodePool
()
=
default
;
...
...
@@ -41,9 +41,9 @@ class JitCodePool {
const
JitCodeMap
&
AllKernels
()
{
return
codes_
;
}
bool
Has
(
size
_t
key
)
const
{
return
codes_
.
find
(
key
)
!=
codes_
.
end
();
}
bool
Has
(
int64
_t
key
)
const
{
return
codes_
.
find
(
key
)
!=
codes_
.
end
();
}
void
Insert
(
size
_t
key
,
GenBasePtr
value
)
{
void
Insert
(
int64
_t
key
,
GenBasePtr
value
)
{
codes_
.
emplace
(
key
,
std
::
move
(
value
));
}
...
...
paddle/fluid/operators/jit/registry.h
浏览文件 @
cfc83c14
...
...
@@ -17,6 +17,7 @@
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility> // for std::move
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
cfc83c14
...
...
@@ -886,7 +886,11 @@ void TestKernelVBroadcast() {
// test pool
TEST
(
JITKernel_pool
,
jitcreator
)
{
const
auto
&
jitcreators
=
jit
::
JitCodeCreatorPool
::
Instance
().
AllCreators
();
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ
(
jitcreators
.
size
(),
0UL
);
#else
EXPECT_EQ
(
jitcreators
.
size
(),
25UL
);
#endif
}
TEST
(
JITKernel_pool
,
jitpool
)
{
...
...
@@ -894,13 +898,25 @@ TEST(JITKernel_pool, jitpool) {
const
auto
&
kers
=
jit
::
JitCodePool
<
jit
::
kVAdd
>
().
Instance
().
AllKernels
();
EXPECT_EQ
(
kers
.
size
(),
0UL
);
jit
::
GetAllCandidateKernels
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>
(
3
);
// after call GetAllCandidateKernels, it will create jitcode Automatically
// after call GetAllCandidateKernels, it will create jitcode Automatically
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ
(
kers
.
size
(),
0UL
);
#else
EXPECT_EQ
(
kers
.
size
(),
1UL
);
#endif
}
TEST
(
JITKernel_pool
,
more
)
{
const
auto
&
kers
=
jit
::
KernelPool
::
Instance
().
AllKernels
();
#if defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ
(
kers
.
size
(),
10UL
);
#else
#ifdef PADDLE_WITH_MKLML
EXPECT_EQ
(
kers
.
size
(),
21UL
);
#else
EXPECT_EQ
(
kers
.
size
(),
8UL
);
#endif
#endif
}
TEST
(
JITKernel_pool
,
refer
)
{
...
...
@@ -915,7 +931,11 @@ TEST(JITKernel_helper, GetAllCandidateKernels) {
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE
(
fp_kers
.
size
(),
1UL
);
// refer
#else
#ifdef PADDLE_WITH_MKLML
EXPECT_GE
(
fp_kers
.
size
(),
3UL
);
// jitcode, mkl, refer
#else
EXPECT_GE
(
fp_kers
.
size
(),
2UL
);
// jitcode, refer
#endif
#endif
auto
db_kers
=
...
...
@@ -923,18 +943,48 @@ TEST(JITKernel_helper, GetAllCandidateKernels) {
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE
(
db_kers
.
size
(),
1UL
);
// refer
#else
#ifdef PADDLE_WITH_MKLML
EXPECT_GE
(
db_kers
.
size
(),
2UL
);
// mkl, refer
#else
EXPECT_GE
(
db_kers
.
size
(),
1UL
);
// refer
#endif
#endif
}
TEST
(
JITKernel_helper
,
GetAllCandidateFuncsWithTypes
)
{
auto
fp_kers
=
jit
::
GetAllCandidateFuncsWithTypes
<
jit
::
VExpTuple
<
float
>
,
CPUPlace
>
(
10
);
#if defined(__APPLE__) || defined(__OSX__)
EXPECT_GE
(
fp_kers
.
size
(),
1UL
);
// refer
#else
#if !defined(PADDLE_WITH_MKLML) || defined(_WIN32)
EXPECT_GE
(
fp_kers
.
size
(),
2UL
);
// jitcode/mkl, refer
#else
EXPECT_GE
(
fp_kers
.
size
(),
3UL
);
// jitcode, mkl, refer
#endif
#endif
auto
db_kers
=
jit
::
GetAllCandidateFuncsWithTypes
<
jit
::
VExpTuple
<
double
>
,
CPUPlace
>
(
10
);
#if defined(__APPLE__) || defined(__OSX__) || !defined(PADDLE_WITH_MKLML)
EXPECT_GE
(
db_kers
.
size
(),
1UL
);
// refer
#else
EXPECT_GE
(
db_kers
.
size
(),
2UL
);
// mkl, refer
#endif
}
TEST
(
JITKernel_helper
,
KernelFuncs
)
{
auto
f1
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
().
At
(
3
);
auto
f2
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
()[
3
];
EXPECT_TRUE
(
f1
!=
nullptr
);
EXPECT_TRUE
(
f1
==
f2
);
auto
f3
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
()[
5
];
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_TRUE
(
f2
==
f3
);
#else
EXPECT_TRUE
(
f2
!=
f3
);
#endif
}
TEST
(
JITKernel_helper
,
GetAllCandidateFuncs
)
{
...
...
@@ -1011,6 +1061,134 @@ TEST(JITKernel_helper, attr) {
EXPECT_EQ
(
out
.
str
().
size
(),
14
);
}
// test keys
TEST
(
JITKernel_key
,
int
)
{
EXPECT_TRUE
(
jit
::
JitCodeKey
<
int
>
(
2
)
==
jit
::
JitCodeKey
<
int
>
(
2
));
EXPECT_TRUE
(
jit
::
JitCodeKey
<
int
>
(
2
)
==
jit
::
JitCodeKey
<
int64_t
>
(
2
));
EXPECT_TRUE
(
jit
::
JitCodeKey
<
int
>
(
2
)
!=
jit
::
JitCodeKey
<
int
>
(
3
));
}
TEST
(
JITKernel_key
,
gru
)
{
jit
::
gru_attr_t
attr1
(
8
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr2
(
8
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr3
(
9
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr4
(
9
,
jit
::
kVSigmoid
,
jit
::
kVIdentity
);
jit
::
gru_attr_t
attr5
(
9
,
jit
::
kVTanh
,
jit
::
kVIdentity
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr4
);
auto
key5
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr5
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
!=
key3
);
EXPECT_TRUE
(
key2
!=
key4
);
EXPECT_TRUE
(
key2
!=
key5
);
EXPECT_TRUE
(
key3
!=
key4
);
EXPECT_TRUE
(
key3
!=
key5
);
EXPECT_TRUE
(
key4
!=
key5
);
}
TEST
(
JITKernel_key
,
lstm
)
{
jit
::
lstm_attr_t
attr1
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr2
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr3
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr4
(
9
,
jit
::
kVRelu
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr5
(
9
,
jit
::
kVRelu
,
jit
::
kVSigmoid
,
jit
::
kVTanh
,
true
);
jit
::
lstm_attr_t
attr6
(
9
,
jit
::
kVRelu
,
jit
::
kVSigmoid
,
jit
::
kVTanh
,
true
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr4
);
auto
key5
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr5
);
auto
key6
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr6
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
!=
key3
);
EXPECT_TRUE
(
key2
!=
key4
);
EXPECT_TRUE
(
key2
!=
key5
);
EXPECT_TRUE
(
key3
!=
key4
);
EXPECT_TRUE
(
key3
!=
key5
);
EXPECT_TRUE
(
key4
!=
key5
);
EXPECT_TRUE
(
key5
==
key6
);
}
TEST
(
JITKernel_key
,
seq_pool
)
{
jit
::
seq_pool_attr_t
attr1
(
2
,
jit
::
SeqPoolType
::
kSum
,
1
);
jit
::
seq_pool_attr_t
attr2
(
2
,
jit
::
SeqPoolType
::
kSum
,
3
);
jit
::
seq_pool_attr_t
attr3
(
3
,
jit
::
SeqPoolType
::
kSum
,
3
);
jit
::
seq_pool_attr_t
attr4
(
3
,
jit
::
SeqPoolType
::
kAvg
,
3
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
seq_pool_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
seq_pool_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
seq_pool_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
seq_pool_attr_t
>
(
attr4
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
!=
key3
);
EXPECT_TRUE
(
key2
!=
key4
);
EXPECT_TRUE
(
key3
!=
key4
);
}
TEST
(
JITKernel_key
,
matmul
)
{
jit
::
matmul_attr_t
attr1
(
1
,
2
,
3
);
jit
::
matmul_attr_t
attr2
(
1
,
2
,
3
);
jit
::
matmul_attr_t
attr3
(
1
,
3
,
3
);
jit
::
matmul_attr_t
attr4
(
2
,
3
,
4
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
matmul_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
matmul_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
matmul_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
matmul_attr_t
>
(
attr4
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
!=
key3
);
EXPECT_TRUE
(
key2
!=
key4
);
EXPECT_TRUE
(
key3
!=
key4
);
}
TEST
(
JITKernel_key
,
emb_seq_pool
)
{
jit
::
emb_seq_pool_attr_t
attr1
(
1
,
2
,
3
,
4
,
5
,
jit
::
SeqPoolType
::
kSum
);
jit
::
emb_seq_pool_attr_t
attr2
(
1
,
2
,
3
,
4
,
5
,
jit
::
SeqPoolType
::
kSum
);
jit
::
emb_seq_pool_attr_t
attr3
(
10
,
2
,
9
,
8
,
7
,
jit
::
SeqPoolType
::
kAvg
);
jit
::
emb_seq_pool_attr_t
attr4
(
10
,
3
,
9
,
8
,
7
,
jit
::
SeqPoolType
::
kSum
);
jit
::
emb_seq_pool_attr_t
attr5
(
1
,
6
,
3
,
4
,
5
,
jit
::
SeqPoolType
::
kSum
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
emb_seq_pool_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
emb_seq_pool_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
emb_seq_pool_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
emb_seq_pool_attr_t
>
(
attr4
);
auto
key5
=
jit
::
JitCodeKey
<
jit
::
emb_seq_pool_attr_t
>
(
attr5
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key2
!=
key4
);
EXPECT_TRUE
(
key2
!=
key5
);
EXPECT_TRUE
(
key4
!=
key5
);
}
TEST
(
JITKernel_key
,
sgd
)
{
jit
::
sgd_attr_t
attr1
(
1
,
2
,
3
,
4
,
5
);
jit
::
sgd_attr_t
attr2
(
1
,
2
,
3
,
4
,
5
);
jit
::
sgd_attr_t
attr3
(
9
,
8
,
7
,
4
,
6
);
jit
::
sgd_attr_t
attr4
(
1
,
2
,
3
,
6
,
5
);
jit
::
sgd_attr_t
attr5
(
10
,
9
,
8
,
7
,
6
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
sgd_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
sgd_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
sgd_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
sgd_attr_t
>
(
attr4
);
auto
key5
=
jit
::
JitCodeKey
<
jit
::
sgd_attr_t
>
(
attr5
);
EXPECT_TRUE
(
key1
==
key2
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key3
!=
key4
);
EXPECT_TRUE
(
key3
!=
key5
);
EXPECT_TRUE
(
key4
!=
key5
);
}
// test kernerls
#define TestKernelVMul TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN
...
...
@@ -1080,43 +1258,3 @@ TEST_CPU_KERNEL(MatMul);
TEST_CPU_KERNEL
(
Softmax
);
TEST_CPU_KERNEL
(
Sgd
);
TEST_CPU_KERNEL
(
VBroadcast
);
TEST
(
JITKernel
,
kernel_func
)
{
auto
f1
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
().
At
(
3
);
auto
f2
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
()[
3
];
EXPECT_TRUE
(
f1
!=
nullptr
);
EXPECT_TRUE
(
f1
==
f2
);
// TODO(TJ): check not equal
}
TEST
(
JITKernel_key
,
lstm
)
{
jit
::
lstm_attr_t
attr1
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr2
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr3
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr4
(
9
,
jit
::
kVRelu
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
lstm_attr_t
>
(
attr4
);
EXPECT_TRUE
(
key1
!=
key2
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key3
!=
key4
);
}
TEST
(
JITKernel_key
,
gru
)
{
jit
::
gru_attr_t
attr1
(
8
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr2
(
9
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr3
(
9
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
gru_attr_t
attr4
(
9
,
jit
::
kVSigmoid
,
jit
::
kVIdentity
);
auto
key1
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr1
);
auto
key2
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr2
);
auto
key3
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr3
);
auto
key4
=
jit
::
JitCodeKey
<
jit
::
gru_attr_t
>
(
attr4
);
EXPECT_TRUE
(
key1
!=
key2
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key3
!=
key4
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录