Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
45bdd84d
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
45bdd84d
编写于
3月 10, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enhance the jitkernel helper and add unit tests
test=develop
上级
14a764c9
变更
27
显示空白变更内容
内联
并排
Showing
27 changed file
with
328 addition
and
182 deletion
+328
-182
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+3
-25
paddle/fluid/operators/jit/gen/act.cc
paddle/fluid/operators/jit/gen/act.cc
+7
-7
paddle/fluid/operators/jit/gen/blas.cc
paddle/fluid/operators/jit/gen/blas.cc
+2
-2
paddle/fluid/operators/jit/gen/embseqpool.cc
paddle/fluid/operators/jit/gen/embseqpool.cc
+1
-1
paddle/fluid/operators/jit/gen/gru.cc
paddle/fluid/operators/jit/gen/gru.cc
+1
-1
paddle/fluid/operators/jit/gen/hopv.cc
paddle/fluid/operators/jit/gen/hopv.cc
+1
-1
paddle/fluid/operators/jit/gen/jitcode.h
paddle/fluid/operators/jit/gen/jitcode.h
+1
-1
paddle/fluid/operators/jit/gen/lstm.cc
paddle/fluid/operators/jit/gen/lstm.cc
+1
-1
paddle/fluid/operators/jit/gen/matmul.cc
paddle/fluid/operators/jit/gen/matmul.cc
+1
-1
paddle/fluid/operators/jit/gen/seqpool.cc
paddle/fluid/operators/jit/gen/seqpool.cc
+1
-1
paddle/fluid/operators/jit/gen/sgd.cc
paddle/fluid/operators/jit/gen/sgd.cc
+1
-1
paddle/fluid/operators/jit/gen/vbroadcast.cc
paddle/fluid/operators/jit/gen/vbroadcast.cc
+1
-1
paddle/fluid/operators/jit/gen_base.cc
paddle/fluid/operators/jit/gen_base.cc
+1
-1
paddle/fluid/operators/jit/gen_base.h
paddle/fluid/operators/jit/gen_base.h
+4
-3
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+82
-34
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+4
-3
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+3
-0
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
+1
-1
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
+2
-1
paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
+1
-1
paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
+2
-1
paddle/fluid/operators/jit/more/mix/mix.cc
paddle/fluid/operators/jit/more/mix/mix.cc
+8
-8
paddle/fluid/operators/jit/more/mix/mix.h
paddle/fluid/operators/jit/more/mix/mix.h
+6
-6
paddle/fluid/operators/jit/more/mkl/mkl.cc
paddle/fluid/operators/jit/more/mkl/mkl.cc
+24
-23
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+7
-7
paddle/fluid/operators/jit/registry.h
paddle/fluid/operators/jit/registry.h
+2
-2
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+160
-48
未找到文件。
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
45bdd84d
...
@@ -111,33 +111,11 @@ template <typename KernelTuple, typename PlaceType, typename... Args>
...
@@ -111,33 +111,11 @@ template <typename KernelTuple, typename PlaceType, typename... Args>
void
BenchAllImpls
(
const
typename
KernelTuple
::
attr_type
&
attr
,
Args
...
args
)
{
void
BenchAllImpls
(
const
typename
KernelTuple
::
attr_type
&
attr
,
Args
...
args
)
{
BenchFunc
<
KernelTuple
,
Args
...
>
benchmark
;
BenchFunc
<
KernelTuple
,
Args
...
>
benchmark
;
std
::
vector
<
std
::
pair
<
std
::
string
,
double
>>
infos
;
std
::
vector
<
std
::
pair
<
std
::
string
,
double
>>
infos
;
// test refer
auto
funcs
=
jit
::
GetAllCandidateFuncsWithTypes
<
KernelTuple
,
PlaceType
>
(
attr
);
auto
refer
=
jit
::
GetRefer
<
KernelTuple
>
();
for
(
auto
f
:
funcs
)
{
if
(
!
refer
)
{
infos
.
push_back
(
std
::
make_pair
(
f
.
first
,
benchmark
(
f
.
second
,
args
...)));
LOG
(
FATAL
)
<<
"Refer can not be empty!"
;
}
}
infos
.
push_back
(
std
::
make_pair
(
"Refer"
,
benchmark
(
refer
,
args
...)));
// test jitcode
auto
jitcode
=
jit
::
GetJitCode
<
KernelTuple
,
PlaceType
>
(
attr
);
if
(
jitcode
)
{
infos
.
push_back
(
std
::
make_pair
(
"JitCode"
,
benchmark
(
jitcode
,
args
...)));
}
// test all impls in more
jit
::
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
PlaceType
());
auto
&
pool
=
jit
::
KernelPool
().
Instance
().
AllKernels
();
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
jit
::
KernelMore
<
KernelTuple
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
auto
more
=
i
->
GetFunc
();
infos
.
push_back
(
std
::
make_pair
(
i
->
ImplType
(),
benchmark
(
more
,
args
...)));
}
}
}
// Test result from Get function
// Test result from Get function
auto
tgt
=
jit
::
KernelFuncs
<
KernelTuple
,
PlaceType
>::
Cache
().
At
(
attr
);
auto
tgt
=
jit
::
KernelFuncs
<
KernelTuple
,
PlaceType
>::
Cache
().
At
(
attr
);
if
(
!
tgt
)
{
if
(
!
tgt
)
{
...
...
paddle/fluid/operators/jit/gen/act.cc
浏览文件 @
45bdd84d
...
@@ -81,7 +81,7 @@ void VActJitCode::genCode() {
...
@@ -81,7 +81,7 @@ void VActJitCode::genCode() {
#define DECLARE_ACT_CREATOR(name) \
#define DECLARE_ACT_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
class name##Creator : public JitCodeCreator<int> { \
public: \
public: \
bool
UseMe(const int& attr) const override;
\
bool
CanBeUsed(const int& attr) const override;
\
size_t CodeSize(const int& d) const override; \
size_t CodeSize(const int& d) const override; \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
...
@@ -96,27 +96,27 @@ DECLARE_ACT_CREATOR(VSigmoid);
...
@@ -96,27 +96,27 @@ DECLARE_ACT_CREATOR(VSigmoid);
DECLARE_ACT_CREATOR
(
VTanh
);
DECLARE_ACT_CREATOR
(
VTanh
);
// TODO(TJ): tuning use me
// TODO(TJ): tuning use me
bool
VReluCreator
::
UseMe
(
const
int
&
d
)
const
{
bool
VReluCreator
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
bool
VSquareCreator
::
UseMe
(
const
int
&
d
)
const
{
bool
VSquareCreator
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
bool
VIdentityCreator
::
UseMe
(
const
int
&
d
)
const
{
bool
VIdentityCreator
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
bool
VExpCreator
::
UseMe
(
const
int
&
d
)
const
{
bool
VExpCreator
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
<
32
;
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
<
32
;
}
}
bool
VSigmoidCreator
::
UseMe
(
const
int
&
d
)
const
{
bool
VSigmoidCreator
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
bool
VTanhCreator
::
UseMe
(
const
int
&
d
)
const
{
bool
VTanhCreator
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
...
...
paddle/fluid/operators/jit/gen/blas.cc
浏览文件 @
45bdd84d
...
@@ -142,7 +142,7 @@ void NCHW16CMulNCJitCode::genCode() {
...
@@ -142,7 +142,7 @@ void NCHW16CMulNCJitCode::genCode() {
class
NCHW16CMulNCCreator
:
public
JitCodeCreator
<
int
>
{
class
NCHW16CMulNCCreator
:
public
JitCodeCreator
<
int
>
{
public:
public:
bool
UseMe
(
const
int
&
attr
)
const
override
{
bool
CanBeUsed
(
const
int
&
attr
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx512f
);
return
platform
::
MayIUse
(
platform
::
avx512f
);
}
}
size_t
CodeSize
(
const
int
&
d
)
const
override
{
return
256
*
1024
;
}
size_t
CodeSize
(
const
int
&
d
)
const
override
{
return
256
*
1024
;
}
...
@@ -154,7 +154,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
...
@@ -154,7 +154,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
#define DECLARE_BLAS_CREATOR(name) \
#define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
class name##Creator : public JitCodeCreator<int> { \
public: \
public: \
bool
UseMe(const int& attr) const override {
\
bool
CanBeUsed(const int& attr) const override {
\
return platform::MayIUse(platform::avx) && attr <= 1024; \
return platform::MayIUse(platform::avx) && attr <= 1024; \
} \
} \
size_t CodeSize(const int& d) const override { \
size_t CodeSize(const int& d) const override { \
...
...
paddle/fluid/operators/jit/gen/embseqpool.cc
浏览文件 @
45bdd84d
...
@@ -121,7 +121,7 @@ void EmbSeqPoolJitCode::genCode() {
...
@@ -121,7 +121,7 @@ void EmbSeqPoolJitCode::genCode() {
class
EmbSeqPoolCreator
:
public
JitCodeCreator
<
emb_seq_pool_attr_t
>
{
class
EmbSeqPoolCreator
:
public
JitCodeCreator
<
emb_seq_pool_attr_t
>
{
public:
public:
bool
UseMe
(
const
emb_seq_pool_attr_t
&
attr
)
const
override
{
bool
CanBeUsed
(
const
emb_seq_pool_attr_t
&
attr
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
return
platform
::
MayIUse
(
platform
::
avx
)
&&
attr
.
table_width
%
YMM_FLOAT_BLOCK
==
0
;
attr
.
table_width
%
YMM_FLOAT_BLOCK
==
0
;
}
}
...
...
paddle/fluid/operators/jit/gen/gru.cc
浏览文件 @
45bdd84d
...
@@ -86,7 +86,7 @@ void GRUJitCode::genCode() {
...
@@ -86,7 +86,7 @@ void GRUJitCode::genCode() {
class name##Creator : public JitCodeCreator<gru_attr_t> { \
class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \
public: \
/* TODO(TJ): enable more */
\
/* TODO(TJ): enable more */
\
bool
UseMe(const gru_attr_t& attr) const override {
\
bool
CanBeUsed(const gru_attr_t& attr) const override {
\
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
} \
size_t CodeSize(const gru_attr_t& attr) const override { \
size_t CodeSize(const gru_attr_t& attr) const override { \
...
...
paddle/fluid/operators/jit/gen/hopv.cc
浏览文件 @
45bdd84d
...
@@ -76,7 +76,7 @@ void HOPVJitCode::genCode() {
...
@@ -76,7 +76,7 @@ void HOPVJitCode::genCode() {
#define DECLARE_HOP_CREATOR(name) \
#define DECLARE_HOP_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
class name##Creator : public JitCodeCreator<int> { \
public: \
public: \
bool
UseMe(const int& attr) const override {
\
bool
CanBeUsed(const int& attr) const override {
\
return platform::MayIUse(platform::avx); \
return platform::MayIUse(platform::avx); \
} \
} \
size_t CodeSize(const int& d) const override { \
size_t CodeSize(const int& d) const override { \
...
...
paddle/fluid/operators/jit/gen/jitcode.h
浏览文件 @
45bdd84d
...
@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
...
@@ -73,7 +73,7 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
virtual
void
genCode
()
=
0
;
virtual
void
genCode
()
=
0
;
size_t
getSize
()
const
override
{
return
CodeGenerator
::
getSize
();
}
size_t
getSize
()
const
override
{
return
CodeGenerator
::
getSize
();
}
const
unsigned
char
*
getCodeInternal
()
override
{
const
unsigned
char
*
getCodeInternal
()
const
override
{
const
Xbyak
::
uint8
*
code
=
CodeGenerator
::
getCode
();
const
Xbyak
::
uint8
*
code
=
CodeGenerator
::
getCode
();
return
code
;
return
code
;
}
}
...
...
paddle/fluid/operators/jit/gen/lstm.cc
浏览文件 @
45bdd84d
...
@@ -114,7 +114,7 @@ void LSTMJitCode::genCode() {
...
@@ -114,7 +114,7 @@ void LSTMJitCode::genCode() {
class name##Creator : public JitCodeCreator<lstm_attr_t> { \
class name##Creator : public JitCodeCreator<lstm_attr_t> { \
public: \
public: \
/* TODO(TJ): enable more */
\
/* TODO(TJ): enable more */
\
bool
UseMe(const lstm_attr_t& attr) const override {
\
bool
CanBeUsed(const lstm_attr_t& attr) const override {
\
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
} \
size_t CodeSize(const lstm_attr_t& attr) const override { \
size_t CodeSize(const lstm_attr_t& attr) const override { \
...
...
paddle/fluid/operators/jit/gen/matmul.cc
浏览文件 @
45bdd84d
...
@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() {
...
@@ -98,7 +98,7 @@ void MatMulJitCode::genCode() {
class
MatMulCreator
:
public
JitCodeCreator
<
matmul_attr_t
>
{
class
MatMulCreator
:
public
JitCodeCreator
<
matmul_attr_t
>
{
public:
public:
bool
UseMe
(
const
matmul_attr_t
&
attr
)
const
override
{
bool
CanBeUsed
(
const
matmul_attr_t
&
attr
)
const
override
{
return
attr
.
m
==
1
&&
platform
::
MayIUse
(
platform
::
avx512f
)
&&
return
attr
.
m
==
1
&&
platform
::
MayIUse
(
platform
::
avx512f
)
&&
attr
.
n
%
ZMM_FLOAT_BLOCK
==
0
&&
attr
.
k
<
512
;
attr
.
n
%
ZMM_FLOAT_BLOCK
==
0
&&
attr
.
k
<
512
;
}
}
...
...
paddle/fluid/operators/jit/gen/seqpool.cc
浏览文件 @
45bdd84d
...
@@ -57,7 +57,7 @@ void SeqPoolJitCode::genCode() {
...
@@ -57,7 +57,7 @@ void SeqPoolJitCode::genCode() {
class
SeqPoolCreator
:
public
JitCodeCreator
<
seq_pool_attr_t
>
{
class
SeqPoolCreator
:
public
JitCodeCreator
<
seq_pool_attr_t
>
{
public:
public:
bool
UseMe
(
const
seq_pool_attr_t
&
attr
)
const
override
{
bool
CanBeUsed
(
const
seq_pool_attr_t
&
attr
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
size_t
CodeSize
(
const
seq_pool_attr_t
&
attr
)
const
override
{
size_t
CodeSize
(
const
seq_pool_attr_t
&
attr
)
const
override
{
...
...
paddle/fluid/operators/jit/gen/sgd.cc
浏览文件 @
45bdd84d
...
@@ -104,7 +104,7 @@ void SgdJitCode::genCode() {
...
@@ -104,7 +104,7 @@ void SgdJitCode::genCode() {
class
SgdCreator
:
public
JitCodeCreator
<
sgd_attr_t
>
{
class
SgdCreator
:
public
JitCodeCreator
<
sgd_attr_t
>
{
public:
public:
bool
UseMe
(
const
sgd_attr_t
&
attr
)
const
override
{
bool
CanBeUsed
(
const
sgd_attr_t
&
attr
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
return
platform
::
MayIUse
(
platform
::
avx
)
&&
attr
.
grad_width
%
YMM_FLOAT_BLOCK
==
0
;
attr
.
grad_width
%
YMM_FLOAT_BLOCK
==
0
;
}
}
...
...
paddle/fluid/operators/jit/gen/vbroadcast.cc
浏览文件 @
45bdd84d
...
@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() {
...
@@ -69,7 +69,7 @@ void VBroadcastJitCode::genCode() {
class
VBroadcastCreator
:
public
JitCodeCreator
<
int64_t
>
{
class
VBroadcastCreator
:
public
JitCodeCreator
<
int64_t
>
{
public:
public:
bool
UseMe
(
const
int64_t
&
w
)
const
override
{
bool
CanBeUsed
(
const
int64_t
&
w
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
w
%
YMM_FLOAT_BLOCK
==
0
;
return
platform
::
MayIUse
(
platform
::
avx
)
&&
w
%
YMM_FLOAT_BLOCK
==
0
;
}
}
size_t
CodeSize
(
const
int64_t
&
w
)
const
override
{
size_t
CodeSize
(
const
int64_t
&
w
)
const
override
{
...
...
paddle/fluid/operators/jit/gen_base.cc
浏览文件 @
45bdd84d
...
@@ -31,7 +31,7 @@ namespace paddle {
...
@@ -31,7 +31,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
jit
{
namespace
jit
{
// refer do not need
useme
, it would be the last one.
// refer do not need
CanBeUsed
, it would be the last one.
void
GenBase
::
dumpCode
(
const
unsigned
char
*
code
)
const
{
void
GenBase
::
dumpCode
(
const
unsigned
char
*
code
)
const
{
if
(
code
)
{
if
(
code
)
{
static
int
counter
=
0
;
static
int
counter
=
0
;
...
...
paddle/fluid/operators/jit/gen_base.h
浏览文件 @
45bdd84d
...
@@ -31,9 +31,10 @@ class GenBase : public Kernel {
...
@@ -31,9 +31,10 @@ class GenBase : public Kernel {
virtual
~
GenBase
()
=
default
;
virtual
~
GenBase
()
=
default
;
virtual
std
::
string
name
()
const
=
0
;
virtual
std
::
string
name
()
const
=
0
;
virtual
size_t
getSize
()
const
=
0
;
virtual
size_t
getSize
()
const
=
0
;
virtual
const
unsigned
char
*
getCodeInternal
()
=
0
;
virtual
const
unsigned
char
*
getCodeInternal
()
const
=
0
;
const
char
*
ImplType
()
const
override
{
return
"JitCode"
;
}
template
<
typename
Func
>
template
<
typename
Func
>
Func
getCode
()
{
Func
getCode
()
const
{
const
unsigned
char
*
code
=
this
->
getCodeInternal
();
const
unsigned
char
*
code
=
this
->
getCodeInternal
();
if
(
FLAGS_dump_jitcode
)
{
if
(
FLAGS_dump_jitcode
)
{
this
->
dumpCode
(
code
);
this
->
dumpCode
(
code
);
...
@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator {
...
@@ -65,7 +66,7 @@ class JitCodeCreator : public GenCreator {
virtual
~
JitCodeCreator
()
=
default
;
virtual
~
JitCodeCreator
()
=
default
;
// condition when this jit code can be used.
// condition when this jit code can be used.
virtual
bool
UseMe
(
const
Attr
&
attr
)
const
=
0
;
virtual
bool
CanBeUsed
(
const
Attr
&
attr
)
const
=
0
;
// estimate this code size
// estimate this code size
virtual
size_t
CodeSize
(
const
Attr
&
attr
)
const
=
0
;
virtual
size_t
CodeSize
(
const
Attr
&
attr
)
const
=
0
;
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
45bdd84d
...
@@ -14,9 +14,6 @@
...
@@ -14,9 +14,6 @@
#pragma once
#pragma once
extern
"C"
{
#include <xxhash.h>
}
#include <iostream>
#include <iostream>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
...
@@ -36,31 +33,30 @@ template <typename KernelTuple, typename PlaceType>
...
@@ -36,31 +33,30 @@ template <typename KernelTuple, typename PlaceType>
inline
typename
std
::
enable_if
<
inline
typename
std
::
enable_if
<
std
::
is_same
<
typename
KernelTuple
::
data_type
,
float
>::
value
&&
std
::
is_same
<
typename
KernelTuple
::
data_type
,
float
>::
value
&&
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
,
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
,
typename
KernelTuple
::
func_type
>::
type
const
Kernel
*
>::
type
GetJitCode
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
GetJitCode
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
using
Func
=
typename
KernelTuple
::
func_type
;
using
Attr
=
typename
KernelTuple
::
attr_type
;
using
Attr
=
typename
KernelTuple
::
attr_type
;
size_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
size_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
auto
&
codes
=
JitCodePool
<
KernelTuple
::
kernel_type
>
().
Instance
();
auto
&
codes
=
JitCodePool
<
KernelTuple
::
kernel_type
>
::
Instance
();
if
(
codes
.
Has
(
key
))
{
if
(
codes
.
Has
(
key
))
{
return
codes
.
AllKernels
().
at
(
key
)
->
template
getCode
<
Func
>
();
return
codes
.
AllKernels
().
at
(
key
)
.
get
();
}
}
// creator is not related with attr, so can use KernelKey as key
// creator is not related with attr, so can use KernelKey as key
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
PlaceType
());
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
PlaceType
());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto
&
creator_map
=
JitCodeCreatorPool
().
Instance
().
AllCreators
();
auto
&
creator_map
=
JitCodeCreatorPool
::
Instance
().
AllCreators
();
auto
iter
=
creator_map
.
find
(
kkey
);
auto
iter
=
creator_map
.
find
(
kkey
);
if
(
iter
!=
creator_map
.
end
())
{
if
(
iter
!=
creator_map
.
end
())
{
auto
&
creators
=
iter
->
second
;
auto
&
creators
=
iter
->
second
;
for
(
auto
&
cur
:
creators
)
{
for
(
auto
&
cur
:
creators
)
{
auto
i
=
dynamic_cast
<
const
JitCodeCreator
<
Attr
>*>
(
cur
.
get
());
auto
i
=
dynamic_cast
<
const
JitCodeCreator
<
Attr
>*>
(
cur
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
if
(
i
&&
i
->
CanBeUsed
(
attr
))
{
auto
p
=
i
->
CreateJitCode
(
attr
);
auto
p
=
i
->
CreateJitCode
(
attr
);
if
(
p
)
{
if
(
p
)
{
auto
f
=
p
->
template
getCode
<
Func
>
();
auto
res
=
p
.
get
();
codes
.
Insert
(
key
,
std
::
move
(
p
));
codes
.
Insert
(
key
,
std
::
move
(
p
));
return
f
;
return
res
;
}
}
}
}
}
}
...
@@ -72,7 +68,7 @@ template <typename KernelTuple, typename PlaceType>
...
@@ -72,7 +68,7 @@ template <typename KernelTuple, typename PlaceType>
inline
typename
std
::
enable_if
<
inline
typename
std
::
enable_if
<
!
std
::
is_same
<
typename
KernelTuple
::
data_type
,
float
>::
value
||
!
std
::
is_same
<
typename
KernelTuple
::
data_type
,
float
>::
value
||
!
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
,
!
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
,
typename
KernelTuple
::
func_type
>::
type
const
Kernel
*
>::
type
GetJitCode
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
GetJitCode
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
return
nullptr
;
return
nullptr
;
}
}
...
@@ -80,8 +76,8 @@ GetJitCode(const typename KernelTuple::attr_type& attr) {
...
@@ -80,8 +76,8 @@ GetJitCode(const typename KernelTuple::attr_type& attr) {
// Refer code do not related with attr, which is just for cast
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
// Refer is always on CPUPlace
template
<
typename
KernelTuple
>
template
<
typename
KernelTuple
>
inline
typename
KernelTuple
::
func_type
GetRefer
()
{
inline
const
Kernel
*
GetReferKernel
()
{
auto
&
ref_pool
=
ReferKernelPool
().
Instance
().
AllKernels
();
auto
&
ref_pool
=
ReferKernelPool
::
Instance
().
AllKernels
();
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
platform
::
CPUPlace
());
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
platform
::
CPUPlace
());
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
PADDLE_ENFORCE
(
ref_iter
!=
ref_pool
.
end
(),
PADDLE_ENFORCE
(
ref_iter
!=
ref_pool
.
end
(),
...
@@ -90,36 +86,93 @@ inline typename KernelTuple::func_type GetRefer() {
...
@@ -90,36 +86,93 @@ inline typename KernelTuple::func_type GetRefer() {
for
(
auto
&
impl
:
ref_impls
)
{
for
(
auto
&
impl
:
ref_impls
)
{
auto
i
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
impl
.
get
());
auto
i
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
impl
.
get
());
if
(
i
)
{
if
(
i
)
{
return
i
->
GetFunc
()
;
return
i
;
}
}
}
}
return
nullptr
;
return
nullptr
;
}
}
template
<
typename
KernelTuple
,
typename
PlaceType
=
platform
::
CPUPlace
>
template
<
typename
KernelTuple
>
typename
KernelTuple
::
func_type
Get
(
inline
typename
KernelTuple
::
func_type
GetReferFunc
()
{
auto
ker
=
GetReferKernel
<
KernelTuple
>
();
auto
p
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
ker
);
PADDLE_ENFORCE
(
p
,
"The Refer kernel should exsit"
);
return
p
->
GetFunc
();
}
// Return all Kernels that can be used
template
<
typename
KernelTuple
,
typename
PlaceType
>
std
::
vector
<
const
Kernel
*>
GetAllCandidateKernels
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
const
typename
KernelTuple
::
attr_type
&
attr
)
{
auto
jitfunc
=
GetJitCode
<
KernelTuple
,
PlaceType
>
(
attr
);
// the search order shoudl be jitcode > more > refer
if
(
jitfunc
)
{
std
::
vector
<
const
Kernel
*>
res
;
return
jitfunc
;
auto
jitker
=
GetJitCode
<
KernelTuple
,
PlaceType
>
(
attr
);
if
(
jitker
)
{
res
.
emplace_back
(
jitker
);
}
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
//
more kernel
pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
PlaceType
());
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
PlaceType
());
auto
&
pool
=
KernelPool
().
Instance
().
AllKernels
();
auto
&
pool
=
KernelPool
::
Instance
().
AllKernels
();
auto
iter
=
pool
.
find
(
kkey
);
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
KernelMore
<
KernelTuple
>*>
(
impl
.
get
());
auto
i
=
dynamic_cast
<
const
KernelMore
<
KernelTuple
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
if
(
i
&&
i
->
CanBeUsed
(
attr
))
{
re
turn
i
->
GetFunc
(
);
re
s
.
emplace_back
(
i
);
}
}
}
}
}
}
// The last implementation should be reference function on CPUPlace.
// The last implementation should be reference function on CPUPlace.
return
GetRefer
<
KernelTuple
>
();
auto
ref
=
GetReferKernel
<
KernelTuple
>
();
PADDLE_ENFORCE
(
ref
!=
nullptr
,
"Refer Kernel can not be empty."
);
res
.
emplace_back
(
ref
);
return
res
;
}
template
<
typename
KernelTuple
,
typename
PlaceType
=
platform
::
CPUPlace
>
std
::
vector
<
std
::
pair
<
std
::
string
,
typename
KernelTuple
::
func_type
>>
GetAllCandidateFuncsWithTypes
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
using
Func
=
typename
KernelTuple
::
func_type
;
auto
kers
=
GetAllCandidateKernels
<
KernelTuple
,
PlaceType
>
(
attr
);
std
::
vector
<
std
::
pair
<
std
::
string
,
Func
>>
res
;
for
(
auto
k
:
kers
)
{
std
::
string
name
=
k
->
ImplType
();
if
(
name
==
"JitCode"
)
{
auto
i
=
dynamic_cast
<
const
GenBase
*>
(
k
);
PADDLE_ENFORCE
(
i
,
"jitcode kernel cast can not fail."
);
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
template
getCode
<
Func
>()));
}
else
{
auto
i
=
dynamic_cast
<
const
KernelMore
<
KernelTuple
>*>
(
k
);
PADDLE_ENFORCE
(
i
,
"kernel cast can not fail."
);
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
GetFunc
()));
}
}
return
res
;
}
template
<
typename
KernelTuple
,
typename
PlaceType
=
platform
::
CPUPlace
>
std
::
vector
<
typename
KernelTuple
::
func_type
>
GetAllCandidateFuncs
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
auto
funcs
=
GetAllCandidateFuncsWithTypes
<
KernelTuple
,
PlaceType
>
(
attr
);
std
::
vector
<
typename
KernelTuple
::
func_type
>
res
;
for
(
auto
&
i
:
funcs
)
{
res
.
emplace_back
(
i
.
second
);
}
return
res
;
}
template
<
typename
KernelTuple
,
typename
PlaceType
=
platform
::
CPUPlace
>
typename
KernelTuple
::
func_type
GetDefaultBestFunc
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
auto
funcs
=
GetAllCandidateFuncs
<
KernelTuple
,
PlaceType
>
(
attr
);
PADDLE_ENFORCE_GE
(
funcs
.
size
(),
1UL
);
// Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one,
// which is searched in order and tuned by offline.
return
funcs
[
0
];
}
}
template
<
typename
KernelTuple
,
typename
PlaceType
>
template
<
typename
KernelTuple
,
typename
PlaceType
>
...
@@ -134,17 +187,13 @@ class KernelFuncs {
...
@@ -134,17 +187,13 @@ class KernelFuncs {
// the exposed interface to use
// the exposed interface to use
typename
KernelTuple
::
func_type
At
(
typename
KernelTuple
::
func_type
At
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
const
typename
KernelTuple
::
attr_type
&
attr
)
{
// XXH64: 13.8 GB/s
// Maybe here is not good enough, not all kernels should have jitcode
// TODO(TJ): change me, maybe not all attr change need one key, should be
int64_t
key
=
JitCodeKey
<
typename
KernelTuple
::
attr_type
>
(
attr
);
// attrkey
int64_t
key
=
XXH64
(
&
attr
,
sizeof
(
typename
KernelTuple
::
attr_type
),
0
);
if
(
Has
(
key
))
{
if
(
Has
(
key
))
{
return
funcs_
.
at
(
key
);
return
funcs_
.
at
(
key
);
}
}
// If do not have this attr in cache,
// If do not have this attr in cache then get the default best
// then could run some runtime benchmark of this attr and save the best one.
auto
func
=
GetDefaultBestFunc
<
KernelTuple
,
PlaceType
>
(
attr
);
// Here just get the offline benchmarked best one.
auto
func
=
Get
<
KernelTuple
,
PlaceType
>
(
attr
);
Insert
(
key
,
func
);
Insert
(
key
,
func
);
return
func
;
return
func
;
}
}
...
@@ -156,7 +205,6 @@ class KernelFuncs {
...
@@ -156,7 +205,6 @@ class KernelFuncs {
protected:
protected:
bool
Has
(
int64_t
key
)
const
{
return
funcs_
.
find
(
key
)
!=
funcs_
.
end
();
}
bool
Has
(
int64_t
key
)
const
{
return
funcs_
.
find
(
key
)
!=
funcs_
.
end
();
}
void
Insert
(
int64_t
key
,
typename
KernelTuple
::
func_type
func
)
{
void
Insert
(
int64_t
key
,
typename
KernelTuple
::
func_type
func
)
{
funcs_
.
emplace
(
key
,
func
);
funcs_
.
emplace
(
key
,
func
);
}
}
...
...
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
45bdd84d
...
@@ -302,6 +302,7 @@ class Kernel {
...
@@ -302,6 +302,7 @@ class Kernel {
public:
public:
Kernel
()
=
default
;
Kernel
()
=
default
;
virtual
~
Kernel
()
=
default
;
virtual
~
Kernel
()
=
default
;
virtual
const
char
*
ImplType
()
const
=
0
;
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
};
};
...
@@ -312,8 +313,8 @@ class KernelMore : public Kernel {
...
@@ -312,8 +313,8 @@ class KernelMore : public Kernel {
using
Func
=
typename
KernelTuple
::
func_type
;
using
Func
=
typename
KernelTuple
::
func_type
;
using
Attr
=
typename
KernelTuple
::
attr_type
;
using
Attr
=
typename
KernelTuple
::
attr_type
;
virtual
Func
GetFunc
()
const
{
return
func
;
}
virtual
Func
GetFunc
()
const
{
return
func
;
}
virtual
bool
UseMe
(
const
Attr
&
attr
)
const
=
0
;
// specify this kernel can be used, means it should not fail if use it.
virtual
const
char
*
ImplType
(
)
const
=
0
;
virtual
bool
CanBeUsed
(
const
Attr
&
attr
)
const
=
0
;
protected:
protected:
Func
func
{
nullptr
};
Func
func
{
nullptr
};
...
@@ -323,7 +324,7 @@ template <typename KernelTuple>
...
@@ -323,7 +324,7 @@ template <typename KernelTuple>
class
ReferKernel
:
public
KernelMore
<
KernelTuple
>
{
class
ReferKernel
:
public
KernelMore
<
KernelTuple
>
{
public:
public:
// Refer code can always be used
// Refer code can always be used
bool
UseMe
(
const
typename
KernelTuple
::
attr_type
&
attr
)
const
override
{
bool
CanBeUsed
(
const
typename
KernelTuple
::
attr_type
&
attr
)
const
override
{
return
true
;
return
true
;
}
}
const
char
*
ImplType
()
const
override
{
return
"Refer"
;
}
const
char
*
ImplType
()
const
override
{
return
"Refer"
;
}
...
...
paddle/fluid/operators/jit/kernel_key.cc
浏览文件 @
45bdd84d
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
* limitations under the License. */
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include <xxhash.h>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -49,6 +50,8 @@ static inline int act_type_convert(KernelType type) {
...
@@ -49,6 +50,8 @@ static inline int act_type_convert(KernelType type) {
template
<
>
template
<
>
size_t
JitCodeKey
<
lstm_attr_t
>
(
const
lstm_attr_t
&
attr
)
{
size_t
JitCodeKey
<
lstm_attr_t
>
(
const
lstm_attr_t
&
attr
)
{
// XXH64: 13.8 GB/s
size_t
key
=
attr
.
d
;
size_t
key
=
attr
.
d
;
int
gate_key
=
act_type_convert
(
attr
.
act_gate
)
<<
1
;
int
gate_key
=
act_type_convert
(
attr
.
act_gate
)
<<
1
;
int
cand_key
=
act_type_convert
(
attr
.
act_cand
)
<<
(
1
+
act_type_shift
);
int
cand_key
=
act_type_convert
(
attr
.
act_cand
)
<<
(
1
+
act_type_shift
);
...
...
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
浏览文件 @
45bdd84d
...
@@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
...
@@ -161,7 +161,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
}
}
}
}
bool
CRFDecodingKernel
::
UseMe
(
const
int
&
d
)
const
{
bool
CRFDecodingKernel
::
CanBeUsed
(
const
int
&
d
)
const
{
#ifdef __AVX512F__
#ifdef __AVX512F__
constexpr
int
block
=
ZMM_FLOAT_BLOCK
;
constexpr
int
block
=
ZMM_FLOAT_BLOCK
;
#else
#else
...
...
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
浏览文件 @
45bdd84d
...
@@ -29,7 +29,8 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
...
@@ -29,7 +29,8 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
class
CRFDecodingKernel
:
public
KernelMore
<
CRFDecodingTuple
<
float
>>
{
class
CRFDecodingKernel
:
public
KernelMore
<
CRFDecodingTuple
<
float
>>
{
public:
public:
CRFDecodingKernel
()
{
this
->
func
=
CRFDecoding
;
}
CRFDecodingKernel
()
{
this
->
func
=
CRFDecoding
;
}
bool
UseMe
(
const
typename
CRFDecodingTuple
<
float
>::
attr_type
&
)
const
override
;
bool
CanBeUsed
(
const
typename
CRFDecodingTuple
<
float
>::
attr_type
&
)
const
override
;
const
char
*
ImplType
()
const
override
{
return
"Intrinsic"
;
}
const
char
*
ImplType
()
const
override
{
return
"Intrinsic"
;
}
};
};
...
...
paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
浏览文件 @
45bdd84d
...
@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
...
@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
}
}
}
}
bool
LayerNormKernel
::
UseMe
(
const
int
&
d
)
const
{
bool
LayerNormKernel
::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
>=
YMM_FLOAT_BLOCK
;
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
>=
YMM_FLOAT_BLOCK
;
}
}
...
...
paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
浏览文件 @
45bdd84d
...
@@ -30,7 +30,8 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
...
@@ -30,7 +30,8 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
class
LayerNormKernel
:
public
KernelMore
<
LayerNormTuple
<
float
>>
{
class
LayerNormKernel
:
public
KernelMore
<
LayerNormTuple
<
float
>>
{
public:
public:
LayerNormKernel
()
{
this
->
func
=
LayerNorm
;
}
LayerNormKernel
()
{
this
->
func
=
LayerNorm
;
}
bool
UseMe
(
const
typename
LayerNormTuple
<
float
>::
attr_type
&
)
const
override
;
bool
CanBeUsed
(
const
typename
LayerNormTuple
<
float
>::
attr_type
&
)
const
override
;
const
char
*
ImplType
()
const
override
{
return
"Intrinsic"
;
}
const
char
*
ImplType
()
const
override
{
return
"Intrinsic"
;
}
};
};
...
...
paddle/fluid/operators/jit/more/mix/mix.cc
浏览文件 @
45bdd84d
...
@@ -204,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
...
@@ -204,21 +204,21 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
}
}
// TODO(TJ): tuning me
// TODO(TJ): tuning me
bool
VSigmoidKernel
::
UseMe
(
const
int
&
d
)
const
{
return
true
;
}
bool
VSigmoidKernel
::
CanBeUsed
(
const
int
&
d
)
const
{
return
true
;
}
bool
VTanhKernel
::
UseMe
(
const
int
&
d
)
const
{
return
true
;
}
bool
VTanhKernel
::
CanBeUsed
(
const
int
&
d
)
const
{
return
true
;
}
bool
SoftmaxKernel
::
UseMe
(
const
int
&
d
)
const
{
return
true
;
}
bool
SoftmaxKernel
::
CanBeUsed
(
const
int
&
d
)
const
{
return
true
;
}
bool
LSTMCtHtKernel
::
UseMe
(
const
lstm_attr_t
&
attr
)
const
{
return
true
;
}
bool
LSTMCtHtKernel
::
CanBeUsed
(
const
lstm_attr_t
&
attr
)
const
{
return
true
;
}
bool
LSTMC1H1Kernel
::
UseMe
(
const
lstm_attr_t
&
attr
)
const
{
return
true
;
}
bool
LSTMC1H1Kernel
::
CanBeUsed
(
const
lstm_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUH1Kernel
::
UseMe
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUH1Kernel
::
CanBeUsed
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUHtPart1Kernel
::
UseMe
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUHtPart1Kernel
::
CanBeUsed
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUHtPart2Kernel
::
UseMe
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUHtPart2Kernel
::
CanBeUsed
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
}
// namespace mix
}
// namespace mix
}
// namespace more
}
// namespace more
...
...
paddle/fluid/operators/jit/more/mix/mix.h
浏览文件 @
45bdd84d
...
@@ -38,7 +38,7 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
...
@@ -38,7 +38,7 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
class name##Kernel : public KernelMore<name##Tuple<T>> { \
class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \
public: \
name##Kernel() { this->func = name; } \
name##Kernel() { this->func = name; } \
bool
UseMe
(const typename name##Tuple<T>::attr_type&) const override; \
bool
CanBeUsed
(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \
const char* ImplType() const override { return "Mixed"; } \
}
}
...
...
paddle/fluid/operators/jit/more/mkl/mkl.cc
浏览文件 @
45bdd84d
...
@@ -130,104 +130,105 @@ void ASum<double>(const double* x, double* res, int n) {
...
@@ -130,104 +130,105 @@ void ASum<double>(const double* x, double* res, int n) {
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template
<
>
template
<
>
bool
VMulKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VMulKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
}
template
<
>
template
<
>
bool
VAddKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VAddKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
>
512
;
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
>
512
;
}
}
template
<
>
template
<
>
bool
VScalKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VScalKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
}
template
<
>
template
<
>
bool
VExpKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VExpKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
d
>
7
;
return
d
>
7
;
}
}
template
<
>
template
<
>
bool
VSquareKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VSquareKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
d
>
7
;
return
d
>
7
;
}
}
template
<
>
template
<
>
bool
VCopyKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VCopyKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
d
>
15
;
return
d
>
15
;
}
}
template
<
>
template
<
>
bool
VBroadcastKernel
<
float
>::
UseMe
(
const
int64_t
&
d
)
const
{
bool
VBroadcastKernel
<
float
>::
CanBeUsed
(
const
int64_t
&
d
)
const
{
return
d
>
127
;
return
d
>
127
;
}
}
template
<
>
template
<
>
bool
VBroadcastKernel
<
double
>::
UseMe
(
const
int64_t
&
attr
)
const
{
bool
VBroadcastKernel
<
double
>::
CanBeUsed
(
const
int64_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
VSigmoidKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VSigmoidKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
d
>
7
;
return
d
>
7
;
}
}
template
<
>
template
<
>
bool
VTanhKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
VTanhKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
return
d
>
7
;
return
d
>
7
;
}
}
template
<
>
template
<
>
bool
SeqPoolKernel
<
float
>::
UseMe
(
const
seq_pool_attr_t
&
attr
)
const
{
bool
SeqPoolKernel
<
float
>::
CanBeUsed
(
const
seq_pool_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
SeqPoolKernel
<
double
>::
UseMe
(
const
seq_pool_attr_t
&
attr
)
const
{
bool
SeqPoolKernel
<
double
>::
CanBeUsed
(
const
seq_pool_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
EmbSeqPoolKernel
<
float
>::
UseMe
(
const
emb_seq_pool_attr_t
&
attr
)
const
{
bool
EmbSeqPoolKernel
<
float
>::
CanBeUsed
(
const
emb_seq_pool_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
EmbSeqPoolKernel
<
double
>::
UseMe
(
const
emb_seq_pool_attr_t
&
attr
)
const
{
bool
EmbSeqPoolKernel
<
double
>::
CanBeUsed
(
const
emb_seq_pool_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
SgdKernel
<
float
>::
UseMe
(
const
sgd_attr_t
&
attr
)
const
{
bool
SgdKernel
<
float
>::
CanBeUsed
(
const
sgd_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
SgdKernel
<
double
>::
UseMe
(
const
sgd_attr_t
&
attr
)
const
{
bool
SgdKernel
<
double
>::
CanBeUsed
(
const
sgd_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
MatMulKernel
<
float
>::
UseMe
(
const
matmul_attr_t
&
attr
)
const
{
bool
MatMulKernel
<
float
>::
CanBeUsed
(
const
matmul_attr_t
&
attr
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
template
<
>
template
<
>
bool
MatMulKernel
<
double
>::
UseMe
(
const
matmul_attr_t
&
attr
)
const
{
bool
MatMulKernel
<
double
>::
CanBeUsed
(
const
matmul_attr_t
&
attr
)
const
{
return
true
;
return
true
;
}
}
template
<
>
template
<
>
bool
SoftmaxKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
bool
SoftmaxKernel
<
float
>::
CanBeUsed
(
const
int
&
d
)
const
{
// tuned on avx2
// tuned on avx2
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
<
60
;
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
<
60
;
}
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \
template <> \
bool func##Kernel<double>::
UseMe
(const int& d) const { \
bool func##Kernel<double>::
CanBeUsed
(const int& d) const { \
return true; \
return true; \
}
}
...
...
paddle/fluid/operators/jit/more/mkl/mkl.h
浏览文件 @
45bdd84d
...
@@ -180,7 +180,7 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
...
@@ -180,7 +180,7 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
class name##Kernel : public KernelMore<name##Tuple<T>> { \
class name##Kernel : public KernelMore<name##Tuple<T>> { \
public: \
public: \
name##Kernel() { this->func = name<T>; } \
name##Kernel() { this->func = name<T>; } \
bool
UseMe
(const typename name##Tuple<T>::attr_type&) const override; \
bool
CanBeUsed
(const typename name##Tuple<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \
const char* ImplType() const override { return "MKL"; } \
}
}
...
...
paddle/fluid/operators/jit/registry.h
浏览文件 @
45bdd84d
...
@@ -49,7 +49,7 @@ struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
...
@@ -49,7 +49,7 @@ struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
void
operator
()(
KernelType
kt
)
const
{
void
operator
()(
KernelType
kt
)
const
{
KernelKey
kkey
(
kt
,
PlaceType
());
KernelKey
kkey
(
kt
,
PlaceType
());
Pool
().
Instance
().
Insert
(
kkey
,
Pool
::
Instance
().
Insert
(
kkey
,
std
::
move
(
make_unique
<
const
KERNEL_IMPL_TYPE
>
()));
std
::
move
(
make_unique
<
const
KERNEL_IMPL_TYPE
>
()));
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelImpls
...
>>::
value
;
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelImpls
...
>>::
value
;
JitKernelRegistrarFunctor
<
Pool
,
PlaceType
,
I
+
1
==
size
,
I
+
1
,
JitKernelRegistrarFunctor
<
Pool
,
PlaceType
,
I
+
1
==
size
,
I
+
1
,
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
45bdd84d
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <algorithm>
#include <algorithm>
#include <iostream>
#include <random>
#include <random>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -68,31 +69,11 @@ template <typename KernelTuple, typename PlaceType, typename Tester,
...
@@ -68,31 +69,11 @@ template <typename KernelTuple, typename PlaceType, typename Tester,
typename
...
Args
>
typename
...
Args
>
void
TestAllImpls
(
const
typename
KernelTuple
::
attr_type
&
attr
,
void
TestAllImpls
(
const
typename
KernelTuple
::
attr_type
&
attr
,
const
Tester
&
verifier
,
const
Args
&
...
args
)
{
const
Tester
&
verifier
,
const
Args
&
...
args
)
{
// test jitcode
auto
funcs
=
jit
::
GetAllCandidateFuncsWithTypes
<
KernelTuple
,
PlaceType
>
(
attr
);
auto
jitcode
=
jit
::
GetJitCode
<
KernelTuple
,
PlaceType
>
(
attr
);
for
(
auto
f
:
funcs
)
{
if
(
jitcode
)
{
VLOG
(
10
)
<<
"Test Kernel "
<<
f
.
first
;
VLOG
(
10
)
<<
"Test Jitcode Kernel "
;
verifier
(
f
.
second
,
args
...);
verifier
(
jitcode
,
args
...);
}
}
// test all impls in more
jit
::
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
PlaceType
());
auto
&
pool
=
jit
::
KernelPool
().
Instance
().
AllKernels
();
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
jit
::
KernelMore
<
KernelTuple
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
auto
more
=
i
->
GetFunc
();
VLOG
(
10
)
<<
"Test More Kernel : "
<<
i
->
ImplType
();
verifier
(
more
,
args
...);
}
}
}
// test result from Get function
VLOG
(
10
)
<<
"Test final get function "
;
auto
tgt
=
jit
::
KernelFuncs
<
KernelTuple
,
PlaceType
>::
Cache
().
At
(
attr
);
verifier
(
tgt
,
args
...);
}
}
template
<
typename
KernelTuple
,
typename
PlaceType
>
template
<
typename
KernelTuple
,
typename
PlaceType
>
...
@@ -100,7 +81,7 @@ void TestKernelXYZN() {
...
@@ -100,7 +81,7 @@ void TestKernelXYZN() {
using
T
=
typename
KernelTuple
::
data_type
;
using
T
=
typename
KernelTuple
::
data_type
;
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
for
(
int
d
:
TestSizes
())
{
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
d
),
y
(
d
),
zref
(
d
);
std
::
vector
<
T
>
x
(
d
),
y
(
d
),
zref
(
d
);
...
@@ -159,7 +140,7 @@ void TestKernelAXYN() {
...
@@ -159,7 +140,7 @@ void TestKernelAXYN() {
using
T
=
typename
KernelTuple
::
data_type
;
using
T
=
typename
KernelTuple
::
data_type
;
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
for
(
int
d
:
TestSizes
())
{
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
const
T
a
=
static_cast
<
T
>
(
3
);
const
T
a
=
static_cast
<
T
>
(
3
);
...
@@ -202,7 +183,7 @@ void TestKernelXYN() {
...
@@ -202,7 +183,7 @@ void TestKernelXYN() {
using
T
=
typename
KernelTuple
::
data_type
;
using
T
=
typename
KernelTuple
::
data_type
;
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
for
(
int
d
:
TestSizes
())
{
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
d
),
yref
(
d
);
std
::
vector
<
T
>
x
(
d
),
yref
(
d
);
...
@@ -245,7 +226,7 @@ void TestKernelXRN() {
...
@@ -245,7 +226,7 @@ void TestKernelXRN() {
auto
last_acc
=
FLAGS_acc
;
auto
last_acc
=
FLAGS_acc
;
FLAGS_acc
=
1e-4
;
FLAGS_acc
=
1e-4
;
for
(
int
d
:
TestSizes
())
{
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
d
);
std
::
vector
<
T
>
x
(
d
);
RandomVec
<
T
>
(
d
,
x
.
data
());
RandomVec
<
T
>
(
d
,
x
.
data
());
...
@@ -279,7 +260,7 @@ void TestKernelLSTM() {
...
@@ -279,7 +260,7 @@ void TestKernelLSTM() {
const
jit
::
lstm_attr_t
attr
(
const
jit
::
lstm_attr_t
attr
(
d
,
jit
::
to_kerneltype
(
act_gate
),
jit
::
to_kerneltype
(
act_cand
),
d
,
jit
::
to_kerneltype
(
act_gate
),
jit
::
to_kerneltype
(
act_cand
),
jit
::
to_kerneltype
(
act_cell
),
use_peephole
);
jit
::
to_kerneltype
(
act_cell
),
use_peephole
);
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
xsrc
(
4
*
d
),
wp
(
3
*
d
),
ct_1
(
d
);
std
::
vector
<
T
>
xsrc
(
4
*
d
),
wp
(
3
*
d
),
ct_1
(
d
);
std
::
vector
<
T
>
ct_ref
(
d
),
ht_ref
(
d
),
checked
(
2
*
d
);
std
::
vector
<
T
>
ct_ref
(
d
),
ht_ref
(
d
),
checked
(
2
*
d
);
...
@@ -370,7 +351,7 @@ void TestKernelGRU() {
...
@@ -370,7 +351,7 @@ void TestKernelGRU() {
for
(
auto
&
act_cand
:
all_acts
)
{
for
(
auto
&
act_cand
:
all_acts
)
{
const
jit
::
gru_attr_t
attr
(
d
,
jit
::
to_kerneltype
(
act_gate
),
const
jit
::
gru_attr_t
attr
(
d
,
jit
::
to_kerneltype
(
act_gate
),
jit
::
to_kerneltype
(
act_cand
));
jit
::
to_kerneltype
(
act_cand
));
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
xsrc
(
3
*
d
),
ht_1
(
d
),
ht_ref
(
d
);
std
::
vector
<
T
>
xsrc
(
3
*
d
),
ht_1
(
d
),
ht_ref
(
d
);
RandomVec
<
T
>
(
3
*
d
,
xsrc
.
data
());
RandomVec
<
T
>
(
3
*
d
,
xsrc
.
data
());
...
@@ -423,7 +404,7 @@ void TestKernelNCHW16CMulNC() {
...
@@ -423,7 +404,7 @@ void TestKernelNCHW16CMulNC() {
using
T
=
typename
KernelTuple
::
data_type
;
using
T
=
typename
KernelTuple
::
data_type
;
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
const
int
n
=
3
,
c
=
16
*
4
,
h
=
10
,
w
=
10
;
const
int
n
=
3
,
c
=
16
*
4
,
h
=
10
,
w
=
10
;
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
int
sz
=
n
*
c
*
h
*
w
;
int
sz
=
n
*
c
*
h
*
w
;
std
::
vector
<
T
>
x
(
sz
),
y
(
n
*
c
),
zref
(
sz
);
std
::
vector
<
T
>
x
(
sz
),
y
(
n
*
c
),
zref
(
sz
);
...
@@ -439,7 +420,9 @@ void TestKernelNCHW16CMulNC() {
...
@@ -439,7 +420,9 @@ void TestKernelNCHW16CMulNC() {
constexpr
int
simd_width
=
ZMM_FLOAT_BLOCK
;
constexpr
int
simd_width
=
ZMM_FLOAT_BLOCK
;
int
C
=
c
/
simd_width
;
int
C
=
c
/
simd_width
;
auto
tgt
=
jit
::
KernelFuncs
<
KernelTuple
,
PlaceType
>::
Cache
().
At
(
0
);
auto
tgt
=
jit
::
KernelFuncs
<
KernelTuple
,
PlaceType
>::
Cache
().
At
(
0
);
auto
jitcode
=
jit
::
GetJitCode
<
KernelTuple
,
PlaceType
>
(
0
);
auto
funcs
=
jit
::
GetAllCandidateFuncs
<
KernelTuple
,
PlaceType
>
(
0
);
EXPECT_GT
(
funcs
.
size
(),
0UL
);
auto
jitcode
=
funcs
[
0
];
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_TRUE
(
tgt
!=
nullptr
);
if
(
std
::
is_same
<
T
,
float
>::
value
&&
if
(
std
::
is_same
<
T
,
float
>::
value
&&
...
@@ -482,7 +465,7 @@ void TestKernelLayerNorm() {
...
@@ -482,7 +465,7 @@ void TestKernelLayerNorm() {
int
left
=
n
*
x_dim_0
;
int
left
=
n
*
x_dim_0
;
for
(
int
x_dim_1
:
TestSizes
())
{
for
(
int
x_dim_1
:
TestSizes
())
{
int
right
=
x_dim_1
;
int
right
=
x_dim_1
;
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
int
sz
=
left
*
right
;
int
sz
=
left
*
right
;
std
::
vector
<
T
>
x
(
sz
),
mean
(
left
),
var
(
left
),
scale
(
right
),
bias
(
right
),
std
::
vector
<
T
>
x
(
sz
),
mean
(
left
),
var
(
left
),
scale
(
right
),
bias
(
right
),
...
@@ -555,7 +538,7 @@ void TestKernelCRFDecoding() {
...
@@ -555,7 +538,7 @@ void TestKernelCRFDecoding() {
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
2000
));
test_sizes
.
erase
(
std
::
remove
(
test_sizes
.
begin
(),
test_sizes
.
end
(),
2000
));
for
(
int
seq_len
:
{
1
,
11
,
17
,
50
})
{
for
(
int
seq_len
:
{
1
,
11
,
17
,
50
})
{
for
(
int
tag_num
:
test_sizes
)
{
for
(
int
tag_num
:
test_sizes
)
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
int
x_sz
=
seq_len
*
tag_num
;
int
x_sz
=
seq_len
*
tag_num
;
int
w_sz
=
(
tag_num
+
state_trans_base_idx
)
*
tag_num
;
int
w_sz
=
(
tag_num
+
state_trans_base_idx
)
*
tag_num
;
...
@@ -606,7 +589,7 @@ void TestKernelSeqPool() {
...
@@ -606,7 +589,7 @@ void TestKernelSeqPool() {
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
for
(
int
h
:
test_sizes
)
{
for
(
int
h
:
test_sizes
)
{
attr
.
h
=
h
;
attr
.
h
=
h
;
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
h
*
w
),
yref
(
w
);
std
::
vector
<
T
>
x
(
h
*
w
),
yref
(
w
);
RandomVec
<
T
>
(
h
*
w
,
x
.
data
());
RandomVec
<
T
>
(
h
*
w
,
x
.
data
());
...
@@ -649,7 +632,7 @@ void TestKernelEmbSeqPool() {
...
@@ -649,7 +632,7 @@ void TestKernelEmbSeqPool() {
for
(
auto
type
:
pool_types
)
{
for
(
auto
type
:
pool_types
)
{
for
(
int
idx_w
:
{
1
,
2
,
10
,
16
})
{
for
(
int
idx_w
:
{
1
,
2
,
10
,
16
})
{
for
(
int
idx_h
:
{
1
,
2
,
9
,
13
,
16
})
{
for
(
int
idx_h
:
{
1
,
2
,
9
,
13
,
16
})
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
int64_t
>
idx
(
idx_h
*
idx_w
);
std
::
vector
<
int64_t
>
idx
(
idx_h
*
idx_w
);
RandomVec
<
int64_t
>
(
idx_h
*
idx_w
,
idx
.
data
(),
0
,
tbl_h
-
1
);
RandomVec
<
int64_t
>
(
idx_h
*
idx_w
,
idx
.
data
(),
0
,
tbl_h
-
1
);
...
@@ -701,7 +684,7 @@ void TestKernelMatMul() {
...
@@ -701,7 +684,7 @@ void TestKernelMatMul() {
for
(
int
m
:
{
1
,
2
,
3
,
4
})
{
for
(
int
m
:
{
1
,
2
,
3
,
4
})
{
for
(
int
n
:
{
1
,
2
,
3
,
4
})
{
for
(
int
n
:
{
1
,
2
,
3
,
4
})
{
for
(
int
k
:
TestSizes
())
{
for
(
int
k
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
a
(
m
*
k
),
b
(
k
*
n
),
c
(
m
*
n
);
std
::
vector
<
T
>
a
(
m
*
k
),
b
(
k
*
n
),
c
(
m
*
n
);
RandomVec
<
T
>
(
m
*
k
,
a
.
data
());
RandomVec
<
T
>
(
m
*
k
,
a
.
data
());
...
@@ -740,7 +723,7 @@ void TestKernelSoftmax() {
...
@@ -740,7 +723,7 @@ void TestKernelSoftmax() {
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
VLOG
(
10
)
<<
"Test JITKernel: "
<<
jit
::
to_string
(
KernelTuple
::
kernel_type
);
for
(
int
bs
:
{
1
,
2
,
10
})
{
for
(
int
bs
:
{
1
,
2
,
10
})
{
for
(
int
n
:
TestSizes
())
{
for
(
int
n
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
bs
*
n
),
y
(
bs
*
n
);
std
::
vector
<
T
>
x
(
bs
*
n
),
y
(
bs
*
n
);
RandomVec
<
T
>
(
bs
*
n
,
x
.
data
());
RandomVec
<
T
>
(
bs
*
n
,
x
.
data
());
...
@@ -808,7 +791,7 @@ void TestKernelSgd() {
...
@@ -808,7 +791,7 @@ void TestKernelSgd() {
RandomVec
<
T
>
(
rows_size
*
grad_w
,
grad
.
data
());
RandomVec
<
T
>
(
rows_size
*
grad_w
,
grad
.
data
());
const
int64_t
*
rows_data
=
rows
.
data
();
const
int64_t
*
rows_data
=
rows
.
data
();
const
T
*
grad_data
=
grad
.
data
();
const
T
*
grad_data
=
grad
.
data
();
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
jit
::
sgd_attr_t
attr
(
param_h
,
grad_w
,
rows_size
,
grad_w
,
rows_size
);
jit
::
sgd_attr_t
attr
(
param_h
,
grad_w
,
rows_size
,
grad_w
,
rows_size
);
ref
(
&
lr
,
param_data
,
grad_data
,
rows_data
,
out_data
,
&
attr
);
ref
(
&
lr
,
param_data
,
grad_data
,
rows_data
,
out_data
,
&
attr
);
...
@@ -874,7 +857,7 @@ void TestKernelVBroadcast() {
...
@@ -874,7 +857,7 @@ void TestKernelVBroadcast() {
RandomVec
<
T
>
(
w
,
x
.
data
());
RandomVec
<
T
>
(
w
,
x
.
data
());
const
T
*
x_data
=
x
.
data
();
const
T
*
x_data
=
x
.
data
();
for
(
int64_t
h
:
{
1
,
2
,
6
})
{
for
(
int64_t
h
:
{
1
,
2
,
6
})
{
auto
ref
=
jit
::
GetRefer
<
KernelTuple
>
();
auto
ref
=
jit
::
GetRefer
Func
<
KernelTuple
>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
y
(
w
*
h
);
std
::
vector
<
T
>
y
(
w
*
h
);
T
*
y_data
=
y
.
data
();
T
*
y_data
=
y
.
data
();
...
@@ -900,6 +883,135 @@ void TestKernelVBroadcast() {
...
@@ -900,6 +883,135 @@ void TestKernelVBroadcast() {
}
}
}
}
// test pool
TEST
(
JITKernel_pool
,
jitcreator
)
{
const
auto
&
jitcreators
=
jit
::
JitCodeCreatorPool
::
Instance
().
AllCreators
();
EXPECT_EQ
(
jitcreators
.
size
(),
25UL
);
}
TEST
(
JITKernel_pool
,
jitpool
)
{
// jitpool is related with attr
const
auto
&
kers
=
jit
::
JitCodePool
<
jit
::
kVAdd
>
().
Instance
().
AllKernels
();
EXPECT_EQ
(
kers
.
size
(),
0UL
);
jit
::
GetAllCandidateKernels
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>
(
3
);
// after call GetAllCandidateKernels, it will create jitcode Automatically
EXPECT_EQ
(
kers
.
size
(),
1UL
);
}
TEST
(
JITKernel_pool
,
more
)
{
const
auto
&
kers
=
jit
::
KernelPool
::
Instance
().
AllKernels
();
EXPECT_EQ
(
kers
.
size
(),
21UL
);
}
TEST
(
JITKernel_pool
,
refer
)
{
const
auto
&
kers
=
jit
::
ReferKernelPool
::
Instance
().
AllKernels
();
EXPECT_EQ
(
kers
.
size
(),
29UL
);
}
// test helper
TEST
(
JITKernel_helper
,
GetAllCandidateKernels
)
{
auto
fp_kers
=
jit
::
GetAllCandidateKernels
<
jit
::
VExpTuple
<
float
>
,
CPUPlace
>
(
10
);
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE
(
fp_kers
.
size
(),
1UL
);
// refer
#else
EXPECT_GE
(
fp_kers
.
size
(),
3UL
);
// jitcode, mkl, refer
#endif
auto
db_kers
=
jit
::
GetAllCandidateKernels
<
jit
::
VExpTuple
<
double
>
,
CPUPlace
>
(
10
);
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_GE
(
db_kers
.
size
(),
1UL
);
// refer
#else
EXPECT_GE
(
db_kers
.
size
(),
2UL
);
// mkl, refer
#endif
}
TEST
(
JITKernel_helper
,
GetAllCandidateFuncsWithTypes
)
{
auto
fp_kers
=
jit
::
GetAllCandidateFuncsWithTypes
<
jit
::
VExpTuple
<
float
>
,
CPUPlace
>
(
10
);
EXPECT_GE
(
fp_kers
.
size
(),
3UL
);
// jitcode, mkl, refer
auto
db_kers
=
jit
::
GetAllCandidateFuncsWithTypes
<
jit
::
VExpTuple
<
double
>
,
CPUPlace
>
(
10
);
EXPECT_GE
(
db_kers
.
size
(),
2UL
);
// mkl, refer
}
TEST
(
JITKernel_helper
,
GetAllCandidateFuncs
)
{
auto
funcs
=
jit
::
GetAllCandidateFuncs
<
jit
::
VExpTuple
<
float
>
,
CPUPlace
>
(
10
);
auto
kers
=
jit
::
GetAllCandidateKernels
<
jit
::
VExpTuple
<
float
>
,
CPUPlace
>
(
10
);
EXPECT_EQ
(
funcs
.
size
(),
kers
.
size
());
std
::
vector
<
float
>
x
(
10
),
tgt
(
10
);
RandomVec
<
float
>
(
10
,
x
.
data
());
auto
best
=
jit
::
GetDefaultBestFunc
<
jit
::
VExpTuple
<
float
>
,
CPUPlace
>
(
10
);
best
(
x
.
data
(),
tgt
.
data
(),
10
);
for
(
auto
f
:
funcs
)
{
std
::
vector
<
float
>
y
(
10
);
f
(
x
.
data
(),
y
.
data
(),
10
);
ExpectEQ
<
float
>
(
y
.
data
(),
tgt
.
data
(),
10
);
}
}
TEST
(
JITKernel_helper
,
attr
)
{
std
::
ostringstream
out
;
// KernelTypes
out
<<
jit
::
to_string
(
jit
::
kNone
)
<<
jit
::
to_string
(
jit
::
kCRFDecoding
)
<<
jit
::
to_string
(
jit
::
kEmbSeqPool
)
<<
jit
::
to_string
(
jit
::
kGRUH1
)
<<
jit
::
to_string
(
jit
::
kGRUHtPart1
)
<<
jit
::
to_string
(
jit
::
kGRUHtPart2
)
<<
jit
::
to_string
(
jit
::
kHSum
)
<<
jit
::
to_string
(
jit
::
kHMax
)
<<
jit
::
to_string
(
jit
::
kLSTMCtHt
)
<<
jit
::
to_string
(
jit
::
kLSTMC1H1
)
<<
jit
::
to_string
(
jit
::
kLayerNorm
)
<<
jit
::
to_string
(
jit
::
kMatMul
)
<<
jit
::
to_string
(
jit
::
kNCHW16CMulNC
)
<<
jit
::
to_string
(
jit
::
kSeqPool
)
<<
jit
::
to_string
(
jit
::
kSoftmax
)
<<
jit
::
to_string
(
jit
::
kVAdd
)
<<
jit
::
to_string
(
jit
::
kVAddBias
)
<<
jit
::
to_string
(
jit
::
kVAddRelu
)
<<
jit
::
to_string
(
jit
::
kVBroadcast
)
<<
jit
::
to_string
(
jit
::
kVCopy
)
<<
jit
::
to_string
(
jit
::
kVExp
)
<<
jit
::
to_string
(
jit
::
kVIdentity
)
<<
jit
::
to_string
(
jit
::
kVMul
)
<<
jit
::
to_string
(
jit
::
kVRelu
)
<<
jit
::
to_string
(
jit
::
kVScal
)
<<
jit
::
to_string
(
jit
::
kSgd
)
<<
jit
::
to_string
(
jit
::
kVSigmoid
)
<<
jit
::
to_string
(
jit
::
kVSquare
)
<<
jit
::
to_string
(
jit
::
kVSub
)
<<
jit
::
to_string
(
jit
::
kVTanh
);
EXPECT_EQ
(
out
.
str
().
size
(),
234
);
// SeqPoolTypes
out
.
str
(
""
);
out
<<
jit
::
to_string
(
jit
::
kSum
)
<<
jit
::
to_string
(
jit
::
kAvg
)
<<
jit
::
to_string
(
jit
::
kSqrt
);
EXPECT_EQ
(
out
.
str
().
size
(),
13
);
EXPECT_EQ
(
jit
::
to_kerneltype
(
"relu"
),
jit
::
kVRelu
);
EXPECT_EQ
(
jit
::
to_kerneltype
(
"Identity"
),
jit
::
kVIdentity
);
EXPECT_EQ
(
jit
::
to_kerneltype
(
"VEXP"
),
jit
::
kVExp
);
EXPECT_EQ
(
jit
::
to_kerneltype
(
"SigmoiD"
),
jit
::
kVSigmoid
);
EXPECT_EQ
(
jit
::
to_kerneltype
(
"VTanh"
),
jit
::
kVTanh
);
out
.
str
(
""
);
out
<<
jit
::
lstm_attr_t
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
EXPECT_EQ
(
out
.
str
().
size
(),
89
);
out
.
str
(
""
);
out
<<
jit
::
gru_attr_t
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
);
EXPECT_EQ
(
out
.
str
().
size
(),
52
);
out
.
str
(
""
);
out
<<
jit
::
seq_pool_attr_t
(
8
,
jit
::
SeqPoolType
::
kSum
);
EXPECT_EQ
(
out
.
str
().
size
(),
44
);
out
.
str
(
""
);
out
<<
jit
::
emb_seq_pool_attr_t
(
1
,
2
,
3
,
4
,
5
,
jit
::
SeqPoolType
::
kAvg
);
EXPECT_EQ
(
out
.
str
().
size
(),
93
);
out
.
str
(
""
);
out
<<
jit
::
sgd_attr_t
(
1
,
2
,
3
,
4
,
5
);
EXPECT_EQ
(
out
.
str
().
size
(),
81
);
out
.
str
(
""
);
out
<<
jit
::
matmul_attr_t
(
1
,
2
,
3
);
EXPECT_EQ
(
out
.
str
().
size
(),
14
);
}
// test kernerls
#define TestKernelVMul TestKernelXYZN
#define TestKernelVMul TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN
#define TestKernelVAddRelu TestKernelXYZN
#define TestKernelVAddRelu TestKernelXYZN
...
@@ -969,6 +1081,14 @@ TEST_CPU_KERNEL(Softmax);
...
@@ -969,6 +1081,14 @@ TEST_CPU_KERNEL(Softmax);
TEST_CPU_KERNEL
(
Sgd
);
TEST_CPU_KERNEL
(
Sgd
);
TEST_CPU_KERNEL
(
VBroadcast
);
TEST_CPU_KERNEL
(
VBroadcast
);
TEST
(
JITKernel
,
kernel_func
)
{
auto
f1
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
().
At
(
3
);
auto
f2
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
()[
3
];
EXPECT_TRUE
(
f1
!=
nullptr
);
EXPECT_TRUE
(
f1
==
f2
);
// TODO(TJ): check not equal
}
TEST
(
JITKernel_key
,
lstm
)
{
TEST
(
JITKernel_key
,
lstm
)
{
jit
::
lstm_attr_t
attr1
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr1
(
8
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr2
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
jit
::
lstm_attr_t
attr2
(
9
,
jit
::
kVIdentity
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
...
@@ -1000,11 +1120,3 @@ TEST(JITKernel_key, gru) {
...
@@ -1000,11 +1120,3 @@ TEST(JITKernel_key, gru) {
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key2
==
key3
);
EXPECT_TRUE
(
key3
!=
key4
);
EXPECT_TRUE
(
key3
!=
key4
);
}
}
TEST
(
JITKernel
,
kernel_func
)
{
auto
f1
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
().
At
(
3
);
auto
f2
=
jit
::
KernelFuncs
<
jit
::
VAddTuple
<
float
>
,
CPUPlace
>::
Cache
()[
3
];
EXPECT_TRUE
(
f1
!=
nullptr
);
EXPECT_TRUE
(
f1
==
f2
);
// TODO(TJ): check not equal
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录