Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
45bdd84d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
45bdd84d
编写于
3月 10, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enhance the jitkernel helper and add unit tests
test=develop
上级
14a764c9
变更
27
隐藏空白更改
内联
并排
Showing
27 changed file
with
328 addition
and
182 deletion
+328
-182
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+3
-25
paddle/fluid/operators/jit/gen/act.cc
paddle/fluid/operators/jit/gen/act.cc
+7
-7
paddle/fluid/operators/jit/gen/blas.cc
paddle/fluid/operators/jit/gen/blas.cc
+2
-2
paddle/fluid/operators/jit/gen/embseqpool.cc
paddle/fluid/operators/jit/gen/embseqpool.cc
+1
-1
paddle/fluid/operators/jit/gen/gru.cc
paddle/fluid/operators/jit/gen/gru.cc
+1
-1
paddle/fluid/operators/jit/gen/hopv.cc
paddle/fluid/operators/jit/gen/hopv.cc
+1
-1
paddle/fluid/operators/jit/gen/jitcode.h
paddle/fluid/operators/jit/gen/jitcode.h
+1
-1
paddle/fluid/operators/jit/gen/lstm.cc
paddle/fluid/operators/jit/gen/lstm.cc
+1
-1
paddle/fluid/operators/jit/gen/matmul.cc
paddle/fluid/operators/jit/gen/matmul.cc
+1
-1
paddle/fluid/operators/jit/gen/seqpool.cc
paddle/fluid/operators/jit/gen/seqpool.cc
+1
-1
paddle/fluid/operators/jit/gen/sgd.cc
paddle/fluid/operators/jit/gen/sgd.cc
+1
-1
paddle/fluid/operators/jit/gen/vbroadcast.cc
paddle/fluid/operators/jit/gen/vbroadcast.cc
+1
-1
paddle/fluid/operators/jit/gen_base.cc
paddle/fluid/operators/jit/gen_base.cc
+1
-1
paddle/fluid/operators/jit/gen_base.h
paddle/fluid/operators/jit/gen_base.h
+4
-3
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+82
-34
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+4
-3
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+3
-0
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
+1
-1
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
+2
-1
paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
+1
-1
paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
+2
-1
paddle/fluid/operators/jit/more/mix/mix.cc
paddle/fluid/operators/jit/more/mix/mix.cc
+8
-8
paddle/fluid/operators/jit/more/mix/mix.h
paddle/fluid/operators/jit/more/mix/mix.h
+6
-6
paddle/fluid/operators/jit/more/mkl/mkl.cc
paddle/fluid/operators/jit/more/mkl/mkl.cc
+24
-23
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+7
-7
paddle/fluid/operators/jit/registry.h
paddle/fluid/operators/jit/registry.h
+2
-2
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+160
-48
未找到文件。
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
45bdd84d
...
@@ -111,33 +111,11 @@ template <typename KernelTuple, typename PlaceType, typename... Args>
...
@@ -111,33 +111,11 @@ template <typename KernelTuple, typename PlaceType, typename... Args>
void
BenchAllImpls
(
const
typename
KernelTuple
::
attr_type
&
attr
,
Args
...
args
)
{
void
BenchAllImpls
(
const
typename
KernelTuple
::
attr_type
&
attr
,
Args
...
args
)
{
BenchFunc
<
KernelTuple
,
Args
...
>
benchmark
;
BenchFunc
<
KernelTuple
,
Args
...
>
benchmark
;
std
::
vector
<
std
::
pair
<
std
::
string
,
double
>>
infos
;
std
::
vector
<
std
::
pair
<
std
::
string
,
double
>>
infos
;
// test refer
auto
funcs
=
jit
::
GetAllCandidateFuncsWithTypes
<
KernelTuple
,
PlaceType
>
(
attr
);
auto
refer
=
jit
::
GetRefer
<
KernelTuple
>
();
for
(
auto
f
:
funcs
)
{
if
(
!
refer
)
{
infos
.
push_back
(
std
::
make_pair
(
f
.
first
,
benchmark
(
f
.
second
,
args
...)));
LOG
(
FATAL
)
<<
"Refer can not be empty!"
;
}
}
infos
.
push_back
(
std
::
make_pair
(
"Refer"
,
benchmark
(
refer
,
args
...)));
// test jitcode
auto
jitcode
=
jit
::
GetJitCode
<
KernelTuple
,
PlaceType
>
(
attr
);
if
(
jitcode
)
{
infos
.
push_back
(
std
::
make_pair
(
"JitCode"
,
benchmark
(
jitcode
,
args
...)));
}
// test all impls in more
jit
::
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
PlaceType
());
auto
&
pool
=
jit
::
KernelPool
().
Instance
().
AllKernels
();
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
jit
::
KernelMore
<
KernelTuple
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
auto
more
=
i
->
GetFunc
();
infos
.
push_back
(
std
::
make_pair
(
i
->
ImplType
(),
benchmark
(
more
,
args
...)));
}
}
}
// Test result from Get function
// Test result from Get function
auto
tgt
=
jit
::
KernelFuncs
<
KernelTuple
,
PlaceType
>::
Cache
().
At
(
attr
);
auto
tgt
=
jit
::
KernelFuncs
<
KernelTuple
,
PlaceType
>::
Cache
().
At
(
attr
);
if
(
!
tgt
)
{
if
(
!
tgt
)
{
...
...
paddle/fluid/operators/jit/gen/act.cc
浏览文件 @
45bdd84d
...
@@ -81,7 +81,7 @@ void VActJitCode::genCode() {
...
@@ -81,7 +81,7 @@ void VActJitCode::genCode() {
#define DECLARE_ACT_CREATOR(name) \
#define DECLARE_ACT_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
class name##Creator : public JitCodeCreator<int> { \
public: \
public: \
bool
UseMe(const int& attr) const override;
\
bool
CanBeUsed(const int& attr) const override;
\
size_t CodeSize(const int& d) const override; \
size_t CodeSize(const int& d) const override; \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
...
@@ -96,27 +96,27 @@ DECLARE_ACT_CREATOR(VSigmoid);
...
@@ -96,27 +96,27 @@ DECLARE_ACT_CREATOR(VSigmoid);
DECLARE_ACT_CREATOR
(
VTanh
);
DECLARE_ACT_CREATOR
(
VTanh
);
// TODO(TJ): tuning use me
// TODO(TJ): tuning use me
bool
VReluCreator
::
UseMe
(
const
int
&
d
)
const
{
bool
VReluCreator
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
bool
VSquareCreator
::
UseMe
(
const
int
&
d
)
const
{
bool
VSquareCreator
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
bool
VIdentityCreator
::
UseMe
(
const
int
&
d
)
const
{
bool
VIdentityCreator
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
bool
VExpCreator
::
UseMe
(
const
int
&
d
)
const
{
bool
VExpCreator
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
<
32
;
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
<
32
;
}
}
bool
VSigmoidCreator
::
UseMe
(
const
int
&
d
)
const
{
bool
VSigmoidCreator
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
bool
VTanhCreator
::
UseMe
(
const
int
&
d
)
const
{
bool
VTanhCreator
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
...
...
paddle/fluid/operators/jit/gen/blas.cc
浏览文件 @
45bdd84d
...
@@ -142,7 +142,7 @@ void NCHW16CMulNCJitCode::genCode() {
...
@@ -142,7 +142,7 @@ void NCHW16CMulNCJitCode::genCode() {
class
NCHW16CMulNCCreator
:
public
JitCodeCreator
<
int
>
{
class
NCHW16CMulNCCreator
:
public
JitCodeCreator
<
int
>
{
public:
public:
bool
UseMe
(
const
int
&
attr
)
const
override
{
bool
CanBeUsed
(
const
int
&
attr
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx512f
);
return
platform
::
MayIUse
(
platform
::
avx512f
);
}
}
size_t
CodeSize
(
const
int
&
d
)
const
override
{
return
256
*
1024
;
}
size_t
CodeSize
(
const
int
&
d
)
const
override
{
return
256
*
1024
;
}
...
@@ -154,7 +154,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
...
@@ -154,7 +154,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
#define DECLARE_BLAS_CREATOR(name) \
#define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
class name##Creator : public JitCodeCreator<int> { \
public: \
public: \
bool
UseMe(const int& attr) const override {
\
bool
CanBeUsed(const int& attr) const override {
\
return platform::MayIUse(platform::avx) && attr <= 1024; \
return platform::MayIUse(platform::avx) && attr <= 1024; \
} \
} \
size_t CodeSize(const int& d) const override { \
size_t CodeSize(const int& d) const override { \
...
...
paddle/fluid/operators/jit/gen/embseqpool.cc
浏览文件 @
45bdd84d
...
@@ -121,7 +121,7 @@ void EmbSeqPoolJitCode::genCode() {
...
@@ -121,7 +121,7 @@ void EmbSeqPoolJitCode::genCode() {
class
EmbSeqPoolCreator
:
public
JitCodeCreator
<
emb_seq_pool_attr_t
>
{
class
EmbSeqPoolCreator
:
public
JitCodeCreator
<
emb_seq_pool_attr_t
>
{
public:
public:
bool
UseMe
(
const
emb_seq_pool_attr_t
&
attr
)
const
override
{
bool
CanBeUsed
(
const
emb_seq_pool_attr_t
&
attr
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
return
platform
::
MayIUse
(
platform
::
avx
)
&&
attr
.
table_width
%
YMM_FLOAT_BLOCK
==
0
;
attr
.
table_width
%
YMM_FLOAT_BLOCK
==
0
;
}
}
...
...
paddle/fluid/operators/jit/gen/gru.cc
浏览文件 @
45bdd84d
...
@@ -86,7 +86,7 @@ void GRUJitCode::genCode() {
...
@@ -86,7 +86,7 @@ void GRUJitCode::genCode() {
class name##Creator : public JitCodeCreator<gru_attr_t> { \
class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \
public: \
/* TODO(TJ): enable more */
\
/* TODO(TJ): enable more */
\
bool
UseMe(const gru_attr_t& attr) const override {
\
bool
CanBeUsed(const gru_attr_t& attr) const override {
\
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
} \
size_t CodeSize(const gru_attr_t& attr) const override { \
size_t CodeSize(const gru_attr_t& attr) const override { \
...
...
paddle/fluid/operators/jit/gen/hopv.cc
浏览文件 @
45bdd84d
...
@@ -76,7 +76,7 @@ void HOPVJitCode::genCode() {
...
@@ -76,7 +76,7 @@ void HOPVJitCode::genCode() {
#define DECLARE_HOP_CREATOR(name) \
#define DECLARE_HOP_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
class name##Creator : public JitCodeCreator<int> { \
public: \
public: \
bool
UseMe(const int& attr) const override {
\
bool
CanBeUsed(const int& attr) const override {
\
return platform::MayIUse(platform::avx); \
return platform::MayIUse(platform::avx); \
} \
} \
size_t CodeSize(const int& d) const override { \
size_t CodeSize(const int& d) const override { \
...
...
paddle/fluid/operators/jit/gen/jitcode.h
浏览文件 @
45bdd84d
...
@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
...
@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
virtual
void
genCode
()
=
0
;
virtual
void
genCode
()
=
0
;
size_t
getSize
()
const
override
{
return
CodeGenerator
::
getSize
();
}
size_t
getSize
()
const
override
{
return
CodeGenerator
::
getSize
();
}
const
unsigned
char
*
getCodeInternal
()
override
{
const
unsigned
char
*
getCodeInternal
()
const
override
{
const
Xbyak
::
uint8
*
code
=
CodeGenerator
::
getCode
();
const
Xbyak
::
uint8
*
code
=
CodeGenerator
::
getCode
();
return
code
;
return
code
;
}
}
...
...
paddle/fluid/operators/jit/gen/lstm.cc
浏览文件 @
45bdd84d
...
@@ -114,7 +114,7 @@ void LSTMJitCode::genCode() {
...
@@ -114,7 +114,7 @@ void LSTMJitCode::genCode() {
class name##Creator : public JitCodeCreator<lstm_attr_t> { \
class name##Creator : public JitCodeCreator<lstm_attr_t> { \
public: \
public: \
/* TODO(TJ): enable more */
\
/* TODO(TJ): enable more */
\
bool
UseMe(const lstm_attr_t& attr) const override {
\
bool
CanBeUsed(const lstm_attr_t& attr) const override {
\
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
} \
size_t CodeSize(const lstm_attr_t& attr) const override { \
size_t CodeSize(const lstm_attr_t& attr) const override { \
...
...
paddle/fluid/operators/jit/gen/matmul.cc
浏览文件 @
45bdd84d
...
@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() {
...
@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() {
class
MatMulCreator
:
public
JitCodeCreator
<
matmul_attr_t
>
{
class
MatMulCreator
:
public
JitCodeCreator
<
matmul_attr_t
>
{
public:
public:
bool
UseMe
(
const
matmul_attr_t
&
attr
)
const
override
{
bool
CanBeUsed
(
const
matmul_attr_t
&
attr
)
const
override
{
return
attr
.
m
==
1
&&
platform
::
MayIUse
(
platform
::
avx512f
)
&&
return
attr
.
m
==
1
&&
platform
::
MayIUse
(
platform
::
avx512f
)
&&
attr
.
n
%
ZMM_FLOAT_BLOCK
==
0
&&
attr
.
k
<
512
;
attr
.
n
%
ZMM_FLOAT_BLOCK
==
0
&&
attr
.
k
<
512
;
}
}
...
...
paddle/fluid/operators/jit/gen/seqpool.cc
浏览文件 @
45bdd84d
...
@@ -57,7 +57,7 @@ void SeqPoolJitCode::genCode() {
...
@@ -57,7 +57,7 @@ void SeqPoolJitCode::genCode() {
class
SeqPoolCreator
:
public
JitCodeCreator
<
seq_pool_attr_t
>
{
class
SeqPoolCreator
:
public
JitCodeCreator
<
seq_pool_attr_t
>
{
public:
public:
bool
UseMe
(
const
seq_pool_attr_t
&
attr
)
const
override
{
bool
CanBeUsed
(
const
seq_pool_attr_t
&
attr
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
size_t
CodeSize
(
const
seq_pool_attr_t
&
attr
)
const
override
{
size_t
CodeSize
(
const
seq_pool_attr_t
&
attr
)
const
override
{
...
...
paddle/fluid/operators/jit/gen/sgd.cc
浏览文件 @
45bdd84d
...
@@ -104,7 +104,7 @@ void SgdJitCode::genCode() {
...
@@ -104,7 +104,7 @@ void SgdJitCode::genCode() {
class
SgdCreator
:
public
JitCodeCreator
<
sgd_attr_t
>
{
class
SgdCreator
:
public
JitCodeCreator
<
sgd_attr_t
>
{
public:
public:
bool
UseMe
(
const
sgd_attr_t
&
attr
)
const
override
{
bool
CanBeUsed
(
const
sgd_attr_t
&
attr
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
return
platform
::
MayIUse
(
platform
::
avx
)
&&
attr
.
grad_width
%
YMM_FLOAT_BLOCK
==
0
;
attr
.
grad_width
%
YMM_FLOAT_BLOCK
==
0
;
}
}
...
...
paddle/fluid/operators/jit/gen/vbroadcast.cc
浏览文件 @
45bdd84d
...
@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() {
...
@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() {
class
VBroadcastCreator
:
public
JitCodeCreator
<
int64_t
>
{
class
VBroadcastCreator
:
public
JitCodeCreator
<
int64_t
>
{
public:
public:
bool
UseMe
(
const
int64_t
&
w
)
const
override
{
bool
CanBeUsed
(
const
int64_t
&
w
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
w
%
YMM_FLOAT_BLOCK
==
0
;
return
platform
::
MayIUse
(
platform
::
avx
)
&&
w
%
YMM_FLOAT_BLOCK
==
0
;
}
}
size_t
CodeSize
(
const
int64_t
&
w
)
const
override
{
size_t
CodeSize
(
const
int64_t
&
w
)
const
override
{
...
...
paddle/fluid/operators/jit/gen_base.cc
浏览文件 @
45bdd84d
...
@@ -31,7 +31,7 @@ namespace paddle {
...
@@ -31,7 +31,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
jit
{
namespace
jit
{
// refer do not need
useme
, it would be the last one.
// refer do not need
CanBeUsed
, it would be the last one.
void
GenBase
::
dumpCode
(
const
unsigned
char
*
code
)
const
{
void
GenBase
::
dumpCode
(
const
unsigned
char
*
code
)
const
{
if
(
code
)
{
if
(
code
)
{
static
int
counter
=
0
;
static
int
counter
=
0
;
...
...
paddle/fluid/operators/jit/gen_base.h
浏览文件 @
45bdd84d
...
@@ -31,9 +31,10 @@ class GenBase : public Kernel {
...
@@ -31,9 +31,10 @@ class GenBase : public Kernel {
virtual
~
GenBase
()
=
default
;
virtual
~
GenBase
()
=
default
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
size_t
getSize
()
const
=
0
;
virtual
size_t
getSize
()
const
=
0
;
virtual
const
unsigned
char
*
getCodeInternal
()
=
0
;
virtual
const
unsigned
char
*
getCodeInternal
()
const
=
0
;
const
char
*
ImplType
()
const
override
{
return
"JitCode"
;
}
template
<
typename
Func
>
template
<
typename
Func
>
Func
getCode
()
{
Func
getCode
()
const
{
const
unsigned
char
*
code
=
this
->
getCodeInternal
();
const
unsigned
char
*
code
=
this
->
getCodeInternal
();
if
(
FLAGS_dump_jitcode
)
{
if
(
FLAGS_dump_jitcode
)
{
this
->
dumpCode
(
code
);
this
->
dumpCode
(
code
);
...
@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator {
...
@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator {
virtual
~
JitCodeCreator
()
=
default
;
virtual
~
JitCodeCreator
()
=
default
;
// condition when this jit code can be used.
// condition when this jit code can be used.
virtual
bool
UseMe
(
const
Attr
&
attr
)
const
=
0
;
virtual
bool
CanBeUsed
(
const
Attr
&
attr
)
const
=
0
;
// estimate this code size
// estimate this code size
virtual
size_t
CodeSize
(
const
Attr
&
attr
)
const
=
0
;
virtual
size_t
CodeSize
(
const
Attr
&
attr
)
const
=
0
;
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
45bdd84d
...
@@ -14,9 +14,6 @@
...
@@ -14,9 +14,6 @@
#pragma once
#pragma once
extern
"C"
{
#include <xxhash.h>
}
#include <iostream>
#include <iostream>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
...
@@ -36,31 +33,30 @@ template <typename KernelTuple, typename PlaceType>
...
@@ -36,31 +33,30 @@ template <typename KernelTuple, typename PlaceType>
inline
typename
std
::
enable_if
<
inline
typename
std
::
enable_if
<
std
::
is_same
<
typename
KernelTuple
::
data_type
,
float
>::
value
&&
std
::
is_same
<
typename
KernelTuple
::
data_type
,
float
>::
value
&&
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
,
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
,
typename
KernelTuple
::
func_type
>::
type
const
Kernel
*
>::
type
GetJitCode
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
GetJitCode
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
using
Func
=
typename
KernelTuple
::
func_type
;
using
Attr
=
typename
KernelTuple
::
attr_type
;
using
Attr
=
typename
KernelTuple
::
attr_type
;
size_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
size_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
auto
&
codes
=
JitCodePool
<
KernelTuple
::
kernel_type
>
().
Instance
();
auto
&
codes
=
JitCodePool
<
KernelTuple
::
kernel_type
>
::
Instance
();
if
(
codes
.
Has
(
key
))
{
if
(
codes
.
Has
(
key
))
{
return
codes
.
AllKernels
().
at
(
key
)
->
template
getCode
<
Func
>
();
return
codes
.
AllKernels
().
at
(
key
)
.
get
();
}
}
// creator is not related with attr, so can use KernelKey as key
// creator is not related with attr, so can use KernelKey as key
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
PlaceType
());
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
PlaceType
());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto
&
creator_map
=
JitCodeCreatorPool
().
Instance
().
AllCreators
();
auto
&
creator_map
=
JitCodeCreatorPool
::
Instance
().
AllCreators
();
auto
iter
=
creator_map
.
find
(
kkey
);
auto
iter
=
creator_map
.
find
(
kkey
);
if
(
iter
!=
creator_map
.
end
())
{
if
(
iter
!=
creator_map
.
end
())
{
auto
&
creators
=
iter
->
second
;
auto
&
creators
=
iter
->
second
;
for
(
auto
&
cur
:
creators
)
{
for
(
auto
&
cur
:
creators
)
{
auto
i
=
dynamic_cast
<
const
JitCodeCreator
<
Attr
>*>
(
cur
.
get
());
auto
i
=
dynamic_cast
<
const
JitCodeCreator
<
Attr
>*>
(
cur
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
if
(
i
&&
i
->
CanBeUsed
(
attr
))
{
auto
p
=
i
->
CreateJitCode
(
attr
);
auto
p
=
i
->
CreateJitCode
(
attr
);
if
(
p
)
{
if
(
p
)
{
auto
f
=
p
->
template
getCode
<
Func
>
();
auto
res
=
p
.
get
();
codes
.
Insert
(
key
,
std
::
move
(
p
));
codes
.
Insert
(
key
,
std
::
move
(
p
));
return
f
;
return
res
;
}
}
}
}
}
}
...
@@ -72,7 +68,7 @@ template <typename KernelTuple, typename PlaceType>
...
@@ -72,7 +68,7 @@ template <typename KernelTuple, typename PlaceType>
inline
typename
std
::
enable_if
<
inline
typename
std
::
enable_if
<
!
std
::
is_same
<
typename
KernelTuple
::
data_type
,
float
>::
value
||
!
std
::
is_same
<
typename
KernelTuple
::
data_type
,
float
>::
value
||
!
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
,
!
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
,
typename
KernelTuple
::
func_type
>::
type
const
Kernel
*
>::
type
GetJitCode
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
GetJitCode
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
return
nullptr
;
return
nullptr
;
}
}
...
@@ -80,8 +76,8 @@ GetJitCode(const typename KernelTuple::attr_type& attr) {
...
@@ -80,8 +76,8 @@ GetJitCode(const typename KernelTuple::attr_type& attr) {
// Refer code do not related with attr, which is just for cast
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
// Refer is always on CPUPlace
template
<
typename
KernelTuple
>
template
<
typename
KernelTuple
>
inline
typename
KernelTuple
::
func_type
GetRefer
()
{
inline
const
Kernel
*
GetReferKernel
()
{
auto
&
ref_pool
=
ReferKernelPool
().
Instance
().
AllKernels
();
auto
&
ref_pool
=
ReferKernelPool
::
Instance
().
AllKernels
();
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
platform
::
CPUPlace
());
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
platform
::
CPUPlace
());
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
PADDLE_ENFORCE
(
ref_iter
!=
ref_pool
.
end
(),
PADDLE_ENFORCE
(
ref_iter
!=
ref_pool
.
end
(),
...
@@ -90,36 +86,93 @@ inline typename KernelTuple::func_type GetRefer() {
...
@@ -90,36 +86,93 @@ inline typename KernelTuple::func_type GetRefer() {
for
(
auto
&
impl
:
ref_impls
)
{
for
(
auto
&
impl
:
ref_impls
)
{
auto
i
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
impl
.
get
());
auto
i
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
impl
.
get
());
if
(
i
)
{
if
(
i
)
{
return
i
->
GetFunc
()
;
return
i
;
}
}
}
}
return
nullptr
;
return
nullptr
;
}
}
template
<
typename
KernelTuple
,
typename
PlaceType
=
platform
::
CPUPlace
>
template
<
typename
KernelTuple
>
typename
KernelTuple
::
func_type
Get
(
inline
typename
KernelTuple
::
func_type
GetReferFunc
()
{
auto
ker
=
GetReferKernel
<
KernelTuple
>
();
auto
p
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
ker
);
PADDLE_ENFORCE
(
p
,
"The Refer kernel should exsit"
);
return
p
->
GetFunc
();
}
// Return all Kernels that can be used
template
<
typename
KernelTuple
,
typename
PlaceType
>
std
::
vector
<
const
Kernel
*>
GetAllCandidateKernels
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
const
typename
KernelTuple
::
attr_type
&
attr
)
{
auto
jitfunc
=
GetJitCode
<
KernelTuple
,
PlaceType
>
(
attr
);
// the search order shoudl be jitcode > more > refer
if
(
jitfunc
)
{
std
::
vector
<
const
Kernel
*>
res
;
return
jitfunc
;
auto
jitker
=
GetJitCode
<
KernelTuple
,
PlaceType
>
(
attr
);
if
(
jitker
)
{
res
.
emplace_back
(
jitker
);
}
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
//
more kernel
pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
PlaceType
());
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
PlaceType
());
auto
&
pool
=
KernelPool
().
Instance
().
AllKernels
();
auto
&
pool
=
KernelPool
::
Instance
().
AllKernels
();
auto
iter
=
pool
.
find
(
kkey
);
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
KernelMore
<
KernelTuple
>*>
(
impl
.
get
());
auto
i
=
dynamic_cast
<
const
KernelMore
<
KernelTuple
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
if
(
i
&&
i
->
CanBeUsed
(
attr
))
{
re
turn
i
->
GetFunc
(
);
re
s
.
emplace_back
(
i
);
}
}
}
}
}
}
// The last implementation should be reference function on CPUPlace.
// The last implementation should be reference function on CPUPlace.
return
GetRefer
<
KernelTuple
>
();
auto
ref
=
GetReferKernel
<
KernelTuple
>
();
PADDLE_ENFORCE
(
ref
!=
nullptr
,
"Refer Kernel can not be empty."
);
res
.
emplace_back
(
ref
);
return
res
;
}
template
<
typename
KernelTuple
,
typename
PlaceType
=
platform
::
CPUPlace
>
std
::
vector
<
std
::
pair
<
std
::
string
,
typename
KernelTuple
::
func_type
>>
GetAllCandidateFuncsWithTypes
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
using
Func
=
typename
KernelTuple
::
func_type
;
auto
kers
=
GetAllCandidateKernels
<
KernelTuple
,
PlaceType
>
(
attr
);
std
::
vector
<
std
::
pair
<
std
::
string
,
Func
>>
res
;
for
(
auto
k
:
kers
)
{
std
::
string
name
=
k
->
ImplType
();
if
(
name
==
"JitCode"
)
{
auto
i
=
dynamic_cast
<
const
GenBase
*>
(
k
);
PADDLE_ENFORCE
(
i
,
"jitcode kernel cast can not fail."
);
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
template
getCode
<
Func
>()));
}
else
{
auto
i
=
dynamic_cast
<
const
KernelMore
<
KernelTuple
>*>
(
k
);
PADDLE_ENFORCE
(
i
,
"kernel cast can not fail."
);
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
GetFunc
()));
}
}
return
res
;
}
template
<
typename
KernelTuple
,
typename
PlaceType
=
platform
::
CPUPlace
>
std
::
vector
<
typename
KernelTuple
::
func_type
>
GetAllCandidateFuncs
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
auto
funcs
=
GetAllCandidateFuncsWithTypes
<
KernelTuple
,
PlaceType
>
(
attr
);
std
::
vector
<
typename
KernelTuple
::
func_type
>
res
;
for
(
auto
&
i
:
funcs
)
{
res
.
emplace_back
(
i
.
second
);
}
return
res
;
}
template
<
typename
KernelTuple
,
typename
PlaceType
=
platform
::
CPUPlace
>
typename
KernelTuple
::
func_type
GetDefaultBestFunc
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
auto
funcs
=
GetAllCandidateFuncs
<
KernelTuple
,
PlaceType
>
(
attr
);
PADDLE_ENFORCE_GE
(
funcs
.
size
(),
1UL
);
// Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one,
// which is searched in order and tuned by offline.
return
funcs
[
0
];
}
}
template
<
typename
KernelTuple
,
typename
PlaceType
>
template
<
typename
KernelTuple
,
typename
PlaceType
>
...
@@ -134,17 +187,13 @@ class KernelFuncs {
...
@@ -134,17 +187,13 @@ class KernelFuncs {
// the exposed interface to use
// the exposed interface to use
typename
KernelTuple
::
func_type
At
(
typename
KernelTuple
::
func_type
At
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
const
typename
KernelTuple
::
attr_type
&
attr
)
{
// XXH64: 13.8 GB/s
// Maybe here is not good enough, not all kernels should have jitcode
// TODO(TJ): change me, maybe not all attr change need one key, should be
int64_t
key
=
JitCodeKey
<
typename
KernelTuple
::
attr_type
>
(
attr
);
// attrkey
int64_t
key
=
XXH64
(
&
attr
,
sizeof
(
typename
KernelTuple
::
attr_type
),
0
);
if
(
Has
(
key
))
{
if
(
Has
(
key
))
{
return
funcs_
.
at
(
key
);
return
funcs_
.
at
(
key
);
}
}
// If do not have this attr in cache,
// If do not have this attr in cache then get the default best
// then could run some runtime benchmark of this attr and save the best one.
auto
func
=
GetDefaultBestFunc
<
KernelTuple
,
PlaceType
>
(
attr
);
// Here just get the offline benchmarked best one.
auto
func
=
Get
<
KernelTuple
,
PlaceType
>
(
attr
);
Insert
(
key
,
func
);
Insert
(
key
,
func
);
return
func
;
return
func
;
}
}
...
@@ -156,7 +205,6 @@ class KernelFuncs {
...
@@ -156,7 +205,6 @@ class KernelFuncs {
protected:
protected:
bool
Has
(
int64_t
key
)
const
{
return
funcs_
.
find
(
key
)
!=
funcs_
.
end
();
}
bool
Has
(
int64_t
key
)
const
{
return
funcs_
.
find
(
key
)
!=
funcs_
.
end
();
}
void
Insert
(
int64_t
key
,
typename
KernelTuple
::
func_type
func
)
{
void
Insert
(
int64_t
key
,
typename
KernelTuple
::
func_type
func
)
{
funcs_
.
emplace
(
key
,
func
);
funcs_
.
emplace
(
key
,
func
);
}
}
...
...
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
45bdd84d
...
@@ -302,6 +302,7 @@ class Kernel {
...
@@ -302,6 +302,7 @@ class Kernel {
public:
public:
Kernel
()
=
default
;
Kernel
()
=
default
;
virtual
~
Kernel
()
=
default
;
virtual
~
Kernel
()
=
default
;
virtual
const
char
*
ImplType
()
const
=
0
;
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
};
};
...
@@ -312,8 +313,8 @@ class KernelMore : public Kernel {
...
@@ -312,8 +313,8 @@ class KernelMore : public Kernel {
using
Func
=
typename
KernelTuple
::
func_type
;
using
Func
=
typename
KernelTuple
::
func_type
;
using
Attr
=
typename
KernelTuple
::
attr_type
;
using
Attr
=
typename
KernelTuple
::
attr_type
;
virtual
Func
GetFunc
()
const
{
return
func
;
}
virtual
Func
GetFunc
()
const
{
return
func
;
}
virtual
bool
UseMe
(
const
Attr
&
attr
)
const
=
0
;
// specify this kernel can be used, means it should not fail if use it.
virtual
const
char
*
ImplType
(
)
const
=
0
;
virtual
bool
CanBeUsed
(
const
Attr
&
attr
)
const
=
0
;
protected:
protected:
Func
func
{
nullptr
};
Func
func
{
nullptr
};
...
@@ -323,7 +324,7 @@ template <typename KernelTuple>
...
@@ -323,7 +324,7 @@ template <typename KernelTuple>
class
ReferKernel
:
public
KernelMore
<
KernelTuple
>
{
class
ReferKernel
:
public
KernelMore
<
KernelTuple
>
{
public:
public:
// Refer code can always be used
// Refer code can always be used
bool
UseMe
(
const
typename
KernelTuple
::
attr_type
&
attr
)
const
override
{
bool
CanBeUsed
(
const
typename
KernelTuple
::
attr_type
&
attr
)
const
override
{
return
true
;
return
true
;
}
}
const
char
*
ImplType
()
const
override
{
return
"Refer"
;
}
const
char
*
ImplType
()
const
override
{
return
"Refer"
;
}
...
...
paddle/fluid/operators/jit/kernel_key.cc
浏览文件 @
45bdd84d
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
* limitations under the License. */
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include <xxhash.h>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -49,6 +50,8 @@ static inline int act_type_convert(KernelType type) {
...
@@ -49,6 +50,8 @@ static inline int act_type_convert(KernelType type) {
template
<
>
template
<
>
size_t
JitCodeKey
<
lstm_attr_t
>
(
const
lstm_attr_t
&
attr
)
{
size_t
JitCodeKey
<
lstm_attr_t
>
(
const
lstm_attr_t
&
attr
)
{
// XXH64: 13.8 GB/s
size_t
key
=
attr
.
d
;
size_t
key
=
attr
.
d
;
int
gate_key
=
act_type_convert
(
attr
.
act_gate
)
<<
1
;
int
gate_key
=
act_type_convert
(
attr
.
act_gate
)
<<
1
;
int
cand_key
=
act_type_convert
(
attr
.
act_cand
)
<<
(
1
+
act_type_shift
);
int
cand_key
=
act_type_convert
(
attr
.
act_cand
)
<<
(
1
+
act_type_shift
);
...
...
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
浏览文件 @
45bdd84d
...
@@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
...
@@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
}
}
}
}
bool
CRFDecodingKernel
::
UseMe
(
const
int
&
d
)
const
{
bool
CRFDecodingKernel
::
CanBeUsed
(
const
int
&
d
)
const
{
#ifdef __AVX512F__
#ifdef __AVX512F__
constexpr
int
block
=
ZMM_FLOAT_BLOCK
;
constexpr
int
block
=
ZMM_FLOAT_BLOCK
;
#else
#else
...
...
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
浏览文件 @
45bdd84d
...
@@ -29,7 +29,8 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
...
@@ -29,7 +29,8 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
class
CRFDecodingKernel
:
public
KernelMore
<
CRFDecodingTuple
<
float
>>
{
class
CRFDecodingKernel
:
public
KernelMore
<
CRFDecodingTuple
<
float
>>
{
public:
public:
CRFDecodingKernel
()
{
this
->
func
=
CRFDecoding
;
}
CRFDecodingKernel
()
{
this
->
func
=
CRFDecoding
;
}
bool
UseMe
(
const
typename
CRFDecodingTuple
<
float
>::
attr_type
&
)
const
override
;
bool
CanBeUsed
(
const
typename
CRFDecodingTuple
<
float
>::
attr_type
&
)
const
override
;
const
char
*
ImplType
()
const
override
{
return
"Intrinsic"
;
}
const
char
*
ImplType
()
const
override
{
return
"Intrinsic"
;
}
};
};
...
...
paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
浏览文件 @
45bdd84d
...
@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
...
@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
}
}
}
}
bool
LayerNormKernel
::
UseMe
(
const
int
&
d
)
const
{
bool
LayerNormKernel
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
>=
YMM_FLOAT_BLOCK
;
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
>=
YMM_FLOAT_BLOCK
;
}
}
...
...
paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
浏览文件 @
45bdd84d
...
@@ -30,7 +30,8 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
...
@@ -30,7 +30,8 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
class
LayerNormKernel
:
public
KernelMore
<
LayerNormTuple
<
float
>>
{
class
LayerNormKernel
:
public
KernelMore
<
LayerNormTuple
<
float
>>
{
public:
public:
LayerNormKernel
()
{
this
->
func
=
LayerNorm
;
}
LayerNormKernel
()
{
this
->
func
=
LayerNorm
;
}
bool
UseMe
(
const
typename
LayerNormTuple
<
float
>::
attr_type
&
)
const
override
;
bool
CanBeUsed
(
const
typename
LayerNormTuple
<
float
>::
attr_type
&
)
const
override
;
const
char
*
ImplType
()
const
override
{
return
"Intrinsic"
;
}
const
char
*
ImplType
()
const
override
{
return
"Intrinsic"
;
}
};
};
...
...
paddle/fluid/operators/jit/more/mix/mix.cc
浏览文件 @
45bdd84d
...
@@ -204,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
...
@@ -204,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
}
}
// TODO(TJ): tuning me
// TODO(TJ): tuning me
bool
VSigmoidKernel
::
UseMe
(
const
int
&
d
)
const
{
return
true
;
}
bool
VSigmoidKernel
::
CanBeUsed
(
const
int
&
d
)
const
{
return
true
;
}
bool
VTanhKernel
::
UseMe
(
const
int
&
d
)
const
{
return
true
;
}
bool
VTanhKernel
::
CanBeUsed
(
const
int
&
d
)
const
{
return
true
;
}
bool
SoftmaxKernel
::
UseMe
(
const
int
&
d
)
const
{
return
true
;
}
bool
SoftmaxKernel
::
CanBeUsed
(
const
int
&
d
)
const
{
return
true
;
}
bool
LSTMCtHtKernel
::
UseMe
(
const
lstm_attr_t
&
attr
)
const
{
return
true
;
}
bool
LSTMCtHtKernel
::
CanBeUsed
(
const
lstm_attr_t
&
attr
)
const
{
return
true
;
}
bool
LSTMC1H1Kernel
::
UseMe
(
const
lstm_attr_t
&
attr
)
const
{
return
true
;
}
bool
LSTMC1H1Kernel
::
CanBeUsed
(
const
lstm_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUH1Kernel
::
UseMe
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUH1Kernel
::
CanBeUsed
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUHtPart1Kernel
::
UseMe
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUHtPart1Kernel
::
CanBeUsed
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUHtPart2Kernel
::
UseMe
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUHtPart2Kernel
::
CanBeUsed
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
}
// namespace mix
}
// namespace mix
}
// namespace more
}
// namespace more
...
...
paddle/fluid/operators/jit/more/mix/mix.h
浏览文件 @
45bdd84d
...
@@ -34,12 +34,12 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
...
@@ -34,12 +34,12 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
void
GRUHtPart1
(
gru_t
*
step
,
const
gru_attr_t
*
attr
);
void
GRUHtPart1
(
gru_t
*
step
,
const
gru_attr_t
*
attr
);
void
GRUHtPart2
(
gru_t
*
step
,
const
gru_attr_t
*
attr
);
void
GRUHtPart2
(
gru_t
*
step
,
const
gru_attr_t
*
attr
);
#define DECLARE_MORE_KERNEL(name) \
#define DECLARE_MORE_KERNEL(name)
\
class name##Kernel : public KernelMore<name##Tuple<T>> { \
class name##Kernel : public KernelMore<name##Tuple<T>> {
\
public: \
public:
\
name##Kernel() { this->func = name; } \
name##Kernel() { this->func = name; }
\
bool
UseMe
(const typename name##Tuple<T>::attr_type&) const override; \
bool
CanBeUsed
(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \
const char* ImplType() const override { return "Mixed"; }
\
}
}
// XYN
// XYN
...
...
paddle/fluid/operators/jit/more/mkl/mkl.cc
浏览文件 @
45bdd84d
...
@@ -130,105 +130,106 @@ void ASum<double>(const double* x, double* res, int n) {
...
@@ -130,105 +130,106 @@ void ASum<double>(const double* x, double* res, int n) {
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template
<
>
template
<
>
bool
VMulKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VMulKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
}
template
<
>
template
<
>
bool
VAddKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VAddKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
>
512
;
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
>
512
;
}
}
template
<
>
template
<
>
bool
VScalKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VScalKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
}
template
<
>
template
<
>
bool
VExpKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VExpKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
d
>
7
;
return
d
>
7
;
}
}
template
<
>
template
<
>
bool
VSquareKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VSquareKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
d
>
7
;
return
d
>
7
;
}
}
template
<
>
template
<
>
bool
VCopyKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VCopyKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
d
>
15
;
return
d
>
15
;
}
}
template
<
>
template
<
>
bool
VBroadcastKernel
<
float
>::
UseMe
(
const
int64_t
&
d
)
const
{
bool
VBroadcastKernel
<
float
>::
CanBeUsed
(
const
int64_t
&
d
)
const
{
return
d
>
127
;
return
d
>
127
;
}
}
template
<
>
template
<
>
bool
VBroadcastKernel
<
double
>::
UseMe
(
const
int64_t
&
attr
)
const
{
bool
VBroadcastKernel
<
double
>::
CanBeUsed
(
const
int64_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
VSigmoidKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VSigmoidKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
d
>
7
;
return
d
>
7
;
}
}
template
<
>
template
<
>
bool
VTanhKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VTanhKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
d
>
7
;
return
d
>
7
;
}
}
template
<
>
template
<
>
bool
SeqPoolKernel
<
float
>::
UseMe
(
const
seq_pool_attr_t
&
attr
)
const
{
bool
SeqPoolKernel
<
float
>::
CanBeUsed
(
const
seq_pool_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
SeqPoolKernel
<
double
>::
UseMe
(
const
seq_pool_attr_t
&
attr
)
const
{
bool
SeqPoolKernel
<
double
>::
CanBeUsed
(
const
seq_pool_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
EmbSeqPoolKernel
<
float
>::
UseMe
(
const
emb_seq_pool_attr_t
&
attr
)
const
{
bool
EmbSeqPoolKernel
<
float
>::
CanBeUsed
(
const
emb_seq_pool_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
EmbSeqPoolKernel
<
double
>::
UseMe
(
const
emb_seq_pool_attr_t
&
attr
)
const
{
bool
EmbSeqPoolKernel
<
double
>::
CanBeUsed
(
const
emb_seq_pool_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
SgdKernel
<
float
>::
UseMe
(
const
sgd_attr_t
&
attr
)
const
{
bool
SgdKernel
<
float
>::
CanBeUsed
(
const
sgd_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
SgdKernel
<
double
>::
UseMe
(
const
sgd_attr_t
&
attr
)
const
{
bool
SgdKernel
<
double
>::
CanBeUsed
(
const
sgd_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
MatMulKernel
<
float
>::
UseMe
(
const
matmul_attr_t
&
attr
)
const
{
bool
MatMulKernel
<
float
>::
CanBeUsed
(
const
matmul_attr_t
&
attr
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
template
<
>
template
<
>
bool
MatMulKernel
<
double
>::
UseMe
(
const
matmul_attr_t
&
attr
)
const
{
bool
MatMulKernel
<
double
>::
CanBeUsed
(
const
matmul_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
SoftmaxKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
SoftmaxKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
// tuned on avx2
// tuned on avx2
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
<
60
;
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
<
60
;
}
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
#define AWALYS_USE_ME_WITH_DOUBLE(func)
\
template <> \
template <>
\
bool func##Kernel<double>::
UseMe
(const int& d) const { \
bool func##Kernel<double>::
CanBeUsed
(const int& d) const { \
return true; \
return true;
\
}
}
AWALYS_USE_ME_WITH_DOUBLE
(
VMul
);
AWALYS_USE_ME_WITH_DOUBLE
(
VMul
);
...
...
paddle/fluid/operators/jit/more/mkl/mkl.h
浏览文件 @
45bdd84d
...
@@ -175,13 +175,13 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
...
@@ -175,13 +175,13 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
}
}
}
}
#define DECLARE_MKL_KERNEL(name) \
#define DECLARE_MKL_KERNEL(name)
\
template <typename T> \
template <typename T>
\
class name##Kernel : public KernelMore<name##Tuple<T>> { \
class name##Kernel : public KernelMore<name##Tuple<T>> {
\
public: \
public:
\
name##Kernel() { this->func = name<T>; } \
name##Kernel() { this->func = name<T>; }
\
bool
UseMe
(const typename name##Tuple<T>::attr_type&) const override; \
bool
CanBeUsed
(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \
const char* ImplType() const override { return "MKL"; }
\
}
}
// ABCMNK
// ABCMNK
...
...
paddle/fluid/operators/jit/registry.h
浏览文件 @
45bdd84d
...
@@ -49,8 +49,8 @@ struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
...
@@ -49,8 +49,8 @@ struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
void
operator
()(
KernelType
kt
)
const
{
void
operator
()(
KernelType
kt
)
const
{
KernelKey
kkey
(
kt
,
PlaceType
());
KernelKey
kkey
(
kt
,
PlaceType
());
Pool
().
Instance
().
Insert
(
kkey
,
Pool
::
Instance
().
Insert
(
kkey
,
std
::
move
(
make_unique
<
const
KERNEL_IMPL_TYPE
>
()));
std
::
move
(
make_unique
<
const
KERNEL_IMPL_TYPE
>
()));
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelImpls
...
>>::
value
;
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelImpls
...
>>::
value
;
JitKernelRegistrarFunctor
<
Pool
,
PlaceType
,
I
+
1
==
size
,
I
+
1
,
JitKernelRegistrarFunctor
<
Pool
,
PlaceType
,
I
+
1
==
size
,
I
+
1
,
KernelImpls
...
>
KernelImpls
...
>
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
45bdd84d
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <algorithm>
#include <algorithm>
#include <iostream>
#include <random>
#include <random>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -68,31 +69,11 @@ template <typename KernelTuple, typename PlaceType, typename Tester,
...
@@ -68,31 +69,11 @@ template <typename KernelTuple, typename PlaceType, typename Tester,
typename
...
Args
>
typename
...
Args
>
void
TestAllImpls
(
const
typename
KernelTuple
::
attr_type
&
attr
,
void
TestAllImpls
(
const
typename
KernelTuple
::
attr_type
&
attr
,
const
Tester
&
verifier
,
const
Args
&
...
args
)
{
const
Tester
&
verifier
,
const
Args
&
...
args
)
{
// test jitcode
auto
funcs
=
jit
::
GetAllCandidateFuncsWithTypes
<
KernelTuple
,
PlaceType
>
(
attr
);
auto
jitcode
=
jit
::
GetJitCode
<
KernelTuple
,
PlaceType
>
(
attr
);
for
(
auto
f
:
funcs
)
{
if
(
jitcode
)
{
VLOG
(
10
)
<<
"Test Kernel "
<<
f
.
first
;
VLOG
(
10
)
<<
"Test Jitcode Kernel "
;
verifier
(
f
.
second
,
args
...);
verifier
(
jitcode
,
args
...);
}
}
// test all impls in more
jit
::
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
PlaceType
());
auto
&
pool
=
jit
::
KernelPool
().
Instance
().
AllKernels
();
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
jit
::
KernelMore
<
KernelTuple
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
auto
more
=
i
->
GetFunc
();
VLOG
(
10
)
<<
"Test More Kernel : "
<<
i
->
ImplType
();
verifier
(
more
,
args
...);
}
}
}
// test result from Get function
VLOG
(
10
)
<<
"Test final get function "
;
auto
tgt
=
jit
::
KernelFuncs
<
KernelTuple
,
PlaceType
>::
Cache
().
At
(
attr
);
verifier
(
tgt
,
args
...);
}
}
template
<
typename
KernelTuple
,
typename
PlaceType
>
template
<
typename
KernelTuple
,
typename
PlaceType
>
...
@@ -100,7 +81,7 @@ void TestKernelXYZN() {
...
@@ -100,7 +81,7 @@ void TestKernelXYZN() {
using
T
=
typename
KernelTuple
::
data_type
;
using
T
=
typename
KernelTuple
::
data_type
;
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
for
(
int
d
:
TestSizes
())
{
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
d
),
y
(
d
),
zref
(
d
);
std
::
vector
<
T
>
x
(
d
),
y
(
d
),
zref
(
d
);
...
@@ -159,7 +140,7 @@ void TestKernelAXYN() {
...
@@ -159,7 +140,7 @@ void TestKernelAXYN() {
using
T
=
typename
KernelTuple
::
data_type
;
using
T
=
typename
KernelTuple
::
data_type
;
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
for
(
int
d
:
TestSizes
())
{
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
const
T
a
=
static_cast
<
T
>
(
3
);
const
T
a
=
static_cast
<
T
>
(
3
);
...
@@ -202,7 +183,7 @@ void TestKernelXYN() {
...
@@ -202,7 +183,7 @@ void TestKernelXYN() {
using
T
=
typename
KernelTuple
::
data_type
;
using
T
=
typename
KernelTuple
::
data_type
;
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
for
(
int
d
:
TestSizes
())
{
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
d
),
yref
(
d
);
std
::
vector
<
T
>
x
(
d
),
yref
(
d
);
...
@@ -245,7 +226,7 @@ void TestKernelXRN() {
...
@@ -245,7 +226,7 @@ void TestKernelXRN() {
auto
last_acc
=
FLAGS_acc
;
auto
last_acc
=
FLAGS_acc
;
FLAGS_acc
=
1e-4
;
FLAGS_acc
=
1e-4
;
for
(
int
d
:
TestSizes
())
{
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
d
);
std
::
vector
<
T
>
x
(
d
);
RandomVec
<
T
>
(
d
,
x
.
data
());
RandomVec
<
T
>
(
d
,
x
.
data
());
...
@@ -279,7 +260,7 @@ void TestKernelLSTM() {
...
@@ -279,7 +260,7 @@ void TestKernelLSTM() {
const
jit
::
lstm_attr_t
attr
(
const
jit
::
lstm_attr_t
attr
(
d
,
jit
::
to_kerneltype
(
act_gate
),
jit
::
to_kerneltype
(
act_cand
),
d
,
jit
::
to_kerneltype
(
act_gate
),
jit
::
to_kerneltype
(
act_cand
),
jit
::
to_kerneltype
(
act_cell
),
use_peephole
);
jit
::
to_kerneltype
(
act_cell
),
use_peephole
);
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
xsrc
(
4
*
d
),
wp
(
3
*
d
),
ct_1
(
d
);
std
::
vector
<
T
>
xsrc
(
4
*
d
),
wp
(
3
*
d
),
ct_1
(
d
);
std
::
vector
<
T
>
ct_ref
(
d
),
ht_ref
(
d
),
checked
(
2
*
d
);
std
::
vector
<
T
>
ct_ref
(
d
),
ht_ref
(
d
),
checked
(
2
*
d
);
...
@@ -370,7 +351,7 @@ void TestKernelGRU() {
...
@@ -370,7 +351,7 @@ void TestKernelGRU() {
for
(
auto
&
act_cand
:
all_acts
)
{
for
(
auto
&
act_cand
:
all_acts
)
{
const
jit
::
gru_attr_t
attr
(
d
,
jit
::
to_kerneltype
(
act_gate
),
const
jit
::
gru_attr_t
attr
(
d
,
jit
::
to_kerneltype
(
act_gate
),
jit
::
to_kerneltype
(
act_cand
));
jit
::
to_kerneltype
(
act_cand
));
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
xsrc
(
3
*
d
),
ht_1
(
d
),
ht_ref
(
d
);
std
::
vector
<
T
>
xsrc
(
3
*
d
),
ht_1
(
d
),
ht_ref
(
d
);
RandomVec
<
T
>
(
3
*
d
,
xsrc
.
data
());
RandomVec
<
T
>
(
3
*
d
,
xsrc
.
data
());
...
@@ -423,7 +404,7 @@ void TestKernelNCHW16CMulNC() {
...
@@ -423,7 +404,7 @@ void TestKernelNCHW16CMulNC() {
using
T
=
typename
KernelTuple
::
data_type
;
using
T
=
typename
KernelTuple
::
data_type
;
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
const
int
n
=
3
,
c
=
16
*
4
,
h
=
10
,
w
=
10
;
const
int
n
=
3
,
c
=
16
*
4
,
h
=
10
,
w
=
10
;
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
int
sz
=
n
*
c
*
h
*
w
;
int
sz
=
n
*
c
*
h
*
w
;
std
::
vector
<
T
>
x
(
sz
),
y
(
n
*
c
),
zref
(
sz
);
std
::
vector
<
T
>
x
(
sz
),
y
(
n
*
c
),
zref
(
sz
);
...
@@ -439,7 +420,9 @@ void TestKernelNCHW16CMulNC() {
...
@@ -439,7 +420,9 @@ void TestKernelNCHW16CMulNC() {
constexpr
int
simd_width
=
ZMM_FLOAT_BLOCK
;
constexpr
int
simd_width
=
ZMM_FLOAT_BLOCK
;
int
C
=
c
/
simd_width
;
int
C
=
c
/
simd_width
;
auto
tgt
=
jit
::
KernelFuncs
<
KernelTuple
,
PlaceType
>::
Cache
().
At
(
0
);
auto
tgt
=
jit
::
KernelFuncs
<
KernelTuple
,
PlaceType
>::
Cache
().
At
(
0
);
auto
jitcode
=
jit
::
GetJitCode
<
KernelTuple
,
PlaceType
>
(
0
);
auto
funcs
=
jit
::
GetAllCandidateFuncs
<
KernelTuple
,
PlaceType
>
(
0
);
EXPECT_GT
(
funcs
.
size
(),
0UL
);
auto
jitcode
=
funcs
[
0
];
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_TRUE
(
tgt
!=
nullptr
);
if
(
std
::
is_same
<
T
,
float
>::
value
&&
if
(
std
::
is_same
<
T
,
float
>::
value
&&
...
@@ -482,7 +465,7 @@ void TestKernelLayerNorm() {
...
@@ -482,7 +465,7 @@ void TestKernelLayerNorm() {
int
left
=
n
*
x_dim_0
;
int
left
=
n
*
x_dim_0
;
for
(
int
x_dim_1
:
TestSizes
())
{
for
(
int
x_dim_1
:
TestSizes
())
{
int
right
=
x_dim_1
;
int
right
=
x_dim_1
;
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
int
sz
=
left
*
right
;
int
sz
=
left
*
right
;
std
::
vector
<
T
>
x
(
sz
),
mean
(
left
),
var
(
left
),
scale
(
right
),
bias
(
right
),
std
::
vector
<
T
>
x
(
sz
),
mean
(
left
),
var
(
left
),
scale
(
right
),
bias
(
right
),
...
@@ -555,7 +538,7 @@ void TestKernelCRFDecoding() {
...
@@ -555,7 +538,7 @@ void TestKernelCRFDecoding() {
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
2000
));
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
2000
));
for
(
int
seq_len
:
{
1
,
11
,
17
,
50
})
{
for
(
int
seq_len
:
{
1
,
11
,
17
,
50
})
{
for
(
int
tag_num
:
test_sizes
)
{
for
(
int
tag_num
:
test_sizes
)
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
int
x_sz
=
seq_len
*
tag_num
;
int
x_sz
=
seq_len
*
tag_num
;
int
w_sz
=
(
tag_num
+
state_trans_base_idx
)
*
tag_num
;
int
w_sz
=
(
tag_num
+
state_trans_base_idx
)
*
tag_num
;
...
@@ -606,7 +589,7 @@ void TestKernelSeqPool() {
...
@@ -606,7 +589,7 @@ void TestKernelSeqPool() {
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
for
(
int
h
:
test_sizes
)
{
for
(
int
h
:
test_sizes
)
{
attr
.
h
=
h
;
attr
.
h
=
h
;
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
h
*
w
),
yref
(
w
);
std
::
vector
<
T
>
x
(
h
*
w
),
yref
(
w
);
RandomVec
<
T
>
(
h
*
w
,
x
.
data
());
RandomVec
<
T
>
(
h
*
w
,
x
.
data
());
...
@@ -649,7 +632,7 @@ void TestKernelEmbSeqPool() {
...
@@ -649,7 +632,7 @@ void TestKernelEmbSeqPool() {
for
(
auto
type
:
pool_types
)
{
for
(
auto
type
:
pool_types
)
{
for
(
int
idx_w
:
{
1
,
2
,
10
,
16
})
{
for
(
int
idx_w
:
{
1
,
2
,
10
,
16
})
{
for
(
int
idx_h
:
{
1
,
2
,
9
,
13
,
16
})
{
for
(
int
idx_h
:
{
1
,
2
,
9
,
13
,
16
})
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
int64_t
>
idx
(
idx_h
*
idx_w
);
std
::
vector
<
int64_t
>
idx
(
idx_h
*
idx_w
);
RandomVec
<
int64_t
>
(
idx_h
*
idx_w
,
idx
.
data
(),
0
,
tbl_h
-
1
);
RandomVec
<
int64_t
>
(
idx_h
*
idx_w
,
idx
.
data
(),
0
,
tbl_h
-
1
);
...
@@ -701,7 +684,7 @@ void TestKernelMatMul() {
...
@@ -701,7 +684,7 @@ void TestKernelMatMul() {
for
(
int
m
:
{
1
,
2
,
3
,
4
})
{
for
(
int
m
:
{
1
,
2
,
3
,
4
})
{
for
(
int
n
:
{
1
,
2
,
3
,
4
})
{
for
(
int
n
:
{
1
,
2
,
3
,
4
})
{
for
(
int
k
:
TestSizes
())
{
for
(
int
k
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
a
(
m
*
k
),
b
(
k
*
n
),
c
(
m
*
n
);
std
::
vector
<
T
>
a
(
m
*
k
),
b
(
k
*
n
),
c
(
m
*
n
);
RandomVec
<
T
>
(
m
*
k
,
a
.
data
());
RandomVec
<
T
>
(
m
*
k
,
a
.
data
());
...
@@ -740,7 +723,7 @@ void TestKernelSoftmax() {
...
@@ -740,7 +723,7 @@ void TestKernelSoftmax() {
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
for
(
int
bs
:
{
1
,
2
,
10
})
{
for
(
int
bs
:
{
1
,
2
,
10
})
{
for
(
int
n
:
TestSizes
())
{
for
(
int
n
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
bs
*
n
),
y
(
bs
*
n
);
std
::
vector
<
T
>
x
(
bs
*
n
),
y
(
bs
*
n
);
RandomVec
<
T
>
(
bs
*
n
,
x
.
data
());
RandomVec
<
T
>
(
bs
*
n
,
x
.
data
());
...
@@ -808,7 +791,7 @@ void TestKernelSgd() {
...
@@ -808,7 +791,7 @@ void TestKernelSgd() {
RandomVec
<
T
>
(
rows_size
*
grad_w
,
grad
.
data
());
RandomVec
<
T
>
(
rows_size
*
grad_w
,
grad
.
data
());
const
int64_t
*
rows_data
=
rows
.
data
();
const
int64_t
*
rows_data
=
rows
.
data
();
const
T
*
grad_data
=
grad
.
data
();
const
T
*
grad_data
=
grad
.
data
();
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
jit
::
sgd_attr_t
attr
(
param_h
,
grad_w
,
rows_size
,
grad_w
,
rows_size
);
jit
::
sgd_attr_t
attr
(
param_h
,
grad_w
,
rows_size
,
grad_w
,
rows_size
);
ref
(
&
lr
,
param_data
,
grad_data
,
rows_data
,
out_data
,
&
attr
);
ref
(
&
lr
,
param_data
,
grad_data
,
rows_data
,
out_data
,
&
attr
);
...
@@ -874,7 +857,7 @@ void TestKernelVBroadcast() {
...
@@ -874,7 +857,7 @@ void TestKernelVBroadcast() {
RandomVec
<
T
>
(
w
,
x
.
data
());
RandomVec
<
T
>
(
w
,
x
.
data
());
const
T
*
x_data
=
x
.
data
();
const
T
*
x_data
=
x
.
data
();
for
(
int64_t
h
:
{
1
,
2
,
6
})
{
for
(
int64_t
h
:
{
1
,
2
,
6
})
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
y
(
w
*
h
);
std
::
vector
<
T
>
y
(
w
*
h
);
T
*
y_data
=
y
.
data
();
T
*
y_data
=
y
.
data
();
...
@@ -900,6 +883,135 @@ void TestKernelVBroadcast() {
...
@@ -900,6 +883,135 @@ void TestKernelVBroadcast() {
}
}
}
}
// test pool
TEST
(
JITKernel_pool
,
jitcreator
)
{
const
auto
&
jitcreators
=
jit
::
JitCodeCreatorPool
::
Instance
().
AllCreators
();
EXPECT_EQ
(
jitcreators
.
size
(),
25UL
);
}
TEST
(
JITKernel_pool
,
jitpool
)
{
// jitpool is related with attr
const
auto
&
kers
=
jit
::
JitCodePool
<
jit
::
kVAdd
>
().
Instance
().
AllKernels
();
EXPECT_EQ
(
kers
.
size
(),
0UL
);
jit
::
GetAllCandidateKernels
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>
(
3
);
// after call GetAllCandidateKernels, it will create jitcode Automatically
EXPECT_EQ
(
kers
.
size
(),
1UL
);
}
TEST
(
JITKernel_pool
,
more
)
{
const
auto
&
kers
=
jit
::
KernelPool
::
Instance
().
AllKernels
();
EXPECT_EQ
(
kers
.
size
(),
21UL
);
}
TEST
(
JITKernel_pool
,
refer
)
{
const
auto
&
kers
=
jit
::
ReferKernelPool
::
Instance
().
AllKernels
();
EXPECT_EQ
(
kers
.
size
(),
29UL
);
}
// test helper
TEST
(
JITKernel_helper
,
GetAllCandidateKernels
)
{
auto
fp_kers
=
jit
::
GetAllCandidateKernels
<
jit
::
VExpTuple
<
float
>
,
CPUPlace
>
(
10
);
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE
(
fp_kers
.
size
(),
1UL
);
// refer
#else
EXPECT_GE
(
fp_kers
.
size
(),
3UL
);
// jitcode, mkl, refer
#endif
auto
db_kers
=
jit
::
GetAllCandidateKernels
<
jit
::
VExpTuple
<
double
>
,
CPUPlace
>
(
10
);
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE
(
db_kers
.
size
(),
1UL
);
// refer
#else
EXPECT_GE
(
db_kers
.
size
(),
2UL
);
// mkl, refer
#endif
}
TEST
(
JITKernel_helper
,
GetAllCandidateFuncsWithTypes
)
{
auto
fp_kers
=
jit
::
GetAllCandidateFuncsWithTypes
<
jit
::
VExpTuple
<
float
>
,
CPUPlace
>
(
10
);
EXPECT_GE
(
fp_kers
.
size
(),
3UL
);
// jitcode, mkl, refer
auto
db_kers
=
jit
::
GetAllCandidateFuncsWithTypes
<
jit
::
VExpTuple
<
double
>
,
CPUPlace
>
(
10
);
EXPECT_GE
(
db_kers
.
size
(),
2UL
);
// mkl, refer
}
TEST
(
JITKernel_helper
,
GetAllCandidateFuncs
)
{
auto
funcs
=
jit
::
GetAllCandidateFuncs
<
jit
::
VExpTuple
<
float
>
,
CPUPlace
>
(
10
);
auto
kers
=
jit
::
GetAllCandidateKernels
<
jit
::
VExpTuple
<
float
>
,
CPUPlace
>
(
10
);
EXPECT_EQ
(
funcs
.
size
(),
kers
.
size
());
std
::
vector
<
float
>
x
(
10
),
tgt
(
10
);
RandomVec
<
float
>
(
10
,
x
.
data
());
auto
best
=
jit
::
GetDefaultBestFunc
<
jit
::
VExpTuple
<
float
>
,
CPUPlace
>
(
10
);
best
(
x
.
data
(),
tgt
.
data
(),
10
);
for
(
auto
f
:
funcs
)
{
std
::
vector
<
float
>
y
(
10
);
f
(
x
.
data
(),
y
.
data
(),
10
);
ExpectEQ
<
float
>
(
y
.
data
(),
tgt
.
data
(),
10
);
}
}
TEST
(
JITKernel_helper
,
attr
)
{
std
::
ostringstream
out
;
// KernelTypes
out
<<
jit
::
to_string
(
jit
::
kNone
)
<<
jit
::
to_string
(
jit
::
kCRFDecoding
)
<<
jit
::
to_string
(
jit
::
kEmbSeqPool
)
<<
jit
::
to_string
(
jit
::
kGRUH1
)
<<
jit
::
to_string
(
jit
::
kGRUHtPart1
)
<<
jit
::
to_string
(
jit
::
kGRUHtPart2
)
<<
jit
::
to_string
(
jit
::
kHSum
)
<<
jit
::
to_string
(
jit
::
kHMax
)
<<
jit
::
to_string
(
jit
::
kLSTMCtHt
)
<<
jit
::
to_string
(
jit
::
kLSTMC1H1
)
<<
jit
::
to_string
(
jit
::
kLayerNorm
)
<<
jit
::
to_string
(
jit
::
kMatMul
)
<<
jit
::
to_string
(
jit
::
kNCHW16CMulNC
)
<<
jit
::
to_string
(
jit
::
kSeqPool
)
<<
jit
::
to_string
(
jit
::
kSoftmax
)
<<
jit
::
to_string
(
jit
::
kVAdd
)
<<
jit
::
to_string
(
jit
::
kVAddBias
)
<<
jit
::
to_string
(
jit
::
kVAddRelu
)
<<
jit
::
to_string
(
jit
::
kVBroadcast
)
<<
jit
::
to_string
(
jit
::
kVCopy
)
<<
jit
::
to_string
(
jit
::
kVExp
)
<<
jit
::
to_string
(
jit
::
kVIdentity
)
<<
jit
::
to_string
(
jit
::
kVMul
)
<<
jit
::
to_string
(
jit
::
kVRelu
)
<<
jit
::
to_string
(
jit
::
kVScal
)
<<
jit
::
to_string
(
jit
::
kSgd
)
<<
jit
::
to_string
(
jit
::
kVSigmoid
)
<<
jit
::
to_string
(
jit
::
kVSquare
)
<<
jit
::
to_string
(
jit
::
kVSub
)
<<
jit
::
to_string
(
jit
::
kVTanh
);
EXPECT_EQ
(
out
.
str
().
size
(),
234
);
// SeqPoolTypes
out
.
str
(
""
);
out
<<
jit
::
to_string
(
jit
::
kSum
)
<<
jit
::
to_string
(
jit
::
kAvg
)
<<
jit
::
to_string
(
jit
::
kSqrt
);
EXPECT_EQ
(
out
.
str
().
size
(),
13
);
EXPECT_EQ
(
jit
::
to_kerneltype
(
"relu"
),
jit
::
kVRelu
);
EXPECT_EQ
(
jit
::
to_kerneltype
(
"Identity"
),
jit
::
kVIdentity
);
EXPECT_EQ
(
jit
::
to_kerneltype
(
"VEXP"
),
jit
::
kVExp
);
EXPECT_EQ
(
jit
::
to_kerneltype
(
"SigmoiD"
),
jit
::
kVSigmoid
);
EXPECT_EQ
(
jit
::
to_kerneltype
(
"VTanh"
),
jit
::
kVTanh
);
out
.
str
(
""
);
out
<<
jit
::
lstm_attr_t
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
EXPECT_EQ
(
out
.
str
().
size
(),
89
);
out
.
str
(
""
);
out
<<
jit
::
gru_attr_t
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
);
EXPECT_EQ
(
out
.
str
().
size
(),
52
);
out
.
str
(
""
);
out
<<
jit
::
seq_pool_attr_t
(
8
,
jit
::
SeqPoolType
::
kSum
);
EXPECT_EQ
(
out
.
str
().
size
(),
44
);
out
.
str
(
""
);
out
<<
jit
::
emb_seq_pool_attr_t
(
1
,
2
,
3
,
4
,
5
,
jit
::
SeqPoolType
::
kAvg
);
EXPECT_EQ
(
out
.
str
().
size
(),
93
);
out
.
str
(
""
);
out
<<
jit
::
sgd_attr_t
(
1
,
2
,
3
,
4
,
5
);
EXPECT_EQ
(
out
.
str
().
size
(),
81
);
out
.
str
(
""
);
out
<<
jit
::
matmul_attr_t
(
1
,
2
,
3
);
EXPECT_EQ
(
out
.
str
().
size
(),
14
);
}
// test kernerls
#define TestKernelVMul TestKernelXYZN
#define TestKernelVMul TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN
#define TestKernelVAddRelu TestKernelXYZN
#define TestKernelVAddRelu TestKernelXYZN
...
@@ -969,6 +1081,14 @@ TEST_CPU_KERNEL(Softmax);
...
@@ -969,6 +1081,14 @@ TEST_CPU_KERNEL(Softmax);
TEST_CPU_KERNEL
(
Sgd
);
TEST_CPU_KERNEL
(
Sgd
);
TEST_CPU_KERNEL
(
VBroadcast
);
TEST_CPU_KERNEL
(
VBroadcast
);
TEST
(
JITKernel
,
kernel_func
)
{
auto
f1
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
().
At
(
3
);
auto
f2
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
()[
3
];
EXPECT_TRUE
(
f1
!=
nullptr
);
EXPECT_TRUE
(
f1
==
f2
);
// TODO(TJ): check not equal
}
TEST
(
JITKernel_key
,
lstm
)
{
TEST
(
JITKernel_key
,
lstm
)
{
jit
::
lstm_attr_t
attr1
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr1
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr2
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr2
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
...
@@ -1000,11 +1120,3 @@ TEST(JITKernel_key, gru) {
...
@@ -1000,11 +1120,3 @@ TEST(JITKernel_key, gru) {
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key3
!=
key4
);
EXPECT_TRUE
(
key3
!=
key4
);
}
}
TEST
(
JITKernel
,
kernel_func
)
{
auto
f1
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
().
At
(
3
);
auto
f2
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
()[
3
];
EXPECT_TRUE
(
f1
!=
nullptr
);
EXPECT_TRUE
(
f1
==
f2
);
// TODO(TJ): check not equal
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录