Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
7f17e561
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7f17e561
编写于
11月 16, 2018
作者:
T
tensor-tang
提交者:
GitHub
11月 16, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14423 from tensor-tang/fea/jit/act
jitcode act relu, exp, sigmoid, tanh
上级
28bd5b7b
1f00723f
变更
10
展开全部
隐藏空白更改
内联
并排
Showing
10 changed file
with
626 addition
and
503 deletion
+626
-503
paddle/fluid/operators/math/cpu_vec.h
paddle/fluid/operators/math/cpu_vec.h
+9
-9
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+230
-10
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+60
-12
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+9
-25
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+31
-43
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
+12
-12
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+224
-350
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+19
-11
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+19
-19
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+13
-12
未找到文件。
paddle/fluid/operators/math/cpu_vec.h
浏览文件 @
7f17e561
...
...
@@ -33,11 +33,11 @@ namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define
AVX
_FLOAT_BLOCK 8
#define
YMM
_FLOAT_BLOCK 8
#define AVX_DOUBLE_BLOCK 4
#define
AVX2
_FLOAT_BLOCK 8
#define
YMM
_FLOAT_BLOCK 8
#define AVX2_DOUBLE_BLOCK 4
#define
AVX512
_FLOAT_BLOCK 16
#define
ZMM
_FLOAT_BLOCK 16
#define AVX512_DOUBLE_BLOCK 8
template
<
typename
T
>
...
...
@@ -88,7 +88,7 @@ template <>
inline
void
vec_scal
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_scal
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
return
;
...
...
@@ -142,7 +142,7 @@ template <>
inline
void
vec_bias_sub
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_bias_sub
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
return
;
...
...
@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
const
float
*
y
,
const
float
*
z
,
float
*
out
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_cross
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
,
z
,
out
);
return
;
...
...
@@ -257,7 +257,7 @@ template <>
inline
void
vec_add_bias
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_add_bias
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
return
;
...
...
@@ -326,7 +326,7 @@ template <>
inline
void
vec_sigmoid
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_sigmoid
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
);
return
;
...
...
@@ -415,7 +415,7 @@ template <>
inline
void
vec_relu
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
*
4
)
{
vec_relu
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
);
return
;
...
...
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
7f17e561
...
...
@@ -41,7 +41,7 @@ void VXXJitCode::generate() {
}
else
if
(
scalar_index_
==
2
)
{
vbroadcastss
(
ymm_src2
,
ptr
[
param2
]);
}
for
(
int
i
=
0
;
i
<
num_
/
AVX
_FLOAT_BLOCK
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_
/
YMM
_FLOAT_BLOCK
;
++
i
)
{
if
(
scalar_index_
!=
1
)
{
vmovups
(
ymm_src1
,
ptr
[
param1
+
offset
]);
}
...
...
@@ -57,9 +57,9 @@ void VXXJitCode::generate() {
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_dst
);
}
vmovups
(
ptr
[
param3
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX
_FLOAT_BLOCK
;
offset
+=
sizeof
(
float
)
*
YMM
_FLOAT_BLOCK
;
}
int
rest
=
num_
%
AVX
_FLOAT_BLOCK
;
int
rest
=
num_
%
YMM
_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
if
(
scalar_index_
!=
1
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
...
...
@@ -118,18 +118,237 @@ void VXXJitCode::generate() {
ret
();
}
bool
ReluJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
);
}
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1
void
ReluJitCode
::
generate
()
{
int
offset
=
0
;
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
static
const
float
exp_float_consts
[]
ALIGN32
=
{
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
2.
f
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
EXP_LOW
),
REPEAT_8TIMES
(
CEPHES_LOG2EF
),
REPEAT_8TIMES
(
CEPHES_EXP_C1
),
REPEAT_8TIMES
(
CEPHES_EXP_C2
),
REPEAT_8TIMES
(
CEPHES_EXP_P0
),
REPEAT_8TIMES
(
CEPHES_EXP_P1
),
REPEAT_8TIMES
(
CEPHES_EXP_P2
),
REPEAT_8TIMES
(
CEPHES_EXP_P3
),
REPEAT_8TIMES
(
CEPHES_EXP_P4
),
REPEAT_8TIMES
(
CEPHES_EXP_P5
),
REPEAT_8TIMES
(
EXP_MAX_INPUT
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MAX
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MIN
)};
static
const
int
exp_int_0x7f
[]
ALIGN32
=
{
REPEAT_8TIMES
(
0x7f
)};
static
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
bool
VActJitCode
::
init
(
int
d
,
operand_type
type
)
{
bool
ok
=
MayIUse
(
avx
);
if
(
type
==
operand_type
::
relu
)
{
return
ok
;
}
else
if
(
type
==
operand_type
::
exp
)
{
// exp is slower than mkl when d >= 256
return
ok
&&
d
%
8
==
0
&&
d
<
256
;
}
else
{
// TODO(TJ): support more
return
ok
&&
d
%
8
==
0
;
}
}
void
VActJitCode
::
relu_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
ymm_t
&
ymm_zero
)
{
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_src
);
}
void
VActJitCode
::
exp_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
assert
(
ymm_src
.
getIdx
()
!=
ymm_dst
.
getIdx
());
// TODO(TJ): use enfore
// check all idx can not equal
ymm_t
ymm_fx
=
ymm_t
(
fx_idx
);
ymm_t
ymm_fy
=
ymm_t
(
fy_idx
);
ymm_t
ymm_mask
=
ymm_t
(
mask_idx
);
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vminps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOW
]);
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
// express exp(x) as exp(g + n*log(2))
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOG2EF
]);
vmulps
(
ymm_fx
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_0P5
]);
vaddps
(
ymm_fx
,
ymm_fx
,
ymm_tmp
);
vroundps
(
ymm_fy
,
ymm_fx
,
0x01
);
// if greater, substract 1
vcmpgtps
(
ymm_mask
,
ymm_fy
,
ymm_fx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vandps
(
ymm_mask
,
ymm_mask
,
ymm_tmp
);
vsubps
(
ymm_fx
,
ymm_fy
,
ymm_mask
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C1
]);
vmulps
(
ymm_fy
,
ymm_fx
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C2
]);
ymm_t
ymm_z
=
ymm_t
(
ymm_mask
.
getIdx
());
vmulps
(
ymm_z
,
ymm_fx
,
ymm_tmp
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_fy
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_z
);
vmulps
(
ymm_z
,
ymm_src
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P0
]);
vmulps
(
ymm_dst
,
ymm_src
,
ymm_tmp
);
for
(
size_t
i
=
OFFSET_EXP_P1
;
i
<
OFFSET_EXP_P5
;
i
+=
(
YMM_FLOAT_BLOCK
*
sizeof
(
float
)))
{
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
i
]);
// P1~P4
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
}
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P5
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_z
);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
// build 2^n
ymm_t
ymm_int
=
ymm_fx
;
vcvttps2dq
(
ymm_int
,
ymm_fx
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_int_0x7f
));
vmovdqa
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
if
(
MayIUse
(
avx2
))
{
vpaddd
(
ymm_int
,
ymm_int
,
ymm_tmp
);
vpslld
(
ymm_int
,
ymm_int
,
23
);
}
else
if
(
MayIUse
(
avx
))
{
xmm_t
xtmp1
=
xmm_t
(
ymm_int
.
getIdx
());
xmm_t
xtmp2
=
xmm_t
(
ymm_tmp
.
getIdx
());
reg64_t
reg_ptr_tmp
=
reg_ptr_global
;
mov
(
reg_ptr_tmp
,
reinterpret_cast
<
size_t
>
(
g_tmp_mem
));
vmovdqa
(
ptr
[
reg_ptr_tmp
],
ymm_int
);
vmovdqa
(
ptr
[
reg_ptr_tmp
+
YMM_FLOAT_BLOCK
*
sizeof
(
float
)],
ymm_tmp
);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_tmp
],
xtmp1
);
// next 128bits
vmovdqa
(
xtmp1
,
ptr
[
reg_ptr_tmp
+
4
/*xmm float block*/
*
sizeof
(
float
)]);
vmovdqa
(
xtmp2
,
ptr
[
reg_ptr_tmp
+
(
YMM_FLOAT_BLOCK
+
4
/*xmm float block*/
)
*
sizeof
(
float
)]);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_tmp
+
4
/*xmm float block*/
*
sizeof
(
float
)],
xtmp1
);
// load out
vmovdqa
(
ymm_int
,
ptr
[
reg_ptr_tmp
]);
}
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_int
);
pop
(
reg_ptr_global
);
}
void
VActJitCode
::
sigmoid_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
// y = 1 / (1 + e^-x)
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MAX
]);
vminps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MIN
]);
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vxorps
(
ymm_tmp
,
ymm_tmp
,
ymm_tmp
);
vsubps
(
ymm_src
,
ymm_tmp
,
ymm_src
);
exp_ymm
(
ymm_dst
,
ymm_src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vdivps
(
ymm_dst
,
ymm_tmp
,
ymm_dst
);
pop
(
reg_ptr_global
);
}
void
VActJitCode
::
tanh_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
// y = 2 / (1 + e^(-2x)) - 1
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
ymm_t
ymm_zero
=
ymm_t
(
mask_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vsubps
(
ymm_tmp
,
ymm_zero
,
ymm_tmp
);
vmulps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
exp_ymm
(
ymm_dst
,
ymm_src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vdivps
(
ymm_dst
,
ymm_tmp
,
ymm_dst
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vsubps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
pop
(
reg_ptr_global
);
}
void
VActJitCode
::
generate
()
{
xmm_t
xmm_zero
=
xmm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
2
);
if
(
type_
==
operand_type
::
relu
)
{
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
}
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_src
);
switch
(
type_
)
{
case
operand_type
::
relu
:
relu_ymm
(
ymm_dst
,
ymm_src
,
ymm_zero
);
break
;
case
operand_type
::
exp
:
exp_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
tanh
:
tanh_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
identity
:
break
;
default:
break
;
}
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX_FLOAT_BLOCK
;
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
if
(
type_
!=
operand_type
::
relu
)
{
// TODO(TJ): remove me
ret
();
return
;
}
int
rest
=
num_
%
AVX
_FLOAT_BLOCK
;
int
rest
=
num_
%
YMM
_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
...
...
@@ -151,6 +370,7 @@ void ReluJitCode::generate() {
}
ret
();
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
7f17e561
...
...
@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
Label
=
Xbyak
::
Label
;
typedef
enum
{
mul
=
0
,
add
}
operand_type
;
typedef
enum
{
mul
=
0
,
add
,
sub
,
relu
,
exp
,
sigmoid
,
tanh
,
identity
}
operand_type
;
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class
VXXJitCode
:
public
JitCode
{
...
...
@@ -85,26 +94,65 @@ class VXXJitCode : public JitCode {
ymm_t
ymm_zero
=
ymm_t
(
3
);
};
class
Relu
JitCode
:
public
JitCode
{
class
VAct
JitCode
:
public
JitCode
{
public:
DECLARE_JIT_CODE
(
ReluJitCode
);
explicit
ReluJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
const
char
*
name
()
const
override
{
std
::
string
base
=
"VActJitCode"
;
switch
(
type_
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
return
base
.
c_str
();
}
explicit
VActJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
,
type_
(
type
)
{}
static
bool
init
(
int
d
,
operand_type
type
);
void
generate
()
override
;
private:
protected:
// compute relu with ymm
void
relu_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
zero
);
// compute exp with ymm
void
exp_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
// compute sigmoid with ymm
void
sigmoid_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
// compute tanh with ymm
void
tanh_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
protected:
int
num_
;
operand_type
type_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
xmm_t
xmm_zero
=
xmm_t
(
0
);
xmm_t
xmm_src
=
xmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
xmm_t
xmm_src
=
xmm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_zero
=
ymm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
7f17e561
...
...
@@ -29,9 +29,9 @@ namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define
AVX_FLOAT_BLOCK 8
#define
AVX2
_FLOAT_BLOCK 8
#define
AVX512
_FLOAT_BLOCK 16
#define
XMM_FLOAT_BLOCK 4
#define
YMM
_FLOAT_BLOCK 8
#define
ZMM
_FLOAT_BLOCK 16
typedef
enum
{
kLT8
,
kEQ8
,
kGT8LT16
,
kEQ16
,
kGT16
}
jit_block
;
...
...
@@ -97,39 +97,23 @@ class VAddBiasKernel : public Kernel {
template
<
typename
T
>
class
VActKernel
:
public
Kernel
{
public:
v
irtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
v
oid
(
*
Compute
)(
const
T
*
,
T
*
,
int
)
;
};
template
<
typename
T
>
class
VReluKernel
:
public
VActKernel
<
T
>
{
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
void
(
*
Compute
)(
const
T
*
,
T
*
,
int
);
};
class
VReluKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VIdentityKernel
:
public
VActKernel
<
T
>
{
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
class
VIdentityKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VExpKernel
:
public
VActKernel
<
T
>
{
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
class
VExpKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VSigmoidKernel
:
public
VActKernel
<
T
>
{
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
class
VSigmoidKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VTanhKernel
:
public
VActKernel
<
T
>
{
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
class
VTanhKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
LSTMKernel
:
public
Kernel
{
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
7f17e561
...
...
@@ -25,10 +25,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
...
...
@@ -128,23 +124,16 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
#endif
#define DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
/* VMUL JitKernel */
template
<
typename
T
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
// roughly estimate the size of code
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -191,11 +180,11 @@ bool VMulKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddKernelImpl
(
int
d
)
:
VAddKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -241,11 +230,11 @@ bool VAddKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
class
VAddReluKernelImpl
:
public
VAddReluKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddReluKernelImpl
(
int
d
)
:
VAddReluKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
true
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -273,11 +262,11 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
template
<
typename
T
>
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -322,11 +311,11 @@ bool VScalKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
class
VAddBiasKernelImpl
:
public
VAddBiasKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddBiasKernelImpl
(
int
d
)
:
VAddBiasKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -355,15 +344,15 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
template
<
typename
T
>
class
VReluKernelImpl
:
public
VReluKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VReluKernelImpl
(
int
d
)
:
VReluKernel
<
T
>
()
{
this
->
num_
=
d
;
// TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
/*init*/
+
d
/
AVX_FLOAT_BLOCK
*
4
/* instructions*/
*
8
/*everage byte for each instruction*/
;
jitcode_
.
reset
(
new
gen
::
ReluJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
size_t
sz
=
96
/* init size */
+
d
/
YMM_FLOAT_BLOCK
*
4
/* instructions */
*
8
/* average bytes for each instruction */
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
relu
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
...
...
@@ -371,24 +360,32 @@ class VReluKernelImpl : public VReluKernel<T> {
this
->
Compute
=
VReluRefer
<
T
>
;
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{
VReluRefer
(
x
,
y
,
this
->
num_
);
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
Relu
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
VAct
JitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
ReluJitCode
::
init
(
d
);
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
relu
);
}
#endif
#undef DECLARE_STATIC_FUNC
template
<
typename
T
>
inline
void
VIdentityRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{}
/* An empty JitKernel */
template
<
typename
T
>
class
VIdentityKernelImpl
:
public
VIdentityKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VIdentityKernelImpl
(
int
d
)
:
VIdentityKernel
<
T
>
()
{
this
->
Compute
=
VIdentityRefer
<
T
>
;
}
};
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
...
...
@@ -396,16 +393,7 @@ REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vaddbias
,
VAddBiasKernel
);
REGISTER_JITKERNEL
(
vrelu
,
VReluKernel
);
/* An empty JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VIdentityKernelImpl
:
public
VIdentityKernel
<
T
>
{
public:
explicit
VIdentityKernelImpl
(
int
d
)
:
VIdentityKernel
<
T
>
()
{
this
->
num_
=
d
;
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{}
};
REGISTER_JITKERNEL_DEPRECATED
(
videntity
,
VIdentityKernel
);
REGISTER_JITKERNEL
(
videntity
,
VIdentityKernel
);
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
浏览文件 @
7f17e561
...
...
@@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ /
AVX
_FLOAT_BLOCK; \
this->rest_ = this->num_ %
AVX
_FLOAT_BLOCK; \
this->end_ = this->num_ /
YMM
_FLOAT_BLOCK; \
this->rest_ = this->num_ %
YMM
_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(
AVX
_FLOAT_BLOCK) \
INIT_ALPHA(
YMM
_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
...
...
@@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
UPDATE_ALPHA(
AVX
_FLOAT_BLOCK) \
UPDATE_ALPHA(
YMM
_FLOAT_BLOCK) \
} \
seq_offset += this->num_; \
} \
...
...
@@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ /
AVX2_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
AVX2_FLOAT_BLOCK;
\
this->end_ = this->num_ /
YMM_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
YMM_FLOAT_BLOCK;
\
} \
template <> \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(
AVX2_FLOAT_BLOCK)
\
INIT_ALPHA(
YMM_FLOAT_BLOCK)
\
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
...
...
@@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
UPDATE_ALPHA(
AVX2_FLOAT_BLOCK)
\
UPDATE_ALPHA(
YMM_FLOAT_BLOCK)
\
} \
seq_offset += this->num_; \
} \
...
...
@@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ /
AVX512_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
AVX512_FLOAT_BLOCK;
\
this->end_ = this->num_ /
ZMM_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
ZMM_FLOAT_BLOCK;
\
} \
template <> \
void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(
AVX512_FLOAT_BLOCK)
\
INIT_ALPHA(
ZMM_FLOAT_BLOCK)
\
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
...
...
@@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
this->num_ + j_offset), \
max_j); \
/* Calculate the offset of next step*/
\
j_offset +=
AVX512_FLOAT_BLOCK;
\
j_offset +=
ZMM_FLOAT_BLOCK;
\
if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \
j_offset += last_offset; \
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
7f17e561
此差异已折叠。
点击以展开。
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
7f17e561
...
...
@@ -15,12 +15,20 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
#define JITKERNEL_DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(int d) { \
...
...
@@ -86,17 +94,17 @@ namespace jitkernel {
namespace
jit
=
platform
::
jit
;
// TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa)
\
if (d <
AVX_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kLT8);
\
} else if (d ==
AVX_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kEQ8);
\
} else if (d >
AVX_FLOAT_BLOCK && d < AVX512
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16);
\
} else if (d ==
AVX512
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16);
\
} else {
\
macro_(ker, dtype, isa, kGT16);
\
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d <
YMM_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kLT8); \
} else if (d ==
YMM_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kEQ8); \
} else if (d >
YMM_FLOAT_BLOCK && d < ZMM
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16); \
} else if (d ==
ZMM
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16); \
} else { \
macro_(ker, dtype, isa, kGT16); \
}
#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
...
...
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
7f17e561
...
...
@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
T
*
checked
)
const
override
{
// gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d
_
);
act_gate_d3_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d3
_
);
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2
_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
Deprecated
(
gates
+
d3_
,
gates
+
d3
_
);
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2
_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
,
d
_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
...
...
@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
,
d_
);
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
,
d2_
);
act_gate_d2_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d
_
);
act_gate_d2_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d2
_
);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
/* get ogated*/
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
act_gate_d_
->
Compute
Deprecated
(
gates
+
d3_
,
gates
+
d3
_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
,
d
_
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2
_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* get outgated, put W_oc * C_t on igated */
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
Deprecated
(
gates
+
d3_
,
gates
+
d3
_
);
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2
_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
,
d
_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
...
...
@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
}
void
ComputeH1
(
T
*
gates
,
T
*
ht
)
const
override
{
act_gate_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_state_d_
->
Compute
Deprecated
(
gates
+
d2_
,
gates
+
d2
_
);
act_gate_d_
->
Compute
(
gates
,
gates
,
d_
);
act_state_d_
->
Compute
(
gates
+
d2_
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
,
d_
);
}
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
// W: {W_update, W_reset; W_state}
act_gate_d2_
->
Compute
Deprecated
(
gates
,
gates
);
act_gate_d2_
->
Compute
(
gates
,
gates
,
d2_
);
vmul_d_
->
Compute
(
ht_1
,
gates
+
d_
,
ht
,
d_
);
}
void
ComputeHtPart2
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
T
*
y
=
gates
+
d2_
;
act_state_d_
->
Compute
Deprecated
(
y
,
y
);
act_state_d_
->
Compute
(
y
,
y
,
d_
);
// out = zt*ht~ + (1-zt)*ht_1
for
(
int
i
=
0
;
i
<
d_
;
++
i
)
{
ht
[
i
]
=
gates
[
i
]
*
y
[
i
]
+
(
static_cast
<
T
>
(
1
)
-
gates
[
i
])
*
ht_1
[
i
];
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
7f17e561
...
...
@@ -181,7 +181,8 @@ TEST(JitKernel, vexp) {
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
ComputeDeprecated
(
x_data
,
ztgt_data
);
// ker->Compute(x_data, ztgt_data);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
...
...
@@ -222,7 +223,7 @@ void vsigmoid_better(
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
0.
f
-
y
[
i
];
}
vexp
->
Compute
Deprecated
(
y
,
y
);
vexp
->
Compute
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
1.
f
/
(
1.
f
+
y
[
i
]);
}
...
...
@@ -253,7 +254,7 @@ TEST(JitKernel, vsigmoid) {
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
Deprecated
(
x_data
,
ztgt_data
);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
...
...
@@ -287,7 +288,7 @@ void vtanh_better(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
const
float
a
=
2.
f
,
b
=
-
1.
f
;
vscal
->
Compute
(
&
a
,
x
,
y
,
n
);
vsigmoid
->
Compute
Deprecated
(
y
,
y
);
vsigmoid
->
Compute
(
y
,
y
,
n
);
vscal
->
Compute
(
&
a
,
y
,
y
,
n
);
vaddbias
->
Compute
(
&
b
,
y
,
y
,
n
);
}
...
...
@@ -321,7 +322,7 @@ TEST(JitKernel, vtanh) {
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
Deprecated
(
x_data
,
ztgt_data
);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
...
...
@@ -344,8 +345,8 @@ void lstm_ctht_ref(
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VExpKernel
<
float
>>&
vexp_1
,
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
vsigmoid_3d
->
Compute
Deprecated
(
gates
+
d
,
gates
+
d
);
vtanh_d
->
Compute
Deprecated
(
gates
,
gates
);
vsigmoid_3d
->
Compute
(
gates
+
d
,
gates
+
d
,
3
*
d
);
vtanh_d
->
Compute
(
gates
,
gates
,
d
);
const
float
*
i
=
gates
+
d
,
*
f
=
gates
+
d
*
2
,
*
o
=
gates
+
d
*
3
;
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
...
...
@@ -355,7 +356,7 @@ void lstm_ctht_ref(
// H_t = act_cell(C_t) * ogated
float
tmp
=
ct
[
k
]
*
2
;
tmp
=
0.
f
-
((
tmp
<
min
)
?
min
:
((
tmp
>
max
)
?
max
:
tmp
));
vexp_1
->
Compute
Deprecated
(
&
tmp
,
&
tmp
);
vexp_1
->
Compute
(
&
tmp
,
&
tmp
,
1
);
tmp
=
2.
f
/
(
1.
f
+
tmp
)
-
1.
f
;
ht
[
k
]
=
tmp
*
o
[
k
];
}
...
...
@@ -373,13 +374,13 @@ void lstm_ctht_better(
const
paddle
::
operators
::
math
::
jitkernel
::
VAddKernel
<
float
>>&
vadd_d
,
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
int
d2
=
d
*
2
;
vsigmoid_3d
->
Compute
Deprecated
(
gates
+
d
,
gates
+
d
);
vtanh_d
->
Compute
Deprecated
(
gates
,
gates
);
vsigmoid_3d
->
Compute
(
gates
+
d
,
gates
+
d
,
3
*
d
);
vtanh_d
->
Compute
(
gates
,
gates
,
d
);
vmul_d
->
Compute
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
vmul_d
->
Compute
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
vadd_d
->
Compute
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
/* H_t = act_cell(C_t) * ogated */
vtanh_d
->
Compute
Deprecated
(
ct
,
gates
+
d2
);
vtanh_d
->
Compute
(
ct
,
gates
+
d2
,
d
);
vmul_d
->
Compute
(
gates
+
d2
,
gates
+
d
*
3
,
ht
,
d
);
}
...
...
@@ -736,7 +737,7 @@ void vaddrelu_better(
const
paddle
::
operators
::
math
::
jitkernel
::
VReluKernel
<
float
>>&
vrelu
,
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
d
)
{
vadd
->
Compute
(
x
,
y
,
z
,
d
);
vrelu
->
Compute
Deprecated
(
z
,
z
);
vrelu
->
Compute
(
z
,
z
,
d
);
}
TEST
(
JitKernel
,
vaddrelu
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录