Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
693e5e65
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
693e5e65
编写于
12月 21, 2018
作者:
T
tensor-tang
提交者:
GitHub
12月 21, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14958 from tensor-tang/refine/jit
enhance jit
上级
95cbe07c
1aaec571
变更
61
隐藏空白更改
内联
并排
Showing
61 changed file
with
4353 addition
and
3541 deletion
+4353
-3541
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+2
-1
paddle/fluid/operators/crf_decoding_op.h
paddle/fluid/operators/crf_decoding_op.h
+4
-5
paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc
.../fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc
+4
-6
paddle/fluid/operators/fused/fusion_gru_op.cc
paddle/fluid/operators/fused/fusion_gru_op.cc
+30
-28
paddle/fluid/operators/fused/fusion_lstm_op.cc
paddle/fluid/operators/fused/fusion_lstm_op.cc
+31
-30
paddle/fluid/operators/jit/CMakeLists.txt
paddle/fluid/operators/jit/CMakeLists.txt
+25
-0
paddle/fluid/operators/jit/README.md
paddle/fluid/operators/jit/README.md
+66
-0
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+231
-0
paddle/fluid/operators/jit/gen/CMakeLists.txt
paddle/fluid/operators/jit/gen/CMakeLists.txt
+28
-0
paddle/fluid/operators/jit/gen/act.cc
paddle/fluid/operators/jit/gen/act.cc
+135
-0
paddle/fluid/operators/jit/gen/act.h
paddle/fluid/operators/jit/gen/act.h
+90
-303
paddle/fluid/operators/jit/gen/blas.cc
paddle/fluid/operators/jit/gen/blas.cc
+186
-0
paddle/fluid/operators/jit/gen/blas.h
paddle/fluid/operators/jit/gen/blas.h
+117
-0
paddle/fluid/operators/jit/gen/gru.cc
paddle/fluid/operators/jit/gen/gru.cc
+116
-0
paddle/fluid/operators/jit/gen/gru.h
paddle/fluid/operators/jit/gen/gru.h
+113
-0
paddle/fluid/operators/jit/gen/jitcode.h
paddle/fluid/operators/jit/gen/jitcode.h
+126
-0
paddle/fluid/operators/jit/gen/lstm.cc
paddle/fluid/operators/jit/gen/lstm.cc
+142
-0
paddle/fluid/operators/jit/gen/lstm.h
paddle/fluid/operators/jit/gen/lstm.h
+118
-0
paddle/fluid/operators/jit/gen_base.cc
paddle/fluid/operators/jit/gen_base.cc
+43
-0
paddle/fluid/operators/jit/gen_base.h
paddle/fluid/operators/jit/gen_base.h
+70
-0
paddle/fluid/operators/jit/helper.cc
paddle/fluid/operators/jit/helper.cc
+76
-0
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+140
-0
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+172
-0
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+47
-0
paddle/fluid/operators/jit/kernel_key.h
paddle/fluid/operators/jit/kernel_key.h
+53
-0
paddle/fluid/operators/jit/kernel_pool.cc
paddle/fluid/operators/jit/kernel_pool.cc
+41
-0
paddle/fluid/operators/jit/kernel_pool.h
paddle/fluid/operators/jit/kernel_pool.h
+119
-0
paddle/fluid/operators/jit/macro.h
paddle/fluid/operators/jit/macro.h
+32
-0
paddle/fluid/operators/jit/more/CMakeLists.txt
paddle/fluid/operators/jit/more/CMakeLists.txt
+17
-0
paddle/fluid/operators/jit/more/intrinsic/CMakeLists.txt
paddle/fluid/operators/jit/more/intrinsic/CMakeLists.txt
+9
-0
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
+181
-0
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
+41
-0
paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
+168
-0
paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
+41
-0
paddle/fluid/operators/jit/more/mix/CMakeLists.txt
paddle/fluid/operators/jit/more/mix/CMakeLists.txt
+14
-0
paddle/fluid/operators/jit/more/mix/mix.cc
paddle/fluid/operators/jit/more/mix/mix.cc
+216
-0
paddle/fluid/operators/jit/more/mix/mix.h
paddle/fluid/operators/jit/more/mix/mix.h
+61
-0
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
+11
-0
paddle/fluid/operators/jit/more/mkl/mkl.cc
paddle/fluid/operators/jit/more/mkl/mkl.cc
+139
-0
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+90
-0
paddle/fluid/operators/jit/refer/CMakeLists.txt
paddle/fluid/operators/jit/refer/CMakeLists.txt
+28
-0
paddle/fluid/operators/jit/refer/refer.cc
paddle/fluid/operators/jit/refer/refer.cc
+50
-0
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+165
-25
paddle/fluid/operators/jit/registry.h
paddle/fluid/operators/jit/registry.h
+167
-0
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+584
-0
paddle/fluid/operators/layer_norm_op.h
paddle/fluid/operators/layer_norm_op.h
+7
-7
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+0
-9
paddle/fluid/operators/math/fc_compute.h
paddle/fluid/operators/math/fc_compute.h
+7
-8
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+0
-334
paddle/fluid/operators/math/jit_gen.cc
paddle/fluid/operators/math/jit_gen.cc
+0
-90
paddle/fluid/operators/math/jit_gen.h
paddle/fluid/operators/math/jit_gen.h
+0
-80
paddle/fluid/operators/math/jit_kernel.cc
paddle/fluid/operators/math/jit_kernel.cc
+0
-39
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+0
-157
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+0
-396
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
+0
-291
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+0
-236
paddle/fluid/operators/math/jit_kernel_impl.h
paddle/fluid/operators/math/jit_kernel_impl.h
+0
-73
paddle/fluid/operators/math/jit_kernel_layer_norm.cc
paddle/fluid/operators/math/jit_kernel_layer_norm.cc
+0
-239
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+0
-179
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+0
-263
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+0
-742
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
693e5e65
...
@@ -16,6 +16,7 @@ add_subdirectory(metrics)
...
@@ -16,6 +16,7 @@ add_subdirectory(metrics)
add_subdirectory
(
optimizers
)
add_subdirectory
(
optimizers
)
add_subdirectory
(
reduce_ops
)
add_subdirectory
(
reduce_ops
)
add_subdirectory
(
sequence_ops
)
add_subdirectory
(
sequence_ops
)
add_subdirectory
(
jit
)
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
add_subdirectory
(
distributed
)
add_subdirectory
(
distributed
)
...
@@ -64,7 +65,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
...
@@ -64,7 +65,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
dynload_warpctc
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
dynload_warpctc
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence_padding sequence_scale cos_sim_functor memory jit_kernel
_helper
concat_and_split cross_entropy softmax vol2col im2col sampler
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
depthwise_conv prelu
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
depthwise_conv prelu
)
...
...
paddle/fluid/operators/crf_decoding_op.h
浏览文件 @
693e5e65
...
@@ -16,7 +16,7 @@ limitations under the License. */
...
@@ -16,7 +16,7 @@ limitations under the License. */
#include <limits>
#include <limits>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/
math/jit_kernel
.h"
#include "paddle/fluid/operators/
jit/kernels
.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -82,10 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
...
@@ -82,10 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor
track
;
Tensor
track
;
int
*
track_value
=
int
*
track_value
=
track
.
mutable_data
<
int
>
(
emission_dims
,
platform
::
CPUPlace
());
track
.
mutable_data
<
int
>
(
emission_dims
,
platform
::
CPUPlace
());
const
auto
&
ker
=
math
::
jitkernel
::
KernelPool
::
Instance
()
auto
ker
=
jit
::
Get
<
jit
::
kCRFDecoding
,
jit
::
CRFDecodingTuples
<
T
>
,
.
template
Get
<
math
::
jitkernel
::
CRFDecodeKernel
<
T
>
>
(
platform
::
CPUPlace
>
(
tag_num
);
static_cast
<
int
>
(
tag_num
));
ker
(
static_cast
<
int
>
(
seq_len
),
x
,
w
,
alpha_value
,
track_value
,
tag_num
);
ker
->
Compute
(
static_cast
<
int
>
(
seq_len
),
x
,
w
,
alpha_value
,
track_value
);
T
max_score
=
-
std
::
numeric_limits
<
T
>::
max
();
T
max_score
=
-
std
::
numeric_limits
<
T
>::
max
();
int
max_i
=
0
;
int
max_i
=
0
;
for
(
size_t
i
=
0
;
i
<
tag_num
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
tag_num
;
++
i
)
{
...
...
paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc
浏览文件 @
693e5e65
...
@@ -18,7 +18,7 @@ limitations under the License. */
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/operators/
math/jit_kernel
.h"
#include "paddle/fluid/operators/
jit/kernels
.h"
#include "xbyak/xbyak.h"
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
#include "xbyak/xbyak_util.h"
...
@@ -108,10 +108,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
...
@@ -108,10 +108,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr
int
simd_width
=
16
;
constexpr
int
simd_width
=
16
;
int
C
=
c
/
simd_width
;
int
C
=
c
/
simd_width
;
const
auto
&
multiply
=
auto
multiply
=
jit
::
Get
<
jit
::
kNCHW16CMulNC
,
jit
::
NCHW16CMulNCTuples
<
T
>
,
math
::
jitkernel
::
KernelPool
::
Instance
()
platform
::
CPUPlace
>
(
0
);
.
template
Get
<
math
::
jitkernel
::
EltwiseMulnChw16cNCKernel
<
T
>
>
(
n
);
#pragma omp parallel for collapse(2)
#pragma omp parallel for collapse(2)
for
(
int
ni
=
0
;
ni
<
n
;
ni
++
)
{
for
(
int
ni
=
0
;
ni
<
n
;
ni
++
)
{
for
(
int
ci
=
0
;
ci
<
C
;
ci
++
)
{
for
(
int
ci
=
0
;
ci
<
C
;
ci
++
)
{
...
@@ -122,7 +120,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
...
@@ -122,7 +120,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
auto
ptr_z
=
auto
ptr_z
=
z_data
+
ni
*
C
*
h
*
w
*
simd_width
+
ci
*
h
*
w
*
simd_width
;
z_data
+
ni
*
C
*
h
*
w
*
simd_width
+
ci
*
h
*
w
*
simd_width
;
multiply
->
Compute
(
ptr_x
,
ptr_y
,
ptr_z
,
h
,
w
);
multiply
(
ptr_x
,
ptr_y
,
ptr_z
,
h
,
w
);
}
}
}
}
}
}
...
...
paddle/fluid/operators/fused/fusion_gru_op.cc
浏览文件 @
693e5e65
...
@@ -15,9 +15,9 @@ limitations under the License. */
...
@@ -15,9 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_gru_op.h"
#include "paddle/fluid/operators/fused/fusion_gru_op.h"
#include <cstring> // for memcpy
#include <cstring> // for memcpy
#include <string>
#include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -182,27 +182,29 @@ class FusionGRUKernel : public framework::OpKernel<T> {
...
@@ -182,27 +182,29 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const int total_T = x_dims[0]; \
const int total_T = x_dims[0]; \
const int D3 = wh_dims[1]
const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \
#define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const int D2 = D * 2; \
const math::jitkernel::gru_attr_t attr( \
const jit::gru_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
ctx.Attr<std::string>("activation")); \
jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
math::jitkernel::gru_t one_step; \
jit::gru_t one_step; \
const auto& ker = \
auto ComputeH1 = \
math::jitkernel::KernelPool::Instance() \
jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
.template Get<math::jitkernel::GRUKernel<T>, \
auto ComputeHtPart1 = \
const math::jitkernel::gru_attr_t&>(attr); \
jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
const T* x_data = x->data<T>(); \
auto ComputeHtPart2 = \
const T* wx_data = wx->data<T>(); \
jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
const T* wh_data = wh->data<T>(); \
const T* x_data = x->data<T>(); \
auto place = ctx.GetPlace(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
T* xx_data = xx->mutable_data<T>(place)
T* xx_data = xx->mutable_data<T>(place)
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
...
@@ -241,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
...
@@ -241,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
}
else
{
}
else
{
one_step
.
gates
=
xx_data
;
one_step
.
gates
=
xx_data
;
one_step
.
ht
=
hidden_out_data
;
one_step
.
ht
=
hidden_out_data
;
ker
->
ComputeH1
(
&
one_step
,
&
attr
);
ComputeH1
(
&
one_step
,
&
attr
);
prev_hidden_data
=
hidden_out_data
;
prev_hidden_data
=
hidden_out_data
;
tstart
=
1
;
tstart
=
1
;
move_step
();
move_step
();
...
@@ -254,12 +256,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
...
@@ -254,12 +256,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
one_step
.
gates
=
xx_data
;
one_step
.
gates
=
xx_data
;
one_step
.
ht_1
=
prev_hidden_data
;
one_step
.
ht_1
=
prev_hidden_data
;
one_step
.
ht
=
hidden_out_data
;
one_step
.
ht
=
hidden_out_data
;
ker
->
ComputeHtPart1
(
&
one_step
,
&
attr
);
ComputeHtPart1
(
&
one_step
,
&
attr
);
// gemm rt * Ws
// gemm rt * Ws
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D
,
D
,
static_cast
<
T
>
(
1
),
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D
,
D
,
static_cast
<
T
>
(
1
),
hidden_out_data
,
D
,
wh_state_data
,
D
,
static_cast
<
T
>
(
1
),
hidden_out_data
,
D
,
wh_state_data
,
D
,
static_cast
<
T
>
(
1
),
xx_data
+
D2
,
D3
);
xx_data
+
D2
,
D3
);
ker
->
ComputeHtPart2
(
&
one_step
,
&
attr
);
ComputeHtPart2
(
&
one_step
,
&
attr
);
// save prev
// save prev
prev_hidden_data
=
hidden_out_data
;
prev_hidden_data
=
hidden_out_data
;
move_step
();
move_step
();
...
@@ -323,7 +325,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
...
@@ -323,7 +325,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
one_step
.
gates
=
cur_in_data
;
one_step
.
gates
=
cur_in_data
;
one_step
.
ht
=
cur_out_data
;
one_step
.
ht
=
cur_out_data
;
ker
->
ComputeH1
(
&
one_step
,
&
attr
);
ComputeH1
(
&
one_step
,
&
attr
);
// add offset
// add offset
cur_in_data
+=
D3
;
cur_in_data
+=
D3
;
cur_out_data
+=
D
;
cur_out_data
+=
D
;
...
@@ -351,7 +353,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
...
@@ -351,7 +353,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
one_step
.
gates
=
cur_batched_data
;
one_step
.
gates
=
cur_batched_data
;
one_step
.
ht_1
=
cur_prev_hidden_data
;
one_step
.
ht_1
=
cur_prev_hidden_data
;
one_step
.
ht
=
cur_out_data
;
one_step
.
ht
=
cur_out_data
;
ker
->
ComputeHtPart1
(
&
one_step
,
&
attr
);
ComputeHtPart1
(
&
one_step
,
&
attr
);
cur_batched_data
+=
D3
;
cur_batched_data
+=
D3
;
cur_prev_hidden_data
+=
D
;
cur_prev_hidden_data
+=
D
;
...
@@ -369,7 +371,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
...
@@ -369,7 +371,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
one_step
.
gates
=
cur_batched_data
;
one_step
.
gates
=
cur_batched_data
;
one_step
.
ht_1
=
cur_prev_hidden_data
;
one_step
.
ht_1
=
cur_prev_hidden_data
;
one_step
.
ht
=
cur_out_data
;
one_step
.
ht
=
cur_out_data
;
ker
->
ComputeHtPart2
(
&
one_step
,
&
attr
);
ComputeHtPart2
(
&
one_step
,
&
attr
);
cur_batched_data
+=
D3
;
cur_batched_data
+=
D3
;
cur_prev_hidden_data
+=
D
;
cur_prev_hidden_data
+=
D
;
cur_out_data
+=
D
;
cur_out_data
+=
D
;
...
...
paddle/fluid/operators/fused/fusion_lstm_op.cc
浏览文件 @
693e5e65
...
@@ -14,9 +14,9 @@ limitations under the License. */
...
@@ -14,9 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_lstm_op.h"
#include "paddle/fluid/operators/fused/fusion_lstm_op.h"
#include <string>
#include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -235,31 +235,32 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -235,31 +235,32 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D = wh_dims[0]; \
const int D = wh_dims[0]; \
const int D4 = wh_dims[1]
const int D4 = wh_dims[1]
#define INIT_OTHER_DEFINES \
#define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/
\
/* diagonal weight*/
\
const T* wp_data = bias->data<T>() + D4; \
const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/
\
/* for peephole only*/
\
T* checked_cell_data = nullptr; \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
} \
} \
const math::jitkernel::lstm_attr_t attr( \
const jit::lstm_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \
D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
ctx.Attr<std::string>("candidate_activation"), \
jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
ctx.Attr<std::string>("cell_activation"), use_peepholes); \
jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
math::jitkernel::lstm_t one_step; \
use_peepholes); \
one_step.wp = wp_data; \
jit::lstm_t one_step; \
one_step.checked = checked_cell_data; \
one_step.wp = wp_data; \
const auto& ker = \
one_step.checked = checked_cell_data; \
math::jitkernel::KernelPool::Instance() \
auto ComputeC1H1 = \
.template Get<math::jitkernel::LSTMKernel<T>, \
jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
const math::jitkernel::lstm_attr_t&>(attr)
auto ComputeCtHt = \
jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
// Wh GEMM
// Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \
#define GEMM_WH_ADDON(bs, prev, out) \
...
@@ -305,7 +306,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -305,7 +306,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step
.
gates
=
xx_data
;
one_step
.
gates
=
xx_data
;
one_step
.
ct
=
c_out_data
;
one_step
.
ct
=
c_out_data
;
one_step
.
ht
=
h_out_data
;
one_step
.
ht
=
h_out_data
;
ker
->
ComputeC1H1
(
&
one_step
,
&
attr
);
ComputeC1H1
(
&
one_step
,
&
attr
);
tstart
=
1
;
tstart
=
1
;
// move one step
// move one step
prev_h_data
=
h_out_data
;
prev_h_data
=
h_out_data
;
...
@@ -321,7 +322,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -321,7 +322,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step
.
ct_1
=
prev_c_data
;
one_step
.
ct_1
=
prev_c_data
;
one_step
.
ct
=
c_out_data
;
one_step
.
ct
=
c_out_data
;
one_step
.
ht
=
h_out_data
;
one_step
.
ht
=
h_out_data
;
ker
->
ComputeCtHt
(
&
one_step
,
&
attr
);
ComputeCtHt
(
&
one_step
,
&
attr
);
// move one step
// move one step
prev_h_data
=
h_out_data
;
prev_h_data
=
h_out_data
;
prev_c_data
=
c_out_data
;
prev_c_data
=
c_out_data
;
...
@@ -401,7 +402,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -401,7 +402,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step
.
gates
=
cur_in_data
;
one_step
.
gates
=
cur_in_data
;
one_step
.
ct
=
cur_c_out_data
;
one_step
.
ct
=
cur_c_out_data
;
one_step
.
ht
=
cur_h_out_data
;
one_step
.
ht
=
cur_h_out_data
;
ker
->
ComputeC1H1
(
&
one_step
,
&
attr
);
ComputeC1H1
(
&
one_step
,
&
attr
);
cur_in_data
+=
D4
;
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_c_out_data
+=
D
;
...
@@ -431,7 +432,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -431,7 +432,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step
.
ct_1
=
cur_prev_c_data
;
one_step
.
ct_1
=
cur_prev_c_data
;
one_step
.
ct
=
cur_c_out_data
;
one_step
.
ct
=
cur_c_out_data
;
one_step
.
ht
=
cur_h_out_data
;
one_step
.
ht
=
cur_h_out_data
;
ker
->
ComputeCtHt
(
&
one_step
,
&
attr
);
ComputeCtHt
(
&
one_step
,
&
attr
);
// move one batch
// move one batch
cur_in_data
+=
D4
;
cur_in_data
+=
D4
;
...
...
paddle/fluid/operators/jit/CMakeLists.txt
0 → 100644
浏览文件 @
693e5e65
set
(
jit_file
${
PADDLE_BINARY_DIR
}
/paddle/fluid/operators/jit/kernels.h
)
file
(
WRITE
${
jit_file
}
"// Generated by the paddle/fluid/operators/jit/CMakeLists.txt. DO NOT EDIT!
\n\n
"
)
file
(
APPEND
${
jit_file
}
"
\#
pragma once
\n
"
)
file
(
APPEND
${
jit_file
}
"
\#
include
\"
paddle/fluid/operators/jit/helper.h
\"\n
"
)
file
(
APPEND
${
jit_file
}
"
\#
include
\"
paddle/fluid/operators/jit/registry.h
\"\n\n
"
)
set
(
JIT_KERNEL_DEPS cpu_info cblas gflags enforce place
)
file
(
GLOB jit_kernel_cc_srcs RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"*.cc"
)
list
(
REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc
)
cc_library
(
jit_kernel_base SRCS
${
jit_kernel_cc_srcs
}
DEPS
${
JIT_KERNEL_DEPS
}
)
# refer must go first
add_subdirectory
(
refer
)
add_subdirectory
(
more
)
if
(
WITH_XBYAK
)
add_subdirectory
(
gen
)
endif
()
cc_library
(
jit_kernel_helper SRCS
${
jit_kernel_cc_srcs
}
DEPS
${
JIT_KERNEL_DEPS
}
)
cc_test
(
jit_kernel_test SRCS test.cc DEPS jit_kernel_helper
)
if
(
NOT WIN32
)
cc_binary
(
jit_kernel_benchmark SRCS benchmark.cc DEPS jit_kernel_helper
)
endif
()
paddle/fluid/operators/jit/README.md
0 → 100644
浏览文件 @
693e5e65
# JIT Kernel
结合函数模板和JIT生成需要的kernel函数。
这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。可以有多重第三方库的实现,每种实现有自己的
`UseMe`
函数负责什么条件下可以被调用。
这里实现的函数可以非常细粒度的函数方法,比如Vector MUL, 也可以是一个复杂的逻辑比如LSTM等。复杂的逻辑也可以由自己的底层函数拼接而成。
目前仅支持CPU上的高性能计算。
## 目录结构
```
txt
PaddlePaddle/Paddle/paddle/fluid/
├── ...
├── operator/
│ ├── .../
└── jit/
├── ...
├── gen/
│ └── ...
|── more/
│ ├── ...
│ ├── mkl/
│ │ └── ...
│ ├── mkldnn/
│ │ └── ...
│ ├── mix/
│ │ └── ...
│ ├── intrinsic/
│ │ └── ...
│ └── openblas/
│ └── ...
└── refer/
└── ...
```
基本类的定义都放在根目录下,根目录下包括gen,more和refer三个目录。每个目录下都是一种或者多种实现,每种kernel算子都需要有reference的实现,用作单元测试的基准,其他的实现都是可选的。
-
gen: 代表使用jit生成的code,需要依赖xbyak库。该实现最关心的就是性能。
-
refer: 代表reference的实现,每种kernel算子都需要有在CPU上的reference的实现,他主要关心的算法逻辑的正确性。
-
more: 下面可以放入跟多实现,可以包括mkl,mkldnn,intrinsic,openblas等,也可以是自身已有的kernel组合。
## 动态获取
提供一个
`jit::Get`
方法,根据kernel类别获取,每种实现都有自己的使用范围,根据范围动态和当前条件选择需要的kernel函数。
## 测试
-
逻辑测试
所有实现都要与refer的code对比,需要满足精度要求, 包括float和double的数据类型
-
性能测试
所有实现的性能对比,并且与最终的
`jit::Get`
方法对比,该方法拿到的性能需要在各种条件下都是最好的。
# 如何添加新的算子
-
在
`KernelType`
中添加
`your_key`
.
-
实现Reference 的逻辑,这个是必须是在CPU上的实现,并且不能依赖任何第三方库。实现后在
`refer/CmakeLists.txt`
中添加
`USE_JITKERNEL_REFER(your_key)`
来使用该kernel.
-
(optional) 实现更多的算法在
`more`
目录下,可以依赖mkl,intrinsic或者mkldnn等第三方库。
-
(optional) 实现基于Xbyak的生成code,在
`gen`
目下。 jitcode需要实现自己的
`JitCodeCreator`
,并注册在与refer相同的
`KernelType`
上。
-
必要时可以添加新的
`KernelTuples`
,可以参考
`XYZNTuples`
,新加的Attr类型需要特例化
`JitCodeKey`
方法。
-
在
`test.cc`
中添加unit test,至少需要测试
`float`
和
`double`
两种数据类型,如有必要需要支持额外的数据类型,比如
`int8`
的相关函数。
-
在
`benchmark.cc`
中添加相应的性能对比,同一种kernel需要对比所有实现,并且确保
`jit::Get`
得到的实现一直是速度最快的。
# 优点
-
统一的Get方法,接口简单。
-
同一套逻辑可以有多套实现,可以依赖多套第三方库,互不影响。
-
目录结构清晰,不会在某个文件中有多个宏定义,导致的可读性差问题。
-
优化方便,可以直接针对某种属性针对性优化,并不影响其他属性下的性能。
-
可以支持多种平台,包括Linux,Mac 和 Windows,至少可以保证每种平台都可以正常work。后期也可以针对不同平台有针对的优化。框架层面可以使用统一接口,不必关心底层实现。
paddle/fluid/operators/jit/benchmark.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
DEFINE_int32
(
burning
,
10
,
"Burning times."
);
DEFINE_int32
(
repeat
,
3000
,
"Repeat times."
);
DEFINE_int32
(
max_size
,
1000
,
"The Max size would be tested."
);
template
<
typename
T
>
void
RandomVec
(
const
int
n
,
T
*
a
,
const
T
lower
=
static_cast
<
T
>
(
-
20.
f
),
const
T
upper
=
static_cast
<
T
>
(
20.
f
),
unsigned
int
seed
=
100
)
{
std
::
mt19937
rng
(
seed
);
std
::
uniform_real_distribution
<
double
>
uniform_dist
(
0
,
1
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
a
[
i
]
=
static_cast
<
T
>
(
uniform_dist
(
rng
)
*
(
upper
-
lower
)
+
lower
);
}
}
std
::
vector
<
int
>
TestSizes
()
{
std
::
vector
<
int
>
s
;
for
(
int
i
=
1
;
i
<=
FLAGS_max_size
;
++
i
)
{
s
.
push_back
(
i
);
}
return
s
;
}
template
<
typename
KernelTuples
,
typename
...
Args
>
struct
BenchFunc
{
// return this function avg time
double
operator
()(
const
typename
KernelTuples
::
func_type
tgt
,
Args
...
args
)
{
for
(
int
i
=
0
;
i
<
FLAGS_burning
;
++
i
)
{
tgt
(
args
...);
}
auto
start
=
paddle
::
platform
::
PosixInNsec
()
/
1e-3
;
for
(
int
i
=
0
;
i
<
FLAGS_repeat
;
++
i
)
{
tgt
(
args
...);
}
auto
end
=
paddle
::
platform
::
PosixInNsec
()
/
1e-3
;
return
static_cast
<
double
>
(
end
-
start
)
/
FLAGS_repeat
;
}
};
namespace
jit
=
paddle
::
operators
::
jit
;
template
<
jit
::
KernelType
KT
,
typename
KernelTuples
,
typename
PlaceType
,
typename
...
Args
>
void
BenchAllImpls
(
const
typename
KernelTuples
::
attr_type
&
attr
,
Args
...
args
)
{
BenchFunc
<
KernelTuples
,
Args
...
>
benchmark
;
std
::
vector
<
std
::
pair
<
std
::
string
,
double
>>
infos
;
// test refer
auto
refer
=
jit
::
GetRefer
<
KT
,
KernelTuples
>
();
if
(
!
refer
)
{
LOG
(
FATAL
)
<<
"Refer can not be empty!"
;
}
infos
.
push_back
(
std
::
make_pair
(
"Refer"
,
benchmark
(
refer
,
args
...)));
// test jitcode
auto
jitcode
=
jit
::
GetJitCode
<
KT
,
KernelTuples
,
PlaceType
>
(
attr
);
if
(
jitcode
)
{
infos
.
push_back
(
std
::
make_pair
(
"JitCode"
,
benchmark
(
jitcode
,
args
...)));
}
// test all impls in more
jit
::
KernelKey
kkey
(
KT
,
PlaceType
());
auto
&
pool
=
jit
::
KernelPool
().
Instance
().
AllKernels
();
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
jit
::
KernelMore
<
KernelTuples
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
auto
more
=
i
->
GetFunc
();
infos
.
push_back
(
std
::
make_pair
(
i
->
ImplType
(),
benchmark
(
more
,
args
...)));
}
}
}
// Test result from Get function
auto
tgt
=
jit
::
Get
<
KT
,
KernelTuples
,
PlaceType
>
(
attr
);
if
(
!
tgt
)
{
LOG
(
FATAL
)
<<
"Target can not be empty!"
;
}
infos
.
push_back
(
std
::
make_pair
(
"Target"
,
benchmark
(
tgt
,
args
...)));
// print
std
::
ostringstream
loginfos
;
loginfos
<<
"Kernel Type "
<<
jit
::
to_string
(
KT
)
<<
": "
<<
attr
<<
": "
;
for
(
auto
pair
:
infos
)
{
loginfos
<<
pair
.
first
<<
" takes "
<<
pair
.
second
<<
" us; "
;
}
LOG
(
INFO
)
<<
loginfos
.
str
();
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
BenchXYZNKernel
()
{
for
(
int
d
:
TestSizes
())
{
std
::
vector
<
T
>
x
(
d
),
y
(
d
),
z
(
d
);
RandomVec
<
T
>
(
d
,
x
.
data
());
RandomVec
<
T
>
(
d
,
y
.
data
());
BenchAllImpls
<
KT
,
jit
::
XYZNTuples
<
T
>
,
PlaceType
>
(
d
,
x
.
data
(),
y
.
data
(),
z
.
data
(),
d
);
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
BenchAXYNKernel
()
{
for
(
int
d
:
TestSizes
())
{
const
T
a
=
static_cast
<
T
>
(
3
);
std
::
vector
<
T
>
x
(
d
),
y
(
d
);
RandomVec
<
T
>
(
d
,
x
.
data
());
BenchAllImpls
<
KT
,
jit
::
AXYNTuples
<
T
>
,
PlaceType
>
(
d
,
&
a
,
x
.
data
(),
y
.
data
(),
d
);
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
BenchXYNKernel
()
{
for
(
int
d
:
TestSizes
())
{
std
::
vector
<
T
>
x
(
d
),
y
(
d
);
RandomVec
<
T
>
(
d
,
x
.
data
());
BenchAllImpls
<
KT
,
jit
::
XYNTuples
<
T
>
,
PlaceType
>
(
d
,
x
.
data
(),
y
.
data
(),
d
);
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
BenchLSTMKernel
()
{
for
(
bool
use_peephole
:
{
true
,
false
})
{
for
(
int
d
:
TestSizes
())
{
const
jit
::
lstm_attr_t
attr
(
d
,
jit
::
kVSigmoid
,
jit
::
kVTanh
,
jit
::
kVTanh
,
use_peephole
);
std
::
vector
<
T
>
x
(
4
*
d
),
ct_1
(
d
),
ct
(
d
),
ht
(
d
),
wp
(
3
*
d
),
checked
(
2
*
d
);
RandomVec
<
T
>
(
4
*
d
,
x
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
T
>
(
3
*
d
,
wp
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
T
>
(
d
,
ct_1
.
data
(),
-
2.
f
,
2.
f
);
const
T
*
ct_1_data
=
ct_1
.
data
();
const
T
*
wp_data
=
wp
.
data
();
T
*
x_data
=
x
.
data
();
T
*
checked_data
=
checked
.
data
();
T
*
ct_data
=
ct
.
data
();
T
*
ht_data
=
ht
.
data
();
jit
::
lstm_t
step
;
step
.
gates
=
x_data
;
step
.
ct_1
=
ct_1_data
;
step
.
ct
=
ct_data
;
step
.
ht
=
ht_data
;
if
(
use_peephole
)
{
step
.
wp
=
wp_data
;
step
.
checked
=
checked_data
;
}
BenchAllImpls
<
KT
,
jit
::
LSTMTuples
<
T
>
,
PlaceType
>
(
attr
,
&
step
,
&
attr
);
}
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
BenchGRUKernel
()
{
for
(
int
d
:
TestSizes
())
{
const
jit
::
gru_attr_t
attr
(
d
,
jit
::
kVSigmoid
,
jit
::
kVTanh
);
std
::
vector
<
T
>
x
(
3
*
d
),
ht_1
(
d
),
ht
(
d
);
RandomVec
<
T
>
(
3
*
d
,
x
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
T
>
(
d
,
ht_1
.
data
(),
-
2.
f
,
2.
f
);
const
T
*
ht_1_data
=
ht_1
.
data
();
T
*
x_data
=
x
.
data
();
T
*
ht_data
=
ht
.
data
();
jit
::
gru_t
step
;
step
.
gates
=
x_data
;
step
.
ht_1
=
ht_1_data
;
step
.
ht
=
ht_data
;
BenchAllImpls
<
KT
,
jit
::
GRUTuples
<
T
>
,
PlaceType
>
(
attr
,
&
step
,
&
attr
);
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
// --burning: the burning time before count
// --repeat: the repeat times
// --max_size: the max size would be tested
int
main
(
int
argc
,
char
*
argv
[])
{
gflags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
google
::
InitGoogleLogging
(
argv
[
0
]);
LOG
(
INFO
)
<<
"Burning "
<<
FLAGS_burning
<<
" times, Repeat "
<<
FLAGS_repeat
<<
" times."
;
using
T
=
float
;
using
PlaceType
=
paddle
::
platform
::
CPUPlace
;
// xyzn
BenchXYZNKernel
<
jit
::
kVMul
,
T
,
PlaceType
>
();
BenchXYZNKernel
<
jit
::
kVAdd
,
T
,
PlaceType
>
();
BenchXYZNKernel
<
jit
::
kVAddRelu
,
T
,
PlaceType
>
();
BenchXYZNKernel
<
jit
::
kVSub
,
T
,
PlaceType
>
();
// axyn
BenchAXYNKernel
<
jit
::
kVScal
,
T
,
PlaceType
>
();
BenchAXYNKernel
<
jit
::
kVAddBias
,
T
,
PlaceType
>
();
// xyn
BenchXYNKernel
<
jit
::
kVRelu
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVIdentity
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVExp
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVSigmoid
,
T
,
PlaceType
>
();
BenchXYNKernel
<
jit
::
kVTanh
,
T
,
PlaceType
>
();
// lstm and peephole
BenchLSTMKernel
<
jit
::
kLSTMCtHt
,
T
,
PlaceType
>
();
BenchLSTMKernel
<
jit
::
kLSTMC1H1
,
T
,
PlaceType
>
();
// gru functions
BenchGRUKernel
<
jit
::
kGRUH1
,
T
,
PlaceType
>
();
BenchGRUKernel
<
jit
::
kGRUHtPart1
,
T
,
PlaceType
>
();
BenchGRUKernel
<
jit
::
kGRUHtPart2
,
T
,
PlaceType
>
();
}
paddle/fluid/operators/jit/gen/CMakeLists.txt
0 → 100644
浏览文件 @
693e5e65
file
(
GLOB jitcode_cc_srcs RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"*.cc"
)
cc_library
(
jit_kernel_jitcode SRCS
${
jitcode_cc_srcs
}
DEPS jit_kernel_base xbyak
)
set
(
JIT_KERNEL_DEPS
${
JIT_KERNEL_DEPS
}
xbyak jit_kernel_jitcode PARENT_SCOPE
)
function
(
USE_JITKERNEL_GEN TARGET
)
file
(
APPEND
${
jit_file
}
"USE_JITKERNEL_GEN(
${
TARGET
}
);
\n
"
)
endfunction
()
# use gen jitcode kernel by name
USE_JITKERNEL_GEN
(
kVMul
)
USE_JITKERNEL_GEN
(
kVAdd
)
#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me
USE_JITKERNEL_GEN
(
kVAddRelu
)
USE_JITKERNEL_GEN
(
kVScal
)
USE_JITKERNEL_GEN
(
kVAddBias
)
USE_JITKERNEL_GEN
(
kVRelu
)
USE_JITKERNEL_GEN
(
kVIdentity
)
USE_JITKERNEL_GEN
(
kVExp
)
USE_JITKERNEL_GEN
(
kVSigmoid
)
USE_JITKERNEL_GEN
(
kVTanh
)
USE_JITKERNEL_GEN
(
kLSTMCtHt
)
USE_JITKERNEL_GEN
(
kLSTMC1H1
)
USE_JITKERNEL_GEN
(
kGRUH1
)
USE_JITKERNEL_GEN
(
kGRUHtPart1
)
USE_JITKERNEL_GEN
(
kGRUHtPart2
)
USE_JITKERNEL_GEN
(
kNCHW16CMulNC
)
paddle/fluid/operators/jit/gen/act.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
const
float
ALIGN32_BEG
exp_float_consts
[]
ALIGN32_END
=
{
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
2.
f
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
EXP_LOW
),
REPEAT_8TIMES
(
CEPHES_LOG2EF
),
REPEAT_8TIMES
(
CEPHES_EXP_C1
),
REPEAT_8TIMES
(
CEPHES_EXP_C2
),
REPEAT_8TIMES
(
CEPHES_EXP_P0
),
REPEAT_8TIMES
(
CEPHES_EXP_P1
),
REPEAT_8TIMES
(
CEPHES_EXP_P2
),
REPEAT_8TIMES
(
CEPHES_EXP_P3
),
REPEAT_8TIMES
(
CEPHES_EXP_P4
),
REPEAT_8TIMES
(
CEPHES_EXP_P5
),
REPEAT_8TIMES
(
EXP_MAX_INPUT
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MAX
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MIN
)};
const
int
ALIGN32_BEG
exp_int_0x7f
[]
ALIGN32_END
=
{
REPEAT_8TIMES
(
0x7f
)};
int
ALIGN32_BEG
g_tmp_mem
[
16
]
ALIGN32_END
=
{
0
};
void
VActJitCode
::
genCode
()
{
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
act
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
type_
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
int
rest
=
num_
%
YMM_FLOAT_BLOCK
;
while
(
rest
>
0
)
{
int
block
=
XMM_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
block
=
4
;
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
}
else
if
(
rest
>=
2
)
{
block
=
2
;
vmovq
(
xmm_src
,
ptr
[
param1
+
offset
]);
}
else
{
block
=
1
;
vmovss
(
xmm_src
,
ptr
[
param1
+
offset
]);
}
act
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
type_
);
if
(
rest
>=
4
)
{
vmovups
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
else
if
(
rest
>=
2
)
{
vmovq
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
else
{
vmovss
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
offset
+=
sizeof
(
float
)
*
block
;
rest
-=
block
;
}
ret
();
}
#define DECLARE_ACT_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override; \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_ACT_CREATOR
(
VRelu
);
DECLARE_ACT_CREATOR
(
VIdentity
);
DECLARE_ACT_CREATOR
(
VExp
);
DECLARE_ACT_CREATOR
(
VSigmoid
);
DECLARE_ACT_CREATOR
(
VTanh
);
// TODO(TJ): tuning use me
size_t
VReluCreator
::
CodeSize
(
const
int
&
d
)
const
{
return
96
/* init size */
+
(
d
/
YMM_FLOAT_BLOCK
+
3
)
*
4
/* instructions */
*
8
/* average bytes for each instruction */
;
}
size_t
VIdentityCreator
::
CodeSize
(
const
int
&
d
)
const
{
return
96
+
(
d
/
YMM_FLOAT_BLOCK
+
3
)
*
4
*
8
;
}
size_t
VExpCreator
::
CodeSize
(
const
int
&
d
)
const
{
return
96
+
(
d
/
YMM_FLOAT_BLOCK
+
3
)
*
70
*
8
;
}
size_t
VSigmoidCreator
::
CodeSize
(
const
int
&
d
)
const
{
return
96
+
(
d
/
YMM_FLOAT_BLOCK
+
3
)
*
82
*
8
;
}
size_t
VTanhCreator
::
CodeSize
(
const
int
&
d
)
const
{
return
96
+
(
d
/
YMM_FLOAT_BLOCK
+
3
)
*
84
*
8
;
}
#undef DECLARE_ACT_CREATOR
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
REGISTER_JITKERNEL_GEN
(
kVRelu
,
gen
::
VReluCreator
);
REGISTER_JITKERNEL_GEN
(
kVIdentity
,
gen
::
VIdentityCreator
);
REGISTER_JITKERNEL_GEN
(
kVExp
,
gen
::
VExpCreator
);
REGISTER_JITKERNEL_GEN
(
kVSigmoid
,
gen
::
VSigmoidCreator
);
REGISTER_JITKERNEL_GEN
(
kVTanh
,
gen
::
VTanhCreator
);
paddle/fluid/operators/
math/jit_code
.h
→
paddle/fluid/operators/
jit/gen/act
.h
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
Licensed under the Apache License, Version 2.0 (the "License");
*
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
*
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
*
You may obtain a copy of the License at
*
http://www.apache.org/licenses/LICENSE-2.0
*
http://www.apache.org/licenses/LICENSE-2.0
*
Unless required by applicable law or agreed to in writing, software
*
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
*
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
*
See the License for the specific language governing permissions and
limitations under the License. */
*
limitations under the License. */
#pragma once
#pragma once
#include <string>
#include <string>
#include "paddle/fluid/operators/math/jit_gen.h"
#include "glog/logging.h"
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
jit
{
namespace
jitkernel
{
namespace
gen
{
namespace
gen
{
using
reg64_t
=
const
Xbyak
::
Reg64
;
using
reg32_t
=
const
Xbyak
::
Reg32
;
using
xmm_t
=
const
Xbyak
::
Xmm
;
using
ymm_t
=
const
Xbyak
::
Ymm
;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
Label
=
Xbyak
::
Label
;
typedef
enum
{
mul
=
0
,
add
,
sub
,
relu
,
exp
,
sigmoid
,
tanh
,
identity
}
operand_type
;
extern
const
float
exp_float_consts
[];
extern
const
float
exp_float_consts
[];
extern
const
int
exp_int_0x7f
[];
extern
const
int
exp_int_0x7f
[];
extern
int
g_tmp_mem
[];
extern
int
g_tmp_mem
[];
...
@@ -79,94 +59,15 @@ extern int g_tmp_mem[];
...
@@ -79,94 +59,15 @@ extern int g_tmp_mem[];
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class
VActFunc
:
public
JitCode
{
class
VXXJitCode
:
public
JitCode
{
public:
const
char
*
name
()
const
override
{
std
::
string
base
=
"VXXJitCode"
;
if
(
scalar_index_
==
1
)
{
base
+=
"_Scalar"
;
}
else
{
base
+=
"_Vec"
;
}
if
(
type_
==
operand_type
::
mul
)
{
base
+=
"_Mul"
;
}
else
if
(
type_
==
operand_type
::
add
)
{
base
+=
"_Add"
;
}
if
(
scalar_index_
==
2
)
{
base
+=
"_Scalar"
;
}
else
{
base
+=
"_Vec"
;
}
base
+=
(
with_relu_
?
"_Relu"
:
""
);
return
base
.
c_str
();
}
explicit
VXXJitCode
(
int
d
,
operand_type
type
,
int
scalar_index
,
bool
with_relu
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
),
scalar_index_
(
scalar_index
),
with_relu_
(
with_relu
)
{}
static
bool
init
(
int
d
,
int
scalar_index
=
0
);
void
generate
()
override
;
private:
int
num_
;
operand_type
type_
;
int
scalar_index_
;
bool
with_relu_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
param3
{
abi_param3
};
xmm_t
xmm_src1
=
xmm_t
(
0
);
xmm_t
xmm_src2
=
xmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
2
);
xmm_t
xmm_zero
=
xmm_t
(
3
);
ymm_t
ymm_src1
=
ymm_t
(
0
);
ymm_t
ymm_src2
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
3
);
};
class
VActJitCode
:
public
JitCode
{
public:
public:
const
char
*
name
()
const
override
{
explicit
VActFunc
(
size_t
code_size
,
void
*
code_ptr
)
std
::
string
base
=
"VActJitCode"
;
:
JitCode
(
code_size
,
code_ptr
)
{}
switch
(
type_
)
{
virtual
const
char
*
name
()
const
=
0
;
case
operand_type
::
relu
:
virtual
void
genCode
()
=
0
;
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
return
base
.
c_str
();
}
explicit
VActJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
)
{}
static
bool
init
(
int
d
,
operand_type
type
);
void
generate
()
override
;
protected:
protected:
// compute
relu
with ymm, xmm
// compute
RELU
with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
relu_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
zero_idx
=
15
)
{
// NOLINT
void
relu_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
zero_idx
=
15
)
{
// NOLINT
JMM
zero
=
JMM
(
zero_idx
);
JMM
zero
=
JMM
(
zero_idx
);
...
@@ -174,7 +75,7 @@ class VActJitCode : public JitCode {
...
@@ -174,7 +75,7 @@ class VActJitCode : public JitCode {
vmaxps
(
dst
,
src
,
zero
);
vmaxps
(
dst
,
src
,
zero
);
}
}
// compute
exp
with ymm, xmm
// compute
EXP
with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
exp_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
int
fx_idx
=
12
,
// NOLINT
void
exp_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
int
fx_idx
=
12
,
// NOLINT
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
...
@@ -258,7 +159,7 @@ class VActJitCode : public JitCode {
...
@@ -258,7 +159,7 @@ class VActJitCode : public JitCode {
pop
(
reg_ptr_global
);
pop
(
reg_ptr_global
);
}
}
// compute
sigmoid
with ymm, xmm
// compute
SIGMOID
with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
sigmoid_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
void
sigmoid_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
...
@@ -283,7 +184,7 @@ class VActJitCode : public JitCode {
...
@@ -283,7 +184,7 @@ class VActJitCode : public JitCode {
pop
(
reg_ptr_global
);
pop
(
reg_ptr_global
);
}
}
// compute
tanh
with ymm, xmm
// compute
TANH
with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
tanh_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
void
tanh_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
...
@@ -310,223 +211,109 @@ class VActJitCode : public JitCode {
...
@@ -310,223 +211,109 @@ class VActJitCode : public JitCode {
pop
(
reg_ptr_global
);
pop
(
reg_ptr_global
);
}
}
// compute IDENTITY with ymm, xmm
template
<
typename
JMM
>
void
identity_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
zero_idx
)
{
// NOLINT
JMM
zero
=
JMM
(
zero_idx
);
vxorps
(
zero
,
zero
,
zero
);
vaddps
(
dst
,
src
,
zero
);
// TODO(TJ): use below
// dst.setIdx(src.getIdx());
}
template
<
typename
JMM
>
template
<
typename
JMM
>
void
act
(
JMM
&
dst
,
JMM
&
src
,
operand_type
type
)
{
// NOLINT
void
act
(
JMM
&
dst
,
JMM
&
src
,
operand_type
type
)
{
// NOLINT
// use 11~15
// use 11~15
switch
(
type
)
{
switch
(
type
)
{
case
operand_type
::
relu
:
case
operand_type
::
RELU
:
relu_jmm
<
JMM
>
(
dst
,
src
,
15
);
relu_jmm
<
JMM
>
(
dst
,
src
,
15
);
break
;
break
;
case
operand_type
::
exp
:
case
operand_type
::
EXP
:
exp_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
exp_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
break
;
case
operand_type
::
sigmoid
:
case
operand_type
::
SIGMOID
:
sigmoid_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
sigmoid_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
break
;
case
operand_type
::
tanh
:
case
operand_type
::
TANH
:
tanh_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
tanh_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
15
);
break
;
break
;
case
operand_type
::
identity
:
case
operand_type
::
IDENTITY
:
identity_jmm
<
JMM
>
(
dst
,
src
,
15
);
break
;
break
;
default:
default:
// throw error
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type
;
break
;
break
;
}
}
}
}
protected:
int
num_
;
operand_type
type_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
xmm_t
xmm_src
=
xmm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
0
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
};
class
LSTMJitCode
:
public
VActJitCode
{
class
VActJitCode
:
public
VActFunc
{
public:
public:
const
char
*
name
()
const
override
{
explicit
VActJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
,
std
::
string
base
=
"LSTMJitCode"
;
void
*
code_ptr
=
nullptr
)
if
(
use_peephole_
)
{
:
VActFunc
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
)
{
base
+=
"_Peephole"
;
if
(
!
(
type_
==
operand_type
::
RELU
||
type_
==
operand_type
::
EXP
||
}
type_
==
operand_type
::
SIGMOID
||
type_
==
operand_type
::
TANH
||
if
(
compute_c1h1_
)
{
type_
==
operand_type
::
IDENTITY
)
)
{
base
+=
"_C1H1"
;
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
}
}
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
this
->
genCode
();
switch
(
type
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
};
AddTypeStr
(
act_gate_
);
AddTypeStr
(
act_cand_
);
AddTypeStr
(
act_cell_
);
return
base
.
c_str
();
}
explicit
LSTMJitCode
(
bool
compute_c1h1
,
const
lstm_attr_t
&
attr
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VActJitCode
(
attr
.
d
,
operand_type
::
sigmoid
/* this is bugy*/
,
code_size
,
code_ptr
),
compute_c1h1_
(
compute_c1h1
)
{
auto
typeExchange
=
[](
const
std
::
string
&
type
)
->
gen
::
operand_type
{
if
(
type
==
"sigmoid"
)
{
return
operand_type
::
sigmoid
;
}
else
if
(
type
==
"relu"
)
{
return
operand_type
::
relu
;
}
else
if
(
type
==
"tanh"
)
{
return
operand_type
::
tanh
;
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
operand_type
::
identity
;
}
// else throw error
return
operand_type
::
identity
;
};
num_
=
attr
.
d
;
use_peephole_
=
attr
.
use_peephole
;
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
act_cell_
=
typeExchange
(
attr
.
act_cell
);
}
}
static
bool
init
(
int
d
);
void
generate
()
override
;
protected:
int
num_
;
bool
compute_c1h1_
;
bool
use_peephole_
;
operand_type
act_gate_
;
operand_type
act_cand_
;
operand_type
act_cell_
;
reg64_t
param1
{
abi_param1
};
};
class
GRUJitCode
:
public
VActJitCode
{
public:
const
char
*
name
()
const
override
{
const
char
*
name
()
const
override
{
std
::
string
base
=
"GRUJitCode"
;
std
::
string
base
=
"VActJitCode"
;
if
(
id_
==
0
)
{
switch
(
type_
)
{
base
+=
"_H1"
;
case
operand_type
::
RELU
:
}
else
if
(
id_
==
1
)
{
base
+=
"_Relu"
;
base
+=
"_HtPart1"
;
break
;
}
else
if
(
id_
==
2
)
{
case
operand_type
::
EXP
:
base
+=
"_HtPart2"
;
base
+=
"_Exp"
;
break
;
case
operand_type
::
SIGMOID
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
TANH
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
IDENTITY
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
}
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
switch
(
type
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
};
AddTypeStr
(
act_gate_
);
AddTypeStr
(
act_cand_
);
return
base
.
c_str
();
return
base
.
c_str
();
}
}
void
genCode
()
override
;
explicit
GRUJitCode
(
int
id
,
const
gru_attr_t
&
attr
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VActJitCode
(
attr
.
d
,
operand_type
::
sigmoid
/* this is bugy*/
,
code_size
,
code_ptr
),
id_
(
id
)
{
auto
typeExchange
=
[](
const
std
::
string
&
type
)
->
gen
::
operand_type
{
if
(
type
==
"sigmoid"
)
{
return
operand_type
::
sigmoid
;
}
else
if
(
type
==
"relu"
)
{
return
operand_type
::
relu
;
}
else
if
(
type
==
"tanh"
)
{
return
operand_type
::
tanh
;
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
operand_type
::
identity
;
}
// else throw error
return
operand_type
::
identity
;
};
num_
=
attr
.
d
;
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
}
static
bool
init
(
int
d
);
void
generate
()
override
;
protected:
protected:
int
id_
;
int
num_
;
int
num_
;
operand_type
act_gate_
;
operand_type
type_
;
operand_type
act_cand_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param1
{
abi_param1
};
};
reg64_t
param2
{
abi_param2
};
#ifdef PADDLE_WITH_MKLDNN
xmm_t
xmm_src
=
xmm_t
(
0
);
struct
EltwiseMulnChw16cNC
:
public
Xbyak
::
CodeGenerator
{
ymm_t
ymm_src
=
ymm_t
(
0
);
explicit
EltwiseMulnChw16cNC
(
size_t
code_size
=
256
*
1024
)
:
Xbyak
::
CodeGenerator
(
code_size
)
{
// RDI is ptr x_input
// RSI is ptr y_input
// RDX is ptr output
// RCX is height
// r8 is width
push
(
rbx
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
xor_
(
rax
,
rax
);
#define DECLARE_ACT_JITCODE(name, op_type) \
xor_
(
r10
,
r10
);
class name##JitCode : public VActJitCode { \
vmovups
(
zmm3
,
ptr
[
rsi
]);
public: \
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
: VActJitCode(d, op_type, code_size, code_ptr) {} \
};
L
(
"h_loop"
);
DECLARE_ACT_JITCODE
(
VRelu
,
operand_type
::
RELU
);
xor_
(
rbx
,
rbx
);
DECLARE_ACT_JITCODE
(
VIdentity
,
operand_type
::
IDENTITY
);
L
(
"w_loop"
);
DECLARE_ACT_JITCODE
(
VExp
,
operand_type
::
EXP
);
vmovups
(
zmm2
,
ptr
[
rdi
+
rax
]);
DECLARE_ACT_JITCODE
(
VSigmoid
,
operand_type
::
SIGMOID
);
vmulps
(
zmm1
,
zmm2
,
zmm3
);
DECLARE_ACT_JITCODE
(
VTanh
,
operand_type
::
TANH
);
vmovups
(
ptr
[
rdx
+
rax
],
zmm1
);
add
(
rax
,
64
);
inc
(
rbx
);
cmp
(
r8
,
rbx
);
jnz
(
"w_loop"
);
inc
(
r10
);
cmp
(
r10
,
rcx
);
jnz
(
"h_loop"
);
pop
(
rbx
);
#undef DECLARE_ACT_JITCODE
ret
();
}
};
#endif
}
// namespace gen
}
// namespace gen
}
// namespace jitkernel
}
// namespace jit
}
// namespace math
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/jit/gen/blas.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/blas.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
void
VXXJitCode
::
genCode
()
{
// do not need push stack, and do not need save avx512reg if do not use avx512
int
offset
=
0
;
if
(
with_relu_
)
{
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
}
if
(
scalar_index_
==
1
)
{
vbroadcastss
(
ymm_src1
,
ptr
[
param1
]);
}
else
if
(
scalar_index_
==
2
)
{
vbroadcastss
(
ymm_src2
,
ptr
[
param2
]);
}
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
if
(
scalar_index_
!=
1
)
{
vmovups
(
ymm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovups
(
ymm_src2
,
ptr
[
param2
+
offset
]);
}
if
(
type_
==
operand_type
::
MUL
)
{
vmulps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
}
else
if
(
type_
==
operand_type
::
ADD
)
{
vaddps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
}
if
(
with_relu_
)
{
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_dst
);
}
vmovups
(
ptr
[
param3
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
int
rest
=
num_
%
YMM_FLOAT_BLOCK
;
while
(
rest
>
0
)
{
int
block
=
XMM_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
block
=
4
;
if
(
scalar_index_
!=
1
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovups
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
}
else
if
(
rest
>=
2
)
{
block
=
2
;
if
(
scalar_index_
!=
1
)
{
vmovq
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovq
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
}
else
{
block
=
1
;
if
(
scalar_index_
!=
1
)
{
vmovss
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovss
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
}
switch
(
type_
)
{
case
operand_type
::
MUL
:
vmulps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
break
;
case
operand_type
::
ADD
:
vaddps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
break
;
default:
break
;
}
if
(
with_relu_
)
{
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_dst
);
}
if
(
rest
>=
4
)
{
vmovups
(
ptr
[
param3
+
offset
],
xmm_dst
);
}
else
if
(
rest
>=
2
)
{
vmovq
(
ptr
[
param3
+
offset
],
xmm_dst
);
}
else
{
vmovss
(
ptr
[
param3
+
offset
],
xmm_dst
);
}
offset
+=
sizeof
(
float
)
*
block
;
rest
-=
block
;
}
ret
();
}
void
NCHW16CMulNCJitCode
::
genCode
()
{
// RDI is ptr x_input
// RSI is ptr y_input
// RDX is ptr output
// RCX is height
// r8 is width
push
(
rbx
);
xor_
(
rax
,
rax
);
xor_
(
r10
,
r10
);
vmovups
(
zmm3
,
ptr
[
rsi
]);
L
(
"h_loop"
);
xor_
(
rbx
,
rbx
);
L
(
"w_loop"
);
vmovups
(
zmm2
,
ptr
[
rdi
+
rax
]);
vmulps
(
zmm1
,
zmm2
,
zmm3
);
vmovups
(
ptr
[
rdx
+
rax
],
zmm1
);
add
(
rax
,
64
);
inc
(
rbx
);
cmp
(
r8
,
rbx
);
jnz
(
"w_loop"
);
inc
(
r10
);
cmp
(
r10
,
rcx
);
jnz
(
"h_loop"
);
pop
(
rbx
);
ret
();
}
class
NCHW16CMulNCCreator
:
public
JitCodeCreator
<
int
>
{
public:
bool
UseMe
(
const
int
&
attr
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx512f
);
}
size_t
CodeSize
(
const
int
&
d
)
const
override
{
return
256
*
1024
;
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
int
&
attr
)
const
override
{
return
make_unique
<
NCHW16CMulNCJitCode
>
(
attr
,
CodeSize
(
attr
));
}
};
#define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_BLAS_CREATOR
(
VMul
);
DECLARE_BLAS_CREATOR
(
VAdd
);
DECLARE_BLAS_CREATOR
(
VSub
);
DECLARE_BLAS_CREATOR
(
VAddRelu
);
DECLARE_BLAS_CREATOR
(
VScal
);
DECLARE_BLAS_CREATOR
(
VAddBias
);
#undef DECLARE_BLAS_CREATOR
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
REGISTER_JITKERNEL_GEN
(
kVMul
,
gen
::
VMulCreator
);
REGISTER_JITKERNEL_GEN
(
kVAdd
,
gen
::
VAddCreator
);
// TODO(TJ): enable sub
// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN
(
kVAddRelu
,
gen
::
VAddReluCreator
);
REGISTER_JITKERNEL_GEN
(
kVScal
,
gen
::
VScalCreator
);
REGISTER_JITKERNEL_GEN
(
kVAddBias
,
gen
::
VAddBiasCreator
);
REGISTER_JITKERNEL_GEN
(
kNCHW16CMulNC
,
gen
::
NCHW16CMulNCCreator
);
paddle/fluid/operators/jit/gen/blas.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class
VXXJitCode
:
public
JitCode
{
public:
explicit
VXXJitCode
(
int
d
,
operand_type
type
,
int
scalar_index
,
bool
with_relu
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
),
type_
(
type
),
scalar_index_
(
scalar_index
),
with_relu_
(
with_relu
)
{
if
(
!
(
type_
==
operand_type
::
MUL
||
type_
==
operand_type
::
ADD
))
{
LOG
(
FATAL
)
<<
"Do not support this operand type: "
<<
type_
;
}
this
->
genCode
();
}
virtual
const
char
*
name
()
const
{
std
::
string
base
=
"VXXJitCode"
;
if
(
scalar_index_
==
1
)
{
base
+=
"_Scalar"
;
}
else
{
base
+=
"_Vec"
;
}
if
(
type_
==
operand_type
::
MUL
)
{
base
+=
"_Mul"
;
}
else
if
(
type_
==
operand_type
::
ADD
)
{
base
+=
"_Add"
;
}
if
(
scalar_index_
==
2
)
{
base
+=
"_Scalar"
;
}
else
{
base
+=
"_Vec"
;
}
base
+=
(
with_relu_
?
"_Relu"
:
""
);
return
base
.
c_str
();
}
void
genCode
()
override
;
private:
int
num_
;
operand_type
type_
;
int
scalar_index_
;
bool
with_relu_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
param3
{
abi_param3
};
xmm_t
xmm_src1
=
xmm_t
(
0
);
xmm_t
xmm_src2
=
xmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
2
);
xmm_t
xmm_zero
=
xmm_t
(
3
);
ymm_t
ymm_src1
=
ymm_t
(
0
);
ymm_t
ymm_src2
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
3
);
};
#define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \
class name##JitCode : public VXXJitCode { \
public: \
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
: VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \
} \
};
DECLARE_BLAS_JITCODE
(
VMul
,
operand_type
::
MUL
,
0
,
false
);
DECLARE_BLAS_JITCODE
(
VAdd
,
operand_type
::
ADD
,
0
,
false
);
DECLARE_BLAS_JITCODE
(
VSub
,
operand_type
::
SUB
,
0
,
false
);
DECLARE_BLAS_JITCODE
(
VAddRelu
,
operand_type
::
ADD
,
0
,
true
);
DECLARE_BLAS_JITCODE
(
VScal
,
operand_type
::
MUL
,
1
,
false
);
DECLARE_BLAS_JITCODE
(
VAddBias
,
operand_type
::
ADD
,
1
,
false
);
#undef DECLARE_BLAS_JITCODE
// nChw16c = nChw16c .* NC
class
NCHW16CMulNCJitCode
:
public
JitCode
{
public:
DECLARE_JIT_CODE
(
NCHW16CMulNCJitCode
);
explicit
NCHW16CMulNCJitCode
(
int
d
/*unused*/
,
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
)
{
this
->
genCode
();
}
void
genCode
()
override
;
};
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/gen/gru.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
void
GRUJitCode
::
genCode
()
{
reg64_t
reg_ptr_gates
=
rax
;
reg64_t
reg_ptr_ht_1
=
r9
;
reg64_t
reg_ptr_ht
=
r10
;
mov
(
reg_ptr_gates
,
ptr
[
param1
+
offsetof
(
gru_t
,
gates
)]);
mov
(
reg_ptr_ht_1
,
ptr
[
param1
+
offsetof
(
gru_t
,
ht_1
)]);
mov
(
reg_ptr_ht
,
ptr
[
param1
+
offsetof
(
gru_t
,
ht
)]);
ymm_t
ymm_one
=
ymm_t
(
0
);
if
(
id_
==
2
)
{
reg64_t
reg_ptr_tmp
=
r11
;
mov
(
reg_ptr_tmp
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_one
,
ptr
[
reg_ptr_tmp
+
OFFSET_EXP_ONE
]);
}
int
offset
=
0
;
int
d
=
num_
*
sizeof
(
float
);
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
ymm_t
ymm_u
=
ymm_t
(
1
);
ymm_t
ymm_r
=
ymm_t
(
2
);
ymm_t
ymm_s
=
ymm_t
(
3
);
ymm_t
ymm_ht_1
=
ymm_t
(
4
);
// W: {W_update, W_reset; W_state}
if
(
id_
==
0
||
id_
==
2
)
{
vmovups
(
ymm_u
,
ptr
[
reg_ptr_gates
+
offset
]);
vmovups
(
ymm_s
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
d
]);
}
if
(
id_
==
1
)
{
vmovups
(
ymm_r
,
ptr
[
reg_ptr_gates
+
offset
+
d
]);
}
if
(
id_
==
1
||
id_
==
2
)
{
vmovups
(
ymm_ht_1
,
ptr
[
reg_ptr_ht_1
+
offset
]);
}
if
(
id_
==
0
)
{
// ht = act_gate(u) * act_cand(s)
act
<
ymm_t
>
(
ymm_u
,
ymm_u
,
act_gate_
);
act
<
ymm_t
>
(
ymm_s
,
ymm_s
,
act_cand_
);
vmulps
(
ymm_s
,
ymm_s
,
ymm_u
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_s
);
}
else
if
(
id_
==
1
)
{
// ht = act_gate(r) * ht_1
act
<
ymm_t
>
(
ymm_r
,
ymm_r
,
act_gate_
);
vmulps
(
ymm_r
,
ymm_r
,
ymm_ht_1
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_r
);
}
else
if
(
id_
==
2
)
{
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t
ymm_one_inner
=
ymm_t
(
ymm_one
.
getIdx
());
act
<
ymm_t
>
(
ymm_u
,
ymm_u
,
act_gate_
);
act
<
ymm_t
>
(
ymm_s
,
ymm_s
,
act_cand_
);
vmulps
(
ymm_s
,
ymm_s
,
ymm_u
);
vsubps
(
ymm_u
,
ymm_one_inner
,
ymm_u
);
vmulps
(
ymm_u
,
ymm_ht_1
,
ymm_u
);
vaddps
(
ymm_u
,
ymm_s
,
ymm_u
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_u
);
}
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
ret
();
}
#define DECLARE_GRU_CREATOR(name) \
class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \
/* TODO(TJ): enable more */
\
bool UseMe(const gru_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const gru_attr_t& attr) const override { \
return 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode( \
const gru_attr_t& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_GRU_CREATOR
(
GRUH1
);
DECLARE_GRU_CREATOR
(
GRUHtPart1
);
DECLARE_GRU_CREATOR
(
GRUHtPart2
);
#undef DECLARE_GRU_CREATOR
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
REGISTER_JITKERNEL_GEN
(
kGRUH1
,
gen
::
GRUH1Creator
);
REGISTER_JITKERNEL_GEN
(
kGRUHtPart1
,
gen
::
GRUHtPart1Creator
);
REGISTER_JITKERNEL_GEN
(
kGRUHtPart2
,
gen
::
GRUHtPart2Creator
);
paddle/fluid/operators/jit/gen/gru.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
class
GRUJitCode
:
public
VActFunc
{
public:
explicit
GRUJitCode
(
int
id
,
const
gru_attr_t
&
attr
,
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
:
VActFunc
(
code_size
,
code_ptr
),
id_
(
id
),
num_
(
attr
.
d
)
{
auto
typeExchange
=
[](
KernelType
type
)
->
gen
::
operand_type
{
if
(
type
==
KernelType
::
kVSigmoid
)
{
return
operand_type
::
SIGMOID
;
}
else
if
(
type
==
KernelType
::
kVRelu
)
{
return
operand_type
::
RELU
;
}
else
if
(
type
==
KernelType
::
kVTanh
)
{
return
operand_type
::
TANH
;
}
else
if
(
type
==
KernelType
::
kVIdentity
)
{
return
operand_type
::
IDENTITY
;
}
else
{
LOG
(
FATAL
)
<<
"Do not support this jit::KernelType: "
<<
type
;
}
return
operand_type
::
IDENTITY
;
};
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
this
->
genCode
();
}
const
char
*
name
()
const
override
{
std
::
string
base
=
"GRUJitCode"
;
if
(
id_
==
0
)
{
base
+=
"_H1"
;
}
else
if
(
id_
==
1
)
{
base
+=
"_HtPart1"
;
}
else
if
(
id_
==
2
)
{
base
+=
"_HtPart2"
;
}
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
switch
(
type
)
{
case
operand_type
::
RELU
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
EXP
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
SIGMOID
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
TANH
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
IDENTITY
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
};
AddTypeStr
(
act_gate_
);
AddTypeStr
(
act_cand_
);
return
base
.
c_str
();
}
void
genCode
()
override
;
protected:
int
id_
;
int
num_
;
operand_type
act_gate_
;
operand_type
act_cand_
;
reg64_t
param1
{
abi_param1
};
};
#define DECLARE_GRU_JITCODE(name, id) \
class name##JitCode : public GRUJitCode { \
public: \
explicit name##JitCode(const gru_attr_t& attr, size_t code_size, \
void* code_ptr = nullptr) \
: GRUJitCode(id, attr, code_size, code_ptr) {} \
};
DECLARE_GRU_JITCODE
(
GRUH1
,
0
);
DECLARE_GRU_JITCODE
(
GRUHtPart1
,
1
);
DECLARE_GRU_JITCODE
(
GRUHtPart2
,
2
);
#undef DECLARE_GRU_JITCODE
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/gen/jitcode.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/platform/cpu_info.h"
#define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
// Application Binary Interface
constexpr
Xbyak
::
Operand
::
Code
abi_param1
(
Xbyak
::
Operand
::
RDI
),
abi_param2
(
Xbyak
::
Operand
::
RSI
),
abi_param3
(
Xbyak
::
Operand
::
RDX
),
abi_param4
(
Xbyak
::
Operand
::
RCX
);
constexpr
Xbyak
::
Operand
::
Code
g_abi_regs
[]
=
{
Xbyak
::
Operand
::
RBX
,
Xbyak
::
Operand
::
RBP
,
Xbyak
::
Operand
::
R12
,
Xbyak
::
Operand
::
R13
,
Xbyak
::
Operand
::
R14
,
Xbyak
::
Operand
::
R15
};
constexpr
int
num_g_abi_regs
=
sizeof
(
g_abi_regs
)
/
sizeof
(
g_abi_regs
[
0
]);
using
reg64_t
=
const
Xbyak
::
Reg64
;
using
reg32_t
=
const
Xbyak
::
Reg32
;
using
xmm_t
=
const
Xbyak
::
Xmm
;
using
ymm_t
=
const
Xbyak
::
Ymm
;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
Label
=
Xbyak
::
Label
;
typedef
enum
{
MUL
=
0
,
ADD
,
SUB
,
RELU
,
EXP
,
SIGMOID
,
TANH
,
IDENTITY
}
operand_type
;
#define DECLARE_JIT_CODE(codename) \
const char* name() const override { return #codename; }
class
JitCode
:
public
GenBase
,
public
Xbyak
::
CodeGenerator
{
public:
explicit
JitCode
(
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
:
Xbyak
::
CodeGenerator
(
(
code_size
%
4096
!=
0
?
(
code_size
/
4096
+
1
)
*
4096
:
code_size
),
code_ptr
)
{}
virtual
const
char
*
name
()
const
=
0
;
virtual
void
genCode
()
=
0
;
size_t
getSize
()
const
override
{
return
CodeGenerator
::
getSize
();
}
const
unsigned
char
*
getCodeInternal
()
override
{
const
Xbyak
::
uint8
*
code
=
CodeGenerator
::
getCode
();
return
code
;
}
protected:
Xbyak
::
Reg64
param1
{
abi_param1
};
const
int
EVEX_max_8b_offt
=
0x200
;
const
Xbyak
::
Reg64
reg_EVEX_max_8b_offt
=
rbp
;
virtual
void
preCode
()
{
for
(
int
i
=
0
;
i
<
num_g_abi_regs
;
++
i
)
{
push
(
Xbyak
::
Reg64
(
g_abi_regs
[
i
]));
}
if
(
platform
::
MayIUse
(
platform
::
avx512f
))
{
mov
(
reg_EVEX_max_8b_offt
,
2
*
EVEX_max_8b_offt
);
}
}
virtual
void
postCode
()
{
for
(
int
i
=
0
;
i
<
num_g_abi_regs
;
++
i
)
{
pop
(
Xbyak
::
Reg64
(
g_abi_regs
[
num_g_abi_regs
-
1
-
i
]));
}
ret
();
}
void
L
(
const
char
*
label
)
{
Xbyak
::
CodeGenerator
::
L
(
label
);
}
void
L
(
const
Xbyak
::
Label
&
label
)
{
Xbyak
::
CodeGenerator
::
L
(
label
);
}
// Enhanced vector extension
Xbyak
::
Address
EVEX_compress_addr
(
Xbyak
::
Reg64
base
,
int
offt
,
bool
bcast
=
false
)
{
int
scale
=
0
;
// Learn from https://github.com/intel/mkl-dnn
if
(
EVEX_max_8b_offt
<=
offt
&&
offt
<
3
*
EVEX_max_8b_offt
)
{
offt
=
offt
-
2
*
EVEX_max_8b_offt
;
scale
=
1
;
}
else
if
(
3
*
EVEX_max_8b_offt
<=
offt
&&
offt
<
5
*
EVEX_max_8b_offt
)
{
offt
=
offt
-
4
*
EVEX_max_8b_offt
;
scale
=
2
;
}
auto
re
=
Xbyak
::
RegExp
()
+
base
+
offt
;
if
(
scale
)
{
re
=
re
+
reg_EVEX_max_8b_offt
*
scale
;
}
if
(
bcast
)
{
return
zword_b
[
re
];
}
else
{
return
zword
[
re
];
}
}
};
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/gen/lstm.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
void
LSTMJitCode
::
genCode
()
{
if
(
use_peephole_
)
{
preCode
();
}
reg64_t
reg_ptr_gates
=
rax
;
reg64_t
reg_ptr_ct_1
=
r9
;
reg64_t
reg_ptr_ct
=
r10
;
reg64_t
reg_ptr_ht
=
r11
;
reg64_t
reg_ptr_wp
=
r12
;
mov
(
reg_ptr_gates
,
ptr
[
param1
+
offsetof
(
lstm_t
,
gates
)]);
mov
(
reg_ptr_ct_1
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ct_1
)]);
mov
(
reg_ptr_ct
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ct
)]);
mov
(
reg_ptr_ht
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ht
)]);
if
(
use_peephole_
)
{
mov
(
reg_ptr_wp
,
ptr
[
param1
+
offsetof
(
lstm_t
,
wp
)]);
}
int
offset
=
0
;
int
d
=
num_
*
sizeof
(
float
);
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
/* gates: W_ch, W_ih, W_fh, W_oh */
ymm_t
ymm_c
=
ymm_t
(
0
);
ymm_t
ymm_i
=
ymm_t
(
1
);
ymm_t
ymm_f
=
ymm_t
(
2
);
ymm_t
ymm_o
=
ymm_t
(
3
);
ymm_t
ymm_ct_1
=
ymm_t
(
4
);
ymm_t
ymm_wp0
=
ymm_t
(
5
);
ymm_t
ymm_wp1
=
ymm_t
(
6
);
ymm_t
ymm_wp2
=
ymm_t
(
7
);
vmovups
(
ymm_c
,
ptr
[
reg_ptr_gates
+
offset
]);
vmovups
(
ymm_i
,
ptr
[
reg_ptr_gates
+
offset
+
d
]);
vmovups
(
ymm_f
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
d
]);
vmovups
(
ymm_o
,
ptr
[
reg_ptr_gates
+
offset
+
3
*
d
]);
if
(
!
compute_c1h1_
)
{
vmovups
(
ymm_ct_1
,
ptr
[
reg_ptr_ct_1
+
offset
]);
}
if
(
use_peephole_
)
{
vmovups
(
ymm_wp0
,
ptr
[
reg_ptr_wp
+
offset
]);
vmovups
(
ymm_wp1
,
ptr
[
reg_ptr_wp
+
offset
+
d
]);
vmovups
(
ymm_wp2
,
ptr
[
reg_ptr_wp
+
offset
+
2
*
d
]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act
<
ymm_t
>
(
ymm_c
,
ymm_c
,
act_cand_
);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if
(
!
compute_c1h1_
&&
use_peephole_
)
{
vmulps
(
ymm_wp0
,
ymm_ct_1
,
ymm_wp0
);
vaddps
(
ymm_i
,
ymm_i
,
ymm_wp0
);
}
act
<
ymm_t
>
(
ymm_i
,
ymm_i
,
act_gate_
);
vmulps
(
ymm_c
,
ymm_c
,
ymm_i
);
if
(
!
compute_c1h1_
)
{
// act_gate(f) or act_gate(ct_1 * wp1 + f)
if
(
use_peephole_
)
{
vmulps
(
ymm_wp1
,
ymm_ct_1
,
ymm_wp1
);
vaddps
(
ymm_f
,
ymm_f
,
ymm_wp1
);
}
act
<
ymm_t
>
(
ymm_f
,
ymm_f
,
act_gate_
);
// ct
vmulps
(
ymm_f
,
ymm_f
,
ymm_ct_1
);
vaddps
(
ymm_f
,
ymm_f
,
ymm_c
);
}
/* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t
ymm_ct
=
compute_c1h1_
?
ymm_c
:
ymm_f
;
ymm_t
ymm_tmp
=
ymm_i
;
act
<
ymm_t
>
(
ymm_tmp
,
ymm_ct
,
act_cell_
);
// act_gate(o) or act_gate(ct * wp2 + o)
if
(
use_peephole_
)
{
vmulps
(
ymm_wp2
,
ymm_ct
,
ymm_wp2
);
vaddps
(
ymm_o
,
ymm_o
,
ymm_wp2
);
}
act
<
ymm_t
>
(
ymm_o
,
ymm_o
,
act_gate_
);
// ht
vmulps
(
ymm_o
,
ymm_o
,
ymm_tmp
);
// save ct and ht
vmovups
(
ptr
[
reg_ptr_ct
+
offset
],
ymm_ct
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_o
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
if
(
use_peephole_
)
{
postCode
();
}
else
{
ret
();
}
}
#define DECLARE_LSTM_CREATOR(name) \
class name##Creator : public JitCodeCreator<lstm_attr_t> { \
public: \
/* TODO(TJ): enable more */
\
bool UseMe(const lstm_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const lstm_attr_t& attr) const override { \
return 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode( \
const lstm_attr_t& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_LSTM_CREATOR
(
LSTMCtHt
);
DECLARE_LSTM_CREATOR
(
LSTMC1H1
);
#undef DECLARE_LSTM_CREATOR
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
REGISTER_JITKERNEL_GEN
(
kLSTMCtHt
,
gen
::
LSTMCtHtCreator
);
REGISTER_JITKERNEL_GEN
(
kLSTMC1H1
,
gen
::
LSTMC1H1Creator
);
paddle/fluid/operators/jit/gen/lstm.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
class
LSTMJitCode
:
public
VActFunc
{
public:
explicit
LSTMJitCode
(
bool
compute_c1h1
,
const
lstm_attr_t
&
attr
,
size_t
code_size
,
void
*
code_ptr
=
nullptr
)
:
VActFunc
(
code_size
,
code_ptr
),
num_
(
attr
.
d
),
compute_c1h1_
(
compute_c1h1
),
use_peephole_
(
attr
.
use_peephole
)
{
auto
typeExchange
=
[](
KernelType
type
)
->
gen
::
operand_type
{
if
(
type
==
KernelType
::
kVSigmoid
)
{
return
operand_type
::
SIGMOID
;
}
else
if
(
type
==
KernelType
::
kVRelu
)
{
return
operand_type
::
RELU
;
}
else
if
(
type
==
KernelType
::
kVTanh
)
{
return
operand_type
::
TANH
;
}
else
if
(
type
==
KernelType
::
kVIdentity
)
{
return
operand_type
::
IDENTITY
;
}
else
{
LOG
(
FATAL
)
<<
"Do not support this jit::KernelType: "
<<
type
;
}
return
operand_type
::
IDENTITY
;
};
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
act_cell_
=
typeExchange
(
attr
.
act_cell
);
this
->
genCode
();
}
const
char
*
name
()
const
override
{
std
::
string
base
=
"LSTMJitCode"
;
if
(
use_peephole_
)
{
base
+=
"_Peephole"
;
}
if
(
compute_c1h1_
)
{
base
+=
"_C1H1"
;
}
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
switch
(
type
)
{
case
operand_type
::
RELU
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
EXP
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
SIGMOID
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
TANH
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
IDENTITY
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
};
AddTypeStr
(
act_gate_
);
AddTypeStr
(
act_cand_
);
AddTypeStr
(
act_cell_
);
return
base
.
c_str
();
}
void
genCode
()
override
;
protected:
int
num_
;
bool
compute_c1h1_
;
bool
use_peephole_
;
operand_type
act_gate_
;
operand_type
act_cand_
;
operand_type
act_cell_
;
reg64_t
param1
{
abi_param1
};
};
#define DECLARE_LSTM_JITCODE(name, compute_c1h1) \
class name##JitCode : public LSTMJitCode { \
public: \
explicit name##JitCode(const lstm_attr_t& attr, size_t code_size, \
void* code_ptr = nullptr) \
: LSTMJitCode(compute_c1h1, attr, code_size, code_ptr) {} \
};
DECLARE_LSTM_JITCODE
(
LSTMCtHt
,
false
);
DECLARE_LSTM_JITCODE
(
LSTMC1H1
,
true
);
#undef DECLARE_LSTM_JITCODE
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/gen_base.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen_base.h"
#include <fstream>
#include <iostream>
#include <sstream>
DEFINE_bool
(
dump_jitcode
,
false
,
"Whether to dump the jitcode to file"
);
namespace
paddle
{
namespace
operators
{
namespace
jit
{
// refer do not need useme, it would be the last one.
void
GenBase
::
dumpCode
(
const
unsigned
char
*
code
)
const
{
if
(
code
)
{
static
int
counter
=
0
;
std
::
ostringstream
filename
;
filename
<<
"paddle_jitcode_"
<<
name
()
<<
"."
<<
counter
<<
".bin"
;
counter
++
;
std
::
ofstream
fout
(
filename
.
str
(),
std
::
ios
::
out
);
if
(
fout
.
is_open
())
{
fout
.
write
(
reinterpret_cast
<
const
char
*>
(
code
),
this
->
getSize
());
fout
.
close
();
}
}
}
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/gen_base.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <memory> // for unique_ptr
#include "paddle/fluid/operators/jit/kernel_base.h"
DECLARE_bool
(
dump_jitcode
);
namespace
paddle
{
namespace
operators
{
namespace
jit
{
class
GenBase
:
public
Kernel
{
public:
virtual
~
GenBase
()
=
default
;
virtual
const
char
*
name
()
const
=
0
;
virtual
size_t
getSize
()
const
=
0
;
virtual
const
unsigned
char
*
getCodeInternal
()
=
0
;
template
<
typename
Func
>
Func
getCode
()
{
const
unsigned
char
*
code
=
this
->
getCodeInternal
();
if
(
FLAGS_dump_jitcode
)
{
this
->
dumpCode
(
code
);
}
return
reinterpret_cast
<
Func
>
(
const_cast
<
unsigned
char
*>
(
code
));
}
protected:
void
dumpCode
(
const
unsigned
char
*
code
)
const
;
};
// Creator is used to creat the jitcode and save in pool.
// Every JitCode should have one creator.
class
GenCreator
{
public:
virtual
~
GenCreator
()
=
default
;
};
template
<
typename
Attr
>
class
JitCodeCreator
:
public
GenCreator
{
public:
virtual
~
JitCodeCreator
()
=
default
;
// condition when this jit code can be used.
virtual
bool
UseMe
(
const
Attr
&
attr
)
const
=
0
;
// estimate this code size
virtual
size_t
CodeSize
(
const
Attr
&
attr
)
const
=
0
;
// create this code
virtual
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
Attr
&
attr
)
const
=
0
;
};
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/helper.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/helper.h"
#include <algorithm> // tolower
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
#define ONE_CASE(key) \
case key: \
return #key
const
char
*
to_string
(
KernelType
kt
)
{
switch
(
kt
)
{
ONE_CASE
(
kVMul
);
ONE_CASE
(
kVAdd
);
ONE_CASE
(
kVAddRelu
);
ONE_CASE
(
kVSub
);
ONE_CASE
(
kVScal
);
ONE_CASE
(
kVAddBias
);
ONE_CASE
(
kVRelu
);
ONE_CASE
(
kVIdentity
);
ONE_CASE
(
kVExp
);
ONE_CASE
(
kVSigmoid
);
ONE_CASE
(
kVTanh
);
ONE_CASE
(
kLSTMCtHt
);
ONE_CASE
(
kLSTMC1H1
);
ONE_CASE
(
kGRUH1
);
ONE_CASE
(
kGRUHtPart1
);
ONE_CASE
(
kGRUHtPart2
);
ONE_CASE
(
kCRFDecoding
);
ONE_CASE
(
kLayerNorm
);
ONE_CASE
(
kNCHW16CMulNC
);
default:
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
kt
);
return
"NOT JITKernel"
;
}
return
nullptr
;
}
#undef ONE_CASE
KernelType
to_kerneltype
(
const
std
::
string
&
act
)
{
std
::
string
lower
=
act
;
std
::
transform
(
lower
.
begin
(),
lower
.
end
(),
lower
.
begin
(),
::
tolower
);
if
(
lower
==
"relu"
||
lower
==
"vrelu"
)
{
return
kVRelu
;
}
else
if
(
lower
==
"identity"
||
lower
==
"videntity"
||
lower
==
""
)
{
return
kVIdentity
;
}
else
if
(
lower
==
"exp"
||
lower
==
"vexp"
)
{
return
kVExp
;
}
else
if
(
lower
==
"sigmoid"
||
lower
==
"vsigmoid"
)
{
return
kVSigmoid
;
}
else
if
(
lower
==
"tanh"
||
lower
==
"vtanh"
)
{
return
kVTanh
;
}
PADDLE_THROW
(
"Not support type: %s, or forget to add this case"
,
act
);
return
kNone
;
}
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/helper.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
template
<
KernelType
KT
,
typename
KernelTuples
,
typename
PlaceType
>
inline
typename
std
::
enable_if
<
std
::
is_same
<
typename
KernelTuples
::
data_type
,
float
>::
value
&&
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
,
typename
KernelTuples
::
func_type
>::
type
GetJitCode
(
const
typename
KernelTuples
::
attr_type
&
attr
)
{
using
Func
=
typename
KernelTuples
::
func_type
;
using
Attr
=
typename
KernelTuples
::
attr_type
;
size_t
key
=
JitCodeKey
<
Attr
>
(
attr
);
auto
&
codes
=
JitCodePool
<
KT
>
().
Instance
();
if
(
codes
.
Has
(
key
))
{
return
codes
.
AllKernels
().
at
(
key
)
->
template
getCode
<
Func
>();
}
// creator is not related with attr, so can use KernelKey as key
KernelKey
kkey
(
KT
,
PlaceType
());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto
&
creator_map
=
JitCodeCreatorPool
().
Instance
().
AllCreators
();
auto
iter
=
creator_map
.
find
(
kkey
);
if
(
iter
!=
creator_map
.
end
())
{
auto
&
creators
=
iter
->
second
;
for
(
auto
&
cur
:
creators
)
{
auto
i
=
dynamic_cast
<
const
JitCodeCreator
<
Attr
>*>
(
cur
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
auto
p
=
i
->
CreateJitCode
(
attr
);
if
(
p
)
{
auto
f
=
p
->
template
getCode
<
Func
>();
codes
.
Insert
(
key
,
std
::
move
(
p
));
return
f
;
}
}
}
}
return
nullptr
;
}
template
<
KernelType
KT
,
typename
KernelTuples
,
typename
PlaceType
>
inline
typename
std
::
enable_if
<
!
std
::
is_same
<
typename
KernelTuples
::
data_type
,
float
>::
value
||
!
std
::
is_same
<
PlaceType
,
platform
::
CPUPlace
>::
value
,
typename
KernelTuples
::
func_type
>::
type
GetJitCode
(
const
typename
KernelTuples
::
attr_type
&
attr
)
{
return
nullptr
;
}
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
template
<
KernelType
KT
,
typename
KernelTuples
>
inline
typename
KernelTuples
::
func_type
GetRefer
()
{
auto
&
ref_pool
=
ReferKernelPool
().
Instance
().
AllKernels
();
KernelKey
kkey
(
KT
,
platform
::
CPUPlace
());
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
PADDLE_ENFORCE
(
ref_iter
!=
ref_pool
.
end
(),
"Every Kernel should have reference function."
);
auto
&
ref_impls
=
ref_iter
->
second
;
for
(
auto
&
impl
:
ref_impls
)
{
auto
i
=
dynamic_cast
<
const
ReferKernel
<
KernelTuples
>*>
(
impl
.
get
());
if
(
i
)
{
return
i
->
GetFunc
();
}
}
return
nullptr
;
}
template
<
KernelType
KT
,
typename
KernelTuples
,
typename
PlaceType
=
platform
::
CPUPlace
>
typename
KernelTuples
::
func_type
Get
(
const
typename
KernelTuples
::
attr_type
&
attr
)
{
auto
jitfunc
=
GetJitCode
<
KT
,
KernelTuples
,
PlaceType
>
(
attr
);
if
(
jitfunc
)
{
return
jitfunc
;
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey
kkey
(
KT
,
PlaceType
());
auto
&
pool
=
KernelPool
().
Instance
().
AllKernels
();
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
KernelMore
<
KernelTuples
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
return
i
->
GetFunc
();
}
}
}
// The last implementation should be reference function on CPUPlace.
return
GetRefer
<
KT
,
KernelTuples
>
();
}
const
char
*
to_string
(
KernelType
kt
);
KernelType
to_kerneltype
(
const
std
::
string
&
act
);
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
lstm_attr_t
&
attr
)
{
os
<<
"dim_size["
<<
attr
.
d
<<
"],act_gate["
<<
to_string
(
attr
.
act_gate
)
<<
"],act_cand["
<<
to_string
(
attr
.
act_cand
)
<<
"],act_cell["
<<
to_string
(
attr
.
act_cell
)
<<
"],use_peephole["
<<
(
attr
.
use_peephole
?
"True"
:
"False"
)
<<
"]"
;
return
os
;
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
gru_attr_t
&
attr
)
{
os
<<
"dim_size["
<<
attr
.
d
<<
"],act_gate["
<<
to_string
(
attr
.
act_gate
)
<<
"],act_cand["
<<
to_string
(
attr
.
act_cand
)
<<
"]"
;
return
os
;
}
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/kernel_base.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/platform/macros.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
typedef
enum
{
kNone
=
0
,
kVMul
=
1
,
kVAdd
=
2
,
kVAddRelu
,
kVSub
,
kVScal
,
kVAddBias
,
kVRelu
,
kVIdentity
,
kVExp
,
kVSigmoid
,
kVTanh
,
kLSTMCtHt
,
kLSTMC1H1
,
kGRUH1
,
kGRUHtPart1
,
kGRUHtPart2
,
kCRFDecoding
,
kLayerNorm
,
kNCHW16CMulNC
,
}
KernelType
;
template
<
typename
T
>
struct
XYZNTuples
{
typedef
T
data_type
;
typedef
int
attr_type
;
typedef
void
(
*
func_type
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
struct
AXYNTuples
:
public
XYZNTuples
<
T
>
{};
template
<
typename
T
>
struct
XYNTuples
{
typedef
T
data_type
;
typedef
int
attr_type
;
typedef
void
(
*
func_type
)(
const
T
*
,
T
*
,
int
);
};
typedef
struct
{
void
*
gates
;
// gates: x_ch, x_ih, x_fh, x_oh
const
void
*
ct_1
;
void
*
ct
;
void
*
ht
;
/* weight_peephole and checked data are only used in peephole*/
const
void
*
wp
{
nullptr
};
// W_ic, W_fc, W_oc
void
*
checked
{
nullptr
};
// size: 2 * d
}
lstm_t
;
typedef
struct
{
void
*
gates
;
// gates: {x_update, x_reset; x_state}
const
void
*
ht_1
;
void
*
ht
;
}
gru_t
;
struct
rnn_attr_s
{
int
d
;
KernelType
act_gate
,
act_cand
;
rnn_attr_s
()
=
default
;
explicit
rnn_attr_s
(
int
_d
,
KernelType
_act_gate
,
KernelType
_act_cand
)
:
d
(
_d
),
act_gate
(
_act_gate
),
act_cand
(
_act_cand
)
{}
};
struct
lstm_attr_s
:
public
rnn_attr_s
{
bool
use_peephole
;
KernelType
act_cell
;
lstm_attr_s
()
=
default
;
explicit
lstm_attr_s
(
int
_d
,
KernelType
_act_gate
,
KernelType
_act_cand
,
KernelType
_act_cell
,
bool
_use_peephole
=
false
)
:
rnn_attr_s
(
_d
,
_act_gate
,
_act_cand
),
use_peephole
(
_use_peephole
),
act_cell
(
_act_cell
)
{}
};
typedef
struct
rnn_attr_s
gru_attr_t
;
typedef
struct
lstm_attr_s
lstm_attr_t
;
template
<
typename
T
>
struct
LSTMTuples
{
typedef
T
data_type
;
typedef
lstm_attr_t
attr_type
;
typedef
void
(
*
func_type
)(
lstm_t
*
,
const
lstm_attr_t
*
);
};
template
<
typename
T
>
struct
GRUTuples
{
typedef
T
data_type
;
typedef
gru_attr_t
attr_type
;
typedef
void
(
*
func_type
)(
gru_t
*
,
const
gru_attr_t
*
);
};
template
<
typename
T
>
struct
CRFDecodingTuples
{
typedef
T
data_type
;
typedef
int
attr_type
;
typedef
void
(
*
func_type
)(
const
int
,
const
T
*
,
const
T
*
,
T
*
,
int
*
,
int
);
};
template
<
typename
T
>
struct
LayerNormTuples
{
typedef
T
data_type
;
typedef
int
attr_type
;
typedef
void
(
*
func_type
)(
T
*
,
T
*
,
T
*
,
T
*
,
const
T
*
,
const
T
*
,
int
,
const
float
,
int
);
};
// nChw16c = nChw16c .* NC
template
<
typename
T
>
struct
NCHW16CMulNCTuples
{
typedef
T
data_type
;
typedef
int
attr_type
;
typedef
void
(
*
func_type
)(
const
T
*
,
const
T
*
,
T
*
,
int
,
int
);
};
// Just for adding to kernel pool without template
class
Kernel
{
public:
Kernel
()
=
default
;
virtual
~
Kernel
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
};
template
<
typename
KernelTuples
>
class
KernelMore
:
public
Kernel
{
public:
using
T
=
typename
KernelTuples
::
data_type
;
using
Func
=
typename
KernelTuples
::
func_type
;
using
Attr
=
typename
KernelTuples
::
attr_type
;
virtual
Func
GetFunc
()
const
{
return
func
;
}
virtual
bool
UseMe
(
const
Attr
&
attr
)
const
=
0
;
virtual
const
char
*
ImplType
()
const
=
0
;
protected:
Func
func
{
nullptr
};
};
template
<
typename
KernelTuples
>
class
ReferKernel
:
public
KernelMore
<
KernelTuples
>
{
public:
// Refer code can always be used
bool
UseMe
(
const
typename
KernelTuples
::
attr_type
&
attr
)
const
override
{
return
true
;
}
const
char
*
ImplType
()
const
override
{
return
"Refer"
;
}
};
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/kernel_key.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
template
<
>
size_t
JitCodeKey
<
int
>
(
const
int
&
d
)
{
return
d
;
}
constexpr
int
act_type_shift
=
3
;
// suppot 2^3 act types
template
<
>
size_t
JitCodeKey
<
lstm_attr_t
>
(
const
lstm_attr_t
&
attr
)
{
size_t
key
=
attr
.
d
;
int
gate_key
=
static_cast
<
int
>
(
attr
.
act_gate
)
<<
1
;
int
cand_key
=
static_cast
<
int
>
(
attr
.
act_cand
)
<<
(
1
+
act_type_shift
);
int
cell_key
=
static_cast
<
int
>
(
attr
.
act_cell
)
<<
(
1
+
act_type_shift
*
2
);
return
(
key
<<
(
1
+
act_type_shift
*
3
))
+
gate_key
+
cand_key
+
cell_key
+
attr
.
use_peephole
;
}
template
<
>
size_t
JitCodeKey
<
gru_attr_t
>
(
const
gru_attr_t
&
attr
)
{
size_t
key
=
attr
.
d
;
return
(
key
<<
(
act_type_shift
*
2
))
+
static_cast
<
int
>
(
attr
.
act_gate
)
+
(
static_cast
<
int
>
(
attr
.
act_cand
)
<<
act_type_shift
);
}
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/kernel_key.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
struct
KernelKey
{
struct
Hash
{
size_t
operator
()(
const
KernelKey
&
key
)
const
{
int
place
=
key
.
place_
.
which
();
// less than 2^8
int
type
=
static_cast
<
int
>
(
key
.
type_
)
<<
8
;
// less than 2^(32-8)
std
::
hash
<
int
>
hasher
;
return
hasher
(
place
+
type
);
}
};
KernelType
type_
;
platform
::
Place
place_
;
KernelKey
(
KernelType
type
,
platform
::
Place
place
)
:
type_
(
type
),
place_
(
place
)
{}
size_t
hash_key
()
const
{
return
Hash
()(
*
this
);
}
bool
operator
==
(
const
KernelKey
&
o
)
const
{
return
platform
::
places_are_same_class
(
place_
,
o
.
place_
)
&&
type_
==
o
.
type_
;
}
bool
operator
!=
(
const
KernelKey
&
o
)
const
{
return
!
(
*
this
==
o
);
}
};
// Every JitCode should have a method to get the key from attribution
template
<
typename
Attr
>
size_t
JitCodeKey
(
const
Attr
&
attr
);
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/kernel_pool.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include <memory> // for shared_ptr
#include <string>
#include <unordered_map>
namespace
paddle
{
namespace
operators
{
namespace
jit
{
JitCodeCreatorPool
&
JitCodeCreatorPool
::
Instance
()
{
static
JitCodeCreatorPool
g_creator_pool
;
return
g_creator_pool
;
}
KernelPool
&
KernelPool
::
Instance
()
{
static
KernelPool
g_kernel_pool
;
return
g_kernel_pool
;
}
ReferKernelPool
&
ReferKernelPool
::
Instance
()
{
static
ReferKernelPool
g_refer_kernel_pool
;
return
g_refer_kernel_pool
;
}
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/kernel_pool.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <memory> // for unique_ptr
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
template
<
KernelType
KT
>
class
JitCodePool
{
typedef
std
::
unique_ptr
<
GenBase
>
GenBasePtr
;
typedef
std
::
unordered_map
<
size_t
,
GenBasePtr
>
JitCodeMap
;
public:
JitCodePool
()
=
default
;
static
JitCodePool
&
Instance
()
{
static
thread_local
JitCodePool
<
KT
>
g_jit_codes
;
return
g_jit_codes
;
}
const
JitCodeMap
&
AllKernels
()
{
return
codes_
;
}
bool
Has
(
size_t
key
)
const
{
return
codes_
.
find
(
key
)
!=
codes_
.
end
();
}
void
Insert
(
size_t
key
,
GenBasePtr
value
)
{
codes_
.
emplace
(
key
,
std
::
move
(
value
));
}
private:
JitCodeMap
codes_
;
DISABLE_COPY_AND_ASSIGN
(
JitCodePool
);
};
class
JitCodeCreatorPool
{
typedef
std
::
unique_ptr
<
const
GenCreator
>
GenCreatorPtr
;
typedef
std
::
unordered_map
<
KernelKey
,
std
::
vector
<
GenCreatorPtr
>
,
KernelKey
::
Hash
>
GenCreatorPtrMap
;
public:
JitCodeCreatorPool
()
=
default
;
static
JitCodeCreatorPool
&
Instance
();
GenCreatorPtrMap
&
AllCreators
()
{
return
creators_
;
}
void
Insert
(
const
KernelKey
&
key
,
GenCreatorPtr
value
)
{
if
(
creators_
.
find
(
key
)
==
creators_
.
end
())
{
creators_
.
emplace
(
key
,
std
::
vector
<
GenCreatorPtr
>
());
}
creators_
.
at
(
key
).
emplace_back
(
std
::
move
(
value
));
}
private:
GenCreatorPtrMap
creators_
;
DISABLE_COPY_AND_ASSIGN
(
JitCodeCreatorPool
);
};
typedef
std
::
unique_ptr
<
const
Kernel
>
KernelPtr
;
typedef
std
::
unordered_map
<
KernelKey
,
std
::
vector
<
KernelPtr
>
,
KernelKey
::
Hash
>
KernelMap
;
class
KernelPool
{
public:
static
KernelPool
&
Instance
();
KernelPool
()
=
default
;
KernelMap
&
AllKernels
()
{
return
pool_
;
}
void
Insert
(
const
KernelKey
&
key
,
KernelPtr
value
)
{
if
(
pool_
.
find
(
key
)
==
pool_
.
end
())
{
pool_
.
emplace
(
key
,
std
::
vector
<
KernelPtr
>
());
}
pool_
.
at
(
key
).
emplace_back
(
std
::
move
(
value
));
}
private:
KernelMap
pool_
;
DISABLE_COPY_AND_ASSIGN
(
KernelPool
);
};
// Every kernel should have refer code and it should be used in unit tests,
// so refer kernels should have it's independent kernel pool
class
ReferKernelPool
{
public:
static
ReferKernelPool
&
Instance
();
ReferKernelPool
()
=
default
;
KernelMap
&
AllKernels
()
{
return
pool_
;
}
void
Insert
(
const
KernelKey
&
key
,
KernelPtr
value
)
{
if
(
pool_
.
find
(
key
)
==
pool_
.
end
())
{
pool_
.
emplace
(
key
,
std
::
vector
<
KernelPtr
>
());
}
pool_
.
at
(
key
).
emplace_back
(
std
::
move
(
value
));
}
private:
KernelMap
pool_
;
DISABLE_COPY_AND_ASSIGN
(
ReferKernelPool
);
};
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/macro.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
namespace
paddle
{
namespace
operators
{
namespace
jit
{
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/more/CMakeLists.txt
0 → 100644
浏览文件 @
693e5e65
function
(
USE_JITKERNEL_MORE TARGET TYPE
)
file
(
APPEND
${
jit_file
}
"USE_JITKERNEL_MORE(
${
TARGET
}
${
TYPE
}
);
\n
"
)
endfunction
()
if
(
WITH_MKLML
)
add_subdirectory
(
mkl
)
endif
()
if
(
WITH_AVX
)
add_subdirectory
(
intrinsic
)
endif
()
# mix should be last
add_subdirectory
(
mix
)
set
(
JIT_KERNEL_DEPS
${
JIT_KERNEL_DEPS
}
PARENT_SCOPE
)
paddle/fluid/operators/jit/more/intrinsic/CMakeLists.txt
0 → 100644
浏览文件 @
693e5e65
file
(
GLOB jit_kernel_cc_intrinsic RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"*.cc"
)
cc_library
(
jit_kernel_intrinsic SRCS
${
jit_kernel_cc_intrinsic
}
DEPS jit_kernel_base
)
set
(
JIT_KERNEL_DEPS
${
JIT_KERNEL_DEPS
}
jit_kernel_intrinsic PARENT_SCOPE
)
# use mkl kernels by name and type
USE_JITKERNEL_MORE
(
kCRFDecoding, intrinsic
)
USE_JITKERNEL_MORE
(
kLayerNorm, intrinsic
)
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h"
#include <limits>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
more
{
namespace
intrinsic
{
// Note: intrinsic code is not runtime build.
// For example, if you build code on AVX, and run on AVX512 it can only use AVX
void
CRFDecoding
(
const
int
seq_len
,
const
float
*
x
,
const
float
*
w
,
float
*
alpha
,
int
*
track
,
int
tag_num
)
{
#ifdef __AVX512F__
const
int
step_size
=
ZMM_FLOAT_BLOCK
;
#else
const
int
step_size
=
YMM_FLOAT_BLOCK
;
#endif
const
int
end
=
tag_num
/
step_size
;
const
int
rest
=
tag_num
%
step_size
;
/* Setup the alpha initial value.*/
int
i_offset
=
0
;
int
last_offset
=
rest
-
step_size
;
for
(
int
i
=
0
;
i
<=
end
;
++
i
)
{
#ifdef __AVX512F__
// Declare the variable for the content of weights, input and alpha values.
__m512
w_content
,
x_content
,
alpha_content
;
// Load the relevant data into the variables from un-aligned address.
w_content
=
_mm512_loadu_ps
(
w
+
i_offset
);
x_content
=
_mm512_loadu_ps
(
x
+
i_offset
);
alpha_content
=
_mm512_add_ps
(
w_content
,
x_content
);
// Save the alpha value.
_mm512_storeu_ps
(
alpha_value
+
i_offset
,
alpha_content
);
#else
// AVX or AVX2
// weights, input and alpha values.
__m256
w_content
,
x_content
,
alpha_content
;
// Load the relevant data into the variables from un-aligned address.
w_content
=
_mm256_loadu_ps
(
w
+
i_offset
);
x_content
=
_mm256_loadu_ps
(
x
+
i_offset
);
alpha_content
=
_mm256_add_ps
(
w_content
,
x_content
);
_mm256_storeu_ps
(
alpha
+
i_offset
,
alpha_content
);
#endif
i_offset
+=
step_size
;
if
(
i
==
end
-
1
)
{
if
(
rest
>
0
)
{
i_offset
+=
last_offset
;
}
else
{
break
;
}
}
}
// Use the column-major strategy to get the location of maximum score.
int
seq_offset
=
0
;
constexpr
int
state_trans_base_idx
=
2
;
for
(
int
k
=
1
;
k
<
seq_len
;
++
k
)
{
int
j_offset
=
0
;
for
(
int
j
=
0
;
j
<=
end
;
++
j
)
{
/* Initialize the variables of maximum score and location.*/
#ifdef __AVX512F__
__m512
max_score
=
_mm512_set1_ps
(
-
std
::
numeric_limits
<
float
>::
max
());
__m512i
max_j
=
_mm512_setzero_si512
();
#else
__m256
max_score
=
_mm256_set1_ps
(
-
std
::
numeric_limits
<
float
>::
max
());
__m256i
max_j
=
_mm256_set1_epi32
(
0
);
#endif
/* Calculate the offset of transition_weights.*/
int
trans_offset
=
state_trans_base_idx
*
tag_num
+
j_offset
;
for
(
int
i
=
0
;
i
<
tag_num
;
++
i
)
{
/* Initalize the content of alpha variable with related offset.*/
#ifdef __AVX512F__
__m512
alpha_content
=
_mm512_set1_ps
(
*
(
alpha
+
seq_offset
+
i
));
/* Obtain the content of weights from un-aligned address.*/
__m512
w_content
=
_mm512_loadu_ps
(
w
+
trans_offset
);
__m512
score_v
=
_mm512_add_ps
(
alpha_content
,
w_content
);
__mmask16
mask
=
_mm512_cmp_ps_mask
(
score_v
,
max_score
,
_CMP_GT_OS
);
/* AVX512 instructions.*/
max_j
=
_mm512_mask_set1_epi32
(
max_j
,
mask
,
i
);
/* Update the max_score value.*/
max_score
=
_mm512_max_ps
(
max_score
,
score_v
);
#else
__m256
alpha_content
=
_mm256_broadcast_ss
(
alpha
+
seq_offset
+
i
);
/* Obtain the content of weights from un-aligned address.*/
__m256
w_content
=
_mm256_loadu_ps
(
w
+
trans_offset
);
__m256
score_v
=
_mm256_add_ps
(
alpha_content
,
w_content
);
__m256
mask
=
_mm256_cmp_ps
(
score_v
,
max_score
,
_CMP_GT_OS
);
/* According to the mask value, update the index of the max_score.*/
#ifdef __AVX2__
max_j
=
_mm256_or_si256
(
_mm256_andnot_si256
((
__m256i
)
mask
,
max_j
),
_mm256_and_si256
((
__m256i
)
mask
,
_mm256_set1_epi32
(
i
)));
#else
__m128i
lo_max_j
=
_mm256_extractf128_si256
(
max_j
,
0
);
__m128i
hi_max_j
=
_mm256_extractf128_si256
(
max_j
,
1
);
__m128i
lo_mask
=
_mm256_extractf128_si256
(
*
(
__m256i
*
)
&
mask
,
0
);
// NOLINT
__m128i
hi_mask
=
_mm256_extractf128_si256
(
*
(
__m256i
*
)
&
mask
,
1
);
// NOLINT
lo_max_j
=
_mm_andnot_si128
(
lo_mask
,
lo_max_j
);
hi_max_j
=
_mm_andnot_si128
(
hi_mask
,
hi_max_j
);
lo_mask
=
_mm_and_si128
(
lo_mask
,
_mm_set1_epi32
(
i
));
hi_mask
=
_mm_and_si128
(
hi_mask
,
_mm_set1_epi32
(
i
));
lo_max_j
=
_mm_or_si128
(
lo_mask
,
lo_max_j
);
hi_max_j
=
_mm_or_si128
(
hi_mask
,
hi_max_j
);
max_j
=
_mm256_insertf128_si256
(
max_j
,
lo_max_j
,
0
);
max_j
=
_mm256_insertf128_si256
(
max_j
,
hi_max_j
,
1
);
#endif
/* Update the max_score value.*/
max_score
=
_mm256_max_ps
(
max_score
,
score_v
);
#endif
trans_offset
+=
tag_num
;
}
/* Update the alpha and track values. */
#ifdef __AVX512F__
__m512
x_content
=
_mm512_loadu_ps
(
x
+
seq_offset
+
this
->
num_
+
j_offset
);
max_score
=
_mm512_add_ps
(
max_score
,
x_content
);
_mm512_storeu_ps
(
alpha
+
seq_offset
+
this
->
num_
+
j_offset
,
max_score
);
_mm512_storeu_si512
(
reinterpret_cast
<
__m512i
*>
(
track
+
seq_offset
+
this
->
num_
+
j_offset
),
max_j
);
#else
__m256
x_content
=
_mm256_loadu_ps
(
x
+
seq_offset
+
tag_num
+
j_offset
);
max_score
=
_mm256_add_ps
(
max_score
,
x_content
);
_mm256_storeu_ps
(
alpha
+
seq_offset
+
tag_num
+
j_offset
,
max_score
);
_mm256_storeu_si256
(
reinterpret_cast
<
__m256i
*>
(
track
+
seq_offset
+
tag_num
+
j_offset
),
max_j
);
#endif
/* Calculate the offset of next step*/
j_offset
+=
step_size
;
if
(
j
==
end
-
1
)
{
if
(
rest
>
0
)
{
j_offset
+=
last_offset
;
}
else
{
break
;
}
}
}
seq_offset
+=
tag_num
;
}
}
bool
CRFDecodingKernel
::
UseMe
(
const
int
&
d
)
const
{
#ifdef __AVX512F__
constexpr
int
block
=
ZMM_FLOAT_BLOCK
;
#else
constexpr
int
block
=
YMM_FLOAT_BLOCK
;
#endif
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
>=
block
;
}
}
// namespace intrinsic
}
// namespace more
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
intrinsic
=
paddle
::
operators
::
jit
::
more
::
intrinsic
;
REGISTER_JITKERNEL_MORE
(
kCRFDecoding
,
intrinsic
,
intrinsic
::
CRFDecodingKernel
);
paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
more
{
namespace
intrinsic
{
void
CRFDecoding
(
const
int
seq_len
,
const
float
*
x
,
const
float
*
w
,
float
*
alpha
,
int
*
track
,
int
tag_num
);
class
CRFDecodingKernel
:
public
KernelMore
<
CRFDecodingTuples
<
float
>>
{
public:
CRFDecodingKernel
()
{
this
->
func
=
CRFDecoding
;
}
bool
UseMe
(
const
typename
CRFDecodingTuples
<
float
>::
attr_type
&
)
const
override
;
const
char
*
ImplType
()
const
override
{
return
"Intrinsic"
;
}
};
}
// namespace intrinsic
}
// namespace more
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/intrinsic/layer_norm.h"
#include <limits>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
more
{
namespace
intrinsic
{
void
LayerNorm
(
float
*
x
,
float
*
out
,
float
*
mean
,
float
*
var
,
const
float
*
scale
,
const
float
*
bias
,
int
height
,
const
float
epsilon
,
int
right
)
{
__m256
sum
;
__m256
mean_vec
,
var_vec
;
__m128
hi
,
lo
;
__m256
tmp
;
size_t
offset
;
size_t
j
;
int
block
=
YMM_FLOAT_BLOCK
;
const
int
rest
=
right
%
block
;
const
int
end
=
right
-
rest
;
__m256
reverse_num_vec
=
_mm256_div_ps
(
_mm256_set1_ps
(
1.0
),
_mm256_set1_ps
(
right
));
__m256
epsilon_vec
=
_mm256_set1_ps
(
epsilon
);
int
rest_mask
=
((
-
1
)
&
(
~
((
~
0U
)
>>
(
sizeof
(
int
)
*
8
-
(
block
-
rest
)))))
&
0x0ff
;
__m256i
mask_vec
=
_mm256_set_epi32
(
rest_mask
&
0x80
?
0xffffffff
:
0
,
rest_mask
&
0x40
?
0xffffffff
:
0
,
rest_mask
&
0x20
?
0xffffffff
:
0
,
rest_mask
&
0x10
?
0xffffffff
:
0
,
rest_mask
&
0x8
?
0xffffffff
:
0
,
rest_mask
&
0x4
?
0xffffffff
:
0
,
rest_mask
&
0x2
?
0xffffffff
:
0
,
rest_mask
&
0x1
?
0xffffffff
:
0
);
for
(
int
i
=
0
;
i
<
height
;
++
i
)
{
offset
=
i
*
right
;
/* get mean */
sum
=
_mm256_setzero_ps
();
for
(
j
=
offset
;
j
<
end
+
offset
;
j
+=
block
)
{
sum
=
_mm256_add_ps
(
sum
,
_mm256_loadu_ps
((
const
float
*
)
x
+
j
));
}
if
(
rest
!=
0
)
{
j
=
offset
+
right
-
block
;
tmp
=
_mm256_loadu_ps
((
const
float
*
)
x
+
j
);
tmp
=
_mm256_blendv_ps
(
_mm256_setzero_ps
(),
tmp
,
*
(
__m256
*
)
&
mask_vec
);
// NOLINT
sum
=
_mm256_add_ps
(
sum
,
tmp
);
}
hi
=
_mm256_extractf128_ps
(
sum
,
1
);
lo
=
_mm256_extractf128_ps
(
sum
,
0
);
sum
=
_mm256_add_ps
(
sum
,
_mm256_insertf128_ps
(
_mm256_insertf128_ps
(
_mm256_setzero_ps
(),
hi
,
0
),
lo
,
1
));
sum
=
_mm256_hadd_ps
(
sum
,
sum
);
sum
=
_mm256_hadd_ps
(
sum
,
sum
);
mean_vec
=
_mm256_mul_ps
(
sum
,
reverse_num_vec
);
mean
[
i
]
=
*
reinterpret_cast
<
float
*>
(
&
mean_vec
);
/* get variance */
sum
=
_mm256_setzero_ps
();
for
(
j
=
offset
;
j
<
end
+
offset
;
j
+=
block
)
{
tmp
=
_mm256_sub_ps
(
_mm256_loadu_ps
((
const
float
*
)
x
+
j
),
mean_vec
);
tmp
=
_mm256_mul_ps
(
tmp
,
tmp
);
sum
=
_mm256_add_ps
(
sum
,
tmp
);
}
if
(
rest
!=
0
)
{
j
=
offset
+
right
-
block
;
tmp
=
_mm256_sub_ps
(
_mm256_loadu_ps
((
const
float
*
)
x
+
j
),
mean_vec
);
tmp
=
_mm256_mul_ps
(
tmp
,
tmp
);
tmp
=
_mm256_blendv_ps
(
_mm256_setzero_ps
(),
tmp
,
*
(
__m256
*
)
&
mask_vec
);
// NOLINT
sum
=
_mm256_add_ps
(
sum
,
tmp
);
}
hi
=
_mm256_extractf128_ps
(
sum
,
1
);
lo
=
_mm256_extractf128_ps
(
sum
,
0
);
sum
=
_mm256_add_ps
(
sum
,
_mm256_insertf128_ps
(
_mm256_insertf128_ps
(
_mm256_setzero_ps
(),
hi
,
0
),
lo
,
1
));
sum
=
_mm256_hadd_ps
(
sum
,
sum
);
sum
=
_mm256_hadd_ps
(
sum
,
sum
);
var_vec
=
_mm256_mul_ps
(
sum
,
reverse_num_vec
);
var
[
i
]
=
*
reinterpret_cast
<
float
*>
(
&
var_vec
);
/* get x_norm and calculate output*/
for
(
j
=
offset
;
j
<
end
+
offset
;
j
+=
block
)
{
tmp
=
_mm256_sub_ps
(
_mm256_loadu_ps
((
const
float
*
)
x
+
j
),
mean_vec
);
tmp
=
_mm256_div_ps
(
tmp
,
_mm256_sqrt_ps
(
_mm256_add_ps
(
var_vec
,
epsilon_vec
)));
_mm256_storeu_ps
(
reinterpret_cast
<
float
*>
(
out
)
+
j
,
tmp
);
}
if
(
rest
!=
0
)
{
j
=
offset
+
right
-
block
;
tmp
=
_mm256_sub_ps
(
_mm256_loadu_ps
((
const
float
*
)
x
+
j
),
mean_vec
);
tmp
=
_mm256_div_ps
(
tmp
,
_mm256_sqrt_ps
(
_mm256_add_ps
(
var_vec
,
epsilon_vec
)));
_mm256_storeu_ps
(
reinterpret_cast
<
float
*>
(
out
)
+
j
,
tmp
);
}
if
(
scale
)
{
if
(
rest
!=
0
)
{
j
=
offset
+
right
-
block
;
tmp
=
_mm256_loadu_ps
((
const
float
*
)
out
+
j
);
}
for
(
j
=
offset
;
j
<
end
+
offset
;
j
+=
block
)
{
_mm256_storeu_ps
(
reinterpret_cast
<
float
*>
(
out
)
+
j
,
_mm256_mul_ps
(
_mm256_loadu_ps
((
const
float
*
)
out
+
j
),
_mm256_loadu_ps
((
const
float
*
)
scale
+
j
-
offset
)));
}
if
(
rest
!=
0
)
{
j
=
offset
+
right
-
block
;
_mm256_storeu_ps
(
reinterpret_cast
<
float
*>
(
out
)
+
j
,
_mm256_mul_ps
(
tmp
,
_mm256_loadu_ps
((
const
float
*
)
scale
+
j
-
offset
)));
}
}
if
(
bias
)
{
if
(
rest
!=
0
)
{
j
=
offset
+
right
-
block
;
tmp
=
_mm256_loadu_ps
((
const
float
*
)
out
+
j
);
}
for
(
j
=
offset
;
j
<
end
+
offset
;
j
+=
block
)
{
_mm256_storeu_ps
(
reinterpret_cast
<
float
*>
(
out
)
+
j
,
_mm256_add_ps
(
_mm256_loadu_ps
((
const
float
*
)
out
+
j
),
_mm256_loadu_ps
((
const
float
*
)
bias
+
j
-
offset
)));
}
if
(
rest
!=
0
)
{
j
=
offset
+
right
-
block
;
_mm256_storeu_ps
(
reinterpret_cast
<
float
*>
(
out
)
+
j
,
_mm256_add_ps
(
tmp
,
_mm256_loadu_ps
((
const
float
*
)
bias
+
j
-
offset
)));
}
}
}
}
bool
LayerNormKernel
::
UseMe
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx
)
&&
d
>=
YMM_FLOAT_BLOCK
;
}
}
// namespace intrinsic
}
// namespace more
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
intrinsic
=
paddle
::
operators
::
jit
::
more
::
intrinsic
;
REGISTER_JITKERNEL_MORE
(
kLayerNorm
,
intrinsic
,
intrinsic
::
LayerNormKernel
);
paddle/fluid/operators/jit/more/intrinsic/layer_norm.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
more
{
namespace
intrinsic
{
void
LayerNorm
(
float
*
x
,
float
*
out
,
float
*
mean
,
float
*
var
,
const
float
*
scale
,
const
float
*
bias
,
int
height
,
const
float
epsilon
,
int
right
);
class
LayerNormKernel
:
public
KernelMore
<
LayerNormTuples
<
float
>>
{
public:
LayerNormKernel
()
{
this
->
func
=
LayerNorm
;
}
bool
UseMe
(
const
typename
LayerNormTuples
<
float
>::
attr_type
&
)
const
override
;
const
char
*
ImplType
()
const
override
{
return
"Intrinsic"
;
}
};
}
// namespace intrinsic
}
// namespace more
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/more/mix/CMakeLists.txt
0 → 100644
浏览文件 @
693e5e65
file
(
GLOB jit_kernel_mix_cc RELATIVE
"
${
CMAKE_CURRENT_SOURCE_DIR
}
"
"*.cc"
)
cc_library
(
jit_kernel_mix SRCS
${
jit_kernel_mix_cc
}
DEPS jit_kernel_base
)
set
(
JIT_KERNEL_DEPS
${
JIT_KERNEL_DEPS
}
jit_kernel_mix PARENT_SCOPE
)
USE_JITKERNEL_MORE
(
kVSigmoid, mix
)
USE_JITKERNEL_MORE
(
kVTanh, mix
)
USE_JITKERNEL_MORE
(
kLSTMCtHt, mix
)
USE_JITKERNEL_MORE
(
kLSTMC1H1, mix
)
USE_JITKERNEL_MORE
(
kGRUH1, mix
)
USE_JITKERNEL_MORE
(
kGRUHtPart1, mix
)
USE_JITKERNEL_MORE
(
kGRUHtPart2, mix
)
paddle/fluid/operators/jit/more/mix/mix.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/mix/mix.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
more
{
namespace
mix
{
void
VSigmoid
(
const
T
*
x
,
T
*
y
,
int
n
)
{
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
static_cast
<
T
>
(
0
)
-
y
[
i
];
}
auto
compute
=
Get
<
KernelType
::
kVExp
,
XYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
n
);
compute
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
y
[
i
]);
}
}
void
VTanh
(
const
T
*
x
,
T
*
y
,
int
n
)
{
const
T
a
=
2
,
b
=
-
1
;
auto
compute_scal
=
Get
<
kVScal
,
AXYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
n
);
auto
compute_addbias
=
Get
<
kVAddBias
,
AXYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
n
);
auto
compute_sigmoid
=
Get
<
kVSigmoid
,
XYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
n
);
compute_scal
(
&
a
,
x
,
y
,
n
);
compute_sigmoid
(
y
,
y
,
n
);
compute_scal
(
&
a
,
y
,
y
,
n
);
compute_addbias
(
&
b
,
y
,
y
,
n
);
}
void
(
*
getActFunc
(
KernelType
type
,
int
d
))(
const
T
*
,
T
*
,
int
)
{
// NOLINT
if
(
type
==
kVSigmoid
)
{
return
Get
<
kVSigmoid
,
XYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
d
);
}
else
if
(
type
==
kVRelu
)
{
return
Get
<
kVRelu
,
XYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
d
);
}
else
if
(
type
==
kVTanh
)
{
return
Get
<
kVTanh
,
XYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
d
);
}
else
if
(
type
==
kVIdentity
)
{
return
Get
<
kVIdentity
,
XYNTuples
<
T
>
,
platform
::
CPUPlace
>
(
d
);
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
return
nullptr
;
}
void
LSTMCtHt
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
const
T
*
ct_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ct_1
);
T
*
ct
=
reinterpret_cast
<
T
*>
(
step
->
ct
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
const
T
*
wp
=
reinterpret_cast
<
const
T
*>
(
step
->
wp
);
T
*
checked
=
reinterpret_cast
<
T
*>
(
step
->
checked
);
const
int
d
=
attr
->
d
;
const
int
d2
=
d
*
2
;
const
int
d3
=
d
*
3
;
auto
vmul_d
=
Get
<
kVMul
,
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
d
);
auto
vadd_d
=
Get
<
kVAdd
,
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
d
);
auto
vadd_d2
=
Get
<
kVAdd
,
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
d2
);
auto
act_gate_d
=
getActFunc
(
attr
->
act_gate
,
d
);
auto
act_gate_d2
=
getActFunc
(
attr
->
act_gate
,
d2
);
auto
act_gate_d3
=
getActFunc
(
attr
->
act_gate
,
d3
);
auto
act_cand_d
=
getActFunc
(
attr
->
act_cand
,
d
);
auto
act_cell_d
=
getActFunc
(
attr
->
act_cell
,
d
);
if
(
attr
->
use_peephole
)
{
vmul_d
(
wp
,
ct_1
,
checked
,
d
);
vmul_d
(
wp
+
d
,
ct_1
,
checked
+
d
,
d
);
vadd_d2
(
checked
,
gates
+
d
,
gates
+
d
,
d2
);
act_gate_d2
(
gates
+
d
,
gates
+
d
,
d2
);
}
else
{
act_gate_d3
(
gates
+
d
,
gates
+
d
,
d3
);
}
// C_t = C_t-1 * fgated + cand_gated * igated
act_cand_d
(
gates
,
gates
,
d
);
vmul_d
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
vmul_d
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
vadd_d
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
if
(
attr
->
use_peephole
)
{
// get ogated
vmul_d
(
wp
+
d2
,
ct
,
gates
+
d
,
d
);
vadd_d
(
gates
+
d
,
gates
+
d3
,
gates
+
d3
,
d
);
act_gate_d
(
gates
+
d3
,
gates
+
d3
,
d
);
}
// H_t = act_cell(C_t) * ogated
act_cell_d
(
ct
,
gates
+
d2
,
d
);
vmul_d
(
gates
+
d2
,
gates
+
d3
,
ht
,
d
);
}
void
LSTMC1H1
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
T
*
ct
=
reinterpret_cast
<
T
*>
(
step
->
ct
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
int
d
=
attr
->
d
;
int
d2
=
d
*
2
;
int
d3
=
d
*
3
;
auto
vmul_d
=
Get
<
kVMul
,
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
d
);
auto
vadd_d
=
Get
<
kVAdd
,
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
d
);
auto
act_gate_d
=
getActFunc
(
attr
->
act_gate
,
d
);
auto
act_cand_d
=
getActFunc
(
attr
->
act_cand
,
d
);
auto
act_cell_d
=
getActFunc
(
attr
->
act_cell
,
d
);
/* C_t = igated * cgated*/
act_gate_d
(
gates
+
d
,
gates
+
d
,
d
);
act_cand_d
(
gates
,
gates
,
d
);
vmul_d
(
gates
,
gates
+
d
,
ct
,
d
);
if
(
attr
->
use_peephole
)
{
// get outgated, put W_oc * C_t on igated
const
T
*
wp
=
reinterpret_cast
<
const
T
*>
(
step
->
wp
);
vmul_d
(
wp
+
d2
,
ct
,
gates
+
d
,
d
);
vadd_d
(
gates
+
d
,
gates
+
d3
,
gates
+
d3
,
d
);
}
/* H_t = act_cell(C_t) * ogated */
act_gate_d
(
gates
+
d3
,
gates
+
d3
,
d
);
act_cell_d
(
ct
,
gates
+
d2
,
d
);
vmul_d
(
gates
+
d2
,
gates
+
d3
,
ht
,
d
);
}
// compute h1 without h0
void
GRUH1
(
gru_t
*
step
,
const
gru_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
int
d
=
attr
->
d
;
int
d2
=
d
*
2
;
auto
act_gate
=
getActFunc
(
attr
->
act_gate
,
d
);
auto
act_cand
=
getActFunc
(
attr
->
act_cand
,
d
);
auto
vmul_d
=
Get
<
kVMul
,
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
d
);
act_gate
(
gates
,
gates
,
d
);
act_cand
(
gates
+
d2
,
gates
+
d2
,
d
);
vmul_d
(
gates
,
gates
+
d2
,
ht
,
d
);
}
// compute the first part of GRU: ht = act_gate(r) * ht_1
void
GRUHtPart1
(
gru_t
*
step
,
const
gru_attr_t
*
attr
)
{
// W: {W_update, W_reset; W_state}
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
const
T
*
ht_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ht_1
);
auto
act_gate
=
getActFunc
(
attr
->
act_gate
,
attr
->
d
);
auto
vmul_d
=
Get
<
kVMul
,
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
attr
->
d
);
act_gate
(
gates
+
attr
->
d
,
gates
+
attr
->
d
,
attr
->
d
);
vmul_d
(
ht_1
,
gates
+
attr
->
d
,
ht
,
attr
->
d
);
}
// compute the second part of GRU:
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
void
GRUHtPart2
(
gru_t
*
step
,
const
gru_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
const
T
*
ht_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ht_1
);
int
d
=
attr
->
d
;
auto
act_gate
=
getActFunc
(
attr
->
act_gate
,
d
);
auto
act_cand
=
getActFunc
(
attr
->
act_cand
,
d
);
T
*
y
=
gates
+
d
*
2
;
act_gate
(
gates
,
gates
,
d
);
act_cand
(
y
,
y
,
d
);
// out = zt*ht~ + (1-zt)*ht_1
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
ht
[
i
]
=
gates
[
i
]
*
y
[
i
]
+
(
static_cast
<
T
>
(
1
)
-
gates
[
i
])
*
ht_1
[
i
];
}
}
// TODO(TJ): tuning me
bool
VSigmoidKernel
::
UseMe
(
const
int
&
d
)
const
{
return
true
;
}
bool
VTanhKernel
::
UseMe
(
const
int
&
d
)
const
{
return
true
;
}
bool
LSTMCtHtKernel
::
UseMe
(
const
lstm_attr_t
&
attr
)
const
{
return
true
;
}
bool
LSTMC1H1Kernel
::
UseMe
(
const
lstm_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUH1Kernel
::
UseMe
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUHtPart1Kernel
::
UseMe
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
bool
GRUHtPart2Kernel
::
UseMe
(
const
gru_attr_t
&
attr
)
const
{
return
true
;
}
}
// namespace mix
}
// namespace more
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
mix
=
paddle
::
operators
::
jit
::
more
::
mix
;
#define REGISTER_MORE_KERNEL(key, func) \
REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel)
REGISTER_MORE_KERNEL
(
kVSigmoid
,
VSigmoid
);
REGISTER_MORE_KERNEL
(
kVTanh
,
VTanh
);
REGISTER_MORE_KERNEL
(
kLSTMCtHt
,
LSTMCtHt
);
REGISTER_MORE_KERNEL
(
kLSTMC1H1
,
LSTMC1H1
);
REGISTER_MORE_KERNEL
(
kGRUH1
,
GRUH1
);
REGISTER_MORE_KERNEL
(
kGRUHtPart1
,
GRUHtPart1
);
REGISTER_MORE_KERNEL
(
kGRUHtPart2
,
GRUHtPart2
);
#undef REGISTER_MORE_KERNEL
paddle/fluid/operators/jit/more/mix/mix.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
more
{
namespace
mix
{
using
T
=
float
;
void
VSigmoid
(
const
T
*
x
,
T
*
y
,
int
n
);
void
VTanh
(
const
T
*
x
,
T
*
y
,
int
n
);
void
LSTMCtHt
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
);
void
LSTMC1H1
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
);
void
GRUH1
(
gru_t
*
step
,
const
gru_attr_t
*
attr
);
void
GRUHtPart1
(
gru_t
*
step
,
const
gru_attr_t
*
attr
);
void
GRUHtPart2
(
gru_t
*
step
,
const
gru_attr_t
*
attr
);
#define DECLARE_MORE_KERNEL(name, tuples) \
class name##Kernel : public KernelMore<tuples<T>> { \
public: \
name##Kernel() { this->func = name; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \
}
// XYN
DECLARE_MORE_KERNEL
(
VSigmoid
,
XYNTuples
);
DECLARE_MORE_KERNEL
(
VTanh
,
XYNTuples
);
DECLARE_MORE_KERNEL
(
LSTMCtHt
,
LSTMTuples
);
DECLARE_MORE_KERNEL
(
LSTMC1H1
,
LSTMTuples
);
DECLARE_MORE_KERNEL
(
GRUH1
,
GRUTuples
);
DECLARE_MORE_KERNEL
(
GRUHtPart1
,
GRUTuples
);
DECLARE_MORE_KERNEL
(
GRUHtPart2
,
GRUTuples
);
#undef DECLARE_MORE_KERNEL
}
// namespace mix
}
// namespace more
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
0 → 100644
浏览文件 @
693e5e65
cc_library
(
jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml
)
set
(
JIT_KERNEL_DEPS
${
JIT_KERNEL_DEPS
}
dynload_mklml jit_kernel_mkl PARENT_SCOPE
)
# use mkl kernels by name and type
USE_JITKERNEL_MORE
(
kVMul, mkl
)
USE_JITKERNEL_MORE
(
kVAdd, mkl
)
USE_JITKERNEL_MORE
(
kVScal, mkl
)
USE_JITKERNEL_MORE
(
kVExp, mkl
)
USE_JITKERNEL_MORE
(
kVSigmoid, mkl
)
USE_JITKERNEL_MORE
(
kVTanh, mkl
)
paddle/fluid/operators/jit/more/mkl/mkl.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/mkl/mkl.h"
#include "paddle/fluid/operators/jit/refer/refer.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/dynload/mklml.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
more
{
namespace
mkl
{
template
<
>
void
VMul
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
platform
::
dynload
::
vsMul
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VMul
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdMul
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VAdd
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
platform
::
dynload
::
vsAdd
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VAdd
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdAdd
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VScal
<
float
>
(
const
float
*
a
,
const
float
*
x
,
float
*
y
,
int
n
)
{
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_sscal
(
n
,
*
a
,
y
,
1
);
}
else
{
refer
::
VScal
<
float
>
(
a
,
x
,
y
,
n
);
}
}
template
<
>
void
VScal
<
double
>
(
const
double
*
a
,
const
double
*
x
,
double
*
y
,
int
n
)
{
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_dscal
(
n
,
*
a
,
y
,
1
);
}
else
{
refer
::
VScal
<
double
>
(
a
,
x
,
y
,
n
);
}
}
template
<
>
void
VExp
<
float
>
(
const
float
*
x
,
float
*
y
,
int
n
)
{
platform
::
dynload
::
vsExp
(
n
,
x
,
y
);
}
template
<
>
void
VExp
<
double
>
(
const
double
*
x
,
double
*
y
,
int
n
)
{
platform
::
dynload
::
vdExp
(
n
,
x
,
y
);
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template
<
>
bool
VMulKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
template
<
>
bool
VAddKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
template
<
>
bool
VScalKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
template
<
>
bool
VExpKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
d
>
7
;
}
template
<
>
bool
VSigmoidKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
d
>
7
;
}
template
<
>
bool
VTanhKernel
<
float
>::
UseMe
(
const
int
&
d
)
const
{
return
d
>
7
;
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \
bool func##Kernel<double>::UseMe(const int& d) const { \
return true; \
}
AWALYS_USE_ME_WITH_DOUBLE
(
VMul
);
AWALYS_USE_ME_WITH_DOUBLE
(
VAdd
);
AWALYS_USE_ME_WITH_DOUBLE
(
VScal
);
AWALYS_USE_ME_WITH_DOUBLE
(
VExp
);
AWALYS_USE_ME_WITH_DOUBLE
(
VSigmoid
);
AWALYS_USE_ME_WITH_DOUBLE
(
VTanh
);
#undef AWALYS_USE_ME_WITH_DOUBLE
}
// namespace mkl
}
// namespace more
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
mkl
=
paddle
::
operators
::
jit
::
more
::
mkl
;
#define REGISTER_MKL_KERNEL(key, func) \
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL
(
kVMul
,
VMul
);
REGISTER_MKL_KERNEL
(
kVAdd
,
VAdd
);
REGISTER_MKL_KERNEL
(
kVScal
,
VScal
);
REGISTER_MKL_KERNEL
(
kVExp
,
VExp
);
REGISTER_MKL_KERNEL
(
kVSigmoid
,
VSigmoid
);
REGISTER_MKL_KERNEL
(
kVTanh
,
VTanh
);
#undef REGISTER_MKL_KERNEL
paddle/fluid/operators/jit/more/mkl/mkl.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
more
{
namespace
mkl
{
template
<
typename
T
>
void
VMul
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
template
<
typename
T
>
void
VAdd
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
template
<
typename
T
>
void
VScal
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
);
template
<
typename
T
>
void
VExp
(
const
T
*
x
,
T
*
y
,
int
n
);
template
<
typename
T
>
void
VSigmoid
(
const
T
*
x
,
T
*
y
,
int
n
)
{
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
static_cast
<
T
>
(
0
)
-
y
[
i
];
}
VExp
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
y
[
i
]);
}
}
template
<
typename
T
>
void
VTanh
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
2
)
*
x
[
i
];
}
VSigmoid
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
2
)
*
y
[
i
]
-
static_cast
<
T
>
(
1
);
}
}
#define DECLARE_MKL_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \
}
// XYZN
DECLARE_MKL_KERNEL
(
VMul
,
XYZNTuples
);
DECLARE_MKL_KERNEL
(
VAdd
,
XYZNTuples
);
// AXYN
DECLARE_MKL_KERNEL
(
VScal
,
AXYNTuples
);
// XYN
DECLARE_MKL_KERNEL
(
VExp
,
XYNTuples
);
DECLARE_MKL_KERNEL
(
VSigmoid
,
XYNTuples
);
DECLARE_MKL_KERNEL
(
VTanh
,
XYNTuples
);
#undef DECLARE_MKL_KERNEL
}
// namespace mkl
}
// namespace more
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/refer/CMakeLists.txt
0 → 100644
浏览文件 @
693e5e65
cc_library
(
jit_kernel_refer SRCS refer.cc DEPS jit_kernel_base
)
set
(
JIT_KERNEL_DEPS
${
JIT_KERNEL_DEPS
}
jit_kernel_refer PARENT_SCOPE
)
function
(
USE_JITKERNEL_REFER TARGET
)
file
(
APPEND
${
jit_file
}
"USE_JITKERNEL_REFER(
${
TARGET
}
);
\n
"
)
endfunction
()
# use refer kernel by name
USE_JITKERNEL_REFER
(
kVMul
)
USE_JITKERNEL_REFER
(
kVAdd
)
USE_JITKERNEL_REFER
(
kVAddRelu
)
USE_JITKERNEL_REFER
(
kVSub
)
USE_JITKERNEL_REFER
(
kVScal
)
USE_JITKERNEL_REFER
(
kVAddBias
)
USE_JITKERNEL_REFER
(
kVRelu
)
USE_JITKERNEL_REFER
(
kVIdentity
)
USE_JITKERNEL_REFER
(
kVExp
)
USE_JITKERNEL_REFER
(
kVSigmoid
)
USE_JITKERNEL_REFER
(
kVTanh
)
USE_JITKERNEL_REFER
(
kLSTMCtHt
)
USE_JITKERNEL_REFER
(
kLSTMC1H1
)
USE_JITKERNEL_REFER
(
kGRUH1
)
USE_JITKERNEL_REFER
(
kGRUHtPart1
)
USE_JITKERNEL_REFER
(
kGRUHtPart2
)
USE_JITKERNEL_REFER
(
kCRFDecoding
)
USE_JITKERNEL_REFER
(
kLayerNorm
)
USE_JITKERNEL_REFER
(
kNCHW16CMulNC
)
paddle/fluid/operators/jit/refer/refer.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/refer/refer.h"
#include "paddle/fluid/operators/jit/registry.h"
namespace
refer
=
paddle
::
operators
::
jit
::
refer
;
#define REGISTER_REFER_KERNEL(key, func) \
REGISTER_JITKERNEL_REFER(key, refer::func##Kernel<float>, \
refer::func##Kernel<double>)
REGISTER_REFER_KERNEL
(
kVMul
,
VMul
);
REGISTER_REFER_KERNEL
(
kVAdd
,
VAdd
);
REGISTER_REFER_KERNEL
(
kVAddRelu
,
VAddRelu
);
REGISTER_REFER_KERNEL
(
kVSub
,
VSub
);
REGISTER_REFER_KERNEL
(
kVScal
,
VScal
);
REGISTER_REFER_KERNEL
(
kVAddBias
,
VAddBias
);
REGISTER_REFER_KERNEL
(
kVRelu
,
VRelu
);
REGISTER_REFER_KERNEL
(
kVIdentity
,
VIdentity
);
REGISTER_REFER_KERNEL
(
kVExp
,
VExp
);
REGISTER_REFER_KERNEL
(
kVSigmoid
,
VSigmoid
);
REGISTER_REFER_KERNEL
(
kVTanh
,
VTanh
);
REGISTER_REFER_KERNEL
(
kLSTMCtHt
,
LSTMCtHt
);
REGISTER_REFER_KERNEL
(
kLSTMC1H1
,
LSTMC1H1
);
REGISTER_REFER_KERNEL
(
kGRUH1
,
GRUH1
);
REGISTER_REFER_KERNEL
(
kGRUHtPart1
,
GRUHtPart1
);
REGISTER_REFER_KERNEL
(
kGRUHtPart2
,
GRUHtPart2
);
REGISTER_REFER_KERNEL
(
kCRFDecoding
,
CRFDecoding
);
REGISTER_REFER_KERNEL
(
kLayerNorm
,
LayerNorm
);
REGISTER_REFER_KERNEL
(
kNCHW16CMulNC
,
NCHW16CMulNC
);
#undef REGISTER_REFER_KERNEL
paddle/fluid/operators/
math/jit_kernel_
refer.h
→
paddle/fluid/operators/
jit/refer/
refer.h
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
Licensed under the Apache License, Version 2.0 (the "License");
*
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
*
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
*
You may obtain a copy of the License at
*
http://www.apache.org/licenses/LICENSE-2.0
*
http://www.apache.org/licenses/LICENSE-2.0
*
Unless required by applicable law or agreed to in writing, software
*
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
*
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
*
See the License for the specific language governing permissions and
limitations under the License. */
*
limitations under the License. */
#pragma once
#pragma once
#include <cmath>
#include <cmath>
#include <string>
#include <limits>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/operators/jit/helper.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
jit
{
namespace
jitkernel
{
namespace
refer
{
namespace
refer
{
/* Refer code only focus on correctness */
// Refer code only focus on correctness
template
<
typename
T
>
template
<
typename
T
>
void
VMul
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
)
{
void
VMul
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
...
@@ -47,6 +48,13 @@ void VAddRelu(const T* x, const T* y, T* z, int n) {
...
@@ -47,6 +48,13 @@ void VAddRelu(const T* x, const T* y, T* z, int n) {
}
}
}
}
template
<
typename
T
>
void
VSub
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
z
[
i
]
=
x
[
i
]
-
y
[
i
];
}
}
template
<
typename
T
>
template
<
typename
T
>
void
VScal
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
)
{
void
VScal
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
...
@@ -69,7 +77,11 @@ void VRelu(const T* x, T* y, int n) {
...
@@ -69,7 +77,11 @@ void VRelu(const T* x, T* y, int n) {
}
}
template
<
typename
T
>
template
<
typename
T
>
inline
void
VIdentity
(
const
T
*
x
,
T
*
y
,
int
n
)
{}
inline
void
VIdentity
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
x
[
i
];
}
}
template
<
typename
T
>
template
<
typename
T
>
void
VExp
(
const
T
*
x
,
T
*
y
,
int
n
)
{
void
VExp
(
const
T
*
x
,
T
*
y
,
int
n
)
{
...
@@ -102,20 +114,22 @@ void VTanh(const T* x, T* y, int n) {
...
@@ -102,20 +114,22 @@ void VTanh(const T* x, T* y, int n) {
}
}
template
<
typename
T
>
template
<
typename
T
>
void
(
*
getActFunc
(
const
std
::
string
&
type
))(
const
T
*
,
T
*
,
int
)
{
// NOLINT
void
(
*
getActFunc
(
KernelType
type
))(
const
T
*
,
T
*
,
int
)
{
// NOLINT
if
(
type
==
"sigmoid"
)
{
if
(
type
==
kVSigmoid
)
{
return
VSigmoid
<
T
>
;
return
VSigmoid
<
T
>
;
}
else
if
(
type
==
"relu"
)
{
}
else
if
(
type
==
kVRelu
)
{
return
VRelu
<
T
>
;
return
VRelu
<
T
>
;
}
else
if
(
type
==
"tanh"
)
{
}
else
if
(
type
==
kVTanh
)
{
return
VTanh
<
T
>
;
return
VTanh
<
T
>
;
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
}
else
if
(
type
==
kVIdentity
)
{
return
VIdentity
<
T
>
;
return
VIdentity
<
T
>
;
}
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
PADDLE_THROW
(
"Not support type: %s"
,
type
);
return
nullptr
;
return
nullptr
;
}
}
// TODO(TJ): add refer gemm and make LSTM kernels combine as same GRU kernels
// compute ct and ht
// compute ct and ht
template
<
typename
T
>
template
<
typename
T
>
void
LSTMCtHt
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
)
{
void
LSTMCtHt
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
)
{
...
@@ -231,8 +245,134 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
...
@@ -231,8 +245,134 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
}
}
}
}
template
<
typename
T
>
void
CRFDecoding
(
const
int
seq_len
,
const
T
*
x
,
const
T
*
w
,
T
*
alpha
,
int
*
track
,
int
right
)
{
constexpr
int
state_trans_base_idx
=
2
;
for
(
int
i
=
0
;
i
<
right
;
++
i
)
{
alpha
[
i
]
=
w
[
i
]
+
x
[
i
];
}
for
(
int
k
=
1
;
k
<
seq_len
;
++
k
)
{
for
(
int
i
=
0
;
i
<
right
;
++
i
)
{
T
max_score
=
-
std
::
numeric_limits
<
T
>::
max
();
int
max_j
=
0
;
for
(
int
j
=
0
;
j
<
right
;
++
j
)
{
T
score
=
alpha
[(
k
-
1
)
*
right
+
j
]
+
w
[(
j
+
state_trans_base_idx
)
*
right
+
i
];
if
(
score
>
max_score
)
{
max_score
=
score
;
max_j
=
j
;
}
}
alpha
[
k
*
right
+
i
]
=
max_score
+
x
[
k
*
right
+
i
];
track
[
k
*
right
+
i
]
=
max_j
;
}
}
}
template
<
typename
T
>
void
LayerNorm
(
T
*
x
,
T
*
out
,
T
*
mean
,
T
*
var
,
const
T
*
scale
,
const
T
*
bias
,
int
height
,
const
float
epsilon
,
int
right
)
{
// get mean
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
T
sum
=
0.0
;
int
offset
=
i
*
right
;
for
(
int
j
=
0
;
j
<
right
;
j
++
)
{
sum
+=
x
[
offset
+
j
];
}
mean
[
i
]
=
sum
/
right
;
}
// get variance
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
T
sum
=
0.0
;
int
offset
=
i
*
right
;
for
(
int
j
=
0
;
j
<
right
;
j
++
)
{
sum
+=
(
x
[
offset
+
j
]
-
mean
[
i
])
*
(
x
[
offset
+
j
]
-
mean
[
i
]);
}
var
[
i
]
=
sum
/
right
;
}
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
int
offset
=
i
*
right
;
T
sqrt_var
=
std
::
sqrt
(
var
[
i
]
+
(
T
)
epsilon
);
for
(
int
j
=
0
;
j
<
right
;
j
++
)
{
out
[
offset
+
j
]
=
(
x
[
offset
+
j
]
-
mean
[
i
])
/
sqrt_var
;
}
}
if
(
scale
)
{
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
int
offset
=
i
*
right
;
for
(
int
j
=
0
;
j
<
right
;
j
++
)
{
out
[
offset
+
j
]
*=
scale
[
j
];
}
}
}
if
(
bias
)
{
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
int
offset
=
i
*
right
;
for
(
int
j
=
0
;
j
<
right
;
j
++
)
{
out
[
offset
+
j
]
+=
bias
[
j
];
}
}
}
}
template
<
typename
T
>
void
NCHW16CMulNC
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
height
,
int
width
)
{
int
offset
=
0
;
for
(
int
h
=
0
;
h
<
height
;
++
h
)
{
for
(
int
w
=
0
;
w
<
width
;
++
w
)
{
for
(
int
i
=
0
;
i
<
16
;
++
i
)
{
z
[
i
+
offset
]
=
y
[
i
]
*
x
[
i
+
offset
];
}
offset
+=
ZMM_FLOAT_BLOCK
;
}
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
}
// const T* x, const T* y, T* z, int n
DECLARE_REFER_KERNEL
(
VMul
,
XYZNTuples
);
DECLARE_REFER_KERNEL
(
VAdd
,
XYZNTuples
);
DECLARE_REFER_KERNEL
(
VAddRelu
,
XYZNTuples
);
DECLARE_REFER_KERNEL
(
VSub
,
XYZNTuples
);
// const T* a, const T* x, T* y, int n
DECLARE_REFER_KERNEL
(
VScal
,
AXYNTuples
);
DECLARE_REFER_KERNEL
(
VAddBias
,
AXYNTuples
);
// const T* x, T* y, int n
DECLARE_REFER_KERNEL
(
VRelu
,
XYNTuples
);
DECLARE_REFER_KERNEL
(
VIdentity
,
XYNTuples
);
DECLARE_REFER_KERNEL
(
VExp
,
XYNTuples
);
DECLARE_REFER_KERNEL
(
VSigmoid
,
XYNTuples
);
DECLARE_REFER_KERNEL
(
VTanh
,
XYNTuples
);
// lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL
(
LSTMCtHt
,
LSTMTuples
);
DECLARE_REFER_KERNEL
(
LSTMC1H1
,
LSTMTuples
);
// gru_t*, const gru_attr_t*
DECLARE_REFER_KERNEL
(
GRUH1
,
GRUTuples
);
DECLARE_REFER_KERNEL
(
GRUHtPart1
,
GRUTuples
);
DECLARE_REFER_KERNEL
(
GRUHtPart2
,
GRUTuples
);
DECLARE_REFER_KERNEL
(
CRFDecoding
,
CRFDecodingTuples
);
DECLARE_REFER_KERNEL
(
LayerNorm
,
LayerNormTuples
);
DECLARE_REFER_KERNEL
(
NCHW16CMulNC
,
NCHW16CMulNCTuples
);
#undef DECLARE_REFER_KERNEL
}
// namespace refer
}
// namespace refer
}
// namespace jitkernel
}
// namespace jit
}
// namespace math
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/jit/registry.h
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <memory>
#include <tuple>
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h" // for UNUSED
namespace
paddle
{
namespace
operators
{
namespace
jit
{
// make_unique is supported since c++14
template
<
typename
T
,
typename
...
Args
>
inline
std
::
unique_ptr
<
T
>
make_unique
(
Args
&&
...
args
)
{
static_assert
(
!
std
::
is_array
<
T
>::
value
,
"T must not be array"
);
return
std
::
unique_ptr
<
T
>
(
new
T
(
std
::
forward
<
Args
>
(
args
)...));
}
template
<
typename
Pool
,
typename
PlaceType
,
bool
IsEnd
,
size_t
I
,
typename
...
KernelImpls
>
struct
JitKernelRegistrarFunctor
;
template
<
typename
Pool
,
typename
PlaceType
,
size_t
I
,
typename
...
KernelImpls
>
struct
JitKernelRegistrarFunctor
<
Pool
,
PlaceType
,
true
,
I
,
KernelImpls
...
>
{
void
operator
()(
KernelType
kt
)
const
{}
};
template
<
typename
Pool
,
typename
PlaceType
,
size_t
I
,
typename
...
KernelImpls
>
struct
JitKernelRegistrarFunctor
<
Pool
,
PlaceType
,
false
,
I
,
KernelImpls
...
>
{
using
KERNEL_IMPL_TYPE
=
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
KernelImpls
...
>>::
type
;
void
operator
()(
KernelType
kt
)
const
{
KernelKey
kkey
(
kt
,
PlaceType
());
Pool
().
Instance
().
Insert
(
kkey
,
std
::
move
(
make_unique
<
const
KERNEL_IMPL_TYPE
>
()));
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelImpls
...
>>::
value
;
JitKernelRegistrarFunctor
<
Pool
,
PlaceType
,
I
+
1
==
size
,
I
+
1
,
KernelImpls
...
>
func
;
func
(
kt
);
}
};
template
<
typename
Pool
,
typename
PlaceType
,
typename
...
KernelImpls
>
class
JitKernelRegistrar
{
public:
explicit
JitKernelRegistrar
(
KernelType
kt
)
{
JitKernelRegistrarFunctor
<
Pool
,
PlaceType
,
false
,
0
,
KernelImpls
...
>
func
;
func
(
kt
);
}
void
Touch
()
{}
};
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Refer always on CPUPlace
#define REGISTER_JITKERNEL_REFER(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_refer_CPUPlace, \
"REGISTER_KERNEL_REFER must be called in global namespace"); \
static ::paddle::operators::jit::JitKernelRegistrar< \
::paddle::operators::jit::ReferKernelPool, ::paddle::platform::CPUPlace, \
__VA_ARGS__> \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \
::paddle::operators::jit::KernelType::kernel_type); \
int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \
return 0; \
}
// kernel_type: should be in paddle::operators::jit::KernelType
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \
"REGISTER_KERNEL_MORE must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit::JitKernelRegistrar< \
::paddle::operators::jit::KernelPool, ::paddle::platform::place_type, \
__VA_ARGS__> \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \
::paddle::operators::jit::KernelType::kernel_type); \
int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \
.Touch(); \
return 0; \
}
#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
#define REGISTER_JITKERNEL_GEN(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"REGISTER_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit::JitKernelRegistrar< \
::paddle::operators::jit::JitCodeCreatorPool, \
::paddle::platform::CPUPlace, __VA_ARGS__> \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \
::paddle::operators::jit::KernelType::kernel_type); \
int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \
return 0; \
}
#define USE_JITKERNEL_GEN(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"USE_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_(); \
static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \
TouchJitKernelReg_gen_##kernel_type##_CPUPlace_()
#define USE_JITKERNEL_REFER(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_refer_CPUPlace_, \
"USE_JITKERNEL_REFER must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_()
#define USE_KERNEL_MORE(kernel_type, impl_type, place_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_, \
"USE_JITKERNEL_MORE must be called in global namespace"); \
extern int \
TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \
static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \
UNUSED = \
TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()
#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/test.cc
0 → 100644
浏览文件 @
693e5e65
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include <random>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/place.h"
template
<
typename
T
>
void
RandomVec
(
const
int
n
,
T
*
a
,
const
T
lower
=
static_cast
<
T
>
(
-
20.
f
),
const
T
upper
=
static_cast
<
T
>
(
20.
f
))
{
static
unsigned
int
seed
=
100
;
std
::
mt19937
rng
(
seed
++
);
std
::
uniform_real_distribution
<
double
>
uniform_dist
(
0
,
1
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
a
[
i
]
=
static_cast
<
T
>
(
uniform_dist
(
rng
)
*
(
upper
-
lower
)
+
lower
);
}
}
template
<
typename
T
>
void
ExpectEQ
(
const
T
*
target
,
const
T
*
refer
,
int
n
)
{
if
(
std
::
is_floating_point
<
T
>::
value
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
EXPECT_NEAR
(
target
[
i
],
refer
[
i
],
1e-5
);
}
}
else
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
EXPECT_EQ
(
target
[
i
],
refer
[
i
]);
}
}
}
std
::
vector
<
int
>
TestSizes
()
{
std
::
vector
<
int
>
s
;
for
(
int
i
=
1
;
i
<
32
;
++
i
)
{
s
.
push_back
(
i
);
}
// test some large size
s
.
push_back
(
100
);
s
.
push_back
(
1000
);
s
.
push_back
(
2000
);
return
s
;
}
namespace
jit
=
paddle
::
operators
::
jit
;
template
<
typename
KernelTuples
,
typename
...
Args
>
struct
TestFuncWithRefer
{
void
operator
()(
const
typename
KernelTuples
::
func_type
tgt
,
Args
...
args
)
{}
};
template
<
typename
T
>
struct
TestFuncWithRefer
<
jit
::
XYZNTuples
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
{
void
operator
()(
const
typename
jit
::
XYZNTuples
<
T
>::
func_type
tgt
,
const
std
::
vector
<
T
>&
x
,
const
std
::
vector
<
T
>&
y
,
const
std
::
vector
<
T
>&
zref
)
{
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_EQ
(
zref
.
size
(),
x
.
size
());
EXPECT_EQ
(
zref
.
size
(),
y
.
size
());
const
T
*
x_data
=
x
.
data
();
const
T
*
y_data
=
y
.
data
();
const
T
*
zref_data
=
zref
.
data
();
const
int
d
=
zref
.
size
();
std
::
vector
<
T
>
ztgt
(
d
);
T
*
ztgt_data
=
ztgt
.
data
();
// test normal
tgt
(
x_data
,
y_data
,
ztgt_data
,
d
);
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
d
);
// test inplace x
std
::
copy
(
x
.
begin
(),
x
.
end
(),
ztgt
.
begin
());
tgt
(
ztgt_data
,
y_data
,
ztgt_data
,
d
);
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
d
);
// test inplace y
std
::
copy
(
y
.
begin
(),
y
.
end
(),
ztgt
.
begin
());
tgt
(
x_data
,
ztgt_data
,
ztgt_data
,
d
);
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
d
);
}
};
template
<
typename
T
>
struct
TestFuncWithRefer
<
jit
::
AXYNTuples
<
T
>
,
T
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
{
void
operator
()(
const
typename
jit
::
AXYNTuples
<
T
>::
func_type
tgt
,
const
T
a
,
const
std
::
vector
<
T
>&
x
,
const
std
::
vector
<
T
>&
yref
)
{
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_EQ
(
yref
.
size
(),
x
.
size
());
const
T
*
x_data
=
x
.
data
();
const
T
*
yref_data
=
yref
.
data
();
const
int
d
=
yref
.
size
();
std
::
vector
<
T
>
ytgt
(
d
);
T
*
ytgt_data
=
ytgt
.
data
();
// test normal
tgt
(
&
a
,
x_data
,
ytgt_data
,
d
);
ExpectEQ
<
T
>
(
ytgt_data
,
yref_data
,
d
);
// test inplace x
std
::
copy
(
x
.
begin
(),
x
.
end
(),
ytgt
.
begin
());
tgt
(
&
a
,
ytgt_data
,
ytgt_data
,
d
);
ExpectEQ
<
T
>
(
ytgt_data
,
yref_data
,
d
);
}
};
template
<
typename
T
>
struct
TestFuncWithRefer
<
jit
::
XYNTuples
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
{
void
operator
()(
const
typename
jit
::
XYNTuples
<
T
>::
func_type
tgt
,
const
std
::
vector
<
T
>&
x
,
const
std
::
vector
<
T
>&
yref
)
{
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_EQ
(
yref
.
size
(),
x
.
size
());
const
T
*
x_data
=
x
.
data
();
const
T
*
yref_data
=
yref
.
data
();
const
int
d
=
yref
.
size
();
std
::
vector
<
T
>
ytgt
(
d
);
T
*
ytgt_data
=
ytgt
.
data
();
// test normal
tgt
(
x_data
,
ytgt_data
,
d
);
ExpectEQ
<
T
>
(
ytgt_data
,
yref_data
,
d
);
// test inplace x
std
::
copy
(
x
.
begin
(),
x
.
end
(),
ytgt
.
begin
());
tgt
(
ytgt_data
,
ytgt_data
,
d
);
ExpectEQ
<
T
>
(
ytgt_data
,
yref_data
,
d
);
}
};
template
<
typename
T
>
struct
TestFuncWithRefer
<
jit
::
LSTMTuples
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
{
void
operator
()(
const
typename
jit
::
LSTMTuples
<
T
>::
func_type
tgt
,
const
std
::
vector
<
T
>&
xsrc
,
const
std
::
vector
<
T
>&
wp
,
const
std
::
vector
<
T
>&
ct_1
,
const
std
::
vector
<
T
>&
ct_ref
,
const
std
::
vector
<
T
>&
ht_ref
,
const
typename
jit
::
LSTMTuples
<
T
>::
attr_type
&
attr
)
{
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_EQ
(
ct_ref
.
size
(),
ht_ref
.
size
());
EXPECT_EQ
(
ct_1
.
size
(),
ht_ref
.
size
());
EXPECT_EQ
(
xsrc
.
size
(),
4
*
ht_ref
.
size
());
EXPECT_EQ
(
wp
.
size
(),
3
*
ht_ref
.
size
());
// x could be changed after compute, so copy to save src
int
d
=
ht_ref
.
size
();
std
::
vector
<
T
>
x
(
xsrc
.
size
()),
ct
(
ct_ref
.
size
()),
ht
(
ht_ref
.
size
());
std
::
vector
<
T
>
checked
(
2
*
d
);
std
::
copy
(
xsrc
.
begin
(),
xsrc
.
end
(),
x
.
begin
());
const
T
*
ct_1_data
=
ct_1
.
data
();
const
T
*
wp_data
=
wp
.
data
();
const
T
*
ct_ref_data
=
ct_ref
.
data
();
const
T
*
ht_ref_data
=
ht_ref
.
data
();
T
*
x_data
=
x
.
data
();
T
*
ct_data
=
ct
.
data
();
T
*
ht_data
=
ht
.
data
();
T
*
checked_data
=
checked
.
data
();
paddle
::
operators
::
jit
::
lstm_t
step
;
step
.
gates
=
x_data
;
step
.
ct_1
=
ct_1_data
;
step
.
ct
=
ct_data
;
step
.
ht
=
ht_data
;
if
(
attr
.
use_peephole
)
{
step
.
wp
=
wp_data
;
step
.
checked
=
checked_data
;
}
tgt
(
&
step
,
&
attr
);
ExpectEQ
<
T
>
(
ct_data
,
ct_ref_data
,
d
);
ExpectEQ
<
T
>
(
ht_data
,
ht_ref_data
,
d
);
}
};
template
<
typename
T
>
struct
TestFuncWithRefer
<
jit
::
GRUTuples
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
{
void
operator
()(
const
typename
jit
::
GRUTuples
<
T
>::
func_type
tgt
,
const
std
::
vector
<
T
>&
xsrc
,
const
std
::
vector
<
T
>&
ht_1
,
const
std
::
vector
<
T
>&
ht_ref
,
const
typename
jit
::
GRUTuples
<
T
>::
attr_type
&
attr
)
{
EXPECT_TRUE
(
tgt
!=
nullptr
);
EXPECT_EQ
(
ht_1
.
size
(),
ht_ref
.
size
());
EXPECT_EQ
(
xsrc
.
size
(),
3
*
ht_ref
.
size
());
// x could be changed after compute, so copy to save src
int
d
=
ht_ref
.
size
();
std
::
vector
<
T
>
x
(
xsrc
.
size
()),
ht
(
ht_ref
.
size
());
std
::
copy
(
xsrc
.
begin
(),
xsrc
.
end
(),
x
.
begin
());
const
T
*
ht_1_data
=
ht_1
.
data
();
const
T
*
ht_ref_data
=
ht_ref
.
data
();
T
*
x_data
=
x
.
data
();
T
*
ht_data
=
ht
.
data
();
paddle
::
operators
::
jit
::
gru_t
step
;
step
.
gates
=
x_data
;
step
.
ht_1
=
ht_1_data
;
step
.
ht
=
ht_data
;
tgt
(
&
step
,
&
attr
);
ExpectEQ
<
T
>
(
ht_data
,
ht_ref_data
,
d
);
}
};
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
KernelTuples
,
typename
PlaceType
,
typename
...
Args
>
void
TestAllImpls
(
const
typename
KernelTuples
::
attr_type
&
attr
,
Args
...
args
)
{
TestFuncWithRefer
<
KernelTuples
,
Args
...
>
test
;
// test jitcode
auto
jitcode
=
jit
::
GetJitCode
<
KT
,
KernelTuples
,
PlaceType
>
(
attr
);
if
(
jitcode
)
{
VLOG
(
10
)
<<
"Test Jitcode Kernel "
;
test
(
jitcode
,
args
...);
}
// test all impls in more
jit
::
KernelKey
kkey
(
KT
,
PlaceType
());
auto
&
pool
=
jit
::
KernelPool
().
Instance
().
AllKernels
();
auto
iter
=
pool
.
find
(
kkey
);
if
(
iter
!=
pool
.
end
())
{
auto
&
impls
=
iter
->
second
;
for
(
auto
&
impl
:
impls
)
{
auto
i
=
dynamic_cast
<
const
jit
::
KernelMore
<
KernelTuples
>*>
(
impl
.
get
());
if
(
i
&&
i
->
UseMe
(
attr
))
{
auto
more
=
i
->
GetFunc
();
VLOG
(
10
)
<<
"Test More Kernel : "
<<
i
->
ImplType
();
test
(
more
,
args
...);
}
}
}
// test result from Get function
// VLOG(10) << "Test Get function ";
auto
tgt
=
jit
::
Get
<
KT
,
KernelTuples
,
PlaceType
>
(
attr
);
test
(
tgt
,
args
...);
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestXYZNKernel
()
{
namespace
jit
=
paddle
::
operators
::
jit
;
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
XYZNTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
d
),
y
(
d
),
zref
(
d
);
RandomVec
<
T
>
(
d
,
x
.
data
());
RandomVec
<
T
>
(
d
,
y
.
data
());
std
::
vector
<
T
>
xinp
(
d
),
yinp
(
d
);
// inplace test
std
::
copy
(
x
.
begin
(),
x
.
end
(),
xinp
.
begin
());
std
::
copy
(
y
.
begin
(),
y
.
end
(),
yinp
.
begin
());
const
T
*
x_data
=
x
.
data
();
const
T
*
y_data
=
y
.
data
();
T
*
zref_data
=
zref
.
data
();
T
*
xinp_data
=
xinp
.
data
();
T
*
yinp_data
=
yinp
.
data
();
// test refer code inplace
ref
(
x_data
,
y_data
,
zref_data
,
d
);
ref
(
x_data
,
yinp_data
,
yinp_data
,
d
);
ref
(
xinp_data
,
y_data
,
xinp_data
,
d
);
ExpectEQ
<
T
>
(
xinp_data
,
zref_data
,
d
);
ExpectEQ
<
T
>
(
yinp_data
,
zref_data
,
d
);
TestAllImpls
<
KT
,
jit
::
XYZNTuples
<
T
>
,
PlaceType
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
(
d
,
x
,
y
,
zref
);
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestAXYNKernel
()
{
namespace
jit
=
paddle
::
operators
::
jit
;
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
AXYNTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
const
T
a
=
static_cast
<
T
>
(
3
);
std
::
vector
<
T
>
x
(
d
),
yref
(
d
);
std
::
vector
<
T
>
xinp
(
d
);
// inplace test
RandomVec
<
T
>
(
d
,
x
.
data
());
std
::
copy
(
x
.
begin
(),
x
.
end
(),
xinp
.
begin
());
const
T
*
x_data
=
x
.
data
();
T
*
yref_data
=
yref
.
data
();
T
*
xinp_data
=
xinp
.
data
();
// test refer code inplace
ref
(
&
a
,
x_data
,
yref_data
,
d
);
ref
(
&
a
,
xinp_data
,
xinp_data
,
d
);
ExpectEQ
<
T
>
(
xinp_data
,
yref_data
,
d
);
TestAllImpls
<
KT
,
jit
::
AXYNTuples
<
T
>
,
PlaceType
,
T
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
(
d
,
a
,
x
,
yref
);
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestXYNKernel
()
{
namespace
jit
=
paddle
::
operators
::
jit
;
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
for
(
int
d
:
TestSizes
())
{
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
XYNTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
d
),
yref
(
d
);
std
::
vector
<
T
>
xinp
(
d
);
// inplace test
RandomVec
<
T
>
(
d
,
x
.
data
(),
-
2.
f
,
2.
f
);
std
::
copy
(
x
.
begin
(),
x
.
end
(),
xinp
.
begin
());
const
T
*
x_data
=
x
.
data
();
T
*
yref_data
=
yref
.
data
();
T
*
xinp_data
=
xinp
.
data
();
// test refer code inplace
ref
(
x_data
,
yref_data
,
d
);
ref
(
xinp_data
,
xinp_data
,
d
);
ExpectEQ
<
T
>
(
xinp_data
,
yref_data
,
d
);
TestAllImpls
<
KT
,
jit
::
XYNTuples
<
T
>
,
PlaceType
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
(
d
,
x
,
yref
);
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestLSTMKernel
()
{
namespace
jit
=
paddle
::
operators
::
jit
;
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
std
::
vector
<
std
::
string
>
all_acts
=
{
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
};
for
(
int
d
:
TestSizes
())
{
for
(
bool
use_peephole
:
{
true
,
false
})
{
for
(
auto
&
act_gate
:
all_acts
)
{
for
(
auto
&
act_cand
:
all_acts
)
{
for
(
auto
&
act_cell
:
all_acts
)
{
const
jit
::
lstm_attr_t
attr
(
d
,
jit
::
to_kerneltype
(
act_gate
),
jit
::
to_kerneltype
(
act_cand
),
jit
::
to_kerneltype
(
act_cell
),
use_peephole
);
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
LSTMTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
xsrc
(
4
*
d
),
wp
(
3
*
d
),
ct_1
(
d
);
std
::
vector
<
T
>
ct_ref
(
d
),
ht_ref
(
d
),
checked
(
2
*
d
);
RandomVec
<
T
>
(
4
*
d
,
xsrc
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
T
>
(
3
*
d
,
wp
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
T
>
(
d
,
ct_1
.
data
(),
-
2.
f
,
2.
f
);
// x could be changed after compute, so copy to save src
std
::
vector
<
T
>
x
(
xsrc
.
size
());
std
::
copy
(
xsrc
.
begin
(),
xsrc
.
end
(),
x
.
begin
());
const
T
*
ct_1_data
=
ct_1
.
data
();
const
T
*
wp_data
=
wp
.
data
();
T
*
x_data
=
x
.
data
();
T
*
checked_data
=
checked
.
data
();
T
*
ct_ref_data
=
ct_ref
.
data
();
T
*
ht_ref_data
=
ht_ref
.
data
();
jit
::
lstm_t
step
;
step
.
gates
=
x_data
;
step
.
ct_1
=
ct_1_data
;
step
.
ct
=
ct_ref_data
;
step
.
ht
=
ht_ref_data
;
if
(
use_peephole
)
{
step
.
wp
=
wp_data
;
step
.
checked
=
checked_data
;
}
ref
(
&
step
,
&
attr
);
VLOG
(
10
)
<<
attr
;
TestAllImpls
<
KT
,
jit
::
LSTMTuples
<
T
>
,
PlaceType
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
(
attr
,
xsrc
,
wp
,
ct_1
,
ct_ref
,
ht_ref
,
attr
);
}
}
}
}
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestGRUKernel
()
{
namespace
jit
=
paddle
::
operators
::
jit
;
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
std
::
vector
<
std
::
string
>
all_acts
=
{
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
};
for
(
int
d
:
TestSizes
())
{
for
(
auto
&
act_gate
:
all_acts
)
{
for
(
auto
&
act_cand
:
all_acts
)
{
const
jit
::
gru_attr_t
attr
(
d
,
jit
::
to_kerneltype
(
act_gate
),
jit
::
to_kerneltype
(
act_cand
));
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
GRUTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
xsrc
(
3
*
d
),
ht_1
(
d
),
ht_ref
(
d
);
RandomVec
<
T
>
(
3
*
d
,
xsrc
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
T
>
(
d
,
ht_1
.
data
(),
-
2.
f
,
2.
f
);
// x could be changed after compute, so copy to save src
std
::
vector
<
T
>
x
(
xsrc
.
size
());
std
::
copy
(
xsrc
.
begin
(),
xsrc
.
end
(),
x
.
begin
());
const
T
*
ht_1_data
=
ht_1
.
data
();
T
*
x_data
=
x
.
data
();
T
*
ht_ref_data
=
ht_ref
.
data
();
jit
::
gru_t
step
;
step
.
gates
=
x_data
;
step
.
ht_1
=
ht_1_data
;
step
.
ht
=
ht_ref_data
;
ref
(
&
step
,
&
attr
);
VLOG
(
10
)
<<
attr
;
TestAllImpls
<
KT
,
jit
::
GRUTuples
<
T
>
,
PlaceType
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>
,
std
::
vector
<
T
>>
(
attr
,
xsrc
,
ht_1
,
ht_ref
,
attr
);
}
}
}
}
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestNCHW16CMulNCKernel
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
const
int
n
=
3
,
c
=
16
*
4
,
h
=
10
,
w
=
10
;
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
NCHW16CMulNCTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
int
sz
=
n
*
c
*
h
*
w
;
std
::
vector
<
T
>
x
(
sz
),
y
(
n
*
c
),
zref
(
sz
);
std
::
vector
<
T
>
ztgt
(
sz
),
zjit
(
sz
);
RandomVec
<
T
>
(
sz
,
x
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
T
>
(
n
*
c
,
y
.
data
(),
-
2.
f
,
2.
f
);
const
T
*
x_data
=
x
.
data
();
const
T
*
y_data
=
y
.
data
();
T
*
zref_data
=
zref
.
data
();
T
*
ztgt_data
=
ztgt
.
data
();
T
*
zjit_data
=
zjit
.
data
();
constexpr
int
simd_width
=
ZMM_FLOAT_BLOCK
;
int
C
=
c
/
simd_width
;
auto
tgt
=
jit
::
Get
<
KT
,
jit
::
NCHW16CMulNCTuples
<
T
>
,
PlaceType
>
(
0
);
auto
jitcode
=
jit
::
GetJitCode
<
KT
,
jit
::
NCHW16CMulNCTuples
<
T
>
,
PlaceType
>
(
0
);
EXPECT_TRUE
(
tgt
!=
nullptr
);
if
(
std
::
is_same
<
T
,
float
>::
value
&&
paddle
::
platform
::
MayIUse
(
paddle
::
platform
::
avx512f
))
{
EXPECT_TRUE
(
jitcode
!=
nullptr
);
}
for
(
int
ni
=
0
;
ni
<
n
;
ni
++
)
{
for
(
int
ci
=
0
;
ci
<
C
;
ci
++
)
{
auto
ptr_x
=
x_data
+
ni
*
C
*
h
*
w
*
simd_width
+
ci
*
h
*
w
*
simd_width
;
auto
ptr_y
=
y_data
+
ni
*
C
*
simd_width
+
ci
*
simd_width
;
auto
ptr_zref
=
zref_data
+
ni
*
C
*
h
*
w
*
simd_width
+
ci
*
h
*
w
*
simd_width
;
auto
ptr_ztgt
=
ztgt_data
+
ni
*
C
*
h
*
w
*
simd_width
+
ci
*
h
*
w
*
simd_width
;
ref
(
ptr_x
,
ptr_y
,
ptr_zref
,
h
,
w
);
tgt
(
ptr_x
,
ptr_y
,
ptr_ztgt
,
h
,
w
);
if
(
jitcode
)
{
auto
ptr_zjit
=
zjit_data
+
ni
*
C
*
h
*
w
*
simd_width
+
ci
*
h
*
w
*
simd_width
;
jitcode
(
ptr_x
,
ptr_y
,
ptr_zjit
,
h
,
w
);
}
}
}
ExpectEQ
<
T
>
(
ztgt_data
,
zref_data
,
sz
);
if
(
jitcode
)
{
ExpectEQ
<
T
>
(
zjit_data
,
zref_data
,
sz
);
}
}
// XYZNTuple
TEST
(
JITKernel
,
kVMul
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestXYZNKernel
<
jit
::
kVMul
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestXYZNKernel
<
jit
::
kVMul
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVAdd
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestXYZNKernel
<
jit
::
kVAdd
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestXYZNKernel
<
jit
::
kVAdd
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVAddRelu
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestXYZNKernel
<
jit
::
kVAddRelu
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestXYZNKernel
<
jit
::
kVAddRelu
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVSub
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestXYZNKernel
<
jit
::
kVSub
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestXYZNKernel
<
jit
::
kVSub
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
// AXYNTuples
TEST
(
JITKernel
,
kVScal
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestAXYNKernel
<
jit
::
kVScal
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestAXYNKernel
<
jit
::
kVScal
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVAddBias
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestAXYNKernel
<
jit
::
kVAddBias
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestAXYNKernel
<
jit
::
kVAddBias
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
// XYNTuples
TEST
(
JITKernel
,
kVRelu
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestXYNKernel
<
jit
::
kVRelu
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVRelu
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVIdentity
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestXYNKernel
<
jit
::
kVIdentity
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVIdentity
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVExp
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestXYNKernel
<
jit
::
kVExp
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVExp
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVSigmoid
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestXYNKernel
<
jit
::
kVSigmoid
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVSigmoid
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kVTanh
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestXYNKernel
<
jit
::
kVTanh
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestXYNKernel
<
jit
::
kVTanh
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
// LSTM
TEST
(
JITKernel
,
kLSTMCtHt
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestLSTMKernel
<
jit
::
kLSTMCtHt
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestLSTMKernel
<
jit
::
kLSTMCtHt
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kLSTMC1H1
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestLSTMKernel
<
jit
::
kLSTMC1H1
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestLSTMKernel
<
jit
::
kLSTMC1H1
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
// GRU
TEST
(
JITKernel
,
kGRUH1
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestGRUKernel
<
jit
::
kGRUH1
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestGRUKernel
<
jit
::
kGRUH1
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kGRUHtPart1
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestGRUKernel
<
jit
::
kGRUHtPart1
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestGRUKernel
<
jit
::
kGRUHtPart1
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kGRUHtPart2
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestGRUKernel
<
jit
::
kGRUHtPart2
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestGRUKernel
<
jit
::
kGRUHtPart2
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
JITKernel
,
kNCHW16CMulNC
)
{
namespace
jit
=
paddle
::
operators
::
jit
;
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
float
,
paddle
::
platform
::
CPUPlace
>
();
TestNCHW16CMulNCKernel
<
jit
::
kNCHW16CMulNC
,
double
,
paddle
::
platform
::
CPUPlace
>
();
}
// TODO(yihua/TJ): add crf decoding and layer norm unit tests
TEST
(
JITKernel
,
pool
)
{
// TODO(TJ): add some test
}
paddle/fluid/operators/layer_norm_op.h
浏览文件 @
693e5e65
...
@@ -19,7 +19,7 @@ limitations under the License. */
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__)
!defined(__OSX__)
#include "paddle/fluid/operators/
math/jit_kernel
.h"
#include "paddle/fluid/operators/
jit/kernels
.h"
#endif
#endif
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
...
@@ -229,12 +229,12 @@ class LayerNormKernel : public framework::OpKernel<T> {
...
@@ -229,12 +229,12 @@ class LayerNormKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ
(
scale
->
numel
(),
right
);
PADDLE_ENFORCE_EQ
(
scale
->
numel
(),
right
);
PADDLE_ENFORCE_EQ
(
bias
->
numel
(),
right
);
PADDLE_ENFORCE_EQ
(
bias
->
numel
(),
right
);
const
auto
&
ker
=
math
::
jitkernel
::
KernelPool
::
Instance
()
auto
ker
=
.
template
Get
<
math
::
jitkernel
::
LayerNormKernel
<
T
>
>
(
jit
::
Get
<
jit
::
kLayerNorm
,
jit
::
LayerNormTuples
<
T
>
,
platform
::
CPUPlace
>
(
static_cast
<
int
>
(
right
)
);
right
);
ker
->
Compute
(
x
.
data
<
T
>
(),
out
.
data
<
T
>
(),
mean
->
data
<
T
>
(),
var
->
data
<
T
>
(),
ker
(
x
.
data
<
T
>
(),
out
.
data
<
T
>
(),
mean
->
data
<
T
>
(),
var
->
data
<
T
>
(),
scale
->
data
<
T
>
(),
bias
->
data
<
T
>
(),
static_cast
<
int
>
(
left
),
scale
->
data
<
T
>
(),
bias
->
data
<
T
>
(),
static_cast
<
int
>
(
left
),
static_cast
<
const
float
>
(
epsilon
)
);
static_cast
<
const
float
>
(
epsilon
),
right
);
#endif
#endif
}
}
};
};
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
693e5e65
...
@@ -73,12 +73,3 @@ if(WITH_GPU)
...
@@ -73,12 +73,3 @@ if(WITH_GPU)
endif
()
endif
()
cc_test
(
concat_test SRCS concat_test.cc DEPS concat_and_split
)
cc_test
(
concat_test SRCS concat_test.cc DEPS concat_and_split
)
cc_test
(
cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info
)
cc_test
(
cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info
)
set
(
JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc jit_kernel_layer_norm.cc
)
set
(
JIT_KERNEL_DEPS cpu_info cblas gflags enforce
)
if
(
WITH_XBYAK
)
list
(
APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc
)
list
(
APPEND JIT_KERNEL_DEPS xbyak
)
endif
()
cc_library
(
jit_kernel SRCS
${
JIT_KERNEL_SRCS
}
DEPS
${
JIT_KERNEL_DEPS
}
)
cc_test
(
jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel
)
paddle/fluid/operators/math/fc_compute.h
浏览文件 @
693e5e65
...
@@ -14,8 +14,8 @@ limitations under the License. */
...
@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -30,22 +30,21 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
...
@@ -30,22 +30,21 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
return
;
return
;
}
}
if
(
relu
)
{
if
(
relu
)
{
const
auto
&
vaddrelu
=
jitkernel
::
KernelPool
::
Instance
()
auto
compute
=
.
template
Get
<
jitkernel
::
VAddReluKernel
<
T
>
>
(
N
);
jit
::
Get
<
jit
::
kVAddRelu
,
jit
::
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
N
);
for
(
int
i
=
0
;
i
<
M
;
i
++
)
{
for
(
int
i
=
0
;
i
<
M
;
i
++
)
{
T
*
dst
=
Y
+
i
*
N
;
T
*
dst
=
Y
+
i
*
N
;
vaddrelu
->
C
ompute
(
B
,
dst
,
dst
,
N
);
c
ompute
(
B
,
dst
,
dst
,
N
);
}
}
}
else
{
}
else
{
const
auto
&
vadd
=
jitkernel
::
KernelPool
::
Instance
()
auto
compute
=
.
template
Get
<
jitkernel
::
VAddKernel
<
T
>
>
(
N
);
jit
::
Get
<
jit
::
kVAdd
,
jit
::
XYZNTuples
<
T
>
,
platform
::
CPUPlace
>
(
N
);
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#pragma omp parallel for
#endif
#endif
for
(
int
i
=
0
;
i
<
M
;
i
++
)
{
for
(
int
i
=
0
;
i
<
M
;
i
++
)
{
T
*
dst
=
Y
+
i
*
N
;
T
*
dst
=
Y
+
i
*
N
;
vadd
->
C
ompute
(
B
,
dst
,
dst
,
N
);
c
ompute
(
B
,
dst
,
dst
,
N
);
}
}
}
}
}
}
...
...
paddle/fluid/operators/math/jit_code.cc
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_code.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
using
namespace
platform
;
// NOLINT
bool
VXXJitCode
::
init
(
int
d
,
int
scalar_index
)
{
// It's not necessary to use avx512 since it would slow down the frequency
// and this kernel is not compute bound.
return
MayIUse
(
avx
)
&&
scalar_index
>=
0
&&
scalar_index
<=
2
;
}
void
VXXJitCode
::
generate
()
{
// do not need push stack, and do not need save avx512reg if do not use avx512
int
offset
=
0
;
if
(
with_relu_
)
{
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
}
if
(
scalar_index_
==
1
)
{
vbroadcastss
(
ymm_src1
,
ptr
[
param1
]);
}
else
if
(
scalar_index_
==
2
)
{
vbroadcastss
(
ymm_src2
,
ptr
[
param2
]);
}
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
if
(
scalar_index_
!=
1
)
{
vmovups
(
ymm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovups
(
ymm_src2
,
ptr
[
param2
+
offset
]);
}
if
(
type_
==
operand_type
::
mul
)
{
vmulps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
}
else
if
(
type_
==
operand_type
::
add
)
{
vaddps
(
ymm_dst
,
ymm_src1
,
ymm_src2
);
}
if
(
with_relu_
)
{
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_dst
);
}
vmovups
(
ptr
[
param3
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
int
rest
=
num_
%
YMM_FLOAT_BLOCK
;
while
(
rest
>
0
)
{
int
block
=
XMM_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
block
=
4
;
if
(
scalar_index_
!=
1
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovups
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
}
else
if
(
rest
>=
2
)
{
block
=
2
;
if
(
scalar_index_
!=
1
)
{
vmovq
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovq
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
}
else
{
block
=
1
;
if
(
scalar_index_
!=
1
)
{
vmovss
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmovss
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
}
switch
(
type_
)
{
case
operand_type
::
mul
:
vmulps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
break
;
case
operand_type
::
add
:
vaddps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
break
;
default:
break
;
}
if
(
with_relu_
)
{
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_dst
);
}
if
(
rest
>=
4
)
{
vmovups
(
ptr
[
param3
+
offset
],
xmm_dst
);
}
else
if
(
rest
>=
2
)
{
vmovq
(
ptr
[
param3
+
offset
],
xmm_dst
);
}
else
{
vmovss
(
ptr
[
param3
+
offset
],
xmm_dst
);
}
offset
+=
sizeof
(
float
)
*
block
;
rest
-=
block
;
}
ret
();
}
const
float
ALIGN32_BEG
exp_float_consts
[]
ALIGN32_END
=
{
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
2.
f
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
EXP_LOW
),
REPEAT_8TIMES
(
CEPHES_LOG2EF
),
REPEAT_8TIMES
(
CEPHES_EXP_C1
),
REPEAT_8TIMES
(
CEPHES_EXP_C2
),
REPEAT_8TIMES
(
CEPHES_EXP_P0
),
REPEAT_8TIMES
(
CEPHES_EXP_P1
),
REPEAT_8TIMES
(
CEPHES_EXP_P2
),
REPEAT_8TIMES
(
CEPHES_EXP_P3
),
REPEAT_8TIMES
(
CEPHES_EXP_P4
),
REPEAT_8TIMES
(
CEPHES_EXP_P5
),
REPEAT_8TIMES
(
EXP_MAX_INPUT
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MAX
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MIN
)};
const
int
ALIGN32_BEG
exp_int_0x7f
[]
ALIGN32_END
=
{
REPEAT_8TIMES
(
0x7f
)};
int
ALIGN32_BEG
g_tmp_mem
[
16
]
ALIGN32_END
=
{
0
};
bool
VActJitCode
::
init
(
int
d
,
operand_type
type
)
{
// TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
return
MayIUse
(
avx
);
}
void
VActJitCode
::
generate
()
{
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
act
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
type_
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
int
rest
=
num_
%
YMM_FLOAT_BLOCK
;
while
(
rest
>
0
)
{
int
block
=
XMM_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
block
=
4
;
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
}
else
if
(
rest
>=
2
)
{
block
=
2
;
vmovq
(
xmm_src
,
ptr
[
param1
+
offset
]);
}
else
{
block
=
1
;
vmovss
(
xmm_src
,
ptr
[
param1
+
offset
]);
}
act
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
type_
);
if
(
rest
>=
4
)
{
vmovups
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
else
if
(
rest
>=
2
)
{
vmovq
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
else
{
vmovss
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
offset
+=
sizeof
(
float
)
*
block
;
rest
-=
block
;
}
ret
();
}
bool
LSTMJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
%
8
==
0
;
}
void
LSTMJitCode
::
generate
()
{
if
(
use_peephole_
)
{
preCode
();
}
reg64_t
reg_ptr_gates
=
rax
;
reg64_t
reg_ptr_ct_1
=
r9
;
reg64_t
reg_ptr_ct
=
r10
;
reg64_t
reg_ptr_ht
=
r11
;
reg64_t
reg_ptr_wp
=
r12
;
mov
(
reg_ptr_gates
,
ptr
[
param1
+
offsetof
(
lstm_t
,
gates
)]);
mov
(
reg_ptr_ct_1
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ct_1
)]);
mov
(
reg_ptr_ct
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ct
)]);
mov
(
reg_ptr_ht
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ht
)]);
if
(
use_peephole_
)
{
mov
(
reg_ptr_wp
,
ptr
[
param1
+
offsetof
(
lstm_t
,
wp
)]);
}
int
offset
=
0
;
int
d
=
num_
*
sizeof
(
float
);
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
/* gates: W_ch, W_ih, W_fh, W_oh */
ymm_t
ymm_c
=
ymm_t
(
0
);
ymm_t
ymm_i
=
ymm_t
(
1
);
ymm_t
ymm_f
=
ymm_t
(
2
);
ymm_t
ymm_o
=
ymm_t
(
3
);
ymm_t
ymm_ct_1
=
ymm_t
(
4
);
ymm_t
ymm_wp0
=
ymm_t
(
5
);
ymm_t
ymm_wp1
=
ymm_t
(
6
);
ymm_t
ymm_wp2
=
ymm_t
(
7
);
vmovups
(
ymm_c
,
ptr
[
reg_ptr_gates
+
offset
]);
vmovups
(
ymm_i
,
ptr
[
reg_ptr_gates
+
offset
+
d
]);
vmovups
(
ymm_f
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
d
]);
vmovups
(
ymm_o
,
ptr
[
reg_ptr_gates
+
offset
+
3
*
d
]);
if
(
!
compute_c1h1_
)
{
vmovups
(
ymm_ct_1
,
ptr
[
reg_ptr_ct_1
+
offset
]);
}
if
(
use_peephole_
)
{
vmovups
(
ymm_wp0
,
ptr
[
reg_ptr_wp
+
offset
]);
vmovups
(
ymm_wp1
,
ptr
[
reg_ptr_wp
+
offset
+
d
]);
vmovups
(
ymm_wp2
,
ptr
[
reg_ptr_wp
+
offset
+
2
*
d
]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act
<
ymm_t
>
(
ymm_c
,
ymm_c
,
act_cand_
);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if
(
!
compute_c1h1_
&&
use_peephole_
)
{
vmulps
(
ymm_wp0
,
ymm_ct_1
,
ymm_wp0
);
vaddps
(
ymm_i
,
ymm_i
,
ymm_wp0
);
}
act
<
ymm_t
>
(
ymm_i
,
ymm_i
,
act_gate_
);
vmulps
(
ymm_c
,
ymm_c
,
ymm_i
);
if
(
!
compute_c1h1_
)
{
// act_gate(f) or act_gate(ct_1 * wp1 + f)
if
(
use_peephole_
)
{
vmulps
(
ymm_wp1
,
ymm_ct_1
,
ymm_wp1
);
vaddps
(
ymm_f
,
ymm_f
,
ymm_wp1
);
}
act
<
ymm_t
>
(
ymm_f
,
ymm_f
,
act_gate_
);
// ct
vmulps
(
ymm_f
,
ymm_f
,
ymm_ct_1
);
vaddps
(
ymm_f
,
ymm_f
,
ymm_c
);
}
/* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t
ymm_ct
=
compute_c1h1_
?
ymm_c
:
ymm_f
;
ymm_t
ymm_tmp
=
ymm_i
;
act
<
ymm_t
>
(
ymm_tmp
,
ymm_ct
,
act_cell_
);
// act_gate(o) or act_gate(ct * wp2 + o)
if
(
use_peephole_
)
{
vmulps
(
ymm_wp2
,
ymm_ct
,
ymm_wp2
);
vaddps
(
ymm_o
,
ymm_o
,
ymm_wp2
);
}
act
<
ymm_t
>
(
ymm_o
,
ymm_o
,
act_gate_
);
// ht
vmulps
(
ymm_o
,
ymm_o
,
ymm_tmp
);
// save ct and ht
vmovups
(
ptr
[
reg_ptr_ct
+
offset
],
ymm_ct
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_o
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
if
(
use_peephole_
)
{
postCode
();
}
else
{
ret
();
}
}
bool
GRUJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
%
8
==
0
;
}
void
GRUJitCode
::
generate
()
{
reg64_t
reg_ptr_gates
=
rax
;
reg64_t
reg_ptr_ht_1
=
r9
;
reg64_t
reg_ptr_ht
=
r10
;
mov
(
reg_ptr_gates
,
ptr
[
param1
+
offsetof
(
gru_t
,
gates
)]);
mov
(
reg_ptr_ht_1
,
ptr
[
param1
+
offsetof
(
gru_t
,
ht_1
)]);
mov
(
reg_ptr_ht
,
ptr
[
param1
+
offsetof
(
gru_t
,
ht
)]);
ymm_t
ymm_one
=
ymm_t
(
0
);
if
(
id_
==
2
)
{
reg64_t
reg_ptr_tmp
=
r11
;
mov
(
reg_ptr_tmp
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_one
,
ptr
[
reg_ptr_tmp
+
OFFSET_EXP_ONE
]);
}
int
offset
=
0
;
int
d
=
num_
*
sizeof
(
float
);
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
ymm_t
ymm_u
=
ymm_t
(
1
);
ymm_t
ymm_r
=
ymm_t
(
2
);
ymm_t
ymm_s
=
ymm_t
(
3
);
ymm_t
ymm_ht_1
=
ymm_t
(
4
);
// W: {W_update, W_reset; W_state}
if
(
id_
==
0
||
id_
==
2
)
{
vmovups
(
ymm_u
,
ptr
[
reg_ptr_gates
+
offset
]);
vmovups
(
ymm_s
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
d
]);
}
if
(
id_
==
1
)
{
vmovups
(
ymm_r
,
ptr
[
reg_ptr_gates
+
offset
+
d
]);
}
if
(
id_
==
1
||
id_
==
2
)
{
vmovups
(
ymm_ht_1
,
ptr
[
reg_ptr_ht_1
+
offset
]);
}
if
(
id_
==
0
)
{
// ht = act_gate(u) * act_cand(s)
act
<
ymm_t
>
(
ymm_u
,
ymm_u
,
act_gate_
);
act
<
ymm_t
>
(
ymm_s
,
ymm_s
,
act_cand_
);
vmulps
(
ymm_s
,
ymm_s
,
ymm_u
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_s
);
}
else
if
(
id_
==
1
)
{
// ht = act_gate(r) * ht_1
act
<
ymm_t
>
(
ymm_r
,
ymm_r
,
act_gate_
);
vmulps
(
ymm_r
,
ymm_r
,
ymm_ht_1
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_r
);
}
else
if
(
id_
==
2
)
{
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t
ymm_one_inner
=
ymm_t
(
ymm_one
.
getIdx
());
act
<
ymm_t
>
(
ymm_u
,
ymm_u
,
act_gate_
);
act
<
ymm_t
>
(
ymm_s
,
ymm_s
,
act_cand_
);
vmulps
(
ymm_s
,
ymm_s
,
ymm_u
);
vsubps
(
ymm_u
,
ymm_one_inner
,
ymm_u
);
vmulps
(
ymm_u
,
ymm_ht_1
,
ymm_u
);
vaddps
(
ymm_u
,
ymm_s
,
ymm_u
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_u
);
}
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
ret
();
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_gen.cc
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_gen.h"
#include <fstream>
#include <iostream>
#include <sstream>
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool
(
dump_jitcode
,
false
,
"Whether to dump the jitcode to file"
);
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
constexpr
Xbyak
::
Operand
::
Code
g_abi_regs
[]
=
{
Xbyak
::
Operand
::
RBX
,
Xbyak
::
Operand
::
RBP
,
Xbyak
::
Operand
::
R12
,
Xbyak
::
Operand
::
R13
,
Xbyak
::
Operand
::
R14
,
Xbyak
::
Operand
::
R15
};
constexpr
int
num_g_abi_regs
=
sizeof
(
g_abi_regs
)
/
sizeof
(
g_abi_regs
[
0
]);
void
JitCode
::
preCode
()
{
for
(
int
i
=
0
;
i
<
num_g_abi_regs
;
++
i
)
{
push
(
Xbyak
::
Reg64
(
g_abi_regs
[
i
]));
}
if
(
platform
::
MayIUse
(
platform
::
avx512f
))
{
mov
(
reg_EVEX_max_8b_offt
,
2
*
EVEX_max_8b_offt
);
}
}
void
JitCode
::
postCode
()
{
for
(
int
i
=
0
;
i
<
num_g_abi_regs
;
++
i
)
{
pop
(
Xbyak
::
Reg64
(
g_abi_regs
[
num_g_abi_regs
-
1
-
i
]));
}
ret
();
}
void
JitCode
::
dumpCode
(
const
Xbyak
::
uint8
*
code
)
const
{
if
(
code
)
{
static
int
counter
=
0
;
std
::
ostringstream
filename
;
filename
<<
"paddle_jitcode_"
<<
name
()
<<
"."
<<
counter
<<
".bin"
;
counter
++
;
std
::
ofstream
fout
(
filename
.
str
(),
std
::
ios
::
out
);
if
(
fout
.
is_open
())
{
fout
.
write
(
reinterpret_cast
<
const
char
*>
(
code
),
getSize
());
fout
.
close
();
}
}
}
Xbyak
::
Address
JitCode
::
EVEX_compress_addr
(
Xbyak
::
Reg64
base
,
int
offt
,
bool
bcast
)
{
int
scale
=
0
;
if
(
EVEX_max_8b_offt
<=
offt
&&
offt
<
3
*
EVEX_max_8b_offt
)
{
offt
=
offt
-
2
*
EVEX_max_8b_offt
;
scale
=
1
;
}
else
if
(
3
*
EVEX_max_8b_offt
<=
offt
&&
offt
<
5
*
EVEX_max_8b_offt
)
{
offt
=
offt
-
4
*
EVEX_max_8b_offt
;
scale
=
2
;
}
auto
re
=
Xbyak
::
RegExp
()
+
base
+
offt
;
if
(
scale
)
{
re
=
re
+
reg_EVEX_max_8b_offt
*
scale
;
}
if
(
bcast
)
{
return
zword_b
[
re
];
}
else
{
return
zword
[
re
];
}
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_gen.h
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <type_traits>
#include "paddle/fluid/platform/macros.h"
#define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
DECLARE_bool
(
dump_jitcode
);
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
namespace
gen
{
#define DECLARE_JIT_CODE(codename) \
const char *name() const override { return #codename; }
// Application Binary Interface
constexpr
Xbyak
::
Operand
::
Code
abi_param1
(
Xbyak
::
Operand
::
RDI
),
abi_param2
(
Xbyak
::
Operand
::
RSI
),
abi_param3
(
Xbyak
::
Operand
::
RDX
),
abi_param4
(
Xbyak
::
Operand
::
RCX
),
abi_not_param1
(
Xbyak
::
Operand
::
RCX
);
class
JitCode
:
public
Xbyak
::
CodeGenerator
{
public:
explicit
JitCode
(
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
Xbyak
::
CodeGenerator
(
code_size
,
code_ptr
)
{}
virtual
~
JitCode
()
{}
virtual
const
char
*
name
()
const
=
0
;
virtual
void
generate
()
=
0
;
template
<
typename
FUNC
>
const
FUNC
getCode
()
{
this
->
generate
();
const
Xbyak
::
uint8
*
code
=
CodeGenerator
::
getCode
();
if
(
FLAGS_dump_jitcode
)
{
this
->
dumpCode
(
code
);
}
return
reinterpret_cast
<
const
FUNC
>
(
code
);
}
DISABLE_COPY_AND_ASSIGN
(
JitCode
);
protected:
Xbyak
::
Reg64
param1
{
abi_param1
};
const
int
EVEX_max_8b_offt
=
0x200
;
const
Xbyak
::
Reg64
reg_EVEX_max_8b_offt
=
rbp
;
void
preCode
();
void
postCode
();
void
dumpCode
(
const
Xbyak
::
uint8
*
code
)
const
;
void
L
(
const
char
*
label
)
{
Xbyak
::
CodeGenerator
::
L
(
label
);
}
void
L
(
const
Xbyak
::
Label
&
label
)
{
Xbyak
::
CodeGenerator
::
L
(
label
);
}
// Enhanced vector extension
Xbyak
::
Address
EVEX_compress_addr
(
Xbyak
::
Reg64
base
,
int
offt
,
bool
bcast
=
false
);
};
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel.cc
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <iostream>
#include <string>
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
KernelPool
&
KernelPool
::
Instance
()
{
static
thread_local
KernelPool
g_jit_kernels
;
return
g_jit_kernels
;
}
std
::
shared_ptr
<
const
Kernel
>
KernelPool
::
Get
(
const
std
::
string
&
key
)
const
{
if
(
kers_
.
find
(
key
)
==
kers_
.
end
())
{
return
nullptr
;
}
return
kers_
.
at
(
key
);
}
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel.h
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <memory> // for shared_ptr
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/macros.h"
// Note: Only support on CPU yet.
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
// TODO(TJ): remove me
typedef
enum
{
kLT8
,
kEQ8
,
kGT8LT16
,
kEQ16
,
kGT16
}
jit_block
;
class
Kernel
{
public:
Kernel
()
=
default
;
virtual
~
Kernel
()
=
default
;
// TODO(TJ): below members should be deprecated.
int
num_
{
0
};
int
end_
{
0
};
int
rest_
{
0
};
DISABLE_COPY_AND_ASSIGN
(
Kernel
);
};
class
KernelPool
{
public:
static
KernelPool
&
Instance
();
template
<
typename
Ker
,
typename
...
ARGS
>
std
::
shared_ptr
<
const
Ker
>
Get
(
ARGS
...
args
);
std
::
shared_ptr
<
const
Kernel
>
Get
(
const
std
::
string
&
key
)
const
;
private:
KernelPool
()
=
default
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
const
Kernel
>>
kers_
;
DISABLE_COPY_AND_ASSIGN
(
KernelPool
);
};
template
<
typename
T
>
class
VMulKernel
:
public
Kernel
{
public:
void
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
class
VAddKernel
:
public
Kernel
{
public:
void
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
class
VAddReluKernel
:
public
Kernel
{
public:
void
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
class
VScalKernel
:
public
Kernel
{
public:
// y = a.*x
void
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
class
VAddBiasKernel
:
public
Kernel
{
public:
// y = a.+x
void
(
*
Compute
)(
const
T
*
,
const
T
*
,
T
*
,
int
);
};
#ifdef PADDLE_WITH_MKLDNN
template
<
typename
T
>
class
EltwiseMulnChw16cNCKernel
:
public
Kernel
{
public:
// nChw16c = nChw16c .* NC
void
(
*
Compute
)(
const
float
*
,
const
float
*
,
float
*
,
int
,
int
);
};
#endif
template
<
typename
T
>
class
VActKernel
:
public
Kernel
{
public:
void
(
*
Compute
)(
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
class
VReluKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VIdentityKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VExpKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VSigmoidKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VTanhKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
LSTMKernel
:
public
Kernel
{
public:
// compute c1 and h1 without c0 or h0
void
(
*
ComputeC1H1
)(
lstm_t
*
,
const
lstm_attr_t
*
);
void
(
*
ComputeCtHt
)(
lstm_t
*
,
const
lstm_attr_t
*
);
};
template
<
typename
T
>
class
GRUKernel
:
public
Kernel
{
public:
// compute h1 without h0
void
(
*
ComputeH1
)(
gru_t
*
,
const
gru_attr_t
*
);
void
(
*
ComputeHtPart1
)(
gru_t
*
,
const
gru_attr_t
*
);
void
(
*
ComputeHtPart2
)(
gru_t
*
,
const
gru_attr_t
*
);
};
template
<
typename
T
>
class
CRFDecodeKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
const
int
seq_len
,
const
T
*
x
,
const
T
*
w
,
T
*
alpha
,
int
*
track
)
const
=
0
;
};
template
<
typename
T
>
class
LayerNormKernel
:
public
Kernel
{
public:
virtual
void
Compute
(
T
*
x
,
T
*
out
,
T
*
mean
,
T
*
var
,
const
T
*
scale
,
const
T
*
bias
,
int
height
,
const
float
epsilon
)
const
=
0
;
};
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_blas.cc
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
void
VMulMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
template
<
>
void
VMulMKL
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
platform
::
dynload
::
vsMul
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VMulMKL
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdMul
(
n
,
x
,
y
,
z
);
}
template
<
typename
T
>
void
VAddMKL
(
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int
n
);
template
<
>
void
VAddMKL
<
float
>
(
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
n
)
{
platform
::
dynload
::
vsAdd
(
n
,
x
,
y
,
z
);
}
template
<
>
void
VAddMKL
<
double
>
(
const
double
*
x
,
const
double
*
y
,
double
*
z
,
int
n
)
{
platform
::
dynload
::
vdAdd
(
n
,
x
,
y
,
z
);
}
template
<
typename
T
>
void
VScalMKL
(
const
T
*
a
,
const
T
*
x
,
T
*
y
,
int
n
);
template
<
>
void
VScalMKL
<
float
>
(
const
float
*
a
,
const
float
*
x
,
float
*
y
,
int
n
)
{
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_sscal
(
n
,
*
a
,
y
,
1
);
}
else
{
refer
::
VScal
<
float
>
(
a
,
x
,
y
,
n
);
}
}
template
<
>
void
VScalMKL
<
double
>
(
const
double
*
a
,
const
double
*
x
,
double
*
y
,
int
n
)
{
if
(
x
==
y
)
{
platform
::
dynload
::
cblas_dscal
(
n
,
*
a
,
y
,
1
);
}
else
{
refer
::
VScal
<
double
>
(
a
,
x
,
y
,
n
);
}
}
#endif
/* VMUL JitKernel */
template
<
typename
T
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
// roughly estimate the size of code
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
if
(
useMKL
(
d
))
{
this
->
Compute
=
VMulMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
refer
::
VMul
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VXXJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VMulKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VXXJitCode
::
init
(
d
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VMulKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
platform
::
MayIUse
(
platform
::
avx512f
)
&&
d
>
512
;
}
template
<
>
bool
VMulKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
/* VAdd JitKernel */
template
<
typename
T
>
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VAddKernelImpl
(
int
d
)
:
VAddKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
if
(
useMKL
(
d
))
{
this
->
Compute
=
VAddMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
refer
::
VAdd
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VXXJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VAddKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VXXJitCode
::
init
(
d
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VAddKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VAddKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
/* EltwiseMul for nChw16c & NC inputs JitKernel */
template
<
typename
T
>
class
EltwiseMulnChw16cNCKernelImpl
:
public
math
::
jitkernel
::
EltwiseMulnChw16cNCKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
EltwiseMulnChw16cNCKernelImpl
(
int
d
)
:
EltwiseMulnChw16cNCKernel
<
T
>
()
{
using
mul_func_t
=
void
(
*
)(
const
float
*
,
const
float
*
,
float
*
,
int
,
int
);
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
// roughly estimate the size of code
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
sz
=
sz
>
4096
?
sz
:
4096
;
jitcode_
.
reset
(
new
gen
::
EltwiseMulnChw16cNC
(
sz
));
this
->
Compute
=
(
mul_func_t
)
jitcode_
->
getCode
();
return
;
}
#endif
PADDLE_THROW
(
"This kernel shouldn't be used in Non-Xbyak, Non-MKL-DNN "
"environemnt"
);
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
EltwiseMulnChw16cNC
>
jitcode_
{
nullptr
};
};
template
<
>
bool
EltwiseMulnChw16cNCKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
true
;
}
#endif
#endif
/* VAddRelu JitKernel */
template
<
typename
T
>
class
VAddReluKernelImpl
:
public
VAddReluKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VAddReluKernelImpl
(
int
d
)
:
VAddReluKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
true
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
this
->
Compute
=
refer
::
VAddRelu
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VXXJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VAddReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VXXJitCode
::
init
(
d
);
}
#endif
/* VScal JitKernel */
template
<
typename
T
>
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
if
(
useMKL
(
d
))
{
this
->
Compute
=
VScalMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
refer
::
VScal
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VXXJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VScalKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VXXJitCode
::
init
(
d
,
1
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VScalKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VScalKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
/* VAddBias JitKernel */
template
<
typename
T
>
class
VAddBiasKernelImpl
:
public
VAddBiasKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VAddBiasKernelImpl
(
int
d
)
:
VAddBiasKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
this
->
Compute
=
refer
::
VAddBias
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VXXJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VAddBiasKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VXXJitCode
::
init
(
d
,
1
);
}
#endif
/* VRelu JitKernel */
template
<
typename
T
>
class
VReluKernelImpl
:
public
VReluKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VReluKernelImpl
(
int
d
)
:
VReluKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
/* init size */
+
d
/
YMM_FLOAT_BLOCK
*
4
/* instructions */
*
8
/* average bytes for each instruction */
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
relu
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
this
->
Compute
=
refer
::
VRelu
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VActJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
relu
);
}
#endif
/* An empty JitKernel */
template
<
typename
T
>
class
VIdentityKernelImpl
:
public
VIdentityKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VIdentityKernelImpl
(
int
d
)
:
VIdentityKernel
<
T
>
()
{
this
->
Compute
=
refer
::
VIdentity
<
T
>
;
}
};
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vaddrelu
,
VAddReluKernel
);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vaddbias
,
VAddBiasKernel
);
REGISTER_JITKERNEL
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL
(
videntity
,
VIdentityKernel
);
#ifdef PADDLE_WITH_MKLDNN
REGISTER_JITKERNEL
(
eltwise_mul_nchw16c
,
EltwiseMulnChw16cNCKernel
);
#endif
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <limits>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
/* CRF Decode JitKernel */
template
<
typename
T
,
platform
::
cpu_isa_t
isa
,
jit_block
>
class
CRFDecodeKernelImpl
:
public
CRFDecodeKernel
<
T
>
{
public:
explicit
CRFDecodeKernelImpl
(
int
tag_num
)
:
CRFDecodeKernel
<
T
>
()
{
this
->
num_
=
tag_num
;
}
void
Compute
(
const
int
seq_len
,
const
T
*
x
,
const
T
*
w
,
T
*
alpha
,
int
*
track
)
const
override
{
constexpr
int
state_trans_base_idx
=
2
;
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
alpha
[
i
]
=
w
[
i
]
+
x
[
i
];
}
for
(
int
k
=
1
;
k
<
seq_len
;
++
k
)
{
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
T
max_score
=
-
std
::
numeric_limits
<
T
>::
max
();
int
max_j
=
0
;
for
(
int
j
=
0
;
j
<
this
->
num_
;
++
j
)
{
T
score
=
alpha
[(
k
-
1
)
*
this
->
num_
+
j
]
+
w
[(
j
+
state_trans_base_idx
)
*
this
->
num_
+
i
];
if
(
score
>
max_score
)
{
max_score
=
score
;
max_j
=
j
;
}
}
alpha
[
k
*
this
->
num_
+
i
]
=
max_score
+
x
[
k
*
this
->
num_
+
i
];
track
[
k
*
this
->
num_
+
i
]
=
max_j
;
}
}
}
};
#define INIT_ALPHA(step_size) \
/* Setup the alpha initial value.*/
\
int i_offset = 0; \
int last_offset = this->rest_ - step_size; \
for (int i = 0; i <= this->end_; ++i) { \
/* weights, input and alpha values. */
\
__m256 w_content, x_content, alpha_content; \
/* Load the relevant data into the variables from un-aligned address.*/
\
w_content = _mm256_loadu_ps(w + i_offset); \
x_content = _mm256_loadu_ps(x + i_offset); \
alpha_content = _mm256_add_ps(w_content, x_content); \
_mm256_storeu_ps(alpha + i_offset, alpha_content); \
i_offset += step_size; \
if (i == this->end_ - 1) { \
if (this->rest_ > 0) { \
i_offset += last_offset; \
} else { \
break; \
} \
} \
}
#define UPDATE_ALPHA(step_size) \
/* Update the alpha and track values. */
\
__m256 x_content = _mm256_loadu_ps(x + seq_offset + this->num_ + j_offset); \
max_score = _mm256_add_ps(max_score, x_content); \
_mm256_storeu_ps(alpha + seq_offset + this->num_ + j_offset, max_score); \
_mm256_storeu_si256( \
reinterpret_cast<__m256i*>(track + seq_offset + this->num_ + j_offset), \
max_j); \
/* Calculate the offset of next step*/
\
j_offset += step_size; \
if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \
j_offset += last_offset; \
} else { \
break; \
} \
}
#define INTRIAVX_FLOAT(block) \
template <> \
CRFDecodeKernelImpl<float, platform::avx, block>::CRFDecodeKernelImpl( \
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, platform::avx, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(YMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
for (int k = 1; k < seq_len; ++k) { \
int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/
\
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<float>::max()); \
__m256i max_j = _mm256_set1_epi32(0); \
/* Calculate the offset of transition_weights.*/
\
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
for (int i = 0; i < this->num_; ++i) { \
/* Initalize the content of alpha variable with related offset.*/
\
__m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); \
/* Obtain the content of weights from un-aligned address.*/
\
__m256 w_content = _mm256_loadu_ps(w + trans_offset); \
__m256 score_v = _mm256_add_ps(alpha_content, w_content); \
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); \
/* According to the mask value, update the index of the max_score.*/
\
/* AVX instructions.*/
\
__m128i lo_max_j = _mm256_extractf128_si256(max_j, 0); \
__m128i hi_max_j = _mm256_extractf128_si256(max_j, 1); \
__m128i lo_mask = _mm256_extractf128_si256(*(__m256i*)&mask, 0); \
__m128i hi_mask = _mm256_extractf128_si256(*(__m256i*)&mask, 1); \
lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j); \
hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j); \
lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i)); \
hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i)); \
lo_max_j = _mm_or_si128(lo_mask, lo_max_j); \
hi_max_j = _mm_or_si128(hi_mask, hi_max_j); \
max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0); \
max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1); \
/* AVX done*/
\
/* Update the max_score value.*/
\
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
} \
seq_offset += this->num_; \
} \
}
#define INTRIAVX2_FLOAT(isa, block) \
template <> \
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(YMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
for (int k = 1; k < seq_len; ++k) { \
int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/
\
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<float>::max()); \
__m256i max_j = _mm256_set1_epi32(0); \
/* Calculate the offset of transition_weights.*/
\
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
for (int i = 0; i < this->num_; ++i) { \
/* Initalize the content of alpha variable with related offset.*/
\
__m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); \
/* Obtain the content of weights from un-aligned address.*/
\
__m256 w_content = _mm256_loadu_ps(w + trans_offset); \
__m256 score_v = _mm256_add_ps(alpha_content, w_content); \
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); \
/* According to the mask value, update the index of the max_score.*/
\
/* AVX2 instructions.*/
\
max_j = _mm256_or_si256( \
_mm256_andnot_si256((__m256i)mask, max_j), \
_mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i))); \
/* Update the max_score value.*/
\
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
} \
seq_offset += this->num_; \
} \
}
#define INTRIAVX512_FLOAT(block) \
template <> \
CRFDecodeKernelImpl<float, platform::avx512f, block>::CRFDecodeKernelImpl( \
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ / ZMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, platform::avx512f, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(ZMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
for (int k = 1; k < seq_len; ++k) { \
int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/
\
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max()); \
__m512i max_j = _mm512_setzero_si512(); \
/* Calculate the offset of transition_weights.*/
\
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
for (int i = 0; i < this->num_; ++i) { \
/* Initalize the content of alpha variable with related offset.*/
\
__m512 alpha_content = _mm512_set1_ps(*(alpha + seq_offset + i)); \
/* Obtain the content of weights from un-aligned address.*/
\
__m512 w_content = _mm512_loadu_ps(w + trans_offset); \
__m512 score_v = _mm512_add_ps(alpha_content, w_content); \
__mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS); \
/* AVX512 instructions.*/
\
max_j = _mm512_mask_set1_epi32(max_j, mask, i); \
/* Update the max_score value.*/
\
max_score = _mm512_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
/* Update the alpha and track values.*/
\
__m512 x_content = \
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \
max_score = _mm512_add_ps(max_score, x_content); \
_mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, \
max_score); \
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \
this->num_ + j_offset), \
max_j); \
/* Calculate the offset of next step*/
\
j_offset += ZMM_FLOAT_BLOCK; \
if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \
j_offset += last_offset; \
} else { \
break; \
} \
} \
} \
seq_offset += this->num_; \
} \
}
#ifdef __AVX__
INTRIAVX_FLOAT
(
kEQ8
);
INTRIAVX_FLOAT
(
kGT8LT16
);
INTRIAVX_FLOAT
(
kEQ16
);
INTRIAVX_FLOAT
(
kGT16
);
#endif
#ifdef __AVX2__
INTRIAVX2_FLOAT
(
platform
::
avx2
,
kEQ8
);
INTRIAVX2_FLOAT
(
platform
::
avx2
,
kGT8LT16
);
INTRIAVX2_FLOAT
(
platform
::
avx2
,
kEQ16
);
INTRIAVX2_FLOAT
(
platform
::
avx2
,
kGT16
);
#endif
#ifdef __AVX512F__
INTRIAVX2_FLOAT
(
platform
::
avx512f
,
kEQ8
);
INTRIAVX2_FLOAT
(
platform
::
avx512f
,
kGT8LT16
);
INTRIAVX512_FLOAT
(
kEQ16
);
INTRIAVX512_FLOAT
(
kGT16
);
#endif
#undef INTRIAVX512_FLOAT
#undef INTRIAVX2_FLOAT
#undef INTRIAVX_FLOAT
#undef INIT_ALPHA
#undef UPDATE_ALPHA
REGISTER_JITKERNEL_DEPRECATED
(
crf_decode
,
CRFDecodeKernel
);
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_exp.cc
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
#ifdef PADDLE_WITH_MKLML
// try to use MKL to speedup
template
<
typename
T
>
void
VExpMKL
(
const
T
*
x
,
T
*
y
,
int
n
);
template
<
>
void
VExpMKL
<
float
>
(
const
float
*
x
,
float
*
y
,
int
n
)
{
platform
::
dynload
::
vsExp
(
n
,
x
,
y
);
}
template
<
>
void
VExpMKL
<
double
>
(
const
double
*
x
,
double
*
y
,
int
n
)
{
platform
::
dynload
::
vdExp
(
n
,
x
,
y
);
}
template
<
typename
T
>
void
VSigmoidMKL
(
const
T
*
x
,
T
*
y
,
int
n
)
{
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
static_cast
<
T
>
(
0
)
-
y
[
i
];
}
VExpMKL
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
y
[
i
]);
}
}
template
<
typename
T
>
void
VTanhMKL
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
2
)
*
x
[
i
];
}
VSigmoidMKL
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
2
)
*
y
[
i
]
-
static_cast
<
T
>
(
1
);
}
}
#endif
/* VExp JitKernel */
template
<
typename
T
>
class
VExpKernelImpl
:
public
VExpKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VExpKernelImpl
(
int
d
)
:
VExpKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
70
*
8
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
exp
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
if
(
useMKL
(
d
))
{
this
->
Compute
=
VExpMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
refer
::
VExp
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VActJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VExpKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
exp
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VExpKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VExpKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
/* VSigmoid JitKernel */
template
<
typename
T
>
class
VSigmoidKernelImpl
:
public
VSigmoidKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VSigmoidKernelImpl
(
int
d
)
:
VSigmoidKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
82
*
8
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
sigmoid
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
// strictly it's a better impl with MKL, then is refer
if
(
useMKL
(
d
))
{
this
->
Compute
=
VSigmoidMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
refer
::
VSigmoid
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VActJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VSigmoidKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
sigmoid
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VSigmoidKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VSigmoidKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
/* VTanh JitKernel */
template
<
typename
T
>
class
VTanhKernelImpl
:
public
VTanhKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VTanhKernelImpl
(
int
d
)
:
VTanhKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
84
*
8
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
tanh
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
// strictly it's a better impl with MKL, then is refer
if
(
useMKL
(
d
))
{
this
->
Compute
=
VTanhMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
refer
::
VTanh
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VActJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VTanhKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
tanh
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VTanhKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VTanhKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
REGISTER_JITKERNEL
(
vsigmoid
,
VSigmoidKernel
);
REGISTER_JITKERNEL
(
vtanh
,
VTanhKernel
);
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_impl.h
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <type_traits>
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
typedef
struct
{
void
*
gates
;
// gates: W_ch, W_ih, W_fh, W_oh
const
void
*
ct_1
;
void
*
ct
;
void
*
ht
;
/* weight_peephole and checked data are only used in peephole*/
const
void
*
wp
{
nullptr
};
void
*
checked
{
nullptr
};
}
lstm_t
;
typedef
struct
{
void
*
gates
;
// gates: {W_update, W_reset; W_state}
const
void
*
ht_1
;
void
*
ht
;
}
gru_t
;
struct
rnn_attr_s
{
int
d
;
std
::
string
act_gate
,
act_cand
;
rnn_attr_s
()
=
default
;
rnn_attr_s
(
int
_d
,
const
std
::
string
&
_act_gate
,
const
std
::
string
&
_act_cand
)
:
d
(
_d
),
act_gate
(
_act_gate
),
act_cand
(
_act_cand
)
{}
};
struct
lstm_attr_s
:
public
rnn_attr_s
{
bool
use_peephole
;
std
::
string
act_cell
;
lstm_attr_s
()
=
default
;
lstm_attr_s
(
int
_d
,
const
std
::
string
&
_act_gate
,
const
std
::
string
&
_act_cand
,
const
std
::
string
&
_act_cell
,
bool
_use_peephole
=
false
)
:
rnn_attr_s
(
_d
,
_act_gate
,
_act_cand
),
use_peephole
(
_use_peephole
),
act_cell
(
_act_cell
)
{}
};
typedef
struct
rnn_attr_s
gru_attr_t
;
typedef
struct
lstm_attr_s
lstm_attr_t
;
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_layer_norm.cc
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <math.h>
#include <limits>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
/* Layer Norm JitKernel */
template
<
typename
T
,
platform
::
cpu_isa_t
isa
,
jit_block
>
class
LayerNormKernelImpl
:
public
LayerNormKernel
<
T
>
{
public:
explicit
LayerNormKernelImpl
(
int
right
)
:
LayerNormKernel
<
T
>
()
{
this
->
num_
=
right
;
}
void
Compute
(
T
*
x
,
T
*
out
,
T
*
mean
,
T
*
var
,
const
T
*
scale
,
const
T
*
bias
,
int
height
,
const
float
epsilon
)
const
override
{
// get mean
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
T
sum
=
0.0
;
int
offset
=
i
*
this
->
num_
;
for
(
int
j
=
0
;
j
<
this
->
num_
;
j
++
)
{
sum
+=
x
[
offset
+
j
];
}
mean
[
i
]
=
sum
/
this
->
num_
;
}
// get variance
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
T
sum
=
0.0
;
int
offset
=
i
*
this
->
num_
;
for
(
int
j
=
0
;
j
<
this
->
num_
;
j
++
)
{
sum
+=
(
x
[
offset
+
j
]
-
mean
[
i
])
*
(
x
[
offset
+
j
]
-
mean
[
i
]);
}
var
[
i
]
=
sum
/
this
->
num_
;
}
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
int
offset
=
i
*
this
->
num_
;
T
sqrt_var
=
sqrt
(
var
[
i
]
+
(
T
)
epsilon
);
for
(
int
j
=
0
;
j
<
this
->
num_
;
j
++
)
{
out
[
offset
+
j
]
=
(
x
[
offset
+
j
]
-
mean
[
i
])
/
sqrt_var
;
}
}
if
(
scale
)
{
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
int
offset
=
i
*
this
->
num_
;
for
(
int
j
=
0
;
j
<
this
->
num_
;
j
++
)
{
out
[
offset
+
j
]
*=
scale
[
j
];
}
}
}
if
(
bias
)
{
for
(
int
i
=
0
;
i
<
height
;
i
++
)
{
int
offset
=
i
*
this
->
num_
;
for
(
int
j
=
0
;
j
<
this
->
num_
;
j
++
)
{
out
[
offset
+
j
]
+=
bias
[
j
];
}
}
}
}
};
#define INTRIAVX_FLOAT(isa, jit_block) \
template <> \
LayerNormKernelImpl<float, isa, jit_block>::LayerNormKernelImpl(int right) \
: LayerNormKernel<float>() { \
this->num_ = right; \
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
this->end_ = this->num_ - this->rest_; \
} \
template <> \
void LayerNormKernelImpl<float, isa, jit_block>::Compute( \
float* x, float* out, float* mean, float* var, const float* scale, \
const float* bias, int height, const float epsilon) const { \
__m256 sum; \
__m256 mean_vec, var_vec; \
__m128 hi, lo; \
__m256 tmp; \
size_t offset; \
size_t j; \
size_t block = YMM_FLOAT_BLOCK; \
__m256 reverse_num_vec = \
_mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \
__m256 epsilon_vec = _mm256_set1_ps(epsilon); \
int rest_mask = \
((-1) & (~((~0U) >> (sizeof(int) * 8 - (YMM_FLOAT_BLOCK - rest_))))) & \
0x0ff; \
__m256i mask_vec = _mm256_set_epi32( \
rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0, \
rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0, \
rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0, \
rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0); \
\
for (int i = 0; i < height; ++i) { \
offset = i * this->num_; \
\
/* get mean */
\
sum = _mm256_setzero_ps(); \
for (j = offset; j < end_ + offset; j += block) { \
sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j)); \
} \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_loadu_ps((const float*)x + j); \
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \
sum = _mm256_add_ps(sum, tmp); \
} \
hi = _mm256_extractf128_ps(sum, 1); \
lo = _mm256_extractf128_ps(sum, 0); \
sum = _mm256_add_ps( \
sum, _mm256_insertf128_ps( \
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \
sum = _mm256_hadd_ps(sum, sum); \
sum = _mm256_hadd_ps(sum, sum); \
mean_vec = _mm256_mul_ps(sum, reverse_num_vec); \
mean[i] = *reinterpret_cast<float*>(&mean_vec); \
\
/* get variance */
\
sum = _mm256_setzero_ps(); \
for (j = offset; j < end_ + offset; j += block) { \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_mul_ps(tmp, tmp); \
sum = _mm256_add_ps(sum, tmp); \
} \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_mul_ps(tmp, tmp); \
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \
sum = _mm256_add_ps(sum, tmp); \
} \
hi = _mm256_extractf128_ps(sum, 1); \
lo = _mm256_extractf128_ps(sum, 0); \
sum = _mm256_add_ps( \
sum, _mm256_insertf128_ps( \
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \
sum = _mm256_hadd_ps(sum, sum); \
sum = _mm256_hadd_ps(sum, sum); \
var_vec = _mm256_mul_ps(sum, reverse_num_vec); \
var[i] = *reinterpret_cast<float*>(&var_vec); \
\
/* get x_norm and calculate output*/
\
for (j = offset; j < end_ + offset; j += block) { \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_div_ps( \
tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp); \
} \
if (rest_ != 0) { \
j = offset + num_ - block; \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_div_ps( \
tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp); \
} \
\
if (scale) { \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_loadu_ps((const float*)out + j); \
} \
for (j = offset; j < end_ + offset; j += block) { \
_mm256_storeu_ps( \
reinterpret_cast<float*>(out) + j, \
_mm256_mul_ps( \
_mm256_loadu_ps((const float*)out + j), \
_mm256_loadu_ps((const float*)scale + j - offset))); \
} \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
_mm256_storeu_ps( \
reinterpret_cast<float*>(out) + j, \
_mm256_mul_ps( \
tmp, _mm256_loadu_ps((const float*)scale + j - offset))); \
} \
} \
\
if (bias) { \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_loadu_ps((const float*)out + j); \
} \
for (j = offset; j < end_ + offset; j += block) { \
_mm256_storeu_ps( \
reinterpret_cast<float*>(out) + j, \
_mm256_add_ps( \
_mm256_loadu_ps((const float*)out + j), \
_mm256_loadu_ps((const float*)bias + j - offset))); \
} \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
_mm256_storeu_ps( \
reinterpret_cast<float*>(out) + j, \
_mm256_add_ps( \
tmp, _mm256_loadu_ps((const float*)bias + j - offset))); \
} \
} \
} \
}
#ifdef __AVX__
INTRIAVX_FLOAT
(
platform
::
avx
,
kEQ8
);
INTRIAVX_FLOAT
(
platform
::
avx
,
kGT8LT16
);
INTRIAVX_FLOAT
(
platform
::
avx
,
kEQ16
);
INTRIAVX_FLOAT
(
platform
::
avx
,
kGT16
);
INTRIAVX_FLOAT
(
platform
::
avx2
,
kEQ8
);
INTRIAVX_FLOAT
(
platform
::
avx2
,
kGT8LT16
);
INTRIAVX_FLOAT
(
platform
::
avx2
,
kEQ16
);
INTRIAVX_FLOAT
(
platform
::
avx2
,
kGT16
);
INTRIAVX_FLOAT
(
platform
::
avx512f
,
kEQ8
);
INTRIAVX_FLOAT
(
platform
::
avx512f
,
kGT8LT16
);
INTRIAVX_FLOAT
(
platform
::
avx512f
,
kEQ16
);
INTRIAVX_FLOAT
(
platform
::
avx512f
,
kGT16
);
#endif
#undef INTRIAVX_FLOAT
REGISTER_JITKERNEL_DEPRECATED
(
layer_norm
,
LayerNormKernel
);
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_macro.h
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
#define JITKERNEL_DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(int d) { \
std::string key(#ker_key "f"); \
if (useJIT(d)) { \
/* only jit code need record d*/
\
return key + "jit" + std::to_string(d); \
} else if (useMKL(d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(int d) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/
\
if (useMKL(d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
#define JITKERNEL_DECLARE(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, int>(int d)
#define JITKERNEL_FIND_KEY(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(d)
#define JITKERNEL_IMPL(ker_class, ker_dtype) \
p = std::dynamic_pointer_cast<ker_class<ker_dtype>>( \
std::make_shared<ker_class##Impl<ker_dtype>>(d))
#define REGISTER_JITKERNEL_WITH_DTYPE(ker_class, ker_dtype, marco_declare, \
macro_find_key, macro_impl) \
marco_declare(ker_class, ker_dtype) { \
macro_find_key(ker_class, ker_dtype); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
macro_impl(ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
marco_declare, macro_find_key, macro_impl) \
marco_define_name(ker_key, ker_class); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, marco_declare, \
macro_find_key, macro_impl); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, marco_declare, \
macro_find_key, macro_impl)
#define REGISTER_JITKERNEL(ker_key, ker_class) \
REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
JITKERNEL_DECLARE, JITKERNEL_FIND_KEY, \
JITKERNEL_IMPL)
// TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < YMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kLT8); \
} else if (d == YMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ8); \
} else if (d > YMM_FLOAT_BLOCK && d < ZMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16); \
} else if (d == ZMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16); \
} else { \
macro_(ker, dtype, isa, kGT16); \
}
#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
if (platform::MayIUse(platform::avx512f)) { \
SEARCH_BLOCK(macro_, ker, dtype, platform::avx512f); \
} else if (platform::MayIUse(platform::avx2)) { \
SEARCH_BLOCK(macro_, ker, dtype, platform::avx2); \
} else if (platform::MayIUse(platform::avx)) { \
SEARCH_BLOCK(macro_, ker, dtype, platform::avx); \
} else { \
SEARCH_BLOCK(macro_, ker, dtype, platform::isa_any); \
}
#define JITKERNEL_KEY(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d)
#define JITKERNEL_NEW_IMPL_DEPRECATED(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(d))
#define JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, ker_dtype, \
dtype_key, marco_declare, macro_key, \
macro_impl) \
marco_declare(ker_class, ker_dtype) { \
std::string key = macro_key(ker_key, dtype_key); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
#define REGISTER_JITKERNEL_DEPRECATED(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, \
JITKERNEL_DECLARE, JITKERNEL_KEY, \
JITKERNEL_NEW_IMPL_DEPRECATED); \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \
JITKERNEL_DECLARE, JITKERNEL_KEY, \
JITKERNEL_NEW_IMPL_DEPRECATED)
#define REGISTER_JITKERNEL_ARGS_DEPRECATED(ker_key, ker_class, marco_declare, \
macro_key, macro_impl) \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, marco_declare, \
macro_key, macro_impl); \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \
marco_declare, macro_key, macro_impl)
#define FOR_EACH_ISA(macro_, block) \
macro_(platform::avx512f, block); \
macro_(platform::avx2, block); \
macro_(platform::avx, block); \
macro_(platform::isa_any, block)
#define FOR_EACH_BLOCK(macro_, isa) \
macro_(isa, kLT8); \
macro_(isa, kEQ8); \
macro_(isa, kGT8LT16); \
macro_(isa, kEQ16); \
macro_(isa, kGT16)
#define FOR_EACH_ISA_BLOCK(macro_) \
FOR_EACH_BLOCK(macro_, platform::avx512f); \
FOR_EACH_BLOCK(macro_, platform::avx2); \
FOR_EACH_BLOCK(macro_, platform::avx); \
FOR_EACH_BLOCK(macro_, platform::isa_any)
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_rnn.cc
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
/* LSTM JitKernel */
template
<
typename
T
>
class
LSTMKernelImpl
:
public
LSTMKernel
<
T
>
{
public:
static
inline
std
::
string
name
(
const
lstm_attr_t
&
attr
)
{
PADDLE_THROW
(
"DType should be either float or double"
);
}
static
inline
bool
useJIT
(
int
d
)
{
return
false
;
}
static
inline
bool
useMKL
(
int
d
)
{
return
false
;
}
explicit
LSTMKernelImpl
(
const
lstm_attr_t
&
attr
)
:
LSTMKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
attr
.
d
))
{
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
90
*
4
*
8
;
jitcode0_
.
reset
(
new
gen
::
LSTMJitCode
(
false
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeCtHt
=
jitcode0_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
jitcode1_
.
reset
(
new
gen
::
LSTMJitCode
(
true
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeC1H1
=
jitcode1_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
return
;
}
#endif
this
->
ComputeCtHt
=
refer
::
LSTMCtHt
<
T
>
;
this
->
ComputeC1H1
=
refer
::
LSTMC1H1
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
LSTMJitCode
>
jitcode0_
{
nullptr
},
jitcode1_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
LSTMKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
LSTMJitCode
::
init
(
d
);
}
#endif
/* Peephole JitKernel */
template
<
typename
T
>
class
PeepholeKernelImpl
:
public
LSTMKernel
<
T
>
{
public:
static
inline
std
::
string
name
(
const
lstm_attr_t
&
attr
)
{
PADDLE_THROW
(
"DType should be either float or double"
);
}
static
inline
bool
useJIT
(
int
d
)
{
return
false
;
}
static
inline
bool
useMKL
(
int
d
)
{
return
false
;
}
explicit
PeepholeKernelImpl
(
const
lstm_attr_t
&
attr
)
:
LSTMKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
attr
.
d
))
{
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
96
*
4
*
8
;
jitcode0_
.
reset
(
new
gen
::
LSTMJitCode
(
false
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeCtHt
=
jitcode0_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
jitcode1_
.
reset
(
new
gen
::
LSTMJitCode
(
true
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeC1H1
=
jitcode1_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
return
;
}
#endif
this
->
ComputeCtHt
=
refer
::
LSTMCtHt
<
T
>
;
this
->
ComputeC1H1
=
refer
::
LSTMC1H1
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
LSTMJitCode
>
jitcode0_
{
nullptr
},
jitcode1_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
PeepholeKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
LSTMJitCode
::
init
(
d
);
}
#endif
#define JITKERNEL_DEFINE_NAME_LSTM(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(const lstm_attr_t& attr) { \
std::string key(#ker_key "f"); \
key += (attr.act_gate + attr.act_cand + attr.act_cell + \
(attr.use_peephole ? "p" : "n")); \
if (useJIT(attr.d)) { \
/* only jit code need record d*/
\
return key + "jit" + std::to_string(attr.d); \
} else if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(const lstm_attr_t& attr) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/
\
if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const LSTMKernel<ker_dtype>> \
KernelPool::Get<LSTMKernel<ker_dtype>, const lstm_attr_t&>( \
const lstm_attr_t& attr)
#define JITKERNEL_FIND_KEY_LSTM(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(attr)
#define JITKERNEL_LSTM_IMPL(ker, dtype) \
if (attr.use_peephole) { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<PeepholeKernelImpl<dtype>>(attr)); \
} else { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype>>(attr)); \
}
REGISTER_JITKERNEL_ARGS
(
lstm
,
LSTMKernel
,
JITKERNEL_DEFINE_NAME_LSTM
,
JITKERNEL_DECLARE_LSTM
,
JITKERNEL_FIND_KEY_LSTM
,
JITKERNEL_LSTM_IMPL
);
#undef JITKERNEL_LSTM_IMPL
#undef JITKERNEL_FIND_KEY_LSTM
#undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_DEFINE_NAME_LSTM
/* GRU JitKernel */
template
<
typename
T
>
class
GRUKernelImpl
:
public
GRUKernel
<
T
>
{
public:
static
inline
std
::
string
name
(
const
gru_attr_t
&
attr
)
{
PADDLE_THROW
(
"DType should be either float or double"
);
}
static
inline
bool
useJIT
(
int
d
)
{
return
false
;
}
static
inline
bool
useMKL
(
int
d
)
{
return
false
;
}
explicit
GRUKernelImpl
(
const
gru_attr_t
&
attr
)
:
GRUKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
attr
.
d
))
{
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
96
*
2
*
8
;
jitcode0_
.
reset
(
new
gen
::
GRUJitCode
(
0
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeH1
=
jitcode0_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
jitcode1_
.
reset
(
new
gen
::
GRUJitCode
(
1
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeHtPart1
=
jitcode1_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
jitcode2_
.
reset
(
new
gen
::
GRUJitCode
(
2
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeHtPart2
=
jitcode2_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
return
;
}
#endif
this
->
ComputeH1
=
refer
::
GRUH1
<
T
>
;
this
->
ComputeHtPart1
=
refer
::
GRUHtPart1
<
T
>
;
this
->
ComputeHtPart2
=
refer
::
GRUHtPart2
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
GRUJitCode
>
jitcode0_
{
nullptr
},
jitcode1_
{
nullptr
},
jitcode2_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
GRUKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
GRUJitCode
::
init
(
d
);
}
#endif
#define JITKERNEL_DEFINE_NAME_GRU(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(const gru_attr_t& attr) { \
std::string key(#ker_key "f"); \
key += (attr.act_gate + attr.act_cand); \
if (useJIT(attr.d)) { \
/* only jit code need record d*/
\
return key + "jit" + std::to_string(attr.d); \
} else if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(const gru_attr_t& attr) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/
\
if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, const gru_attr_t&>( \
const gru_attr_t& attr)
#define JITKERNEL_FIND_KEY_GRU(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(attr)
#define JITKERNEL_GRU_IMPL(ker, dtype) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype>>(attr));
REGISTER_JITKERNEL_ARGS
(
gru
,
GRUKernel
,
JITKERNEL_DEFINE_NAME_GRU
,
JITKERNEL_DECLARE_GRU
,
JITKERNEL_FIND_KEY_GRU
,
JITKERNEL_GRU_IMPL
);
#undef JITKERNEL_GRU_IMPL
#undef JITKERNEL_FIND_KEY_GRU
#undef JITKERNEL_DECLARE_GRU
#undef JITKERNEL_DEFINE_NAME_GRU
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/jit_kernel_test.cc
已删除
100644 → 0
浏览文件 @
95cbe07c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <cmath> // for exp
#include <cstring> // for memcpy
#include <random>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
constexpr
int
repeat
=
20000
;
// TODO(TJ): benchmark and test should be seperated,
// benchmark should verify more sizes
inline
double
GetCurrentUS
()
{
struct
timeval
time
;
gettimeofday
(
&
time
,
NULL
);
return
1e+6
*
time
.
tv_sec
+
time
.
tv_usec
;
}
template
<
typename
T
>
void
RandomVec
(
const
int
n
,
T
*
a
,
const
T
lower
=
static_cast
<
T
>
(
-
20.
f
),
const
T
upper
=
static_cast
<
T
>
(
20.
f
))
{
static
unsigned
int
seed
=
100
;
std
::
mt19937
rng
(
seed
++
);
std
::
uniform_real_distribution
<
double
>
uniform_dist
(
0
,
1
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
a
[
i
]
=
static_cast
<
T
>
(
uniform_dist
(
rng
)
*
(
upper
-
lower
)
+
lower
);
}
}
#if defined __AVX__ || defined __AVX2__
void
vrelu_intri8
(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
__m256
tmp
=
_mm256_loadu_ps
(
x
);
tmp
=
_mm256_max_ps
(
tmp
,
_mm256_setzero_ps
());
_mm256_storeu_ps
(
y
,
tmp
);
}
#endif
TEST
(
JitKernel
,
vrelu
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
namespace
refer
=
paddle
::
operators
::
math
::
jitkernel
::
refer
;
for
(
int
d
:
{
3
,
7
,
8
,
15
,
16
,
30
,
256
,
512
})
{
std
::
vector
<
float
>
x
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
(),
-
10.
f
,
1.
f
);
const
auto
&
ker
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VReluKernel
<
float
>
>
(
d
);
const
float
*
x_data
=
x
.
data
();
float
*
ztgt_data
=
ztgt
.
data
();
float
*
zref_data
=
zref
.
data
();
auto
trefs
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
refer
::
VRelu
<
float
>
(
x_data
,
zref_data
,
d
);
}
auto
trefe
=
GetCurrentUS
();
#if defined __AVX__ || defined __AVX2__
if
(
d
==
8
)
{
auto
si0
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vrelu_intri8
(
d
,
x_data
,
zref_data
);
}
auto
si1
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size 8 intr takes: "
<<
(
si1
-
si0
)
/
repeat
<<
" us"
;
}
#endif
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
<<
" us, tgt takes: "
<<
(
ttgte
-
ttgts
)
/
repeat
<<
" us"
;
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
EXPECT_NEAR
(
ztgt_data
[
i
],
zref_data
[
i
],
1e-3
);
}
}
}
TEST
(
JitKernel
,
vaddbias
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
namespace
refer
=
paddle
::
operators
::
math
::
jitkernel
::
refer
;
for
(
int
d
:
{
7
,
8
,
15
,
16
,
30
,
64
,
100
,
128
,
256
})
{
std
::
vector
<
float
>
x
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
(),
-
2.
f
,
2.
f
);
const
auto
&
ker
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VAddBiasKernel
<
float
>
>
(
d
);
const
float
a
=
2.
f
;
const
float
*
x_data
=
x
.
data
();
float
*
ztgt_data
=
ztgt
.
data
();
float
*
zref_data
=
zref
.
data
();
auto
trefs
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
refer
::
VAddBias
<
float
>
(
&
a
,
x_data
,
zref_data
,
d
);
}
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
&
a
,
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
<<
" us, tgt takes: "
<<
(
ttgte
-
ttgts
)
/
repeat
<<
" us"
;
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
EXPECT_NEAR
(
ztgt_data
[
i
],
zref_data
[
i
],
1e-3
);
}
}
}
#ifdef PADDLE_WITH_MKLML
void
vexp_mkl
(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
paddle
::
platform
::
dynload
::
vsExp
(
n
,
x
,
y
);
}
#endif
TEST
(
JitKernel
,
vexp
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
namespace
refer
=
paddle
::
operators
::
math
::
jitkernel
::
refer
;
for
(
int
d
:
{
1
,
3
,
4
,
6
,
7
,
8
,
12
,
15
,
16
,
20
,
30
,
128
,
256
})
{
std
::
vector
<
float
>
x
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
(),
-
2.
f
,
2.
f
);
const
auto
&
ker
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VExpKernel
<
float
>
>
(
d
);
const
float
*
x_data
=
x
.
data
();
float
*
ztgt_data
=
ztgt
.
data
();
float
*
zref_data
=
zref
.
data
();
auto
trefs
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
refer
::
VExp
<
float
>
(
x_data
,
zref_data
,
d
);
}
auto
trefe
=
GetCurrentUS
();
#ifdef PADDLE_WITH_MKLML
auto
tmkls
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vexp_mkl
(
d
,
x_data
,
zref_data
);
}
auto
tmkle
=
GetCurrentUS
();
#endif
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
// ker->Compute(x_data, ztgt_data);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
#ifdef PADDLE_WITH_MKLML
<<
" us, mkl takes: "
<<
(
tmkle
-
tmkls
)
/
repeat
<<
" us, "
#else
<<
" us, "
#endif
<<
"tgt takes: "
<<
(
ttgte
-
ttgts
)
/
repeat
<<
" us"
;
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
EXPECT_NEAR
(
ztgt_data
[
i
],
zref_data
[
i
],
1e-3
);
}
}
}
void
vsigmoid_better
(
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VExpKernel
<
float
>>&
vexp
,
const
int
n
,
const
float
*
x
,
float
*
y
)
{
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
0.
f
-
y
[
i
];
}
vexp
->
Compute
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
1.
f
/
(
1.
f
+
y
[
i
]);
}
}
TEST
(
JitKernel
,
vsigmoid
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
namespace
refer
=
paddle
::
operators
::
math
::
jitkernel
::
refer
;
for
(
int
d
:
{
1
,
3
,
4
,
6
,
7
,
8
,
15
,
16
,
30
,
32
,
64
,
100
,
128
,
256
})
{
std
::
vector
<
float
>
x
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
(),
-
2.
f
,
2.
f
);
const
auto
&
ker
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VSigmoidKernel
<
float
>
>
(
d
);
const
auto
&
vexp
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VExpKernel
<
float
>
>
(
d
);
const
float
*
x_data
=
x
.
data
();
float
*
ztgt_data
=
ztgt
.
data
();
float
*
zref_data
=
zref
.
data
();
auto
tmkls
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vsigmoid_better
(
vexp
,
d
,
x_data
,
zref_data
);
}
auto
tmkle
=
GetCurrentUS
();
auto
trefs
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
refer
::
VSigmoid
<
float
>
(
x_data
,
zref_data
,
d
);
}
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
<<
" us, better(jit exp) takes: "
<<
(
tmkle
-
tmkls
)
/
repeat
<<
" us, tgt takes: "
<<
(
ttgte
-
ttgts
)
/
repeat
<<
" us"
;
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
EXPECT_NEAR
(
ztgt_data
[
i
],
zref_data
[
i
],
1e-3
);
}
}
}
void
vtanh_better
(
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VScalKernel
<
float
>>&
vscal
,
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VSigmoidKernel
<
float
>>&
vsigmoid
,
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VAddBiasKernel
<
float
>>&
vaddbias
,
const
int
n
,
const
float
*
x
,
float
*
y
)
{
const
float
a
=
2.
f
,
b
=
-
1.
f
;
vscal
->
Compute
(
&
a
,
x
,
y
,
n
);
vsigmoid
->
Compute
(
y
,
y
,
n
);
vscal
->
Compute
(
&
a
,
y
,
y
,
n
);
vaddbias
->
Compute
(
&
b
,
y
,
y
,
n
);
}
TEST
(
JitKernel
,
vtanh
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
namespace
refer
=
paddle
::
operators
::
math
::
jitkernel
::
refer
;
for
(
int
d
:
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
15
,
16
,
30
,
32
,
64
,
100
,
128
,
256
})
{
std
::
vector
<
float
>
x
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
(),
-
2.
f
,
2.
f
);
const
auto
&
ker
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VTanhKernel
<
float
>
>
(
d
);
const
auto
&
vscal
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VScalKernel
<
float
>
>
(
d
);
const
auto
&
vsigmoid
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VSigmoidKernel
<
float
>
>
(
d
);
const
auto
&
vaddbias
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VAddBiasKernel
<
float
>
>
(
d
);
const
float
*
x_data
=
x
.
data
();
float
*
ztgt_data
=
ztgt
.
data
();
float
*
zref_data
=
zref
.
data
();
auto
tmkls
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vtanh_better
(
vscal
,
vsigmoid
,
vaddbias
,
d
,
x_data
,
zref_data
);
}
auto
tmkle
=
GetCurrentUS
();
auto
trefs
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
refer
::
VTanh
<
float
>
(
x_data
,
zref_data
,
d
);
}
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
<<
" us, better(jit exp) takes: "
<<
(
tmkle
-
tmkls
)
/
repeat
<<
" us, tgt takes: "
<<
(
ttgte
-
ttgts
)
/
repeat
<<
" us"
;
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
EXPECT_NEAR
(
ztgt_data
[
i
],
zref_data
[
i
],
1e-3
);
}
}
}
void
lstm_ctht_better
(
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VSigmoidKernel
<
float
>>&
vsigmoid_3d
,
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VTanhKernel
<
float
>>&
vtanh_d
,
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VMulKernel
<
float
>>&
vmul_d
,
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VAddKernel
<
float
>>&
vadd_d
,
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
int
d2
=
d
*
2
;
vsigmoid_3d
->
Compute
(
gates
+
d
,
gates
+
d
,
3
*
d
);
vtanh_d
->
Compute
(
gates
,
gates
,
d
);
vmul_d
->
Compute
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
vmul_d
->
Compute
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
vadd_d
->
Compute
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
/* H_t = act_cell(C_t) * ogated */
vtanh_d
->
Compute
(
ct
,
gates
+
d2
,
d
);
vmul_d
->
Compute
(
gates
+
d2
,
gates
+
d
*
3
,
ht
,
d
);
}
TEST
(
JitKernel
,
lstm
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
namespace
refer
=
paddle
::
operators
::
math
::
jitkernel
::
refer
;
for
(
int
d
:
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
15
,
16
,
30
,
32
,
64
,
100
})
{
int
d4
=
d
*
4
;
int
d3
=
d
*
3
;
std
::
vector
<
float
>
x
(
d4
),
xref
(
d4
);
std
::
vector
<
float
>
ct_1
(
d
),
ct_tgt
(
d
),
ht_tgt
(
d
);
std
::
vector
<
float
>
ct_ref
(
d
),
ht_ref
(
d
);
RandomVec
<
float
>
(
d4
,
x
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
float
>
(
d
,
ct_1
.
data
(),
-
2.
f
,
2.
f
);
memcpy
(
xref
.
data
(),
x
.
data
(),
sizeof
(
float
)
*
d4
);
std
::
string
act_gate
=
"sigmoid"
,
act_cand
=
"tanh"
,
act_cell
=
"tanh"
;
const
jit
::
lstm_attr_t
attr
(
d
,
act_gate
,
act_cand
,
act_cell
,
false
);
const
auto
&
ker
=
jit
::
KernelPool
::
Instance
()
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
jit
::
lstm_attr_t
&>
(
attr
);
// below kernels are used to compute refer
const
auto
&
vsigmoid_3d
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VSigmoidKernel
<
float
>
>
(
d3
);
const
auto
&
vtanh_d
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VTanhKernel
<
float
>
>
(
d
);
const
auto
&
vmul_d
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
float
>
>
(
d
);
const
auto
&
vadd_d
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VAddKernel
<
float
>
>
(
d
);
float
*
x_data
=
x
.
data
();
float
*
xref_data
=
xref
.
data
();
const
float
*
ct_1_data
=
ct_1
.
data
();
float
*
ct_tgt_data
=
ct_tgt
.
data
();
float
*
ht_tgt_data
=
ht_tgt
.
data
();
float
*
ct_ref_data
=
ct_ref
.
data
();
float
*
ht_ref_data
=
ht_ref
.
data
();
// compute once to check correctness
jit
::
lstm_t
step
;
step
.
gates
=
xref_data
;
step
.
ct_1
=
ct_1_data
;
step
.
ct
=
ct_ref_data
;
step
.
ht
=
ht_ref_data
;
refer
::
LSTMCtHt
<
float
>
(
&
step
,
&
attr
);
step
.
gates
=
x_data
;
step
.
ct
=
ct_tgt_data
;
step
.
ht
=
ht_tgt_data
;
ker
->
ComputeCtHt
(
&
step
,
&
attr
);
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
EXPECT_NEAR
(
ct_tgt_data
[
i
],
ct_ref_data
[
i
],
1e-3
);
EXPECT_NEAR
(
ht_tgt_data
[
i
],
ht_ref_data
[
i
],
1e-3
);
}
auto
tmkls
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
lstm_ctht_better
(
vsigmoid_3d
,
vtanh_d
,
vmul_d
,
vadd_d
,
d
,
xref_data
,
ct_1_data
,
ct_ref_data
,
ht_ref_data
);
}
auto
tmkle
=
GetCurrentUS
();
auto
trefs
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
refer
::
LSTMCtHt
<
float
>
(
&
step
,
&
attr
);
}
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
ComputeCtHt
(
&
step
,
&
attr
);
}
auto
ttgte
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
<<
" us, better(jit) takes: "
<<
(
tmkle
-
tmkls
)
/
repeat
<<
" us, tgt takes: "
<<
(
ttgte
-
ttgts
)
/
repeat
<<
" us"
;
}
}
#if defined __AVX__ || defined __AVX2__
void
vscal_intri8
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
__m256
tmp
;
__m256
scalar
=
_mm256_set1_ps
(
a
);
tmp
=
_mm256_loadu_ps
(
x
);
tmp
=
_mm256_mul_ps
(
tmp
,
scalar
);
_mm256_storeu_ps
(
y
,
tmp
);
}
void
vscal_inp_intri8
(
const
int
n
,
const
float
a
,
float
*
x
)
{
__m256
tmp
;
__m256
scalar
=
_mm256_set1_ps
(
a
);
tmp
=
_mm256_loadu_ps
(
x
);
tmp
=
_mm256_mul_ps
(
tmp
,
scalar
);
_mm256_storeu_ps
(
x
,
tmp
);
}
#endif
#ifdef PADDLE_WITH_MKLML
void
vscal_inp_mkl
(
const
int
n
,
const
float
a
,
float
*
x
)
{
paddle
::
platform
::
dynload
::
cblas_sscal
(
n
,
a
,
x
,
1
);
}
#endif
TEST
(
JitKernel
,
vscal
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
namespace
refer
=
paddle
::
operators
::
math
::
jitkernel
::
refer
;
for
(
int
d
:
{
7
,
8
,
15
,
16
,
30
,
256
,
512
})
{
std
::
vector
<
float
>
x
(
d
),
y
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
());
std
::
memcpy
(
y
.
data
(),
x
.
data
(),
sizeof
(
float
)
*
d
);
float
a
=
2.
f
;
const
auto
&
ker
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VScalKernel
<
float
>
>
(
d
);
const
float
*
x_data
=
x
.
data
();
float
*
y_data
=
y
.
data
();
float
*
ztgt_data
=
ztgt
.
data
();
float
*
zref_data
=
zref
.
data
();
auto
trefs
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
refer
::
VScal
<
float
>
(
&
a
,
x_data
,
zref_data
,
d
);
}
auto
trefe
=
GetCurrentUS
();
auto
trefs1
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
refer
::
VScal
<
float
>
(
&
a
,
y_data
,
y_data
,
d
);
}
auto
trefe1
=
GetCurrentUS
();
#ifdef PADDLE_WITH_MKLML
auto
tmkls
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vscal_inp_mkl
(
d
,
a
,
y_data
);
}
auto
tmkle
=
GetCurrentUS
();
#endif
#if defined __AVX__ || defined __AVX2__
if
(
d
==
8
)
{
auto
si0
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vscal_intri8
(
d
,
a
,
x_data
,
zref_data
);
}
auto
si1
=
GetCurrentUS
();
auto
si2
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vscal_inp_intri8
(
d
,
a
,
y_data
);
}
auto
si3
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size 8 intr takes: "
<<
(
si1
-
si0
)
/
repeat
<<
" us, inplace: "
<<
(
si3
-
si2
)
/
repeat
<<
" us"
;
}
#endif
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
&
a
,
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
auto
ttgts1
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
&
a
,
y_data
,
y_data
,
d
);
}
auto
ttgte1
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
<<
" us, inplace takes: "
<<
(
trefe1
-
trefs1
)
/
repeat
#ifdef PADDLE_WITH_MKLML
<<
" us, mkl inplace takes: "
<<
(
tmkle
-
tmkls
)
/
repeat
<<
" us, "
#else
<<
" us, "
#endif
<<
"tgt takes: "
<<
(
ttgte
-
ttgts
)
/
repeat
<<
"us, tgt inplace takes: "
<<
(
ttgte1
-
ttgts1
)
/
repeat
<<
" us"
;
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
EXPECT_NEAR
(
ztgt_data
[
i
],
zref_data
[
i
],
1e-3
);
}
}
}
#if defined __AVX__ || defined __AVX2__
void
vmul_intri8
(
const
int
n
,
const
float
*
x
,
const
float
*
y
,
float
*
z
)
{
__m256
tmpx
,
tmpy
;
tmpx
=
_mm256_loadu_ps
(
x
);
tmpy
=
_mm256_loadu_ps
(
y
);
tmpx
=
_mm256_mul_ps
(
tmpx
,
tmpy
);
_mm256_storeu_ps
(
z
,
tmpx
);
}
#endif
#ifdef PADDLE_WITH_MKLML
void
vmul_mkl
(
const
int
n
,
const
float
*
x
,
const
float
*
y
,
float
*
z
)
{
paddle
::
platform
::
dynload
::
vsMul
(
n
,
x
,
y
,
z
);
}
#endif
TEST
(
JitKernel
,
vmul
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
namespace
refer
=
paddle
::
operators
::
math
::
jitkernel
::
refer
;
for
(
int
d
:
{
7
,
8
,
15
,
16
,
20
,
30
,
256
,
512
,
1000
,
1024
})
{
std
::
vector
<
float
>
x
(
d
),
y
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
());
RandomVec
<
float
>
(
d
,
y
.
data
());
const
auto
&
ker
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
float
>
>
(
d
);
const
float
*
x_data
=
x
.
data
();
const
float
*
y_data
=
y
.
data
();
float
*
ztgt_data
=
ztgt
.
data
();
float
*
zref_data
=
zref
.
data
();
auto
trefs
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
refer
::
VMul
<
float
>
(
x_data
,
y_data
,
zref_data
,
d
);
}
auto
trefe
=
GetCurrentUS
();
#ifdef PADDLE_WITH_MKLML
auto
tmkls
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vmul_mkl
(
d
,
x_data
,
y_data
,
zref_data
);
}
auto
tmkle
=
GetCurrentUS
();
#endif
#if defined __AVX__ || defined __AVX2__
if
(
d
==
8
)
{
auto
si0
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vmul_intri8
(
d
,
x_data
,
y_data
,
zref_data
);
}
auto
si1
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size 8 intr takes: "
<<
(
si1
-
si0
)
/
repeat
;
}
#endif
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
y_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
#ifdef PADDLE_WITH_MKLML
<<
" us, mkl takes: "
<<
(
tmkle
-
tmkls
)
/
repeat
<<
" us, "
#else
<<
" us, "
#endif
<<
"tgt takes: "
<<
(
ttgte
-
ttgts
)
/
repeat
<<
" us"
;
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
EXPECT_NEAR
(
ztgt_data
[
i
],
zref_data
[
i
],
1e-3
);
}
}
}
#if defined __AVX__ || defined __AVX2__
void
vadd_intri8
(
const
int
n
,
const
float
*
x
,
const
float
*
y
,
float
*
z
)
{
__m256
tmpx
,
tmpy
;
tmpx
=
_mm256_loadu_ps
(
x
);
tmpy
=
_mm256_loadu_ps
(
y
);
tmpx
=
_mm256_add_ps
(
tmpx
,
tmpy
);
_mm256_storeu_ps
(
z
,
tmpx
);
}
#endif
#ifdef PADDLE_WITH_MKLML
void
vadd_mkl
(
const
int
n
,
const
float
*
x
,
const
float
*
y
,
float
*
z
)
{
paddle
::
platform
::
dynload
::
vsAdd
(
n
,
x
,
y
,
z
);
}
#endif
TEST
(
JitKernel
,
vadd
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
namespace
refer
=
paddle
::
operators
::
math
::
jitkernel
::
refer
;
for
(
int
d
:
{
7
,
8
,
15
,
16
,
30
,
256
,
512
})
{
std
::
vector
<
float
>
x
(
d
),
y
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
());
RandomVec
<
float
>
(
d
,
y
.
data
());
const
auto
&
ker
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VAddKernel
<
float
>
>
(
d
);
const
float
*
x_data
=
x
.
data
();
const
float
*
y_data
=
y
.
data
();
float
*
ztgt_data
=
ztgt
.
data
();
float
*
zref_data
=
zref
.
data
();
auto
trefs
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
refer
::
VAdd
<
float
>
(
x_data
,
y_data
,
zref_data
,
d
);
}
auto
trefe
=
GetCurrentUS
();
#ifdef PADDLE_WITH_MKLML
auto
tmkls
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vadd_mkl
(
d
,
x_data
,
y_data
,
zref_data
);
}
auto
tmkle
=
GetCurrentUS
();
#endif
#if defined __AVX__ || defined __AVX2__
if
(
d
==
8
)
{
auto
si0
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vadd_intri8
(
d
,
x_data
,
y_data
,
zref_data
);
}
auto
si1
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size 8 intr takes: "
<<
(
si1
-
si0
)
/
repeat
;
}
#endif
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
y_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
#ifdef PADDLE_WITH_MKLML
<<
" us, mkl takes: "
<<
(
tmkle
-
tmkls
)
/
repeat
<<
" us, "
#else
<<
" us, "
#endif
<<
"tgt takes: "
<<
(
ttgte
-
ttgts
)
/
repeat
<<
" us"
;
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
EXPECT_NEAR
(
ztgt_data
[
i
],
zref_data
[
i
],
1e-3
);
}
}
}
void
vaddrelu_better
(
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VAddKernel
<
float
>>&
vadd
,
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VReluKernel
<
float
>>&
vrelu
,
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
d
)
{
vadd
->
Compute
(
x
,
y
,
z
,
d
);
vrelu
->
Compute
(
z
,
z
,
d
);
}
TEST
(
JitKernel
,
vaddrelu
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
namespace
refer
=
paddle
::
operators
::
math
::
jitkernel
::
refer
;
for
(
int
d
:
{
7
,
8
,
15
,
16
,
30
,
256
,
512
})
{
std
::
vector
<
float
>
x
(
d
),
y
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
());
RandomVec
<
float
>
(
d
,
y
.
data
());
const
auto
&
ker
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VAddReluKernel
<
float
>
>
(
d
);
const
auto
&
vadd
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VAddKernel
<
float
>
>
(
d
);
const
auto
&
vrelu
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VReluKernel
<
float
>
>
(
d
);
const
float
*
x_data
=
x
.
data
();
const
float
*
y_data
=
y
.
data
();
float
*
ztgt_data
=
ztgt
.
data
();
float
*
zref_data
=
zref
.
data
();
auto
trefs
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
refer
::
VAddRelu
<
float
>
(
x_data
,
y_data
,
zref_data
,
d
);
}
auto
trefe
=
GetCurrentUS
();
auto
tmkls
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
vaddrelu_better
(
vadd
,
vrelu
,
x_data
,
y_data
,
zref_data
,
d
);
}
auto
tmkle
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
(
x_data
,
y_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
VLOG
(
3
)
<<
"Vec size "
<<
d
<<
": refer takes: "
<<
(
trefe
-
trefs
)
/
repeat
<<
" us, better takes: "
<<
(
tmkle
-
tmkls
)
/
repeat
<<
" us, "
<<
"tgt takes: "
<<
(
ttgte
-
ttgts
)
/
repeat
<<
" us"
;
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
EXPECT_NEAR
(
ztgt_data
[
i
],
zref_data
[
i
],
1e-3
);
}
}
}
TEST
(
JitKernel
,
pool
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
const
int
frame_size
=
4
;
std
::
string
act_gate
=
"sigmoid"
,
act_cand
=
"tanh"
,
act_cell
=
"tanh"
;
jit
::
lstm_attr_t
attr
(
frame_size
,
act_gate
,
act_cand
,
act_cell
,
false
);
// empty call it to avoid unknown flag 'use_pinned_memory' on Mac
paddle
::
platform
::
MayIUse
(
paddle
::
platform
::
avx
);
const
auto
&
plstm1
=
jit
::
KernelPool
::
Instance
()
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
jit
::
lstm_attr_t
&>
(
attr
);
const
auto
&
plstm2
=
jit
::
KernelPool
::
Instance
()
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
jit
::
lstm_attr_t
&>
(
attr
);
EXPECT_EQ
(
plstm1
,
plstm2
);
const
auto
&
peephole
=
jit
::
KernelPool
::
Instance
()
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
jit
::
lstm_attr_t
&>
(
jit
::
lstm_attr_t
(
frame_size
,
act_gate
,
act_cand
,
act_cell
,
true
));
EXPECT_TRUE
(
plstm1
!=
peephole
);
const
auto
&
pvmul_f
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
float
>
>
(
4
);
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
plstm2
)
!=
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_f
));
const
auto
&
pvmul_d
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VMulKernel
<
double
>
>
(
4
);
EXPECT_TRUE
(
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_f
)
!=
std
::
dynamic_pointer_cast
<
const
jit
::
Kernel
>
(
pvmul_d
));
const
auto
&
pvmul_from_key
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulfjit4"
);
#if defined(__APPLE__) || defined(__OSX__) || defined(_WIN32)
EXPECT_EQ
(
pvmul_from_key
,
nullptr
);
#else
EXPECT_EQ
(
pvmul_from_key
,
pvmul_f
);
#endif
const
auto
&
pvmul_from_key2
=
jit
::
KernelPool
::
Instance
().
Get
(
"vmulfjit"
);
EXPECT_TRUE
(
pvmul_from_key2
==
nullptr
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录