Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
6a7f83d4
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6a7f83d4
编写于
11月 23, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enable gru jitcode and refine act and lstm jitcode
test=develop
上级
686eaf20
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
149 addition
and
136 deletion
+149
-136
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+103
-80
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+38
-52
paddle/fluid/operators/math/jit_kernel_refer.h
paddle/fluid/operators/math/jit_kernel_refer.h
+3
-1
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+3
-3
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+2
-0
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
6a7f83d4
...
@@ -140,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) {
...
@@ -140,32 +140,10 @@ bool VActJitCode::init(int d, operand_type type) {
}
}
void
VActJitCode
::
generate
()
{
void
VActJitCode
::
generate
()
{
xmm_t
xmm_zero
=
xmm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
2
);
if
(
type_
==
operand_type
::
relu
)
{
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
}
int
offset
=
0
;
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
switch
(
type_
)
{
act
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
type_
);
case
operand_type
::
relu
:
relu_jmm
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
ymm_zero
);
break
;
case
operand_type
::
exp
:
exp_jmm
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_jmm
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
tanh
:
tanh_jmm
<
ymm_t
>
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
identity
:
break
;
default:
break
;
}
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
}
...
@@ -182,22 +160,7 @@ void VActJitCode::generate() {
...
@@ -182,22 +160,7 @@ void VActJitCode::generate() {
block
=
1
;
block
=
1
;
vmovss
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmovss
(
xmm_src
,
ptr
[
param1
+
offset
]);
}
}
switch
(
type_
)
{
act
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
type_
);
case
operand_type
::
relu
:
relu_jmm
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
xmm_zero
);
break
;
case
operand_type
::
exp
:
exp_jmm
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_jmm
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
tanh
:
tanh_jmm
<
xmm_t
>
(
xmm_dst
,
xmm_src
,
2
,
3
,
4
,
5
);
break
;
default:
break
;
}
if
(
rest
>=
4
)
{
if
(
rest
>=
4
)
{
vmovups
(
ptr
[
param2
+
offset
],
xmm_dst
);
vmovups
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
else
if
(
rest
>=
2
)
{
}
else
if
(
rest
>=
2
)
{
...
@@ -233,52 +196,64 @@ void LSTMJitCode::generate() {
...
@@ -233,52 +196,64 @@ void LSTMJitCode::generate() {
int
offset
=
0
;
int
offset
=
0
;
int
d
=
num_
*
sizeof
(
float
);
int
d
=
num_
*
sizeof
(
float
);
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
/* C_t = C_t-1 * fgated + cand_gated * igated*/
/* gates: W_ch, W_ih, W_fh, W_oh */
// c
ymm_t
ymm_c
=
ymm_t
(
0
);
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
]);
ymm_t
ymm_i
=
ymm_t
(
1
);
act
<
ymm_t
>
(
ymm_c
,
ymm_src
,
act_cand_
);
ymm_t
ymm_f
=
ymm_t
(
2
);
// i
ymm_t
ymm_o
=
ymm_t
(
3
);
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
+
d
]);
ymm_t
ymm_ct_1
=
ymm_t
(
4
);
if
(
!
compute_c1h1_
&&
use_peephole_
)
{
ymm_t
ymm_wp0
=
ymm_t
(
5
);
ymm_t
ymm_wp
=
ymm_t
(
2
);
ymm_t
ymm_wp1
=
ymm_t
(
6
);
ymm_t
ymm_ct_1
=
ymm_t
(
3
);
ymm_t
ymm_wp2
=
ymm_t
(
7
);
vmovups
(
ymm_wp
,
ptr
[
reg_ptr_wp
+
offset
]);
vmovups
(
ymm_c
,
ptr
[
reg_ptr_gates
+
offset
]);
vmovups
(
ymm_i
,
ptr
[
reg_ptr_gates
+
offset
+
d
]);
vmovups
(
ymm_f
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
d
]);
vmovups
(
ymm_o
,
ptr
[
reg_ptr_gates
+
offset
+
3
*
d
]);
if
(
!
compute_c1h1_
)
{
vmovups
(
ymm_ct_1
,
ptr
[
reg_ptr_ct_1
+
offset
]);
vmovups
(
ymm_ct_1
,
ptr
[
reg_ptr_ct_1
+
offset
]);
vmulps
(
ymm_wp
,
ymm_ct_1
,
ymm_wp
);
vaddps
(
ymm_src
,
ymm_src
,
ymm_wp
);
}
}
act
<
ymm_t
>
(
ymm_i
,
ymm_src
,
act_gate_
);
if
(
use_peephole_
)
{
vmovups
(
ymm_wp0
,
ptr
[
reg_ptr_wp
+
offset
]);
vmovups
(
ymm_wp1
,
ptr
[
reg_ptr_wp
+
offset
+
d
]);
vmovups
(
ymm_wp2
,
ptr
[
reg_ptr_wp
+
offset
+
2
*
d
]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act
<
ymm_t
>
(
ymm_c
,
ymm_c
,
act_cand_
);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if
(
!
compute_c1h1_
&&
use_peephole_
)
{
vmulps
(
ymm_wp0
,
ymm_ct_1
,
ymm_wp0
);
vaddps
(
ymm_i
,
ymm_i
,
ymm_wp0
);
}
act
<
ymm_t
>
(
ymm_i
,
ymm_i
,
act_gate_
);
vmulps
(
ymm_c
,
ymm_c
,
ymm_i
);
vmulps
(
ymm_c
,
ymm_c
,
ymm_i
);
if
(
!
compute_c1h1_
)
{
if
(
!
compute_c1h1_
)
{
// f
// act_gate(f) or act_gate(ct_1 * wp1 + f)
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
d
]);
vmovups
(
ymm_i
,
ptr
[
reg_ptr_ct_1
+
offset
]);
if
(
use_peephole_
)
{
if
(
use_peephole_
)
{
ymm_t
ymm_wp
=
ymm_t
(
3
);
vmulps
(
ymm_wp1
,
ymm_ct_1
,
ymm_wp1
);
vmovups
(
ymm_wp
,
ptr
[
reg_ptr_wp
+
offset
+
d
]);
vaddps
(
ymm_f
,
ymm_f
,
ymm_wp1
);
vmulps
(
ymm_wp
,
ymm_i
,
ymm_wp
);
vaddps
(
ymm_src
,
ymm_src
,
ymm_wp
);
}
}
act
<
ymm_t
>
(
ymm_f
,
ymm_src
,
act_gate_
);
act
<
ymm_t
>
(
ymm_f
,
ymm_f
,
act_gate_
);
vmulps
(
ymm_f
,
ymm_f
,
ymm_i
);
// ct
vmulps
(
ymm_f
,
ymm_f
,
ymm_ct_1
);
vaddps
(
ymm_f
,
ymm_f
,
ymm_c
);
vaddps
(
ymm_f
,
ymm_f
,
ymm_c
);
}
}
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t
ymm_ct
=
compute_c1h1_
?
ymm_c
:
ymm_f
;
ymm_t
ymm_ct
=
compute_c1h1_
?
ymm_c
:
ymm_f
;
ymm_t
ymm_o
=
compute_c1h1_
?
ymm_f
:
ymm_c
;
ymm_t
ymm_tmp
=
ymm_i
;
ymm_t
ymm_tmp
=
ymm_i
;
vmovups
(
ptr
[
reg_ptr_ct
+
offset
],
ymm_ct
);
// save ct
act
<
ymm_t
>
(
ymm_tmp
,
ymm_ct
,
act_cell_
);
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
+
3
*
d
]);
// act_gate(o) or act_gate(ct * wp2 + o)
if
(
use_peephole_
)
{
if
(
use_peephole_
)
{
ymm_t
ymm_wp
=
ymm_t
(
2
);
vmulps
(
ymm_wp2
,
ymm_ct
,
ymm_wp2
);
vmovups
(
ymm_wp
,
ptr
[
reg_ptr_wp
+
offset
+
d
*
2
]);
vaddps
(
ymm_o
,
ymm_o
,
ymm_wp2
);
vmulps
(
ymm_wp
,
ymm_ct
,
ymm_wp
);
vaddps
(
ymm_src
,
ymm_src
,
ymm_wp
);
}
}
act
<
ymm_t
>
(
ymm_tmp
,
ymm_ct
,
act_cell_
);
act
<
ymm_t
>
(
ymm_o
,
ymm_o
,
act_gate_
);
act
<
ymm_t
>
(
ymm_o
,
ymm_src
,
act_gate_
);
// ht
vmulps
(
ymm_o
,
ymm_tmp
,
ymm_o
);
vmulps
(
ymm_o
,
ymm_o
,
ymm_tmp
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_o
);
// save ht
// save ct and ht
vmovups
(
ptr
[
reg_ptr_ct
+
offset
],
ymm_ct
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_o
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
}
...
@@ -293,13 +268,61 @@ bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
...
@@ -293,13 +268,61 @@ bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
void
GRUJitCode
::
generate
()
{
void
GRUJitCode
::
generate
()
{
reg64_t
reg_ptr_gates
=
rax
;
reg64_t
reg_ptr_gates
=
rax
;
reg64_t
reg_ptr_ct_1
=
r9
;
reg64_t
reg_ptr_ht_1
=
r9
;
reg64_t
reg_ptr_ct
=
r10
;
reg64_t
reg_ptr_ht
=
r10
;
reg64_t
reg_ptr_ht
=
r11
;
mov
(
reg_ptr_gates
,
ptr
[
param1
+
offsetof
(
gru_t
,
gates
)]);
mov
(
reg_ptr_gates
,
ptr
[
param1
+
offsetof
(
lstm_t
,
gates
)]);
mov
(
reg_ptr_ht_1
,
ptr
[
param1
+
offsetof
(
gru_t
,
ht_1
)]);
mov
(
reg_ptr_ct_1
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ct_1
)]);
mov
(
reg_ptr_ht
,
ptr
[
param1
+
offsetof
(
gru_t
,
ht
)]);
mov
(
reg_ptr_ct
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ct
)]);
ymm_t
ymm_one
=
ymm_t
(
0
);
mov
(
reg_ptr_ht
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ht
)]);
if
(
id_
==
2
)
{
reg64_t
reg_ptr_tmp
=
r11
;
mov
(
reg_ptr_tmp
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_one
,
ptr
[
reg_ptr_tmp
+
OFFSET_EXP_ONE
]);
}
int
offset
=
0
;
int
d
=
num_
*
sizeof
(
float
);
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
ymm_t
ymm_u
=
ymm_t
(
1
);
ymm_t
ymm_r
=
ymm_t
(
2
);
ymm_t
ymm_s
=
ymm_t
(
3
);
ymm_t
ymm_ht_1
=
ymm_t
(
4
);
// W: {W_update, W_reset; W_state}
if
(
id_
==
0
||
id_
==
2
)
{
vmovups
(
ymm_u
,
ptr
[
reg_ptr_gates
+
offset
]);
vmovups
(
ymm_s
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
d
]);
}
if
(
id_
==
1
)
{
vmovups
(
ymm_r
,
ptr
[
reg_ptr_gates
+
offset
+
d
]);
}
if
(
id_
==
1
||
id_
==
2
)
{
vmovups
(
ymm_ht_1
,
ptr
[
reg_ptr_ht_1
+
offset
]);
}
if
(
id_
==
0
)
{
// ht = act_gate(u) * act_cand(s)
act
<
ymm_t
>
(
ymm_u
,
ymm_u
,
act_gate_
);
act
<
ymm_t
>
(
ymm_s
,
ymm_s
,
act_cand_
);
vmulps
(
ymm_s
,
ymm_s
,
ymm_u
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_s
);
}
else
if
(
id_
==
1
)
{
// ht = act_gate(r) * ht_1
act
<
ymm_t
>
(
ymm_r
,
ymm_r
,
act_gate_
);
vmulps
(
ymm_r
,
ymm_r
,
ymm_ht_1
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_r
);
}
else
if
(
id_
==
2
)
{
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t
ymm_one_inner
=
ymm_t
(
ymm_one
.
getIdx
());
act
<
ymm_t
>
(
ymm_u
,
ymm_u
,
act_gate_
);
act
<
ymm_t
>
(
ymm_s
,
ymm_s
,
act_cand_
);
vmulps
(
ymm_s
,
ymm_s
,
ymm_u
);
vsubps
(
ymm_u
,
ymm_one_inner
,
ymm_u
);
vmulps
(
ymm_u
,
ymm_ht_1
,
ymm_u
);
vaddps
(
ymm_u
,
ymm_s
,
ymm_u
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_u
);
}
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
ret
();
ret
();
}
}
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
6a7f83d4
...
@@ -169,31 +169,34 @@ class VActJitCode : public JitCode {
...
@@ -169,31 +169,34 @@ class VActJitCode : public JitCode {
protected:
protected:
// compute relu with ymm, xmm
// compute relu with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
relu_jmm
(
JMM
&
dst
,
JMM
&
src
,
JMM
&
zero
)
{
// NOLINT
void
relu_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
zero_idx
=
15
)
{
// NOLINT
JMM
zero
=
JMM
(
zero_idx
);
vxorps
(
zero
,
zero
,
zero
);
vmaxps
(
dst
,
src
,
zero
);
vmaxps
(
dst
,
src
,
zero
);
}
}
// compute exp with ymm, xmm
// compute exp with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
exp_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
// NOLINT
void
exp_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
int
fx_idx
=
12
,
// NOLINT
int
mask_idx
=
4
,
int
tmp_idx
=
5
)
{
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
using
namespace
platform
::
jit
;
// NOLINT
using
namespace
platform
::
jit
;
// NOLINT
assert
(
src
.
getIdx
()
!=
dst
.
getIdx
());
// TODO(TJ): use enfore
// check all idx can not equal
// check all idx can not equal
JMM
jmm_src
=
JMM
(
src_idx
);
JMM
jmm_fx
=
JMM
(
fx_idx
);
JMM
jmm_fx
=
JMM
(
fx_idx
);
JMM
jmm_fy
=
JMM
(
fy_idx
);
JMM
jmm_fy
=
JMM
(
fy_idx
);
JMM
jmm_mask
=
JMM
(
mask_idx
);
JMM
jmm_mask
=
JMM
(
mask_idx
);
JMM
jmm_tmp
=
JMM
(
tmp_idx
);
JMM
jmm_tmp
=
JMM
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
push
(
reg_ptr_global
);
vmovaps
(
jmm_src
,
src
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vminps
(
src
,
src
,
jmm_tmp
);
vminps
(
jmm_src
,
jmm_
src
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOW
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOW
]);
vmaxps
(
src
,
src
,
jmm_tmp
);
vmaxps
(
jmm_src
,
jmm_
src
,
jmm_tmp
);
// express exp(x) as exp(g + n*log(2))
// express exp(x) as exp(g + n*log(2))
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOG2EF
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOG2EF
]);
vmulps
(
jmm_fx
,
src
,
jmm_tmp
);
vmulps
(
jmm_fx
,
jmm_
src
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_0P5
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_0P5
]);
vaddps
(
jmm_fx
,
jmm_fx
,
jmm_tmp
);
vaddps
(
jmm_fx
,
jmm_fx
,
jmm_tmp
);
vroundps
(
jmm_fy
,
jmm_fx
,
0x01
);
vroundps
(
jmm_fy
,
jmm_fx
,
0x01
);
...
@@ -207,21 +210,21 @@ class VActJitCode : public JitCode {
...
@@ -207,21 +210,21 @@ class VActJitCode : public JitCode {
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C2
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C2
]);
JMM
ymm_z
=
JMM
(
jmm_mask
.
getIdx
());
JMM
ymm_z
=
JMM
(
jmm_mask
.
getIdx
());
vmulps
(
ymm_z
,
jmm_fx
,
jmm_tmp
);
vmulps
(
ymm_z
,
jmm_fx
,
jmm_tmp
);
vsubps
(
src
,
src
,
jmm_fy
);
vsubps
(
jmm_src
,
jmm_
src
,
jmm_fy
);
vsubps
(
src
,
src
,
ymm_z
);
vsubps
(
jmm_src
,
jmm_
src
,
ymm_z
);
vmulps
(
ymm_z
,
src
,
src
);
vmulps
(
ymm_z
,
jmm_src
,
jmm_
src
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P0
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P0
]);
vmulps
(
dst
,
src
,
jmm_tmp
);
vmulps
(
dst
,
jmm_
src
,
jmm_tmp
);
for
(
size_t
i
=
OFFSET_EXP_P1
;
i
<
OFFSET_EXP_P5
;
for
(
size_t
i
=
OFFSET_EXP_P1
;
i
<
OFFSET_EXP_P5
;
i
+=
(
YMM_FLOAT_BLOCK
*
sizeof
(
float
)))
{
i
+=
(
YMM_FLOAT_BLOCK
*
sizeof
(
float
)))
{
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
i
]);
// P1~P4
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
i
]);
// P1~P4
vaddps
(
dst
,
dst
,
jmm_tmp
);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vmulps
(
dst
,
dst
,
src
);
vmulps
(
dst
,
dst
,
jmm_
src
);
}
}
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P5
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P5
]);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vmulps
(
dst
,
dst
,
ymm_z
);
vmulps
(
dst
,
dst
,
ymm_z
);
vaddps
(
dst
,
dst
,
src
);
vaddps
(
dst
,
dst
,
jmm_
src
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
]);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vaddps
(
dst
,
dst
,
jmm_tmp
);
// build 2^n
// build 2^n
...
@@ -258,20 +261,23 @@ class VActJitCode : public JitCode {
...
@@ -258,20 +261,23 @@ class VActJitCode : public JitCode {
// compute sigmoid with ymm, xmm
// compute sigmoid with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
sigmoid_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
fx_idx
=
2
,
// NOLINT
void
sigmoid_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
)
{
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
// y = 1 / (1 + e^-x)
// y = 1 / (1 + e^-x)
JMM
jmm_tmp
=
JMM
(
tmp_idx
);
JMM
jmm_tmp
=
JMM
(
tmp_idx
);
JMM
jmm_src
=
JMM
(
src_idx
);
reg64_t
reg_ptr_global
=
rax
;
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
push
(
reg_ptr_global
);
vmovaps
(
jmm_src
,
src
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MAX
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MAX
]);
vminps
(
src
,
src
,
jmm_tmp
);
vminps
(
jmm_src
,
jmm_
src
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MIN
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MIN
]);
vmaxps
(
src
,
src
,
jmm_tmp
);
vmaxps
(
jmm_src
,
jmm_
src
,
jmm_tmp
);
vxorps
(
jmm_tmp
,
jmm_tmp
,
jmm_tmp
);
vxorps
(
jmm_tmp
,
jmm_tmp
,
jmm_tmp
);
vsubps
(
src
,
jmm_tmp
,
src
);
vsubps
(
jmm_src
,
jmm_tmp
,
jmm_
src
);
exp_jmm
<
JMM
>
(
dst
,
src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
exp_jmm
<
JMM
>
(
dst
,
jmm_src
,
src_idx
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vdivps
(
dst
,
jmm_tmp
,
dst
);
vdivps
(
dst
,
jmm_tmp
,
dst
);
...
@@ -280,19 +286,22 @@ class VActJitCode : public JitCode {
...
@@ -280,19 +286,22 @@ class VActJitCode : public JitCode {
// compute tanh with ymm, xmm
// compute tanh with ymm, xmm
template
<
typename
JMM
>
template
<
typename
JMM
>
void
tanh_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
// NOLINT
void
tanh_jmm
(
JMM
&
dst
,
JMM
&
src
,
int
src_idx
=
11
,
// NOLINT
int
mask_idx
=
4
,
int
tmp_idx
=
5
)
{
int
fx_idx
=
12
,
int
fy_idx
=
13
,
int
mask_idx
=
14
,
int
tmp_idx
=
15
)
{
// y = 2 / (1 + e^(-2x)) - 1
// y = 2 / (1 + e^(-2x)) - 1
JMM
jmm_src
=
JMM
(
src_idx
);
JMM
jmm_tmp
=
JMM
(
tmp_idx
);
JMM
jmm_tmp
=
JMM
(
tmp_idx
);
JMM
jmm_zero
=
JMM
(
mask_idx
);
JMM
jmm_zero
=
JMM
(
mask_idx
);
reg64_t
reg_ptr_global
=
rax
;
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
push
(
reg_ptr_global
);
vmovaps
(
jmm_src
,
src
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vxorps
(
jmm_zero
,
jmm_zero
,
jmm_zero
);
vxorps
(
jmm_zero
,
jmm_zero
,
jmm_zero
);
vsubps
(
jmm_tmp
,
jmm_zero
,
jmm_tmp
);
vsubps
(
jmm_tmp
,
jmm_zero
,
jmm_tmp
);
vmulps
(
src
,
src
,
jmm_tmp
);
vmulps
(
jmm_src
,
jmm_
src
,
jmm_tmp
);
exp_jmm
<
JMM
>
(
dst
,
src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
exp_jmm
<
JMM
>
(
dst
,
jmm_src
,
src_idx
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vaddps
(
dst
,
dst
,
jmm_tmp
);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vmovaps
(
jmm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
...
@@ -304,23 +313,19 @@ class VActJitCode : public JitCode {
...
@@ -304,23 +313,19 @@ class VActJitCode : public JitCode {
template
<
typename
JMM
>
template
<
typename
JMM
>
void
act
(
JMM
&
dst
,
JMM
&
src
,
operand_type
type
)
{
// NOLINT
void
act
(
JMM
&
dst
,
JMM
&
src
,
operand_type
type
)
{
// NOLINT
// use 15
// use 11~15
JMM
zero
=
JMM
(
15
);
if
(
type_
==
operand_type
::
relu
)
{
vxorps
(
zero
,
zero
,
zero
);
}
switch
(
type
)
{
switch
(
type
)
{
case
operand_type
::
relu
:
case
operand_type
::
relu
:
relu_jmm
<
JMM
>
(
dst
,
src
,
zero
);
relu_jmm
<
JMM
>
(
dst
,
src
,
15
);
break
;
break
;
case
operand_type
::
exp
:
case
operand_type
::
exp
:
exp_jmm
<
JMM
>
(
dst
,
src
,
2
,
3
,
4
,
5
);
exp_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
1
5
);
break
;
break
;
case
operand_type
::
sigmoid
:
case
operand_type
::
sigmoid
:
sigmoid_jmm
<
JMM
>
(
dst
,
src
,
2
,
3
,
4
,
5
);
sigmoid_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
1
5
);
break
;
break
;
case
operand_type
::
tanh
:
case
operand_type
::
tanh
:
tanh_jmm
<
JMM
>
(
dst
,
src
,
2
,
3
,
4
,
5
);
tanh_jmm
<
JMM
>
(
dst
,
src
,
11
,
12
,
13
,
14
,
1
5
);
break
;
break
;
case
operand_type
::
identity
:
case
operand_type
::
identity
:
break
;
break
;
...
@@ -414,15 +419,6 @@ class LSTMJitCode : public VActJitCode {
...
@@ -414,15 +419,6 @@ class LSTMJitCode : public VActJitCode {
operand_type
act_cand_
;
operand_type
act_cand_
;
operand_type
act_cell_
;
operand_type
act_cell_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param1
{
abi_param1
};
xmm_t
xmm_src
=
xmm_t
(
0
);
xmm_t
xmm_c
=
xmm_t
(
1
);
xmm_t
xmm_i
=
xmm_t
(
6
);
xmm_t
xmm_f
=
xmm_t
(
7
);
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_c
=
ymm_t
(
1
);
// 2~5 for act
ymm_t
ymm_i
=
ymm_t
(
6
);
ymm_t
ymm_f
=
ymm_t
(
7
);
};
};
class
GRUJitCode
:
public
VActJitCode
{
class
GRUJitCode
:
public
VActJitCode
{
...
@@ -492,16 +488,6 @@ class GRUJitCode : public VActJitCode {
...
@@ -492,16 +488,6 @@ class GRUJitCode : public VActJitCode {
operand_type
act_gate_
;
operand_type
act_gate_
;
operand_type
act_cand_
;
operand_type
act_cand_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param1
{
abi_param1
};
xmm_t
xmm_src
=
xmm_t
(
0
);
xmm_t
xmm_c
=
xmm_t
(
1
);
xmm_t
xmm_i
=
xmm_t
(
6
);
xmm_t
xmm_f
=
xmm_t
(
7
);
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_c
=
ymm_t
(
1
);
ymm_t
ymm_i
=
ymm_t
(
6
);
ymm_t
ymm_f
=
ymm_t
(
7
);
};
};
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
...
...
paddle/fluid/operators/math/jit_kernel_refer.h
浏览文件 @
6a7f83d4
...
@@ -206,7 +206,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
...
@@ -206,7 +206,7 @@ void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
const
T
*
ht_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ht_1
);
const
T
*
ht_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ht_1
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
act_gate
(
gates
,
gates
,
attr
->
d
*
2
);
act_gate
(
gates
+
attr
->
d
,
gates
+
attr
->
d
,
attr
->
d
);
VMul
(
ht_1
,
gates
+
attr
->
d
,
ht
,
attr
->
d
);
VMul
(
ht_1
,
gates
+
attr
->
d
,
ht
,
attr
->
d
);
}
}
...
@@ -215,9 +215,11 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
...
@@ -215,9 +215,11 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
const
T
*
ht_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ht_1
);
const
T
*
ht_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ht_1
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
auto
act_cand
=
getActFunc
<
T
>
(
attr
->
act_cand
);
auto
act_cand
=
getActFunc
<
T
>
(
attr
->
act_cand
);
int
d
=
attr
->
d
;
int
d
=
attr
->
d
;
T
*
y
=
gates
+
d
*
2
;
T
*
y
=
gates
+
d
*
2
;
act_gate
(
gates
,
gates
,
d
);
act_cand
(
y
,
y
,
d
);
act_cand
(
y
,
y
,
d
);
// out = zt*ht~ + (1-zt)*ht_1
// out = zt*ht~ + (1-zt)*ht_1
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
...
...
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
6a7f83d4
...
@@ -177,7 +177,7 @@ class GRUKernelImpl : public GRUKernel<T> {
...
@@ -177,7 +177,7 @@ class GRUKernelImpl : public GRUKernel<T> {
explicit
GRUKernelImpl
(
const
gru_attr_t
&
attr
)
:
GRUKernel
<
T
>
()
{
explicit
GRUKernelImpl
(
const
gru_attr_t
&
attr
)
:
GRUKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
attr
.
d
))
{
if
(
useJIT
(
attr
.
d
))
{
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
84
*
8
;
// should change
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
96
*
2
*
8
;
jitcode0_
.
reset
(
new
gen
::
GRUJitCode
(
0
,
attr
,
sz
>
4096
?
sz
:
4096
));
jitcode0_
.
reset
(
new
gen
::
GRUJitCode
(
0
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeH1
=
this
->
ComputeH1
=
jitcode0_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
jitcode0_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
...
@@ -188,7 +188,7 @@ class GRUKernelImpl : public GRUKernel<T> {
...
@@ -188,7 +188,7 @@ class GRUKernelImpl : public GRUKernel<T> {
jitcode2_
.
reset
(
new
gen
::
GRUJitCode
(
2
,
attr
,
sz
>
4096
?
sz
:
4096
));
jitcode2_
.
reset
(
new
gen
::
GRUJitCode
(
2
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeHtPart2
=
this
->
ComputeHtPart2
=
jitcode
1
_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
jitcode
2
_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
return
;
return
;
}
}
#endif
#endif
...
@@ -207,7 +207,7 @@ class GRUKernelImpl : public GRUKernel<T> {
...
@@ -207,7 +207,7 @@ class GRUKernelImpl : public GRUKernel<T> {
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
template
<
>
template
<
>
bool
GRUKernelImpl
<
float
>::
useJIT
(
int
d
)
{
bool
GRUKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
false
;
// jitcode not ready yet
return
gen
::
GRUJitCode
::
init
(
d
);
}
}
#endif
#endif
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
6a7f83d4
...
@@ -714,6 +714,8 @@ TEST(JitKernel, pool) {
...
@@ -714,6 +714,8 @@ TEST(JitKernel, pool) {
std
::
string
act_gate
=
"sigmoid"
,
act_cand
=
"tanh"
,
act_cell
=
"tanh"
;
std
::
string
act_gate
=
"sigmoid"
,
act_cand
=
"tanh"
,
act_cell
=
"tanh"
;
jit
::
lstm_attr_t
attr
(
frame_size
,
act_gate
,
act_cand
,
act_cell
,
false
);
jit
::
lstm_attr_t
attr
(
frame_size
,
act_gate
,
act_cand
,
act_cell
,
false
);
// empty call it to avoid unknown flag 'use_pinned_memory' on Mac
paddle
::
platform
::
jit
::
MayIUse
(
paddle
::
platform
::
jit
::
avx
);
const
auto
&
plstm1
=
const
auto
&
plstm1
=
jit
::
KernelPool
::
Instance
()
jit
::
KernelPool
::
Instance
()
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
jit
::
lstm_attr_t
&>
(
attr
);
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
jit
::
lstm_attr_t
&>
(
attr
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录