Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
e3b61cf5
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e3b61cf5
编写于
11月 22, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
init gru jitcode and fix lstm jitcode
test=develop
上级
0f254465
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
170 addition
and
42 deletion
+170
-42
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+28
-8
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+109
-31
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+33
-3
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
e3b61cf5
...
...
@@ -214,6 +214,9 @@ void VActJitCode::generate() {
bool
LSTMJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
%
8
==
0
;
}
void
LSTMJitCode
::
generate
()
{
if
(
use_peephole_
)
{
preCode
();
}
reg64_t
reg_ptr_gates
=
rax
;
reg64_t
reg_ptr_ct_1
=
r9
;
reg64_t
reg_ptr_ct
=
r10
;
...
...
@@ -224,18 +227,19 @@ void LSTMJitCode::generate() {
mov
(
reg_ptr_ht
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ht
)]);
int
offset
=
0
;
int
d
=
num_
*
sizeof
(
float
);
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
/* C_t = C_t-1 * fgated + cand_gated * igated*/
// c
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
]);
act
<
ymm_t
>
(
ymm_c
,
ymm_src
,
act_cand_
);
// i
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
+
num_
]);
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
+
d
]);
act
<
ymm_t
>
(
ymm_i
,
ymm_src
,
act_gate_
);
vmulps
(
ymm_c
,
ymm_c
,
ymm_i
);
if
(
!
compute_c1h1_
)
{
// f
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
num_
]);
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
d
]);
act
<
ymm_t
>
(
ymm_f
,
ymm_src
,
act_gate_
);
vmovups
(
ymm_i
,
ptr
[
reg_ptr_ct_1
+
offset
]);
vmulps
(
ymm_f
,
ymm_f
,
ymm_i
);
...
...
@@ -245,20 +249,36 @@ void LSTMJitCode::generate() {
ymm_t
ymm_ct
=
compute_c1h1_
?
ymm_c
:
ymm_f
;
ymm_t
ymm_o
=
compute_c1h1_
?
ymm_f
:
ymm_c
;
ymm_t
ymm_tmp
=
ymm_i
;
vmovups
(
ptr
[
reg_ptr_ct
+
offset
],
ymm_ct
);
// save ct
act
<
ymm_t
>
(
ymm_tmp
,
ymm_ct
,
act_cell_
);
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
+
3
*
num_
]);
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
+
3
*
d
]);
act
<
ymm_t
>
(
ymm_o
,
ymm_src
,
act_gate_
);
vmulps
(
ymm_o
,
ymm_tmp
,
ymm_o
);
// save ct and ht
vmovups
(
ptr
[
reg_ptr_ct
+
offset
],
ymm_ct
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_o
);
vmovups
(
ptr
[
reg_ptr_ht
+
offset
],
ymm_o
);
// save ht
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
ret
();
if
(
use_peephole_
)
{
postCode
();
}
else
{
ret
();
}
}
bool
GRUJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
%
8
==
0
;
}
void
GRUJitCode
::
generate
()
{
reg64_t
reg_ptr_gates
=
rax
;
reg64_t
reg_ptr_ct_1
=
r9
;
reg64_t
reg_ptr_ct
=
r10
;
reg64_t
reg_ptr_ht
=
r11
;
mov
(
reg_ptr_gates
,
ptr
[
param1
+
offsetof
(
lstm_t
,
gates
)]);
mov
(
reg_ptr_ct_1
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ct_1
)]);
mov
(
reg_ptr_ct
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ct
)]);
mov
(
reg_ptr_ht
,
ptr
[
param1
+
offsetof
(
lstm_t
,
ht
)]);
ret
();
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
e3b61cf5
...
...
@@ -302,6 +302,34 @@ class VActJitCode : public JitCode {
pop
(
reg_ptr_global
);
}
template
<
typename
JMM
>
void
act
(
JMM
&
dst
,
JMM
&
src
,
operand_type
type
)
{
// NOLINT
// use 15
JMM
zero
=
JMM
(
15
);
if
(
type_
==
operand_type
::
relu
)
{
vxorps
(
zero
,
zero
,
zero
);
}
switch
(
type
)
{
case
operand_type
::
relu
:
relu_jmm
<
JMM
>
(
dst
,
src
,
zero
);
break
;
case
operand_type
::
exp
:
exp_jmm
<
JMM
>
(
dst
,
src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_jmm
<
JMM
>
(
dst
,
src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
tanh
:
tanh_jmm
<
JMM
>
(
dst
,
src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
identity
:
break
;
default:
// throw error
break
;
}
}
protected:
int
num_
;
operand_type
type_
;
...
...
@@ -386,44 +414,94 @@ class LSTMJitCode : public VActJitCode {
operand_type
act_cand_
;
operand_type
act_cell_
;
reg64_t
param1
{
abi_param1
};
xmm_t
xmm_src
=
xmm_t
(
0
);
xmm_t
xmm_c
=
xmm_t
(
1
);
xmm_t
xmm_i
=
xmm_t
(
2
);
xmm_t
xmm_f
=
xmm_t
(
3
);
xmm_t
xmm_i
=
xmm_t
(
6
);
xmm_t
xmm_f
=
xmm_t
(
7
);
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_c
=
ymm_t
(
1
);
ymm_t
ymm_i
=
ymm_t
(
2
);
ymm_t
ymm_f
=
ymm_t
(
3
);
ymm_t
ymm_c
=
ymm_t
(
1
);
// 2~5 for act
ymm_t
ymm_i
=
ymm_t
(
6
);
ymm_t
ymm_f
=
ymm_t
(
7
);
};
template
<
typename
JMM
>
void
act
(
JMM
&
dst
,
JMM
&
src
,
operand_type
type
)
{
// NOLINT
// use 15
JMM
zero
=
JMM
(
15
);
if
(
type_
==
operand_type
::
relu
)
{
vxorps
(
zero
,
zero
,
zero
);
}
switch
(
type
)
{
case
operand_type
::
relu
:
relu_jmm
<
JMM
>
(
dst
,
src
,
zero
);
break
;
case
operand_type
::
exp
:
exp_jmm
<
JMM
>
(
dst
,
src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_jmm
<
JMM
>
(
dst
,
src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
tanh
:
tanh_jmm
<
JMM
>
(
dst
,
src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
identity
:
break
;
default:
// throw error
break
;
class
GRUJitCode
:
public
VActJitCode
{
public:
const
char
*
name
()
const
override
{
std
::
string
base
=
"GRUJitCode"
;
if
(
id_
==
0
)
{
base
+=
"_H1"
;
}
else
if
(
id_
==
1
)
{
base
+=
"_HtPart1"
;
}
else
if
(
id_
==
2
)
{
base
+=
"_HtPart2"
;
}
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
switch
(
type
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
};
AddTypeStr
(
act_gate_
);
AddTypeStr
(
act_cand_
);
return
base
.
c_str
();
}
explicit
GRUJitCode
(
int
id
,
const
gru_attr_t
&
attr
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VActJitCode
(
attr
.
d
,
operand_type
::
sigmoid
/* this is bugy*/
,
code_size
,
code_ptr
),
id_
(
id
)
{
auto
typeExchange
=
[](
const
std
::
string
&
type
)
->
gen
::
operand_type
{
if
(
type
==
"sigmoid"
)
{
return
operand_type
::
sigmoid
;
}
else
if
(
type
==
"relu"
)
{
return
operand_type
::
relu
;
}
else
if
(
type
==
"tanh"
)
{
return
operand_type
::
tanh
;
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
operand_type
::
identity
;
}
// else throw error
return
operand_type
::
identity
;
};
num_
=
attr
.
d
;
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
}
static
bool
init
(
int
d
);
void
generate
()
override
;
protected:
int
id_
;
int
num_
;
operand_type
act_gate_
;
operand_type
act_cand_
;
reg64_t
param1
{
abi_param1
};
xmm_t
xmm_src
=
xmm_t
(
0
);
xmm_t
xmm_c
=
xmm_t
(
1
);
xmm_t
xmm_i
=
xmm_t
(
6
);
xmm_t
xmm_f
=
xmm_t
(
7
);
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_c
=
ymm_t
(
1
);
ymm_t
ymm_i
=
ymm_t
(
6
);
ymm_t
ymm_f
=
ymm_t
(
7
);
};
#ifdef PADDLE_WITH_MKLDNN
...
...
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
e3b61cf5
...
...
@@ -40,7 +40,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
explicit
LSTMKernelImpl
(
const
lstm_attr_t
&
attr
)
:
LSTMKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
attr
.
d
))
{
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
84
*
8
;
// should change
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
90
*
4
*
8
;
jitcode0_
.
reset
(
new
gen
::
LSTMJitCode
(
false
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeCtHt
=
jitcode0_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
...
...
@@ -66,7 +66,7 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
LSTMKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
false
;
// not ready yet
gen::LSTMJitCode::init(d);
return
gen
::
LSTMJitCode
::
init
(
d
);
}
#endif
...
...
@@ -82,7 +82,7 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
explicit
PeepholeKernelImpl
(
const
lstm_attr_t
&
attr
)
:
LSTMKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
attr
.
d
))
{
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
84
*
8
;
// should change
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
96
*
4
*
8
;
jitcode0_
.
reset
(
new
gen
::
LSTMJitCode
(
false
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeCtHt
=
jitcode0_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
...
...
@@ -175,12 +175,42 @@ class GRUKernelImpl : public GRUKernel<T> {
static
inline
bool
useJIT
(
int
d
)
{
return
false
;
}
static
inline
bool
useMKL
(
int
d
)
{
return
false
;
}
explicit
GRUKernelImpl
(
const
gru_attr_t
&
attr
)
:
GRUKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
attr
.
d
))
{
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
84
*
8
;
// should change
jitcode0_
.
reset
(
new
gen
::
GRUJitCode
(
0
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeH1
=
jitcode0_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
jitcode1_
.
reset
(
new
gen
::
GRUJitCode
(
1
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeHtPart1
=
jitcode1_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
jitcode2_
.
reset
(
new
gen
::
GRUJitCode
(
2
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeHtPart2
=
jitcode1_
->
getCode
<
void
(
*
)(
gru_t
*
,
const
gru_attr_t
*
)
>
();
return
;
}
#endif
this
->
ComputeH1
=
refer
::
GRUH1
<
T
>
;
this
->
ComputeHtPart1
=
refer
::
GRUHtPart1
<
T
>
;
this
->
ComputeHtPart2
=
refer
::
GRUHtPart2
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
GRUJitCode
>
jitcode0_
{
nullptr
},
jitcode1_
{
nullptr
},
jitcode2_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
GRUKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
false
;
// jitcode not ready yet
}
#endif
#define JITKERNEL_DEFINE_NAME_GRU(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(const gru_attr_t& attr) { \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录