Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d53c4756
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d53c4756
编写于
12月 19, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean code and remove unused files
test=develop
上级
95fb3128
变更
28
展开全部
隐藏空白更改
内联
并排
Showing
28 changed file
with
73 addition
and
3608 deletion
+73
-3608
paddle/fluid/operators/jit/README.md
paddle/fluid/operators/jit/README.md
+22
-12
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+3
-2
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+1
-1
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+6
-7
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
+1
-1
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
+4
-2
paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
+1
-1
paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
+3
-2
paddle/fluid/operators/jit/more/mix/mix.cc
paddle/fluid/operators/jit/more/mix/mix.cc
+7
-7
paddle/fluid/operators/jit/more/mix/mix.h
paddle/fluid/operators/jit/more/mix/mix.h
+6
-5
paddle/fluid/operators/jit/more/mkl/mkl.cc
paddle/fluid/operators/jit/more/mkl/mkl.cc
+10
-10
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+7
-6
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+2
-2
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+0
-334
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+0
-532
paddle/fluid/operators/math/jit_gen.cc
paddle/fluid/operators/math/jit_gen.cc
+0
-90
paddle/fluid/operators/math/jit_gen.h
paddle/fluid/operators/math/jit_gen.h
+0
-80
paddle/fluid/operators/math/jit_kernel.cc
paddle/fluid/operators/math/jit_kernel.cc
+0
-39
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+0
-157
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+0
-346
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
+0
-291
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+0
-195
paddle/fluid/operators/math/jit_kernel_impl.h
paddle/fluid/operators/math/jit_kernel_impl.h
+0
-34
paddle/fluid/operators/math/jit_kernel_layer_norm.cc
paddle/fluid/operators/math/jit_kernel_layer_norm.cc
+0
-239
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+0
-179
paddle/fluid/operators/math/jit_kernel_refer.h
paddle/fluid/operators/math/jit_kernel_refer.h
+0
-29
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+0
-263
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+0
-742
未找到文件。
paddle/fluid/operators/jit/README.md
浏览文件 @
d53c4756
# JIT Kernel
# JIT Kernel
结合函数模板和JIT生成需要的kernel函数。
结合函数模板和JIT生成需要的kernel函数。
这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。
这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。可以有多重第三方库的实现,每种实现有自己的
`UseMe`
函数负责什么条件下可以被调用。
这里实现的函数可以非常细粒度的函数方法,比如Vector mul, 也可以是一个复杂的逻辑比如LSTM等。复杂的逻辑也可以由自己的底层函数拼接而成。
目前仅支持CPU上的高性能计算。
目前仅支持CPU上的高性能计算。
## 目录结构
## 目录结构
...
@@ -21,6 +22,8 @@ PaddlePaddle/Paddle/paddle/fluid/
...
@@ -21,6 +22,8 @@ PaddlePaddle/Paddle/paddle/fluid/
│ │ └── ...
│ │ └── ...
│ ├── mkldnn/
│ ├── mkldnn/
│ │ └── ...
│ │ └── ...
│ ├── mix/
│ │ └── ...
│ ├── intrinsic/
│ ├── intrinsic/
│ │ └── ...
│ │ └── ...
│ └── openblas/
│ └── openblas/
...
@@ -29,28 +32,35 @@ PaddlePaddle/Paddle/paddle/fluid/
...
@@ -29,28 +32,35 @@ PaddlePaddle/Paddle/paddle/fluid/
└── ...
└── ...
```
```
基
础class都的根目录下,根目录下包括jitcode,more和refer。每个目录下都是一种实现,每种kernel算子都需要有reference的实现,其他的
都是可选的。
基
本类的定义都放在根目录下,根目录下包括gen,more和refer三个目录。每个目录下都是一种或者多种实现,每种kernel算子都需要有reference的实现,用作单元测试的基准,其他的实现
都是可选的。
-
jitcode: 代表使用jit生成的code,需要依赖xbyak。他关心的
是性能。
-
gen: 代表使用jit生成的code,需要依赖xbyak库。该实现最关心的就
是性能。
-
refer
:代表reference的实现,每种kernel算子都需要有在CPU上的reference的实现,他主要关心的算法逻辑
。
-
refer
: 代表reference的实现,每种kernel算子都需要有在CPU上的reference的实现,他主要关心的算法逻辑的正确性
。
-
more
: 下面可以放入跟多实现,包括mkl,mkldnn
,openblas等,也可以是自身已有的kernel组合。
-
more
: 下面可以放入跟多实现,可以包括mkl,mkldnn,intrinsic
,openblas等,也可以是自身已有的kernel组合。
## 动态获取
## 动态获取
提供一个
get
方法,根据kernel类别获取,每种实现都有自己的使用范围,根据范围动态和当前条件选择需要的kernel函数。
提供一个
`jit::Get`
方法,根据kernel类别获取,每种实现都有自己的使用范围,根据范围动态和当前条件选择需要的kernel函数。
## 测试
## 测试
-
逻辑测试
-
逻辑测试
所有实现都要与refer的code对比,需要满足精度要求, 包括float和double的数据类型
所有实现都要与refer的code对比,需要满足精度要求, 包括float和double的数据类型
-
性能测试
-
性能测试
所有实现的性能对比,并且与最终的
`jit::Get`
方法对比,该方法拿到的性能需要是最好的。
所有实现的性能对比,并且与最终的
`jit::Get`
方法对比,该方法拿到的性能需要
在各种条件下都
是最好的。
# 如何添加新的算子
# 如何添加新的算子
-
在
`KernelType`
中添加
`your_key`
.
-
在
`KernelType`
中添加
`your_key`
.
-
实现Reference 的逻辑,
每个jitkernel的Reference 实现是必须的。不要依赖任何第三方库。并在
`refer/CmakeLists.txt`
中
`USE_JITKERNEL_REFER(your_key)`
.
-
实现Reference 的逻辑,
这个是必须是在CPU上的实现,并且不能依赖任何第三方库。实现后在
`refer/CmakeLists.txt`
中添加
`USE_JITKERNEL_REFER(your_key)`
来使用该kernel
.
-
(optional) 实现更多的算法在
`more`
目录下,可以依赖mkl,
openblas,
或者mkldnn等第三方库。
-
(optional) 实现更多的算法在
`more`
目录下,可以依赖mkl,
intrinsic
或者mkldnn等第三方库。
-
(optional) 实现基于Xbyak的生成code,在
`gen`
目下。 jitcode需要实现自己的
`JitCodeCreator`
,并注册在
KernelType
上。
-
(optional) 实现基于Xbyak的生成code,在
`gen`
目下。 jitcode需要实现自己的
`JitCodeCreator`
,并注册在
与refer相同的
`KernelType`
上。
-
必要时可以添加新的
`KernelTuples`
,可以参考
`XYZNTuples`
,新加的Attr类型需要特例化
`JitCodeKey`
方法。
-
必要时可以添加新的
`KernelTuples`
,可以参考
`XYZNTuples`
,新加的Attr类型需要特例化
`JitCodeKey`
方法。
-
添加unit test,需要测试float和double
-
在
`test.cc`
中添加unit test,至少需要测试
`float`
和
`double`
两种数据类型,如有必要需要支持额外的数据类型,比如
`int8`
的相关函数。
-
添加benchmark确保get得到的速度是最快。
-
在
`benchmark.cc`
中添加相应的性能对比,同一种kernel需要对比所有实现,并且确保
`jit::Get`
得到的实现一直是速度最快的。
# 优点
-
统一的Get方法,接口简单。
-
同一套逻辑可以有多套实现,可以依赖多套第三方库,互不影响。
-
目录结构清晰,不会在某个文件中有多个宏定义,导致的可读性差问题。
-
优化方便,可以直接针对某种属性针对性优化,并不影响其他属性下的性能。
-
可以支持多种平台,包括Linux,Mac 和 Windows,至少可以保证每种平台都可以正常work。后期也可以针对不同平台有针对的优化。框架层面可以使用统一接口,不必关心底层实现。
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
d53c4756
...
@@ -93,10 +93,11 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
...
@@ -93,10 +93,11 @@ void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
if
(
iter
!=
pool
.
end
())
{
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
jit
::
Kernel
Impl
<
KernelTuples
>*>
(
impl
.
get
());
auto
i
=
dynamic_cast
<
const
jit
::
Kernel
More
<
KernelTuples
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
if
(
i
&&
i
->
UseMe
(
attr
))
{
auto
more
=
i
->
GetFunc
();
auto
more
=
i
->
GetFunc
();
infos
.
push_back
(
std
::
make_pair
(
"More"
,
benchmark
(
more
,
args
...)));
infos
.
push_back
(
std
::
make_pair
(
i
->
ImplType
(),
benchmark
(
more
,
args
...)));
}
}
}
}
}
}
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
d53c4756
...
@@ -107,7 +107,7 @@ typename KernelTuples::func_type Get(
...
@@ -107,7 +107,7 @@ typename KernelTuples::func_type Get(
if
(
iter
!=
pool
.
end
())
{
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
Kernel
Impl
<
KernelTuples
>*>
(
impl
.
get
());
auto
i
=
dynamic_cast
<
const
Kernel
More
<
KernelTuples
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
if
(
i
&&
i
->
UseMe
(
attr
))
{
return
i
->
GetFunc
();
return
i
->
GetFunc
();
}
}
...
...
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
d53c4756
...
@@ -144,28 +144,27 @@ class Kernel {
...
@@ -144,28 +144,27 @@ class Kernel {
};
};
template
<
typename
KernelTuples
>
template
<
typename
KernelTuples
>
class
KernelImpl
:
public
Kernel
{
class
KernelMore
:
public
Kernel
{
// TODO(TJ): rename KernelImpl to KernelMore which seems only used in more
// and add name interface for more implements easy for debug
public:
public:
using
T
=
typename
KernelTuples
::
data_type
;
using
T
=
typename
KernelTuples
::
data_type
;
using
Func
=
typename
KernelTuples
::
func_type
;
using
Func
=
typename
KernelTuples
::
func_type
;
using
Attr
=
typename
KernelTuples
::
attr_type
;
using
Attr
=
typename
KernelTuples
::
attr_type
;
virtual
Func
GetFunc
()
const
{
return
func
;
}
virtual
Func
GetFunc
()
const
{
return
func
;
}
// TODO(TJ): const &attr
virtual
bool
UseMe
(
const
Attr
&
attr
)
const
=
0
;
virtual
bool
UseMe
(
Attr
attr
)
const
=
0
;
virtual
const
char
*
ImplType
(
)
const
=
0
;
protected:
protected:
Func
func
{
nullptr
};
Func
func
{
nullptr
};
};
};
template
<
typename
KernelTuples
>
template
<
typename
KernelTuples
>
class
ReferKernel
:
public
Kernel
Impl
<
KernelTuples
>
{
class
ReferKernel
:
public
Kernel
More
<
KernelTuples
>
{
public:
public:
// Refer code can always be used
// Refer code can always be used
bool
UseMe
(
typename
KernelTuples
::
attr_type
attr
)
const
override
{
bool
UseMe
(
const
typename
KernelTuples
::
attr_type
&
attr
)
const
override
{
return
true
;
return
true
;
}
}
const
char
*
ImplType
()
const
override
{
return
"Refer"
;
}
};
};
}
// namespace jit
}
// namespace jit
...
...
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
浏览文件 @
d53c4756
...
@@ -156,7 +156,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
...
@@ -156,7 +156,7 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
}
}
}
}
bool
CRFDecodingKernel
::
UseMe
(
int
d
)
const
{
bool
CRFDecodingKernel
::
UseMe
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
...
...
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
浏览文件 @
d53c4756
...
@@ -26,10 +26,12 @@ namespace intrinsic {
...
@@ -26,10 +26,12 @@ namespace intrinsic {
void
CRFDecoding
(
const
int
seq_len
,
const
float
*
x
,
const
float
*
w
,
void
CRFDecoding
(
const
int
seq_len
,
const
float
*
x
,
const
float
*
w
,
float
*
alpha
,
int
*
track
,
int
tag_num
);
float
*
alpha
,
int
*
track
,
int
tag_num
);
class
CRFDecodingKernel
:
public
Kernel
Impl
<
CRFDecodingTuples
<
float
>>
{
class
CRFDecodingKernel
:
public
Kernel
More
<
CRFDecodingTuples
<
float
>>
{
public:
public:
CRFDecodingKernel
()
{
this
->
func
=
CRFDecoding
;
}
CRFDecodingKernel
()
{
this
->
func
=
CRFDecoding
;
}
bool
UseMe
(
typename
CRFDecodingTuples
<
float
>::
attr_type
)
const
override
;
bool
UseMe
(
const
typename
CRFDecodingTuples
<
float
>::
attr_type
&
)
const
override
;
const
char
*
ImplType
()
const
override
{
return
"Intrinsic"
;
}
};
};
}
// namespace intrinsic
}
// namespace intrinsic
...
...
paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
浏览文件 @
d53c4756
...
@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
...
@@ -153,7 +153,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
}
}
}
}
bool
LayerNormKernel
::
UseMe
(
int
d
)
const
{
bool
LayerNormKernel
::
UseMe
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
...
...
paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
浏览文件 @
d53c4756
...
@@ -27,10 +27,11 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
...
@@ -27,10 +27,11 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
const
float
*
scale
,
const
float
*
bias
,
int
height
,
const
float
*
scale
,
const
float
*
bias
,
int
height
,
const
float
epsilon
,
int
right
);
const
float
epsilon
,
int
right
);
class
LayerNormKernel
:
public
Kernel
Impl
<
LayerNormTuples
<
float
>>
{
class
LayerNormKernel
:
public
Kernel
More
<
LayerNormTuples
<
float
>>
{
public:
public:
LayerNormKernel
()
{
this
->
func
=
LayerNorm
;
}
LayerNormKernel
()
{
this
->
func
=
LayerNorm
;
}
bool
UseMe
(
typename
LayerNormTuples
<
float
>::
attr_type
)
const
override
;
bool
UseMe
(
const
typename
LayerNormTuples
<
float
>::
attr_type
&
)
const
override
;
const
char
*
ImplType
()
const
override
{
return
"Intrinsic"
;
}
};
};
}
// namespace intrinsic
}
// namespace intrinsic
...
...
paddle/fluid/operators/jit/more/mix/mix.cc
浏览文件 @
d53c4756
...
@@ -180,19 +180,19 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
...
@@ -180,19 +180,19 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
}
}
// TODO(TJ): tuning me
// TODO(TJ): tuning me
bool
VSigmoidKernel
::
UseMe
(
int
d
)
const
{
return
true
;
}
bool
VSigmoidKernel
::
UseMe
(
const
int
&
d
)
const
{
return
true
;
}
bool
VTanhKernel
::
UseMe
(
int
d
)
const
{
return
true
;
}
bool
VTanhKernel
::
UseMe
(
const
int
&
d
)
const
{
return
true
;
}
bool
LSTMCtHtKernel
::
UseMe
(
lstm_attr_t
attr
)
const
{
return
true
;
}
bool
LSTMCtHtKernel
::
UseMe
(
const
lstm_attr_t
&
attr
)
const
{
return
true
;
}
bool
LSTMC1H1Kernel
::
UseMe
(
lstm_attr_t
attr
)
const
{
return
true
;
}
bool
LSTMC1H1Kernel
::
UseMe
(
const
lstm_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUH1Kernel
::
UseMe
(
gru_attr_t
attr
)
const
{
return
true
;
}
bool
GRUH1Kernel
::
UseMe
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUHtPart1Kernel
::
UseMe
(
gru_attr_t
attr
)
const
{
return
true
;
}
bool
GRUHtPart1Kernel
::
UseMe
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUHtPart2Kernel
::
UseMe
(
gru_attr_t
attr
)
const
{
return
true
;
}
bool
GRUHtPart2Kernel
::
UseMe
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
}
// namespace mix
}
// namespace mix
}
// namespace more
}
// namespace more
...
...
paddle/fluid/operators/jit/more/mix/mix.h
浏览文件 @
d53c4756
...
@@ -33,11 +33,12 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
...
@@ -33,11 +33,12 @@ void GRUH1(gru_t* step, const gru_attr_t* attr);
void
GRUHtPart1
(
gru_t
*
step
,
const
gru_attr_t
*
attr
);
void
GRUHtPart1
(
gru_t
*
step
,
const
gru_attr_t
*
attr
);
void
GRUHtPart2
(
gru_t
*
step
,
const
gru_attr_t
*
attr
);
void
GRUHtPart2
(
gru_t
*
step
,
const
gru_attr_t
*
attr
);
#define DECLARE_MORE_KERNEL(name, tuples) \
#define DECLARE_MORE_KERNEL(name, tuples) \
class name##Kernel : public KernelImpl<tuples<T>> { \
class name##Kernel : public KernelMore<tuples<T>> { \
public: \
public: \
name##Kernel() { this->func = name; } \
name##Kernel() { this->func = name; } \
bool UseMe(typename tuples<T>::attr_type) const override; \
bool UseMe(const typename tuples<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \
}
}
// XYN
// XYN
...
...
paddle/fluid/operators/jit/more/mkl/mkl.cc
浏览文件 @
d53c4756
...
@@ -74,39 +74,39 @@ void VExp<double>(const double* x, double* y, int n) {
...
@@ -74,39 +74,39 @@ void VExp<double>(const double* x, double* y, int n) {
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template
<
>
template
<
>
bool
VMulKernel
<
float
>::
UseMe
(
int
d
)
const
{
bool
VMulKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
}
template
<
>
template
<
>
bool
VAddKernel
<
float
>::
UseMe
(
int
d
)
const
{
bool
VAddKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
}
template
<
>
template
<
>
bool
VScalKernel
<
float
>::
UseMe
(
int
d
)
const
{
bool
VScalKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
}
template
<
>
template
<
>
bool
VExpKernel
<
float
>::
UseMe
(
int
d
)
const
{
bool
VExpKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
d
>
7
;
return
d
>
7
;
}
}
template
<
>
template
<
>
bool
VSigmoidKernel
<
float
>::
UseMe
(
int
d
)
const
{
bool
VSigmoidKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
d
>
7
;
return
d
>
7
;
}
}
template
<
>
template
<
>
bool
VTanhKernel
<
float
>::
UseMe
(
int
d
)
const
{
bool
VTanhKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
d
>
7
;
return
d
>
7
;
}
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
#define AWALYS_USE_ME_WITH_DOUBLE(func)
\
template <> \
template <>
\
bool func##Kernel<double>::UseMe(
int
d) const { \
bool func##Kernel<double>::UseMe(
const int&
d) const { \
return true; \
return true;
\
}
}
AWALYS_USE_ME_WITH_DOUBLE
(
VMul
);
AWALYS_USE_ME_WITH_DOUBLE
(
VMul
);
...
...
paddle/fluid/operators/jit/more/mkl/mkl.h
浏览文件 @
d53c4756
...
@@ -60,12 +60,13 @@ void VTanh(const T* x, T* y, int n) {
...
@@ -60,12 +60,13 @@ void VTanh(const T* x, T* y, int n) {
}
}
}
}
#define DECLARE_MKL_KERNEL(name, tuples) \
#define DECLARE_MKL_KERNEL(name, tuples) \
template <typename T> \
template <typename T> \
class name##Kernel : public KernelImpl<tuples<T>> { \
class name##Kernel : public KernelMore<tuples<T>> { \
public: \
public: \
name##Kernel() { this->func = name<T>; } \
name##Kernel() { this->func = name<T>; } \
bool UseMe(typename tuples<T>::attr_type) const override; \
bool UseMe(const typename tuples<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \
}
}
// XYZN
// XYZN
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
d53c4756
...
@@ -228,10 +228,10 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
...
@@ -228,10 +228,10 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
if
(
iter
!=
pool
.
end
())
{
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
jit
::
Kernel
Impl
<
KernelTuples
>*>
(
impl
.
get
());
auto
i
=
dynamic_cast
<
const
jit
::
Kernel
More
<
KernelTuples
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
if
(
i
&&
i
->
UseMe
(
attr
))
{
auto
more
=
i
->
GetFunc
();
auto
more
=
i
->
GetFunc
();
VLOG
(
10
)
<<
"Test More Kernel
"
;
VLOG
(
10
)
<<
"Test More Kernel
: "
<<
i
->
ImplType
()
;
test
(
more
,
args
...);
test
(
more
,
args
...);
}
}
}
}
...
...
paddle/fluid/operators/math/jit_code.cc
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_code.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
using
namespace
platform
;
// NOLINT
bool
VXXJitCode
::
init
(
int
d
,
int
scalar_index
)
{
// It's not necessary to use avx512 since it would slow down the frequency
// and this kernel is not compute bound.
return
MayIUse
(
avx
)
&&
scalar_index
>=
0
&&
scalar_index
<=
2
;
}
void
VXXJitCode
::
generate
()
{
// do not need push stack, and do not need save avx512reg if do not use avx512
int
offset
=
0
;
if
(
with_relu_
)
{
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
}
if
(
scalar_index_
==
1
)
{
vbroadcastss
(
ymm_src1
,
ptr
[
param1
]);
}
else
if
(
scalar_index_
==
2
)
{
vbroadcastss
(
ymm_src2
,
ptr
[
param2
]);
}
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
if
(
scalar_index_
!=
1
)
{
vmovups
(
ymm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovups
(
ymm_src2
,
ptr
[
param2
+
offset
]);
}
if
(
type_
==
operand_type
::
mul
)
{
vmulps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
}
else
if
(
type_
==
operand_type
::
add
)
{
vaddps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
}
if
(
with_relu_
)
{
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_dst
);
}
vmovups
(
ptr
[
param3
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
int
rest
=
num_
%
YMM_FLOAT_BLOCK
;
while
(
rest
>
0
)
{
int
block
=
XMM_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
block
=
4
;
if
(
scalar_index_
!=
1
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovups
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
}
else
if
(
rest
>=
2
)
{
block
=
2
;
if
(
scalar_index_
!=
1
)
{
vmovq
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovq
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
}
else
{
block
=
1
;
if
(
scalar_index_
!=
1
)
{
vmovss
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovss
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
}
switch
(
type_
)
{
case
operand_type
::
mul
:
vmulps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
break
;
case
operand_type
::
add
:
vaddps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
break
;
default:
break
;
}
if
(
with_relu_
)
{
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_dst
);
}
if
(
rest
>=
4
)
{
vmovups
(
ptr
[
param3
+
offset
],
xmm_dst
);
}
else
if
(
rest
>=
2
)
{
vmovq
(
ptr
[
param3
+
offset
],
xmm_dst
);
}
else
{
vmovss
(
ptr
[
param3
+
offset
],
xmm_dst
);
}
offset
+=
sizeof
(
float
)
*
block
;
rest
-=
block
;
}
ret
();
}
const
float
ALIGN32_BEG
exp_float_consts
[]
ALIGN32_END
=
{
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
2.
f
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
EXP_LOW
),
REPEAT_8TIMES
(
CEPHES_LOG2EF
),
REPEAT_8TIMES
(
CEPHES_EXP_C1
),
REPEAT_8TIMES
(
CEPHES_EXP_C2
),
REPEAT_8TIMES
(
CEPHES_EXP_P0
),
REPEAT_8TIMES
(
CEPHES_EXP_P1
),
REPEAT_8TIMES
(
CEPHES_EXP_P2
),
REPEAT_8TIMES
(
CEPHES_EXP_P3
),
REPEAT_8TIMES
(
CEPHES_EXP_P4
),
REPEAT_8TIMES
(
CEPHES_EXP_P5
),
REPEAT_8TIMES
(
EXP_MAX_INPUT
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MAX
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MIN
)};
const
int
ALIGN32_BEG
exp_int_0x7f
[]
ALIGN32_END
=
{
REPEAT_8TIMES
(
0x7f
)};
int
ALIGN32_BEG
g_tmp_mem
[
16
]
ALIGN32_END
=
{
0
};
bool
VActJitCode
::
init
(
int
d
,
operand_type
type
)
{
// TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
return
MayIUse
(
avx
);
}
void
VActJitCode
::
generate
()
{
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
act
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
type_
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
int
rest
=
num_
%
YMM_FLOAT_BLOCK
;
while
(
rest
>
0
)
{
int
block
=
XMM_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
block
=
4
;
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
}
else
if
(
rest
>=
2
)
{
block
=
2
;
vmovq
(
xmm_src
,
ptr
[
param1
+
offset
]);
}
else
{
block
=
1
;
vmovss
(
xmm_src
,
ptr
[
param1
+
offset
]);
}
act
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
type_
);
if
(
rest
>=
4
)
{
vmovups
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
else
if
(
rest
>=
2
)
{
vmovq
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
else
{
vmovss
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
offset
+=
sizeof
(
float
)
*
block
;
rest
-=
block
;
}
ret
();
}
bool
LSTMJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
%
8
==
0
;
}
void
LSTMJitCode
::
generate
()
{
if
(
use_peephole_
)
{
preCode
();
}
reg64_t
reg_ptr_gates
=
rax
;
reg64_t
reg_ptr_ct_1
=
r9
;
reg64_t
reg_ptr_ct
=
r10
;
reg64_t
reg_ptr_ht
=
r11
;
reg64_t
reg_ptr_wp
=
r12
;
mov
(
reg_ptr_gates
,
ptr
[
param1
+
offsetof
(
lstm_t
,
gates
)]);
mov
(
reg_ptr_ct_1
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ct_1
)]);
mov
(
reg_ptr_ct
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ct
)]);
mov
(
reg_ptr_ht
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ht
)]);
if
(
use_peephole_
)
{
mov
(
reg_ptr_wp
,
ptr
[
param1
+
offsetof
(
lstm_t
,
wp
)]);
}
int
offset
=
0
;
int
d
=
num_
*
sizeof
(
float
);
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
/* gates: W_ch, W_ih, W_fh, W_oh */
ymm_t
ymm_c
=
ymm_t
(
0
);
ymm_t
ymm_i
=
ymm_t
(
1
);
ymm_t
ymm_f
=
ymm_t
(
2
);
ymm_t
ymm_o
=
ymm_t
(
3
);
ymm_t
ymm_ct_1
=
ymm_t
(
4
);
ymm_t
ymm_wp0
=
ymm_t
(
5
);
ymm_t
ymm_wp1
=
ymm_t
(
6
);
ymm_t
ymm_wp2
=
ymm_t
(
7
);
vmovups
(
ymm_c
,
ptr
[
reg_ptr_gates
+
offset
]);
vmovups
(
ymm_i
,
ptr
[
reg_ptr_gates
+
offset
+
d
]);
vmovups
(
ymm_f
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
d
]);
vmovups
(
ymm_o
,
ptr
[
reg_ptr_gates
+
offset
+
3
*
d
]);
if
(
!
compute_c1h1_
)
{
vmovups
(
ymm_ct_1
,
ptr
[
reg_ptr_ct_1
+
offset
]);
}
if
(
use_peephole_
)
{
vmovups
(
ymm_wp0
,
ptr
[
reg_ptr_wp
+
offset
]);
vmovups
(
ymm_wp1
,
ptr
[
reg_ptr_wp
+
offset
+
d
]);
vmovups
(
ymm_wp2
,
ptr
[
reg_ptr_wp
+
offset
+
2
*
d
]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act
<
ymm_t
>
(
ymm_c
,
ymm_c
,
act_cand_
);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if
(
!
compute_c1h1_
&&
use_peephole_
)
{
vmulps
(
ymm_wp0
,
ymm_ct_1
,
ymm_wp0
);
vaddps
(
ymm_i
,
ymm_i
,
ymm_wp0
);
}
act
<
ymm_t
>
(
ymm_i
,
ymm_i
,
act_gate_
);
vmulps
(
ymm_c
,
ymm_c
,
ymm_i
);
if
(
!
compute_c1h1_
)
{
// act_gate(f) or act_gate(ct_1 * wp1 + f)
if
(
use_peephole_
)
{
vmulps
(
ymm_wp1
,
ymm_ct_1
,
ymm_wp1
);
vaddps
(
ymm_f
,
ymm_f
,
ymm_wp1
);
}
act
<
ymm_t
>
(
ymm_f
,
ymm_f
,
act_gate_
);
// ct
vmulps
(
ymm_f
,
ymm_f
,
ymm_ct_1
);
vaddps
(
ymm_f
,
ymm_f
,
ymm_c
);
}
/* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t
ymm_ct
=
compute_c1h1_
?
ymm_c
:
ymm_f
;
ymm_t
ymm_tmp
=
ymm_i
;
act
<
ymm_t
>
(
ymm_tmp
,
ymm_ct
,
act_cell_
);
// act_gate(o) or act_gate(ct * wp2 + o)
if
(
use_peephole_
)
{
vmulps
(
ymm_wp2
,
ymm_ct
,
ymm_wp2
);
vaddps
(
ymm_o
,
ymm_o
,
ymm_wp2
);
}
act
<
ymm_t
>
(
ymm_o
,
ymm_o
,
act_gate_
);
// ht
vmulps
(
ymm_o
,
ymm_o
,
ymm_tmp
);
// save ct and ht
vmovups
(
ptr
[
reg_ptr_ct
+
offset
],
ymm_ct
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_o
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
if
(
use_peephole_
)
{
postCode
();
}
else
{
ret
();
}
}
bool
GRUJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
%
8
==
0
;
}
void
GRUJitCode
::
generate
()
{
reg64_t
reg_ptr_gates
=
rax
;
reg64_t
reg_ptr_ht_1
=
r9
;
reg64_t
reg_ptr_ht
=
r10
;
mov
(
reg_ptr_gates
,
ptr
[
param1
+
offsetof
(
gru_t
,
gates
)]);
mov
(
reg_ptr_ht_1
,
ptr
[
param1
+
offsetof
(
gru_t
,
ht_1
)]);
mov
(
reg_ptr_ht
,
ptr
[
param1
+
offsetof
(
gru_t
,
ht
)]);
ymm_t
ymm_one
=
ymm_t
(
0
);
if
(
id_
==
2
)
{
reg64_t
reg_ptr_tmp
=
r11
;
mov
(
reg_ptr_tmp
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_one
,
ptr
[
reg_ptr_tmp
+
OFFSET_EXP_ONE
]);
}
int
offset
=
0
;
int
d
=
num_
*
sizeof
(
float
);
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
ymm_t
ymm_u
=
ymm_t
(
1
);
ymm_t
ymm_r
=
ymm_t
(
2
);
ymm_t
ymm_s
=
ymm_t
(
3
);
ymm_t
ymm_ht_1
=
ymm_t
(
4
);
// W: {W_update, W_reset; W_state}
if
(
id_
==
0
||
id_
==
2
)
{
vmovups
(
ymm_u
,
ptr
[
reg_ptr_gates
+
offset
]);
vmovups
(
ymm_s
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
d
]);
}
if
(
id_
==
1
)
{
vmovups
(
ymm_r
,
ptr
[
reg_ptr_gates
+
offset
+
d
]);
}
if
(
id_
==
1
||
id_
==
2
)
{
vmovups
(
ymm_ht_1
,
ptr
[
reg_ptr_ht_1
+
offset
]);
}
if
(
id_
==
0
)
{
// ht = act_gate(u) * act_cand(s)
act
<
ymm_t
>
(
ymm_u
,
ymm_u
,
act_gate_
);
act
<
ymm_t
>
(
ymm_s
,
ymm_s
,
act_cand_
);
vmulps
(
ymm_s
,
ymm_s
,
ymm_u
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_s
);
}
else
if
(
id_
==
1
)
{
// ht = act_gate(r) * ht_1
act
<
ymm_t
>
(
ymm_r
,
ymm_r
,
act_gate_
);
vmulps
(
ymm_r
,
ymm_r
,
ymm_ht_1
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_r
);
}
else
if
(
id_
==
2
)
{
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t
ymm_one_inner
=
ymm_t
(
ymm_one
.
getIdx
());
act
<
ymm_t
>
(
ymm_u
,
ymm_u
,
act_gate_
);
act
<
ymm_t
>
(
ymm_s
,
ymm_s
,
act_cand_
);
vmulps
(
ymm_s
,
ymm_s
,
ymm_u
);
vsubps
(
ymm_u
,
ymm_one_inner
,
ymm_u
);
vmulps
(
ymm_u
,
ymm_ht_1
,
ymm_u
);
vaddps
(
ymm_u
,
ymm_s
,
ymm_u
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_u
);
}
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
ret
();
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_code.h
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/operators/math/jit_gen.h"
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
using
reg64_t
=
const
Xbyak
::
Reg64
;
using
reg32_t
=
const
Xbyak
::
Reg32
;
using
xmm_t
=
const
Xbyak
::
Xmm
;
using
ymm_t
=
const
Xbyak
::
Ymm
;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
Label
=
Xbyak
::
Label
;
typedef
enum
{
mul
=
0
,
add
,
sub
,
relu
,
exp
,
sigmoid
,
tanh
,
identity
}
operand_type
;
extern
const
float
exp_float_consts
[];
extern
const
int
exp_int_0x7f
[];
extern
int
g_tmp_mem
[];
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class
VXXJitCode
:
public
JitCode
{
public:
const
char
*
name
()
const
override
{
std
::
string
base
=
"VXXJitCode"
;
if
(
scalar_index_
==
1
)
{
base
+=
"_Scalar"
;
}
else
{
base
+=
"_Vec"
;
}
if
(
type_
==
operand_type
::
mul
)
{
base
+=
"_Mul"
;
}
else
if
(
type_
==
operand_type
::
add
)
{
base
+=
"_Add"
;
}
if
(
scalar_index_
==
2
)
{
base
+=
"_Scalar"
;
}
else
{
base
+=
"_Vec"
;
}
base
+=
(
with_relu_
?
"_Relu"
:
""
);
return
base
.
c_str
();
}
explicit
VXXJitCode
(
int
d
,
operand_type
type
,
int
scalar_index
,
bool
with_relu
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
),
scalar_index_
(
scalar_index
),
with_relu_
(
with_relu
)
{}
static
bool
init
(
int
d
,
int
scalar_index
=
0
);
void
generate
()
override
;
private:
int
num_
;
operand_type
type_
;
int
scalar_index_
;
bool
with_relu_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
param3
{
abi_param3
};
xmm_t
xmm_src1
=
xmm_t
(
0
);
xmm_t
xmm_src2
=
xmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
2
);
xmm_t
xmm_zero
=
xmm_t
(
3
);
ymm_t
ymm_src1
=
ymm_t
(
0
);
ymm_t
ymm_src2
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
3
);
};
class
VActJitCode
:
public
JitCode
{
public:
const
char
*
name
()
const
override
{
std
::
string
base
=
"VActJitCode"
;
switch
(
type_
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
return
base
.
c_str
();
}
explicit
VActJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
)
{}
static
bool
init
(
int
d
,
operand_type
type
);
void
generate
()
override
;
protected:
// compute relu with ymm, xmm
template
<
typename
JMM
>
void
relu_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
zero_idx
=
15
)
{
// NOLINT
JMM
zero
=
JMM
(
zero_idx
);
vxorps
(
zero
,
zero
,
zero
);
vmaxps
(
dst
,
src
,
zero
);
}
// compute exp with ymm, xmm
template
<
typename
JMM
>
void
exp_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
int
fx_idx
=
12
,
// NOLINT
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
using
namespace
platform
;
// NOLINT
// check all idx can not equal
JMM
jmm_src
=
JMM
(
src_idx
);
JMM
jmm_fx
=
JMM
(
fx_idx
);
JMM
jmm_fy
=
JMM
(
fy_idx
);
JMM
jmm_mask
=
JMM
(
mask_idx
);
JMM
jmm_tmp
=
JMM
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
vmovaps
(
jmm_src
,
src
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vminps
(
jmm_src
,
jmm_src
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOW
]);
vmaxps
(
jmm_src
,
jmm_src
,
jmm_tmp
);
// express exp(x) as exp(g + n*log(2))
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOG2EF
]);
vmulps
(
jmm_fx
,
jmm_src
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_0P5
]);
vaddps
(
jmm_fx
,
jmm_fx
,
jmm_tmp
);
vroundps
(
jmm_fy
,
jmm_fx
,
0x01
);
// if greater, substract 1
vcmpgtps
(
jmm_mask
,
jmm_fy
,
jmm_fx
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
]);
vandps
(
jmm_mask
,
jmm_mask
,
jmm_tmp
);
vsubps
(
jmm_fx
,
jmm_fy
,
jmm_mask
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C1
]);
vmulps
(
jmm_fy
,
jmm_fx
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C2
]);
JMM
ymm_z
=
JMM
(
jmm_mask
.
getIdx
());
vmulps
(
ymm_z
,
jmm_fx
,
jmm_tmp
);
vsubps
(
jmm_src
,
jmm_src
,
jmm_fy
);
vsubps
(
jmm_src
,
jmm_src
,
ymm_z
);
vmulps
(
ymm_z
,
jmm_src
,
jmm_src
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P0
]);
vmulps
(
dst
,
jmm_src
,
jmm_tmp
);
for
(
size_t
i
=
OFFSET_EXP_P1
;
i
<
OFFSET_EXP_P5
;
i
+=
(
YMM_FLOAT_BLOCK
*
sizeof
(
float
)))
{
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
i
]);
// P1~P4
vaddps
(
dst
,
dst
,
jmm_tmp
);
vmulps
(
dst
,
dst
,
jmm_src
);
}
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P5
]);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vmulps
(
dst
,
dst
,
ymm_z
);
vaddps
(
dst
,
dst
,
jmm_src
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
]);
vaddps
(
dst
,
dst
,
jmm_tmp
);
// build 2^n
JMM
ymm_int
=
jmm_fx
;
vcvttps2dq
(
ymm_int
,
jmm_fx
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_int_0x7f
));
vmovdqa
(
jmm_tmp
,
ptr
[
reg_ptr_global
]);
if
(
MayIUse
(
avx2
)
||
std
::
is_same
<
JMM
,
xmm_t
>::
value
)
{
vpaddd
(
ymm_int
,
ymm_int
,
jmm_tmp
);
vpslld
(
ymm_int
,
ymm_int
,
23
);
}
else
if
(
MayIUse
(
avx
))
{
xmm_t
xtmp1
=
xmm_t
(
ymm_int
.
getIdx
());
xmm_t
xtmp2
=
xmm_t
(
jmm_tmp
.
getIdx
());
reg64_t
reg_ptr_tmp
=
reg_ptr_global
;
mov
(
reg_ptr_tmp
,
reinterpret_cast
<
size_t
>
(
g_tmp_mem
));
vmovdqa
(
ptr
[
reg_ptr_tmp
],
ymm_int
);
vmovdqa
(
ptr
[
reg_ptr_tmp
+
YMM_FLOAT_BLOCK
*
sizeof
(
float
)],
jmm_tmp
);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_tmp
],
xtmp1
);
// next 128bits
vmovdqa
(
xtmp1
,
ptr
[
reg_ptr_tmp
+
XMM_FLOAT_BLOCK
*
sizeof
(
float
)]);
vmovdqa
(
xtmp2
,
ptr
[
reg_ptr_tmp
+
(
YMM_FLOAT_BLOCK
+
XMM_FLOAT_BLOCK
)
*
sizeof
(
float
)]);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_tmp
+
XMM_FLOAT_BLOCK
*
sizeof
(
float
)],
xtmp1
);
// load out
vmovdqa
(
ymm_int
,
ptr
[
reg_ptr_tmp
]);
}
vmulps
(
dst
,
dst
,
ymm_int
);
pop
(
reg_ptr_global
);
}
// compute sigmoid with ymm, xmm
template
<
typename
JMM
>
void
sigmoid_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
// y = 1 / (1 + e^-x)
JMM
jmm_tmp
=
JMM
(
tmp_idx
);
JMM
jmm_src
=
JMM
(
src_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
vmovaps
(
jmm_src
,
src
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MAX
]);
vminps
(
jmm_src
,
jmm_src
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MIN
]);
vmaxps
(
jmm_src
,
jmm_src
,
jmm_tmp
);
vxorps
(
jmm_tmp
,
jmm_tmp
,
jmm_tmp
);
vsubps
(
jmm_src
,
jmm_tmp
,
jmm_src
);
exp_jmm
<
JMM
>
(
dst
,
jmm_src
,
src_idx
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vdivps
(
dst
,
jmm_tmp
,
dst
);
pop
(
reg_ptr_global
);
}
// compute tanh with ymm, xmm
template
<
typename
JMM
>
void
tanh_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
// y = 2 / (1 + e^(-2x)) - 1
JMM
jmm_src
=
JMM
(
src_idx
);
JMM
jmm_tmp
=
JMM
(
tmp_idx
);
JMM
jmm_zero
=
JMM
(
mask_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
vmovaps
(
jmm_src
,
src
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vxorps
(
jmm_zero
,
jmm_zero
,
jmm_zero
);
vsubps
(
jmm_tmp
,
jmm_zero
,
jmm_tmp
);
vmulps
(
jmm_src
,
jmm_src
,
jmm_tmp
);
exp_jmm
<
JMM
>
(
dst
,
jmm_src
,
src_idx
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vdivps
(
dst
,
jmm_tmp
,
dst
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vsubps
(
dst
,
dst
,
jmm_tmp
);
pop
(
reg_ptr_global
);
}
template
<
typename
JMM
>
void
act
(
JMM
&
dst
,
JMM
&
src
,
operand_type
type
)
{
// NOLINT
// use 11~15
switch
(
type
)
{
case
operand_type
::
relu
:
relu_jmm
<
JMM
>
(
dst
,
src
,
15
);
break
;
case
operand_type
::
exp
:
exp_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
case
operand_type
::
tanh
:
tanh_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
case
operand_type
::
identity
:
break
;
default:
// throw error
break
;
}
}
protected:
int
num_
;
operand_type
type_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
xmm_t
xmm_src
=
xmm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
0
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
class
LSTMJitCode
:
public
VActJitCode
{
public:
const
char
*
name
()
const
override
{
std
::
string
base
=
"LSTMJitCode"
;
if
(
use_peephole_
)
{
base
+=
"_Peephole"
;
}
if
(
compute_c1h1_
)
{
base
+=
"_C1H1"
;
}
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
switch
(
type
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
};
AddTypeStr
(
act_gate_
);
AddTypeStr
(
act_cand_
);
AddTypeStr
(
act_cell_
);
return
base
.
c_str
();
}
explicit
LSTMJitCode
(
bool
compute_c1h1
,
const
lstm_attr_t
&
attr
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VActJitCode
(
attr
.
d
,
operand_type
::
sigmoid
/* this is bugy*/
,
code_size
,
code_ptr
),
compute_c1h1_
(
compute_c1h1
)
{
auto
typeExchange
=
[](
const
std
::
string
&
type
)
->
gen
::
operand_type
{
if
(
type
==
"sigmoid"
)
{
return
operand_type
::
sigmoid
;
}
else
if
(
type
==
"relu"
)
{
return
operand_type
::
relu
;
}
else
if
(
type
==
"tanh"
)
{
return
operand_type
::
tanh
;
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
operand_type
::
identity
;
}
// else throw error
return
operand_type
::
identity
;
};
num_
=
attr
.
d
;
use_peephole_
=
attr
.
use_peephole
;
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
act_cell_
=
typeExchange
(
attr
.
act_cell
);
}
static
bool
init
(
int
d
);
void
generate
()
override
;
protected:
int
num_
;
bool
compute_c1h1_
;
bool
use_peephole_
;
operand_type
act_gate_
;
operand_type
act_cand_
;
operand_type
act_cell_
;
reg64_t
param1
{
abi_param1
};
};
class
GRUJitCode
:
public
VActJitCode
{
public:
const
char
*
name
()
const
override
{
std
::
string
base
=
"GRUJitCode"
;
if
(
id_
==
0
)
{
base
+=
"_H1"
;
}
else
if
(
id_
==
1
)
{
base
+=
"_HtPart1"
;
}
else
if
(
id_
==
2
)
{
base
+=
"_HtPart2"
;
}
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
switch
(
type
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
};
AddTypeStr
(
act_gate_
);
AddTypeStr
(
act_cand_
);
return
base
.
c_str
();
}
explicit
GRUJitCode
(
int
id
,
const
gru_attr_t
&
attr
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VActJitCode
(
attr
.
d
,
operand_type
::
sigmoid
/* this is bugy*/
,
code_size
,
code_ptr
),
id_
(
id
)
{
auto
typeExchange
=
[](
const
std
::
string
&
type
)
->
gen
::
operand_type
{
if
(
type
==
"sigmoid"
)
{
return
operand_type
::
sigmoid
;
}
else
if
(
type
==
"relu"
)
{
return
operand_type
::
relu
;
}
else
if
(
type
==
"tanh"
)
{
return
operand_type
::
tanh
;
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
operand_type
::
identity
;
}
// else throw error
return
operand_type
::
identity
;
};
num_
=
attr
.
d
;
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
}
static
bool
init
(
int
d
);
void
generate
()
override
;
protected:
int
id_
;
int
num_
;
operand_type
act_gate_
;
operand_type
act_cand_
;
reg64_t
param1
{
abi_param1
};
};
#ifdef PADDLE_WITH_MKLDNN
struct
EltwiseMulnChw16cNC
:
public
Xbyak
::
CodeGenerator
{
explicit
EltwiseMulnChw16cNC
(
size_t
code_size
=
256
*
1024
)
:
Xbyak
::
CodeGenerator
(
code_size
)
{
// RDI is ptr x_input
// RSI is ptr y_input
// RDX is ptr output
// RCX is height
// r8 is width
push
(
rbx
);
xor_
(
rax
,
rax
);
xor_
(
r10
,
r10
);
vmovups
(
zmm3
,
ptr
[
rsi
]);
L
(
"h_loop"
);
xor_
(
rbx
,
rbx
);
L
(
"w_loop"
);
vmovups
(
zmm2
,
ptr
[
rdi
+
rax
]);
vmulps
(
zmm1
,
zmm2
,
zmm3
);
vmovups
(
ptr
[
rdx
+
rax
],
zmm1
);
add
(
rax
,
64
);
inc
(
rbx
);
cmp
(
r8
,
rbx
);
jnz
(
"w_loop"
);
inc
(
r10
);
cmp
(
r10
,
rcx
);
jnz
(
"h_loop"
);
pop
(
rbx
);
ret
();
}
};
#endif
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_gen.cc
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_gen.h"
#include <fstream>
#include <iostream>
#include <sstream>
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool
(
dump_jitcode
,
false
,
"Whether to dump the jitcode to file"
);
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
constexpr
Xbyak
::
Operand
::
Code
g_abi_regs
[]
=
{
Xbyak
::
Operand
::
RBX
,
Xbyak
::
Operand
::
RBP
,
Xbyak
::
Operand
::
R12
,
Xbyak
::
Operand
::
R13
,
Xbyak
::
Operand
::
R14
,
Xbyak
::
Operand
::
R15
};
constexpr
int
num_g_abi_regs
=
sizeof
(
g_abi_regs
)
/
sizeof
(
g_abi_regs
[
0
]);
void
JitCode
::
preCode
()
{
for
(
int
i
=
0
;
i
<
num_g_abi_regs
;
++
i
)
{
push
(
Xbyak
::
Reg64
(
g_abi_regs
[
i
]));
}
if
(
platform
::
MayIUse
(
platform
::
avx512f
))
{
mov
(
reg_EVEX_max_8b_offt
,
2
*
EVEX_max_8b_offt
);
}
}
void
JitCode
::
postCode
()
{
for
(
int
i
=
0
;
i
<
num_g_abi_regs
;
++
i
)
{
pop
(
Xbyak
::
Reg64
(
g_abi_regs
[
num_g_abi_regs
-
1
-
i
]));
}
ret
();
}
void
JitCode
::
dumpCode
(
const
Xbyak
::
uint8
*
code
)
const
{
if
(
code
)
{
static
int
counter
=
0
;
std
::
ostringstream
filename
;
filename
<<
"paddle_jitcode_"
<<
name
()
<<
"."
<<
counter
<<
".bin"
;
counter
++
;
std
::
ofstream
fout
(
filename
.
str
(),
std
::
ios
::
out
);
if
(
fout
.
is_open
())
{
fout
.
write
(
reinterpret_cast
<
const
char
*>
(
code
),
getSize
());
fout
.
close
();
}
}
}
Xbyak
::
Address
JitCode
::
EVEX_compress_addr
(
Xbyak
::
Reg64
base
,
int
offt
,
bool
bcast
)
{
int
scale
=
0
;
if
(
EVEX_max_8b_offt
<=
offt
&&
offt
<
3
*
EVEX_max_8b_offt
)
{
offt
=
offt
-
2
*
EVEX_max_8b_offt
;
scale
=
1
;
}
else
if
(
3
*
EVEX_max_8b_offt
<=
offt
&&
offt
<
5
*
EVEX_max_8b_offt
)
{
offt
=
offt
-
4
*
EVEX_max_8b_offt
;
scale
=
2
;
}
auto
re
=
Xbyak
::
RegExp
()
+
base
+
offt
;
if
(
scale
)
{
re
=
re
+
reg_EVEX_max_8b_offt
*
scale
;
}
if
(
bcast
)
{
return
zword_b
[
re
];
}
else
{
return
zword
[
re
];
}
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_gen.h
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <type_traits>
#include "paddle/fluid/platform/macros.h"
#define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
DECLARE_bool
(
dump_jitcode
);
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
#define DECLARE_JIT_CODE(codename) \
const char *name() const override { return #codename; }
// Application Binary Interface
constexpr
Xbyak
::
Operand
::
Code
abi_param1
(
Xbyak
::
Operand
::
RDI
),
abi_param2
(
Xbyak
::
Operand
::
RSI
),
abi_param3
(
Xbyak
::
Operand
::
RDX
),
abi_param4
(
Xbyak
::
Operand
::
RCX
),
abi_not_param1
(
Xbyak
::
Operand
::
RCX
);
class
JitCode
:
public
Xbyak
::
CodeGenerator
{
public:
explicit
JitCode
(
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
Xbyak
::
CodeGenerator
(
code_size
,
code_ptr
)
{}
virtual
~
JitCode
()
{}
virtual
const
char
*
name
()
const
=
0
;
virtual
void
generate
()
=
0
;
template
<
typename
FUNC
>
const
FUNC
getCode
()
{
this
->
generate
();
const
Xbyak
::
uint8
*
code
=
CodeGenerator
::
getCode
();
if
(
FLAGS_dump_jitcode
)
{
this
->
dumpCode
(
code
);
}
return
reinterpret_cast
<
const
FUNC
>
(
code
);
}
DISABLE_COPY_AND_ASSIGN
(
JitCode
);
protected:
Xbyak
::
Reg64
param1
{
abi_param1
};
const
int
EVEX_max_8b_offt
=
0x200
;
const
Xbyak
::
Reg64
reg_EVEX_max_8b_offt
=
rbp
;
void
preCode
();
void
postCode
();
void
dumpCode
(
const
Xbyak
::
uint8
*
code
)
const
;
void
L
(
const
char
*
label
)
{
Xbyak
::
CodeGenerator
::
L
(
label
);
}
void
L
(
const
Xbyak
::
Label
&
label
)
{
Xbyak
::
CodeGenerator
::
L
(
label
);
}
// Enhanced vector extension
Xbyak
::
Address
EVEX_compress_addr
(
Xbyak
::
Reg64
base
,
int
offt
,
bool
bcast
=
false
);
};
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel.cc
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <iostream>
#include <string>
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
KernelPool
&
KernelPool
::
Instance
()
{
static
thread_local
KernelPool
g_jit_kernels
;
return
g_jit_kernels
;
}
std
::
shared_ptr
<
const
Kernel
>
KernelPool
::
Get
(
const
std
::
string
&
key
)
const
{
if
(
kers_
.
find
(
key
)
==
kers_
.
end
())
{
return
nullptr
;
}
return
kers_
.
at
(
key
);
}
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel.h
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <memory> // for shared_ptr
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/macros.h"
// Note: Only support on CPU yet.
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
// TODO(TJ): remove me
typedef
enum
{
kLT8
,
kEQ8
,
kGT8LT16
,
kEQ16
,
kGT16
}
jit_block
;
class
Kernel
{
public:
Kernel
()
=
default
;
virtual
~
Kernel
()
=
default
;
// TODO(TJ): below members should be deprecated.
int
num_
{
0
};
int
end_
{
0
};
int
rest_
{
0
};
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
};
class
KernelPool
{
public:
static
KernelPool
&
Instance
();
template
<
typename
Ker
,
typename
...
ARGS
>
std
::
shared_ptr
<
const
Ker
>
Get
(
ARGS
...
args
);
std
::
shared_ptr
<
const
Kernel
>
Get
(
const
std
::
string
&
key
)
const
;
private:
KernelPool
()
=
default
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
const
Kernel
>>
kers_
;
DISABLE_COPY_AND_ASSIGN
(
KernelPool
);
};
template
<
typename
T
>
class
VMulKernel
:
public
Kernel
{
public:
void
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
class
VAddKernel
:
public
Kernel
{
public:
void
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
class
VAddReluKernel
:
public
Kernel
{
public:
void
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
class
VScalKernel
:
public
Kernel
{
public:
// y = a.*x
void
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
class
VAddBiasKernel
:
public
Kernel
{
public:
// y = a.+x
void
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
};
#ifdef PADDLE_WITH_MKLDNN
template
<
typename
T
>
class
EltwiseMulnChw16cNCKernel
:
public
Kernel
{
public:
// nChw16c = nChw16c .* NC
void
(
*
Compute
)(
const
float
*
,
const
float
*
,
float
*
,
int
,
int
);
};
#endif
template
<
typename
T
>
class
VActKernel
:
public
Kernel
{
public:
void
(
*
Compute
)(
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
class
VReluKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VIdentityKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VExpKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VSigmoidKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VTanhKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
LSTMKernel
:
public
Kernel
{
public:
// compute c1 and h1 without c0 or h0
void
(
*
ComputeC1H1
)(
lstm_t
*
,
const
lstm_attr_t
*
);
void
(
*
ComputeCtHt
)(
lstm_t
*
,
const
lstm_attr_t
*
);
};
template
<
typename
T
>
class
GRUKernel
:
public
Kernel
{
public:
// compute h1 without h0
void
(
*
ComputeH1
)(
gru_t
*
,
const
gru_attr_t
*
);
void
(
*
ComputeHtPart1
)(
gru_t
*
,
const
gru_attr_t
*
);
void
(
*
ComputeHtPart2
)(
gru_t
*
,
const
gru_attr_t
*
);
};
template
<
typename
T
>
class
CRFDecodeKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
seq_len
,
const
T
*
x
,
const
T
*
w
,
T
*
alpha
,
int
*
track
)
const
=
0
;
};
template
<
typename
T
>
class
LayerNormKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
T
*
x
,
T
*
out
,
T
*
mean
,
T
*
var
,
const
T
*
scale
,
const
T
*
bias
,
int
height
,
const
float
epsilon
)
const
=
0
;
};
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_blas.cc
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
/* VMUL JitKernel */
template
<
typename
T
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
// roughly estimate the size of code
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
if
(
useMKL
(
d
))
{
this
->
Compute
=
VMulMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
refer
::
VMul
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VXXJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VMulKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VXXJitCode
::
init
(
d
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VMulKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
template
<
>
bool
VMulKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
/* VAdd JitKernel */
template
<
typename
T
>
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VAddKernelImpl
(
int
d
)
:
VAddKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
if
(
useMKL
(
d
))
{
this
->
Compute
=
VAddMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
refer
::
VAdd
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VXXJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VAddKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VXXJitCode
::
init
(
d
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VAddKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VAddKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
/* EltwiseMul for nChw16c & NC inputs JitKernel */
template
<
typename
T
>
class
EltwiseMulnChw16cNCKernelImpl
:
public
math
::
jitkernel
::
EltwiseMulnChw16cNCKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
EltwiseMulnChw16cNCKernelImpl
(
int
d
)
:
EltwiseMulnChw16cNCKernel
<
T
>
()
{
using
mul_func_t
=
void
(
*
)(
const
float
*
,
const
float
*
,
float
*
,
int
,
int
);
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
// roughly estimate the size of code
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
sz
=
sz
>
4096
?
sz
:
4096
;
jitcode_
.
reset
(
new
gen
::
EltwiseMulnChw16cNC
(
sz
));
this
->
Compute
=
(
mul_func_t
)
jitcode_
->
getCode
();
return
;
}
#endif
PADDLE_THROW
(
"This kernel shouldn't be used in Non-Xbyak, Non-MKL-DNN "
"environemnt"
);
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
EltwiseMulnChw16cNC
>
jitcode_
{
nullptr
};
};
template
<
>
bool
EltwiseMulnChw16cNCKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
true
;
}
#endif
#endif
/* VAddRelu JitKernel */
template
<
typename
T
>
class
VAddReluKernelImpl
:
public
VAddReluKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VAddReluKernelImpl
(
int
d
)
:
VAddReluKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
true
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
this
->
Compute
=
refer
::
VAddRelu
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VXXJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VAddReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VXXJitCode
::
init
(
d
);
}
#endif
/* VScal JitKernel */
template
<
typename
T
>
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
if
(
useMKL
(
d
))
{
this
->
Compute
=
VScalMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
refer
::
VScal
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VXXJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VScalKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VXXJitCode
::
init
(
d
,
1
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VScalKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VScalKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
/* VAddBias JitKernel */
template
<
typename
T
>
class
VAddBiasKernelImpl
:
public
VAddBiasKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VAddBiasKernelImpl
(
int
d
)
:
VAddBiasKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
this
->
Compute
=
refer
::
VAddBias
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VXXJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VAddBiasKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VXXJitCode
::
init
(
d
,
1
);
}
#endif
/* VRelu JitKernel */
template
<
typename
T
>
class
VReluKernelImpl
:
public
VReluKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VReluKernelImpl
(
int
d
)
:
VReluKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
/* init size */
+
d
/
YMM_FLOAT_BLOCK
*
4
/* instructions */
*
8
/* average bytes for each instruction */
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
relu
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
this
->
Compute
=
refer
::
VRelu
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VActJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
relu
);
}
#endif
/* An empty JitKernel */
template
<
typename
T
>
class
VIdentityKernelImpl
:
public
VIdentityKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VIdentityKernelImpl
(
int
d
)
:
VIdentityKernel
<
T
>
()
{
this
->
Compute
=
refer
::
VIdentity
<
T
>
;
}
};
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vaddbias
,
VAddBiasKernel
);
REGISTER_JITKERNEL
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL
(
videntity
,
VIdentityKernel
);
#ifdef PADDLE_WITH_MKLDNN
REGISTER_JITKERNEL
(
eltwise_mul_nchw16c
,
EltwiseMulnChw16cNCKernel
);
#endif
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <limits>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
/* CRF Decode JitKernel */
template
<
typename
T
,
platform
::
cpu_isa_t
isa
,
jit_block
>
class
CRFDecodeKernelImpl
:
public
CRFDecodeKernel
<
T
>
{
public:
explicit
CRFDecodeKernelImpl
(
int
tag_num
)
:
CRFDecodeKernel
<
T
>
()
{
this
->
num_
=
tag_num
;
}
void
Compute
(
const
int
seq_len
,
const
T
*
x
,
const
T
*
w
,
T
*
alpha
,
int
*
track
)
const
override
{
constexpr
int
state_trans_base_idx
=
2
;
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
alpha
[
i
]
=
w
[
i
]
+
x
[
i
];
}
for
(
int
k
=
1
;
k
<
seq_len
;
++
k
)
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
T
max_score
=
-
std
::
numeric_limits
<
T
>::
max
();
int
max_j
=
0
;
for
(
int
j
=
0
;
j
<
this
->
num_
;
++
j
)
{
T
score
=
alpha
[(
k
-
1
)
*
this
->
num_
+
j
]
+
w
[(
j
+
state_trans_base_idx
)
*
this
->
num_
+
i
];
if
(
score
>
max_score
)
{
max_score
=
score
;
max_j
=
j
;
}
}
alpha
[
k
*
this
->
num_
+
i
]
=
max_score
+
x
[
k
*
this
->
num_
+
i
];
track
[
k
*
this
->
num_
+
i
]
=
max_j
;
}
}
}
};
#define INIT_ALPHA(step_size) \
/* Setup the alpha initial value.*/
\
int i_offset = 0; \
int last_offset = this->rest_ - step_size; \
for (int i = 0; i <= this->end_; ++i) { \
/* weights, input and alpha values. */
\
__m256 w_content, x_content, alpha_content; \
/* Load the relevant data into the variables from un-aligned address.*/
\
w_content = _mm256_loadu_ps(w + i_offset); \
x_content = _mm256_loadu_ps(x + i_offset); \
alpha_content = _mm256_add_ps(w_content, x_content); \
_mm256_storeu_ps(alpha + i_offset, alpha_content); \
i_offset += step_size; \
if (i == this->end_ - 1) { \
if (this->rest_ > 0) { \
i_offset += last_offset; \
} else { \
break; \
} \
} \
}
#define UPDATE_ALPHA(step_size) \
/* Update the alpha and track values. */
\
__m256 x_content = _mm256_loadu_ps(x + seq_offset + this->num_ + j_offset); \
max_score = _mm256_add_ps(max_score, x_content); \
_mm256_storeu_ps(alpha + seq_offset + this->num_ + j_offset, max_score); \
_mm256_storeu_si256( \
reinterpret_cast<__m256i*>(track + seq_offset + this->num_ + j_offset), \
max_j); \
/* Calculate the offset of next step*/
\
j_offset += step_size; \
if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \
j_offset += last_offset; \
} else { \
break; \
} \
}
#define INTRIAVX_FLOAT(block) \
template <> \
CRFDecodeKernelImpl<float, platform::avx, block>::CRFDecodeKernelImpl( \
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, platform::avx, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(YMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
for (int k = 1; k < seq_len; ++k) { \
int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/
\
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<float>::max()); \
__m256i max_j = _mm256_set1_epi32(0); \
/* Calculate the offset of transition_weights.*/
\
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
for (int i = 0; i < this->num_; ++i) { \
/* Initalize the content of alpha variable with related offset.*/
\
__m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); \
/* Obtain the content of weights from un-aligned address.*/
\
__m256 w_content = _mm256_loadu_ps(w + trans_offset); \
__m256 score_v = _mm256_add_ps(alpha_content, w_content); \
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); \
/* According to the mask value, update the index of the max_score.*/
\
/* AVX instructions.*/
\
__m128i lo_max_j = _mm256_extractf128_si256(max_j, 0); \
__m128i hi_max_j = _mm256_extractf128_si256(max_j, 1); \
__m128i lo_mask = _mm256_extractf128_si256(*(__m256i*)&mask, 0); \
__m128i hi_mask = _mm256_extractf128_si256(*(__m256i*)&mask, 1); \
lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j); \
hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j); \
lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i)); \
hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i)); \
lo_max_j = _mm_or_si128(lo_mask, lo_max_j); \
hi_max_j = _mm_or_si128(hi_mask, hi_max_j); \
max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0); \
max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1); \
/* AVX done*/
\
/* Update the max_score value.*/
\
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
} \
seq_offset += this->num_; \
} \
}
#define INTRIAVX2_FLOAT(isa, block) \
template <> \
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(YMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
for (int k = 1; k < seq_len; ++k) { \
int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/
\
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<float>::max()); \
__m256i max_j = _mm256_set1_epi32(0); \
/* Calculate the offset of transition_weights.*/
\
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
for (int i = 0; i < this->num_; ++i) { \
/* Initalize the content of alpha variable with related offset.*/
\
__m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); \
/* Obtain the content of weights from un-aligned address.*/
\
__m256 w_content = _mm256_loadu_ps(w + trans_offset); \
__m256 score_v = _mm256_add_ps(alpha_content, w_content); \
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); \
/* According to the mask value, update the index of the max_score.*/
\
/* AVX2 instructions.*/
\
max_j = _mm256_or_si256( \
_mm256_andnot_si256((__m256i)mask, max_j), \
_mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i))); \
/* Update the max_score value.*/
\
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
} \
seq_offset += this->num_; \
} \
}
#define INTRIAVX512_FLOAT(block) \
template <> \
CRFDecodeKernelImpl<float, platform::avx512f, block>::CRFDecodeKernelImpl( \
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ / ZMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, platform::avx512f, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(ZMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
for (int k = 1; k < seq_len; ++k) { \
int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/
\
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max()); \
__m512i max_j = _mm512_setzero_si512(); \
/* Calculate the offset of transition_weights.*/
\
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
for (int i = 0; i < this->num_; ++i) { \
/* Initalize the content of alpha variable with related offset.*/
\
__m512 alpha_content = _mm512_set1_ps(*(alpha + seq_offset + i)); \
/* Obtain the content of weights from un-aligned address.*/
\
__m512 w_content = _mm512_loadu_ps(w + trans_offset); \
__m512 score_v = _mm512_add_ps(alpha_content, w_content); \
__mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS); \
/* AVX512 instructions.*/
\
max_j = _mm512_mask_set1_epi32(max_j, mask, i); \
/* Update the max_score value.*/
\
max_score = _mm512_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
/* Update the alpha and track values.*/
\
__m512 x_content = \
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \
max_score = _mm512_add_ps(max_score, x_content); \
_mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, \
max_score); \
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \
this->num_ + j_offset), \
max_j); \
/* Calculate the offset of next step*/
\
j_offset += ZMM_FLOAT_BLOCK; \
if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \
j_offset += last_offset; \
} else { \
break; \
} \
} \
} \
seq_offset += this->num_; \
} \
}
#ifdef __AVX__
INTRIAVX_FLOAT
(
kEQ8
);
INTRIAVX_FLOAT
(
kGT8LT16
);
INTRIAVX_FLOAT
(
kEQ16
);
INTRIAVX_FLOAT
(
kGT16
);
#endif
#ifdef __AVX2__
INTRIAVX2_FLOAT
(
platform
::
avx2
,
kEQ8
);
INTRIAVX2_FLOAT
(
platform
::
avx2
,
kGT8LT16
);
INTRIAVX2_FLOAT
(
platform
::
avx2
,
kEQ16
);
INTRIAVX2_FLOAT
(
platform
::
avx2
,
kGT16
);
#endif
#ifdef __AVX512F__
INTRIAVX2_FLOAT
(
platform
::
avx512f
,
kEQ8
);
INTRIAVX2_FLOAT
(
platform
::
avx512f
,
kGT8LT16
);
INTRIAVX512_FLOAT
(
kEQ16
);
INTRIAVX512_FLOAT
(
kGT16
);
#endif
#undef INTRIAVX512_FLOAT
#undef INTRIAVX2_FLOAT
#undef INTRIAVX_FLOAT
#undef INIT_ALPHA
#undef UPDATE_ALPHA
REGISTER_JITKERNEL_DEPRECATED
(
crf_decode
,
CRFDecodeKernel
);
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_exp.cc
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
/* VExp JitKernel */
template
<
typename
T
>
class
VExpKernelImpl
:
public
VExpKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VExpKernelImpl
(
int
d
)
:
VExpKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
70
*
8
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
exp
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
if
(
useMKL
(
d
))
{
this
->
Compute
=
VExpMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
refer
::
VExp
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VActJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VExpKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
exp
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VExpKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VExpKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
/* VSigmoid JitKernel */
template
<
typename
T
>
class
VSigmoidKernelImpl
:
public
VSigmoidKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VSigmoidKernelImpl
(
int
d
)
:
VSigmoidKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
82
*
8
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
sigmoid
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
// strictly it's a better impl with MKL, then is refer
if
(
useMKL
(
d
))
{
this
->
Compute
=
VSigmoidMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
refer
::
VSigmoid
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VActJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VSigmoidKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
sigmoid
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VSigmoidKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VSigmoidKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
/* VTanh JitKernel */
template
<
typename
T
>
class
VTanhKernelImpl
:
public
VTanhKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VTanhKernelImpl
(
int
d
)
:
VTanhKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
84
*
8
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
tanh
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
// strictly it's a better impl with MKL, then is refer
if
(
useMKL
(
d
))
{
this
->
Compute
=
VTanhMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
refer
::
VTanh
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VActJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VTanhKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
tanh
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VTanhKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VTanhKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
REGISTER_JITKERNEL
(
vsigmoid
,
VSigmoidKernel
);
REGISTER_JITKERNEL
(
vtanh
,
VTanhKernel
);
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_impl.h
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <type_traits>
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_layer_norm.cc
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <math.h>
#include <limits>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
/* Layer Norm JitKernel */
template
<
typename
T
,
platform
::
cpu_isa_t
isa
,
jit_block
>
class
LayerNormKernelImpl
:
public
LayerNormKernel
<
T
>
{
public:
explicit
LayerNormKernelImpl
(
int
right
)
:
LayerNormKernel
<
T
>
()
{
this
->
num_
=
right
;
}
void
Compute
(
T
*
x
,
T
*
out
,
T
*
mean
,
T
*
var
,
const
T
*
scale
,
const
T
*
bias
,
int
height
,
const
float
epsilon
)
const
override
{
// get mean
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
T
sum
=
0.0
;
int
offset
=
i
*
this
->
num_
;
for
(
int
j
=
0
;
j
<
this
->
num_
;
j
++
)
{
sum
+=
x
[
offset
+
j
];
}
mean
[
i
]
=
sum
/
this
->
num_
;
}
// get variance
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
T
sum
=
0.0
;
int
offset
=
i
*
this
->
num_
;
for
(
int
j
=
0
;
j
<
this
->
num_
;
j
++
)
{
sum
+=
(
x
[
offset
+
j
]
-
mean
[
i
])
*
(
x
[
offset
+
j
]
-
mean
[
i
]);
}
var
[
i
]
=
sum
/
this
->
num_
;
}
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
int
offset
=
i
*
this
->
num_
;
T
sqrt_var
=
sqrt
(
var
[
i
]
+
(
T
)
epsilon
);
for
(
int
j
=
0
;
j
<
this
->
num_
;
j
++
)
{
out
[
offset
+
j
]
=
(
x
[
offset
+
j
]
-
mean
[
i
])
/
sqrt_var
;
}
}
if
(
scale
)
{
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
int
offset
=
i
*
this
->
num_
;
for
(
int
j
=
0
;
j
<
this
->
num_
;
j
++
)
{
out
[
offset
+
j
]
*=
scale
[
j
];
}
}
}
if
(
bias
)
{
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
int
offset
=
i
*
this
->
num_
;
for
(
int
j
=
0
;
j
<
this
->
num_
;
j
++
)
{
out
[
offset
+
j
]
+=
bias
[
j
];
}
}
}
}
};
#define INTRIAVX_FLOAT(isa, jit_block) \
template <> \
LayerNormKernelImpl<float, isa, jit_block>::LayerNormKernelImpl(int right) \
: LayerNormKernel<float>() { \
this->num_ = right; \
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
this->end_ = this->num_ - this->rest_; \
} \
template <> \
void LayerNormKernelImpl<float, isa, jit_block>::Compute( \
float* x, float* out, float* mean, float* var, const float* scale, \
const float* bias, int height, const float epsilon) const { \
__m256 sum; \
__m256 mean_vec, var_vec; \
__m128 hi, lo; \
__m256 tmp; \
size_t offset; \
size_t j; \
size_t block = YMM_FLOAT_BLOCK; \
__m256 reverse_num_vec = \
_mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \
__m256 epsilon_vec = _mm256_set1_ps(epsilon); \
int rest_mask = \
((-1) & (~((~0U) >> (sizeof(int) * 8 - (YMM_FLOAT_BLOCK - rest_))))) & \
0x0ff; \
__m256i mask_vec = _mm256_set_epi32( \
rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0, \
rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0, \
rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0, \
rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0); \
\
for (int i = 0; i < height; ++i) { \
offset = i * this->num_; \
\
/* get mean */
\
sum = _mm256_setzero_ps(); \
for (j = offset; j < end_ + offset; j += block) { \
sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j)); \
} \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_loadu_ps((const float*)x + j); \
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \
sum = _mm256_add_ps(sum, tmp); \
} \
hi = _mm256_extractf128_ps(sum, 1); \
lo = _mm256_extractf128_ps(sum, 0); \
sum = _mm256_add_ps( \
sum, _mm256_insertf128_ps( \
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \
sum = _mm256_hadd_ps(sum, sum); \
sum = _mm256_hadd_ps(sum, sum); \
mean_vec = _mm256_mul_ps(sum, reverse_num_vec); \
mean[i] = *reinterpret_cast<float*>(&mean_vec); \
\
/* get variance */
\
sum = _mm256_setzero_ps(); \
for (j = offset; j < end_ + offset; j += block) { \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_mul_ps(tmp, tmp); \
sum = _mm256_add_ps(sum, tmp); \
} \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_mul_ps(tmp, tmp); \
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \
sum = _mm256_add_ps(sum, tmp); \
} \
hi = _mm256_extractf128_ps(sum, 1); \
lo = _mm256_extractf128_ps(sum, 0); \
sum = _mm256_add_ps( \
sum, _mm256_insertf128_ps( \
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \
sum = _mm256_hadd_ps(sum, sum); \
sum = _mm256_hadd_ps(sum, sum); \
var_vec = _mm256_mul_ps(sum, reverse_num_vec); \
var[i] = *reinterpret_cast<float*>(&var_vec); \
\
/* get x_norm and calculate output*/
\
for (j = offset; j < end_ + offset; j += block) { \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_div_ps( \
tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp); \
} \
if (rest_ != 0) { \
j = offset + num_ - block; \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_div_ps( \
tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp); \
} \
\
if (scale) { \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_loadu_ps((const float*)out + j); \
} \
for (j = offset; j < end_ + offset; j += block) { \
_mm256_storeu_ps( \
reinterpret_cast<float*>(out) + j, \
_mm256_mul_ps( \
_mm256_loadu_ps((const float*)out + j), \
_mm256_loadu_ps((const float*)scale + j - offset))); \
} \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
_mm256_storeu_ps( \
reinterpret_cast<float*>(out) + j, \
_mm256_mul_ps( \
tmp, _mm256_loadu_ps((const float*)scale + j - offset))); \
} \
} \
\
if (bias) { \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_loadu_ps((const float*)out + j); \
} \
for (j = offset; j < end_ + offset; j += block) { \
_mm256_storeu_ps( \
reinterpret_cast<float*>(out) + j, \
_mm256_add_ps( \
_mm256_loadu_ps((const float*)out + j), \
_mm256_loadu_ps((const float*)bias + j - offset))); \
} \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
_mm256_storeu_ps( \
reinterpret_cast<float*>(out) + j, \
_mm256_add_ps( \
tmp, _mm256_loadu_ps((const float*)bias + j - offset))); \
} \
} \
} \
}
#ifdef __AVX__
INTRIAVX_FLOAT
(
platform
::
avx
,
kEQ8
);
INTRIAVX_FLOAT
(
platform
::
avx
,
kGT8LT16
);
INTRIAVX_FLOAT
(
platform
::
avx
,
kEQ16
);
INTRIAVX_FLOAT
(
platform
::
avx
,
kGT16
);
INTRIAVX_FLOAT
(
platform
::
avx2
,
kEQ8
);
INTRIAVX_FLOAT
(
platform
::
avx2
,
kGT8LT16
);
INTRIAVX_FLOAT
(
platform
::
avx2
,
kEQ16
);
INTRIAVX_FLOAT
(
platform
::
avx2
,
kGT16
);
INTRIAVX_FLOAT
(
platform
::
avx512f
,
kEQ8
);
INTRIAVX_FLOAT
(
platform
::
avx512f
,
kGT8LT16
);
INTRIAVX_FLOAT
(
platform
::
avx512f
,
kEQ16
);
INTRIAVX_FLOAT
(
platform
::
avx512f
,
kGT16
);
#endif
#undef INTRIAVX_FLOAT
REGISTER_JITKERNEL_DEPRECATED
(
layer_norm
,
LayerNormKernel
);
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_macro.h
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
#define JITKERNEL_DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(int d) { \
std::string key(#ker_key "f"); \
if (useJIT(d)) { \
/* only jit code need record d*/
\
return key + "jit" + std::to_string(d); \
} else if (useMKL(d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(int d) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/
\
if (useMKL(d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
#define JITKERNEL_DECLARE(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, int>(int d)
#define JITKERNEL_FIND_KEY(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(d)
#define JITKERNEL_IMPL(ker_class, ker_dtype) \
p = std::dynamic_pointer_cast<ker_class<ker_dtype>>( \
std::make_shared<ker_class##Impl<ker_dtype>>(d))
#define REGISTER_JITKERNEL_WITH_DTYPE(ker_class, ker_dtype, marco_declare, \
macro_find_key, macro_impl) \
marco_declare(ker_class, ker_dtype) { \
macro_find_key(ker_class, ker_dtype); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
macro_impl(ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
marco_declare, macro_find_key, macro_impl) \
marco_define_name(ker_key, ker_class); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, marco_declare, \
macro_find_key, macro_impl); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, marco_declare, \
macro_find_key, macro_impl)
#define REGISTER_JITKERNEL(ker_key, ker_class) \
REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
JITKERNEL_DECLARE, JITKERNEL_FIND_KEY, \
JITKERNEL_IMPL)
// TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < YMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kLT8); \
} else if (d == YMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ8); \
} else if (d > YMM_FLOAT_BLOCK && d < ZMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16); \
} else if (d == ZMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16); \
} else { \
macro_(ker, dtype, isa, kGT16); \
}
#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
if (platform::MayIUse(platform::avx512f)) { \
SEARCH_BLOCK(macro_, ker, dtype, platform::avx512f); \
} else if (platform::MayIUse(platform::avx2)) { \
SEARCH_BLOCK(macro_, ker, dtype, platform::avx2); \
} else if (platform::MayIUse(platform::avx)) { \
SEARCH_BLOCK(macro_, ker, dtype, platform::avx); \
} else { \
SEARCH_BLOCK(macro_, ker, dtype, platform::isa_any); \
}
#define JITKERNEL_KEY(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d)
#define JITKERNEL_NEW_IMPL_DEPRECATED(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(d))
#define JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, ker_dtype, \
dtype_key, marco_declare, macro_key, \
macro_impl) \
marco_declare(ker_class, ker_dtype) { \
std::string key = macro_key(ker_key, dtype_key); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
#define REGISTER_JITKERNEL_DEPRECATED(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, \
JITKERNEL_DECLARE, JITKERNEL_KEY, \
JITKERNEL_NEW_IMPL_DEPRECATED); \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \
JITKERNEL_DECLARE, JITKERNEL_KEY, \
JITKERNEL_NEW_IMPL_DEPRECATED)
#define REGISTER_JITKERNEL_ARGS_DEPRECATED(ker_key, ker_class, marco_declare, \
macro_key, macro_impl) \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, marco_declare, \
macro_key, macro_impl); \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \
marco_declare, macro_key, macro_impl)
#define FOR_EACH_ISA(macro_, block) \
macro_(platform::avx512f, block); \
macro_(platform::avx2, block); \
macro_(platform::avx, block); \
macro_(platform::isa_any, block)
#define FOR_EACH_BLOCK(macro_, isa) \
macro_(isa, kLT8); \
macro_(isa, kEQ8); \
macro_(isa, kGT8LT16); \
macro_(isa, kEQ16); \
macro_(isa, kGT16)
#define FOR_EACH_ISA_BLOCK(macro_) \
FOR_EACH_BLOCK(macro_, platform::avx512f); \
FOR_EACH_BLOCK(macro_, platform::avx2); \
FOR_EACH_BLOCK(macro_, platform::avx); \
FOR_EACH_BLOCK(macro_, platform::isa_any)
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_refer.h
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
refer
{}
// namespace refer
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_rnn.cc
已删除
100644 → 0
浏览文件 @
95fb3128
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
/* LSTM JitKernel */
template
<
typename
T
>
class
LSTMKernelImpl
:
public
LSTMKernel
<
T
>
{
public:
static
inline
std
::
string
name
(
const
lstm_attr_t
&
attr
)
{
PADDLE_THROW
(
"DType should be either float or double"
);
}
static
inline
bool
useJIT
(
int
d
)
{
return
false
;
}
static
inline
bool
useMKL
(
int
d
)
{
return
false
;
}
explicit
LSTMKernelImpl
(
const
lstm_attr_t
&
attr
)
:
LSTMKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
attr
.
d
))
{
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
90
*
4
*
8
;
jitcode0_
.
reset
(
new
gen
::
LSTMJitCode
(
false
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeCtHt
=
jitcode0_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
jitcode1_
.
reset
(
new
gen
::
LSTMJitCode
(
true
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeC1H1
=
jitcode1_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
return
;
}
#endif
this
->
ComputeCtHt
=
refer
::
LSTMCtHt
<
T
>
;
this
->
ComputeC1H1
=
refer
::
LSTMC1H1
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
LSTMJitCode
>
jitcode0_
{
nullptr
},
jitcode1_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
LSTMKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
LSTMJitCode
::
init
(
d
);
}
#endif
/* Peephole JitKernel */
template
<
typename
T
>
class
PeepholeKernelImpl
:
public
LSTMKernel
<
T
>
{
public:
static
inline
std
::
string
name
(
const
lstm_attr_t
&
attr
)
{
PADDLE_THROW
(
"DType should be either float or double"
);
}
static
inline
bool
useJIT
(
int
d
)
{
return
false
;
}
static
inline
bool
useMKL
(
int
d
)
{
return
false
;
}
explicit
PeepholeKernelImpl
(
const
lstm_attr_t
&
attr
)
:
LSTMKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
attr
.
d
))
{
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
96
*
4
*
8
;
jitcode0_
.
reset
(
new
gen
::
LSTMJitCode
(
false
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeCtHt
=
jitcode0_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
jitcode1_
.
reset
(
new
gen
::
LSTMJitCode
(
true
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeC1H1
=
jitcode1_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
return
;
}
#endif
this
->
ComputeCtHt
=
refer
::
LSTMCtHt
<
T
>
;
this
->
ComputeC1H1
=
refer
::
LSTMC1H1
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
LSTMJitCode
>
jitcode0_
{
nullptr
},
jitcode1_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
PeepholeKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
LSTMJitCode
::
init
(
d
);
}
#endif
#define JITKERNEL_DEFINE_NAME_LSTM(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(const lstm_attr_t& attr) { \
std::string key(#ker_key "f"); \
key += (attr.act_gate + attr.act_cand + attr.act_cell + \
(attr.use_peephole ? "p" : "n")); \
if (useJIT(attr.d)) { \
/* only jit code need record d*/
\
return key + "jit" + std::to_string(attr.d); \
} else if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(const lstm_attr_t& attr) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/
\
if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const LSTMKernel<ker_dtype>> \
KernelPool::Get<LSTMKernel<ker_dtype>, const lstm_attr_t&>( \
const lstm_attr_t& attr)
#define JITKERNEL_FIND_KEY_LSTM(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(attr)
#define JITKERNEL_LSTM_IMPL(ker, dtype) \
if (attr.use_peephole) { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<PeepholeKernelImpl<dtype>>(attr)); \
} else { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype>>(attr)); \
}
REGISTER_JITKERNEL_ARGS
(
lstm
,
LSTMKernel
,
JITKERNEL_DEFINE_NAME_LSTM
,
JITKERNEL_DECLARE_LSTM
,
JITKERNEL_FIND_KEY_LSTM
,
JITKERNEL_LSTM_IMPL
);
#undef JITKERNEL_LSTM_IMPL
#undef JITKERNEL_FIND_KEY_LSTM
#undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_DEFINE_NAME_LSTM
/* GRU JitKernel */
template
<
typename
T
>
class
GRUKernelImpl
:
public
GRUKernel
<
T
>
{
public:
static
inline
std
::
string
name
(
const
gru_attr_t
&
attr
)
{
PADDLE_THROW
(
"DType should be either float or double"
);
}
static
inline
bool
useJIT
(
int
d
)
{
return
false
;
}
static
inline
bool
useMKL
(
int
d
)
{
return
false
;
}
explicit
GRUKernelImpl
(
const
gru_attr_t
&
attr
)
:
GRUKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
attr
.
d
))
{
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
96
*
2
*
8
;
jitcode0_
.
reset
(
new
gen
::
GRUJitCode
(
0
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeH1
=
jitcode0_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
jitcode1_
.
reset
(
new
gen
::
GRUJitCode
(
1
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeHtPart1
=
jitcode1_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
jitcode2_
.
reset
(
new
gen
::
GRUJitCode
(
2
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeHtPart2
=
jitcode2_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
return
;
}
#endif
this
->
ComputeH1
=
refer
::
GRUH1
<
T
>
;
this
->
ComputeHtPart1
=
refer
::
GRUHtPart1
<
T
>
;
this
->
ComputeHtPart2
=
refer
::
GRUHtPart2
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
GRUJitCode
>
jitcode0_
{
nullptr
},
jitcode1_
{
nullptr
},
jitcode2_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
GRUKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
GRUJitCode
::
init
(
d
);
}
#endif
#define JITKERNEL_DEFINE_NAME_GRU(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(const gru_attr_t& attr) { \
std::string key(#ker_key "f"); \
key += (attr.act_gate + attr.act_cand); \
if (useJIT(attr.d)) { \
/* only jit code need record d*/
\
return key + "jit" + std::to_string(attr.d); \
} else if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(const gru_attr_t& attr) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/
\
if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, const gru_attr_t&>( \
const gru_attr_t& attr)
#define JITKERNEL_FIND_KEY_GRU(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(attr)
#define JITKERNEL_GRU_IMPL(ker, dtype) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype>>(attr));
REGISTER_JITKERNEL_ARGS
(
gru
,
GRUKernel
,
JITKERNEL_DEFINE_NAME_GRU
,
JITKERNEL_DECLARE_GRU
,
JITKERNEL_FIND_KEY_GRU
,
JITKERNEL_GRU_IMPL
);
#undef JITKERNEL_GRU_IMPL
#undef JITKERNEL_FIND_KEY_GRU
#undef JITKERNEL_DECLARE_GRU
#undef JITKERNEL_DEFINE_NAME_GRU
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_test.cc
已删除
100644 → 0
浏览文件 @
95fb3128
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录