Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f9138608
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f9138608
编写于
11月 21, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
jitkernel lstm refer support peephole
test=develop
上级
2f9b5f23
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
250 addition
and
263 deletion
+250
-263
paddle/fluid/operators/fused/fusion_lstm_op.cc
paddle/fluid/operators/fused/fusion_lstm_op.cc
+46
-27
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+3
-3
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+30
-12
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+3
-12
paddle/fluid/operators/math/jit_kernel_impl.h
paddle/fluid/operators/math/jit_kernel_impl.h
+10
-4
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+4
-4
paddle/fluid/operators/math/jit_kernel_refer.h
paddle/fluid/operators/math/jit_kernel_refer.h
+28
-7
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+109
-179
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+17
-15
未找到文件。
paddle/fluid/operators/fused/fusion_lstm_op.cc
浏览文件 @
f9138608
...
...
@@ -236,27 +236,31 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D = wh_dims[0]; \
const int D4 = wh_dims[1]
#define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/
\
const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/
\
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
} \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::LSTMKernel<T>, const std::string&, \
const std::string&, const std::string&>( \
ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("candidate_activation"), \
ctx.Attr<std::string>("cell_activation"), D, use_peepholes)
#define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/
\
const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/
\
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
} \
const math::jitkernel::lstm_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("candidate_activation"), \
ctx.Attr<std::string>("cell_activation"), use_peepholes); \
math::jitkernel::lstm_t one_step; \
one_step.wp = wp_data; \
one_step.checked = checked_cell_data; \
const auto& ker = \
math::jitkernel::KernelPool::Instance() \
.template Get<math::jitkernel::LSTMKernel<T>, \
const math::jitkernel::lstm_attr_t&>(attr)
// Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \
...
...
@@ -299,7 +303,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_h_data
=
h0_data
+
bid
*
D
;
prev_c_data
=
c0_data
+
bid
*
D
;
}
else
{
ker
->
ComputeC1H1
(
xx_data
,
c_out_data
,
h_out_data
,
wp_data
);
one_step
.
gates
=
xx_data
;
one_step
.
ct
=
c_out_data
;
one_step
.
ht
=
h_out_data
;
ker
->
ComputeC1H1
(
&
one_step
,
&
attr
);
tstart
=
1
;
// move one step
prev_h_data
=
h_out_data
;
...
...
@@ -310,8 +317,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
}
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
ker
->
ComputeCtHt
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
,
wp_data
,
checked_cell_data
);
one_step
.
gates
=
xx_data
;
one_step
.
ct_1
=
prev_c_data
;
one_step
.
ct
=
c_out_data
;
one_step
.
ht
=
h_out_data
;
ker
->
ComputeCtHt
(
&
one_step
,
&
attr
);
// move one step
prev_h_data
=
h_out_data
;
prev_c_data
=
c_out_data
;
...
...
@@ -388,7 +399,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
ker
->
ComputeC1H1
(
cur_in_data
,
cur_c_out_data
,
cur_h_out_data
,
wp_data
);
one_step
.
gates
=
cur_in_data
;
one_step
.
ct
=
cur_c_out_data
;
one_step
.
ht
=
cur_h_out_data
;
ker
->
ComputeC1H1
(
&
one_step
,
&
attr
);
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
...
...
@@ -413,8 +428,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T
*
cur_c_out_data
=
batched_c_out_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
ker
->
ComputeCtHt
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
cur_h_out_data
,
wp_data
,
checked_cell_data
);
one_step
.
gates
=
cur_in_data
;
one_step
.
ct_1
=
cur_prev_c_data
;
one_step
.
ct
=
cur_c_out_data
;
one_step
.
ht
=
cur_h_out_data
;
ker
->
ComputeCtHt
(
&
one_step
,
&
attr
);
// move one batch
cur_in_data
+=
D4
;
cur_prev_c_data
+=
D
;
...
...
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
f9138608
...
...
@@ -233,7 +233,7 @@ void LSTMJitCode::generate() {
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
+
num_
]);
act
<
ymm_t
>
(
ymm_i
,
ymm_src
,
act_gate_
);
vmulps
(
ymm_c
,
ymm_c
,
ymm_i
);
if
(
first
_
)
{
if
(
!
compute_c1h1
_
)
{
// f
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
+
2
*
num_
]);
act
<
ymm_t
>
(
ymm_f
,
ymm_src
,
act_gate_
);
...
...
@@ -242,8 +242,8 @@ void LSTMJitCode::generate() {
vaddps
(
ymm_f
,
ymm_f
,
ymm_c
);
}
/* H_t = act_cell(C_t) * ogated */
ymm_t
ymm_ct
=
first
_
?
ymm_c
:
ymm_f
;
ymm_t
ymm_o
=
first
_
?
ymm_f
:
ymm_c
;
ymm_t
ymm_ct
=
compute_c1h1
_
?
ymm_c
:
ymm_f
;
ymm_t
ymm_o
=
compute_c1h1
_
?
ymm_f
:
ymm_c
;
ymm_t
ymm_tmp
=
ymm_i
;
act
<
ymm_t
>
(
ymm_tmp
,
ymm_ct
,
act_cell_
);
vmovups
(
ymm_src
,
ptr
[
reg_ptr_gates
+
offset
+
3
*
num_
]);
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
f9138608
...
...
@@ -319,6 +319,12 @@ class LSTMJitCode : public VActJitCode {
public:
const
char
*
name
()
const
override
{
std
::
string
base
=
"LSTMJitCode"
;
if
(
use_peephole_
)
{
base
+=
"_Peephole"
;
}
if
(
compute_c1h1_
)
{
base
+=
"_C1H1"
;
}
auto
AddTypeStr
=
[
&
](
operand_type
type
)
{
switch
(
type
)
{
case
operand_type
::
relu
:
...
...
@@ -340,30 +346,42 @@ class LSTMJitCode : public VActJitCode {
break
;
}
};
if
(
first_
)
{
base
+=
"_C1H1"
;
}
AddTypeStr
(
act_gate_
);
AddTypeStr
(
act_cand_
);
AddTypeStr
(
act_cell_
);
return
base
.
c_str
();
}
explicit
LSTMJitCode
(
int
d
,
bool
first
,
operand_type
act_gate
,
operand_type
act_cand
,
operand_type
act_cell
,
explicit
LSTMJitCode
(
bool
compute_c1h1
,
const
lstm_attr_t
&
attr
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VActJitCode
(
d
,
act_gate
,
code_size
,
code_ptr
),
num_
(
d
),
first_
(
first
),
act_gate_
(
act_gate
),
act_cand_
(
act_cand
),
act_cell_
(
act_cell
)
{}
:
VActJitCode
(
attr
.
d
,
operand_type
::
sigmoid
/* this is bugy*/
,
code_size
,
code_ptr
),
compute_c1h1_
(
compute_c1h1
)
{
auto
typeExchange
=
[](
const
std
::
string
&
type
)
->
gen
::
operand_type
{
if
(
type
==
"sigmoid"
)
{
return
operand_type
::
sigmoid
;
}
else
if
(
type
==
"relu"
)
{
return
operand_type
::
relu
;
}
else
if
(
type
==
"tanh"
)
{
return
operand_type
::
tanh
;
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
operand_type
::
identity
;
}
// else throw error
return
operand_type
::
identity
;
};
num_
=
attr
.
d
;
use_peephole_
=
attr
.
use_peephole
;
act_gate_
=
typeExchange
(
attr
.
act_gate
);
act_cand_
=
typeExchange
(
attr
.
act_cand
);
act_cell_
=
typeExchange
(
attr
.
act_cell
);
}
static
bool
init
(
int
d
);
void
generate
()
override
;
protected:
int
num_
;
bool
first_
;
bool
compute_c1h1_
;
bool
use_peephole_
;
operand_type
act_gate_
;
operand_type
act_cand_
;
operand_type
act_cell_
;
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
f9138608
...
...
@@ -122,18 +122,9 @@ class VTanhKernel : public VActKernel<T> {};
template
<
typename
T
>
class
LSTMKernel
:
public
Kernel
{
public:
virtual
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
/* below only used in peephole*/
const
T
*
wp_data
=
nullptr
,
T
*
checked
=
nullptr
)
const
=
0
;
virtual
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
/* below only used in peephole*/
const
T
*
wp_data
=
nullptr
)
const
=
0
;
// void (*ComputeCtHt)(lstm_t *);
// // compute c1 and h1 without c0 or h0
// void (*ComputeC1H1)(lstm_t *);
void
(
*
ComputeCtHt
)(
lstm_t
*
,
const
lstm_attr_t
*
);
// compute c1 and h1 without c0 or h0
void
(
*
ComputeC1H1
)(
lstm_t
*
,
const
lstm_attr_t
*
);
};
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_impl.h
浏览文件 @
f9138608
...
...
@@ -33,18 +33,24 @@ typedef struct {
const
void
*
ct_1
;
void
*
ct
;
void
*
ht
;
/*
below
only used in peephole*/
const
void
*
wp
_data
{
nullptr
};
/*
weight_peephole and checked data are
only used in peephole*/
const
void
*
wp
{
nullptr
};
void
*
checked
{
nullptr
};
}
lstm_t
;
typedef
struct
lstm_attr_s
{
bool
use_peephole
;
int
d
;
std
::
string
act_gate
,
act_cand
,
act_cell
;
lstm_attr_s
()
=
default
;
lstm_attr_s
(
int
_d
,
const
std
::
string
&
_act_gate
,
const
std
::
string
&
_act_cand
,
const
std
::
string
&
_act_cell
)
:
d
(
_d
),
act_gate
(
_act_gate
),
act_cand
(
_act_cand
),
act_cell
(
_act_cell
)
{}
const
std
::
string
&
_act_cand
,
const
std
::
string
&
_act_cell
,
bool
_use_peephole
=
false
)
:
use_peephole
(
_use_peephole
),
d
(
_d
),
act_gate
(
_act_gate
),
act_cand
(
_act_cand
),
act_cell
(
_act_cell
)
{}
}
lstm_attr_t
;
}
// namespace jitkernel
...
...
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
f9138608
...
...
@@ -82,10 +82,10 @@ namespace jitkernel {
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
marco_declare, macro_find_key, macro_impl) \
marco_define_name(ker_key, ker_class); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float,
JITKERNEL_DECLARE,
\
JITKERNEL_FIND_KEY, JITKERNEL_IMPL);
\
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double,
JITKERNEL_DECLARE,
\
JITKERNEL_FIND_KEY, JITKERNEL_IMPL
)
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float,
marco_declare,
\
macro_find_key, macro_impl);
\
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double,
marco_declare,
\
macro_find_key, macro_impl
)
#define REGISTER_JITKERNEL(ker_key, ker_class) \
REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
...
...
paddle/fluid/operators/math/jit_kernel_refer.h
浏览文件 @
f9138608
...
...
@@ -117,11 +117,13 @@ void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT
}
template
<
typename
T
>
void
LSTMCtHt
(
lstm_t
*
step
,
lstm_attr_t
*
attr
)
{
void
LSTMCtHt
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
const
T
*
ct_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ct_1
);
T
*
ct
=
reinterpret_cast
<
T
*>
(
step
->
ct
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
const
T
*
wp
=
reinterpret_cast
<
const
T
*>
(
step
->
wp
);
T
*
checked
=
reinterpret_cast
<
T
*>
(
step
->
checked
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
auto
act_cand
=
getActFunc
<
T
>
(
attr
->
act_cand
);
auto
act_cell
=
getActFunc
<
T
>
(
attr
->
act_cell
);
...
...
@@ -129,23 +131,36 @@ void LSTMCtHt(lstm_t* step, lstm_attr_t* attr) {
int
d2
=
d
*
2
;
int
d3
=
d
*
3
;
// gates: W_ch, W_ih, W_fh, W_oh
act_gate
(
gates
+
d
,
gates
+
d
,
d3
);
if
(
attr
->
use_peephole
)
{
VMul
(
wp
,
ct_1
,
checked
,
d
);
VMul
(
wp
+
d
,
ct_1
,
checked
+
d
,
d
);
VAdd
(
checked
,
gates
+
d
,
gates
+
d
,
d2
);
act_gate
(
gates
+
d
,
gates
+
d
,
d2
);
}
else
{
act_gate
(
gates
+
d
,
gates
+
d
,
d3
);
}
/
* C_t = C_t-1 * fgated + cand_gated * igated */
/
/ C_t = C_t-1 * fgated + cand_gated * igated
act_cand
(
gates
,
gates
,
d
);
VMul
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
VMul
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
VAdd
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
/* H_t = act_cell(C_t) * ogated */
if
(
attr
->
use_peephole
)
{
// get ogated
VMul
(
wp
+
d2
,
ct
,
gates
+
d
,
d
);
VAdd
(
gates
+
d
,
gates
+
d3
,
gates
+
d3
,
d
);
act_gate
(
gates
+
d3
,
gates
+
d3
,
d
);
}
// H_t = act_cell(C_t) * ogated
act_cell
(
ct
,
gates
+
d2
,
d
);
VMul
(
gates
+
d2
,
gates
+
d3
,
ht
,
d
);
}
// compute c1 and h1 without c0 or h0
template
<
typename
T
>
void
LSTMC1H1
(
lstm_t
*
step
,
lstm_attr_t
*
attr
)
{
void
LSTMC1H1
(
lstm_t
*
step
,
const
lstm_attr_t
*
attr
)
{
T
*
gates
=
reinterpret_cast
<
T
*>
(
step
->
gates
);
const
T
*
ct_1
=
reinterpret_cast
<
const
T
*>
(
step
->
ct_1
);
T
*
ct
=
reinterpret_cast
<
T
*>
(
step
->
ct
);
T
*
ht
=
reinterpret_cast
<
T
*>
(
step
->
ht
);
auto
act_gate
=
getActFunc
<
T
>
(
attr
->
act_gate
);
...
...
@@ -158,10 +173,16 @@ void LSTMC1H1(lstm_t* step, lstm_attr_t* attr) {
act_gate
(
gates
+
d
,
gates
+
d
,
d
);
act_cand
(
gates
,
gates
,
d
);
VMul
(
gates
,
gates
+
d
,
ct
,
d
);
if
(
attr
->
use_peephole
)
{
// get outgated, put W_oc * C_t on igated
const
T
*
wp
=
reinterpret_cast
<
const
T
*>
(
step
->
wp
);
VMul
(
wp
+
d2
,
ct
,
gates
+
d
,
d
);
VAdd
(
gates
+
d
,
gates
+
d3
,
gates
+
d3
,
d
);
}
/* H_t = act_cell(C_t) * ogated */
act_gate
(
gates
+
d3
,
gates
+
d3
,
d
);
act_cell
(
ct
,
gates
+
d2
,
d
);
V
m
ul
(
gates
+
d2
,
gates
+
d3
,
ht
,
d
);
V
M
ul
(
gates
+
d2
,
gates
+
d3
,
ht
,
d
);
}
}
// namespace refer
...
...
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
f9138608
...
...
@@ -15,9 +15,14 @@ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
...
...
@@ -154,211 +159,136 @@ static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
#endif
/* LSTM JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
>
class
LSTMKernelImpl
:
public
LSTMKernel
<
T
>
{
public:
explicit
LSTMKernelImpl
(
const
std
::
string
&
act_gate
,
const
std
::
string
&
act_cand
,
const
std
::
string
&
act_cell
,
int
d
)
:
LSTMKernel
<
T
>
()
{
d_
=
d
;
d2_
=
d
*
2
;
d3_
=
d
*
3
;
act_gate_d3_
=
GetActKernel
<
T
>
(
act_gate
,
d3_
);
act_gate_d_
=
GetActKernel
<
T
>
(
act_gate
,
d
);
act_cand_d_
=
GetActKernel
<
T
>
(
act_cand
,
d
);
act_cell_d_
=
GetActKernel
<
T
>
(
act_cell
,
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
vadd_d_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d
);
static
inline
std
::
string
name
(
const
lstm_attr_t
&
attr
)
{
PADDLE_THROW
(
"DType should be either float or double"
);
}
static
inline
bool
useJIT
(
int
d
)
{
return
false
;
}
static
inline
bool
useMKL
(
int
d
)
{
return
false
;
}
explicit
LSTMKernelImpl
(
const
lstm_attr_t
&
attr
)
:
LSTMKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
attr
.
d
))
{
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
84
*
8
;
// should change
jitcode0_
.
reset
(
new
gen
::
LSTMJitCode
(
false
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeCtHt
=
jitcode0_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
jitcode1_
.
reset
(
new
gen
::
LSTMJitCode
(
true
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeC1H1
=
jitcode1_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
return
;
}
#endif
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
T
*
checked
)
const
override
{
// gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d3_
);
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
,
d_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
this
->
ComputeCtHt
=
refer
::
LSTMCtHt
<
T
>
;
this
->
ComputeC1H1
=
refer
::
LSTMC1H1
<
T
>
;
}
#ifdef PADDLE_WITH_XBYAK
private:
int
d_
,
d2_
,
d3_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_d3_
,
act_gate_d_
,
act_cand_d_
,
act_cell_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
std
::
shared_ptr
<
const
VAddKernel
<
T
>>
vadd_d_
;
#ifdef __AVX__
std
::
unique_ptr
<
const
AVXAct
>
avx_act_gate_
,
avx_act_cand_
,
avx_act_cell_
;
std
::
unique_ptr
<
gen
::
LSTMJitCode
>
jitcode0_
{
nullptr
},
jitcode1_
{
nullptr
};
#endif
};
#define INTRI8_FLOAT(isa) \
template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \
: LSTMKernel<float>() { \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
avx_act_cand_ = GetAVXAct<isa>(act_cand); \
avx_act_cell_ = GetAVXAct<isa>(act_cell); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht, \
const float* wp_data, float* checked) const { \
/* gates: W_ch, W_ih, W_fh, W_oh */
\
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
f = _mm256_loadu_ps(gates + 16); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/
\
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
i = _mm256_loadu_ps(ct_1); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
f = _mm256_add_ps(c, f); \
_mm256_storeu_ps(ct, f); \
/* H_t = act_cell(C_t) * ogated */
\
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/
\
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */
\
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
}
// TODO(TJ): optimize keq16
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
);
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
LSTMKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
false
;
// not ready yet gen::LSTMJitCode::init(d);
}
#endif
/* Peephole JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
>
class
PeepholeKernelImpl
:
public
LSTMKernel
<
T
>
{
public:
explicit
PeepholeKernelImpl
(
const
std
::
string
&
act_gate
,
const
std
::
string
&
act_cand
,
const
std
::
string
&
act_cell
,
int
d
)
:
LSTMKernel
<
T
>
()
{
d_
=
d
;
d2_
=
d
*
2
;
d3_
=
d
*
3
;
act_gate_d_
=
GetActKernel
<
T
>
(
act_gate
,
d
);
act_cand_d_
=
GetActKernel
<
T
>
(
act_cand
,
d
);
act_cell_d_
=
GetActKernel
<
T
>
(
act_cell
,
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
vadd_d_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d
);
vadd_d2_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d2_
);
act_gate_d2_
=
GetActKernel
<
T
>
(
act_gate
,
d2_
);
static
inline
std
::
string
name
(
const
lstm_attr_t
&
attr
)
{
PADDLE_THROW
(
"DType should be either float or double"
);
}
static
inline
bool
useJIT
(
int
d
)
{
return
false
;
}
static
inline
bool
useMKL
(
int
d
)
{
return
false
;
}
explicit
PeepholeKernelImpl
(
const
lstm_attr_t
&
attr
)
:
LSTMKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
attr
.
d
))
{
size_t
sz
=
96
+
attr
.
d
/
YMM_FLOAT_BLOCK
*
84
*
8
;
// should change
jitcode0_
.
reset
(
new
gen
::
LSTMJitCode
(
false
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeCtHt
=
jitcode0_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
jitcode1_
.
reset
(
new
gen
::
LSTMJitCode
(
true
,
attr
,
sz
>
4096
?
sz
:
4096
));
this
->
ComputeC1H1
=
jitcode1_
->
getCode
<
void
(
*
)(
lstm_t
*
,
const
lstm_attr_t
*
)
>
();
return
;
}
#endif
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
T
*
checked
)
const
override
{
/* get fgated and igated*/
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
,
d_
);
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
,
d2_
);
act_gate_d2_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d2_
);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
/* get ogated*/
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
this
->
ComputeCtHt
=
refer
::
LSTMCtHt
<
T
>
;
this
->
ComputeC1H1
=
refer
::
LSTMC1H1
<
T
>
;
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* get outgated, put W_oc * C_t on igated */
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
,
d_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
#ifdef PADDLE_WITH_XBYAK
private:
int
d_
,
d2_
,
d3_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_d2_
,
act_gate_d_
,
act_cand_d_
,
act_cell_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
std
::
shared_ptr
<
const
VAddKernel
<
T
>>
vadd_d_
,
vadd_d2_
;
std
::
unique_ptr
<
gen
::
LSTMJitCode
>
jitcode0_
{
nullptr
},
jitcode1_
{
nullptr
};
#endif
};
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const LSTMKernel<ker_dtype>> \
KernelPool::Get<LSTMKernel<ker_dtype>, const std::string&, \
const std::string&, const std::string&, int, bool>( \
const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d, bool use_peephole)
#define JITKERNEL_KEY_LSTM(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_cand + act_cell + \
(use_peephole ? "p" : "n")
#define JITKERNEL_NEW_LSTM_IMPL(ker, dtype, isa, k) \
if (use_peephole) { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<PeepholeKernelImpl<dtype, isa, k>>( \
act_gate, act_cand, act_cell, d)); \
} else { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_cand, \
act_cell, d)); \
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
PeepholeKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
false
;
// peephole jitcode not ready yet
}
#endif
#define JITKERNEL_DEFINE_NAME_LSTM(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(const lstm_attr_t& attr) { \
std::string key(#ker_key "f"); \
key += (attr.act_gate + attr.act_cand + attr.act_cell + \
(attr.use_peephole ? "p" : "n")); \
if (useJIT(attr.d)) { \
/* only jit code need record d*/
\
return key + "jit" + std::to_string(attr.d); \
} else if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(const lstm_attr_t& attr) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/
\
if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
REGISTER_JITKERNEL_ARGS_DEPRECATED
(
lstm
,
LSTMKernel
,
JITKERNEL_DECLARE_LSTM
,
JITKERNEL_KEY_LSTM
,
JITKERNEL_NEW_LSTM_IMPL
);
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const LSTMKernel<ker_dtype>> \
KernelPool::Get<LSTMKernel<ker_dtype>, const lstm_attr_t&>( \
const lstm_attr_t& attr)
#define JITKERNEL_FIND_KEY_LSTM(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(attr)
#define JITKERNEL_LSTM_IMPL(ker, dtype) \
if (attr.use_peephole) { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<PeepholeKernelImpl<dtype>>(attr)); \
} else { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype>>(attr)); \
}
#undef INTRI8_FLOAT
#undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_KEY_LSTM
#undef JITKERNEL_NEW_LSTM_IMPL
REGISTER_JITKERNEL_ARGS
(
lstm
,
LSTMKernel
,
JITKERNEL_DEFINE_NAME_LSTM
,
JITKERNEL_DECLARE_LSTM
,
JITKERNEL_FIND_KEY_LSTM
,
JITKERNEL_LSTM_IMPL
);
/* GRU JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
f9138608
...
...
@@ -341,11 +341,11 @@ TEST(JitKernel, lstm) {
RandomVec
<
float
>
(
d
,
ct_1
.
data
(),
-
2.
f
,
2.
f
);
memcpy
(
xref
.
data
(),
x
.
data
(),
sizeof
(
float
)
*
d4
);
std
::
string
act_gate
=
"sigmoid"
,
act_cand
=
"tanh"
,
act_cell
=
"tanh"
;
const
jit
::
lstm_attr_t
attr
(
d
,
act_gate
,
act_cand
,
act_cell
,
false
);
const
auto
&
ker
=
jit
::
KernelPool
::
Instance
()
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
std
::
string
&
,
const
std
::
string
&
,
const
std
::
string
&>
(
act_gate
,
act_cand
,
act_cell
,
d
,
false
);
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
jit
::
lstm_attr_t
&>
(
attr
);
// below kernels are used to compute refer
const
auto
&
vsigmoid_3d
=
jit
::
KernelPool
::
Instance
().
template
Get
<
jit
::
VSigmoidKernel
<
float
>
>
(
...
...
@@ -366,14 +366,16 @@ TEST(JitKernel, lstm) {
float
*
ht_ref_data
=
ht_ref
.
data
();
// compute once to check correctness
jit
::
lstm_t
step
;
jit
::
lstm_attr_t
attr
(
d
,
act_gate
,
act_cand
,
act_cell
);
step
.
gates
=
xref_data
;
step
.
ct_1
=
ct_1_data
;
step
.
ct
=
ct_ref_data
;
step
.
ht
=
ht_ref_data
;
refer
::
LSTMCtHt
<
float
>
(
&
step
,
&
attr
);
ker
->
ComputeCtHt
(
x_data
,
ct_1_data
,
ct_tgt_data
,
ht_tgt_data
);
step
.
gates
=
x_data
;
step
.
ct
=
ct_tgt_data
;
step
.
ht
=
ht_tgt_data
;
ker
->
ComputeCtHt
(
&
step
,
&
attr
);
for
(
int
i
=
0
;
i
<
d
;
++
i
)
{
EXPECT_NEAR
(
ct_tgt_data
[
i
],
ct_ref_data
[
i
],
1e-3
);
EXPECT_NEAR
(
ht_tgt_data
[
i
],
ht_ref_data
[
i
],
1e-3
);
...
...
@@ -392,7 +394,7 @@ TEST(JitKernel, lstm) {
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
ComputeCtHt
(
x_data
,
ct_1_data
,
ct_tgt_data
,
ht_tgt_data
);
ker
->
ComputeCtHt
(
&
step
,
&
attr
);
}
auto
ttgte
=
GetCurrentUS
();
VLOG
(
30
)
<<
"Vec size "
<<
d
...
...
@@ -710,21 +712,21 @@ TEST(JitKernel, pool) {
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
const
int
frame_size
=
4
;
std
::
string
act_gate
=
"sigmoid"
,
act_cand
=
"tanh"
,
act_cell
=
"tanh"
;
jit
::
lstm_attr_t
attr
(
frame_size
,
act_gate
,
act_cand
,
act_cell
,
false
);
const
auto
&
plstm1
=
jit
::
KernelPool
::
Instance
()
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
std
::
string
&
,
const
std
::
string
&
,
const
std
::
string
&>
(
act_gate
,
act_cand
,
act_cell
,
frame_size
,
false
);
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
jit
::
lstm_attr_t
&>
(
attr
);
const
auto
&
plstm2
=
jit
::
KernelPool
::
Instance
()
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
std
::
string
&
,
const
std
::
string
&
,
const
std
::
string
&>
(
act_gate
,
act_cand
,
act_cell
,
frame_size
,
false
);
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
jit
::
lstm_attr_t
&>
(
attr
);
EXPECT_EQ
(
plstm1
,
plstm2
);
const
auto
&
peephole
=
jit
::
KernelPool
::
Instance
()
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
std
::
string
&
,
const
std
::
string
&
,
const
std
::
string
&>
(
act_gate
,
act_cand
,
act_cell
,
frame_size
,
true
);
.
template
Get
<
jit
::
LSTMKernel
<
float
>,
const
jit
::
lstm_attr_t
&>
(
jit
::
lstm_attr_t
(
frame_size
,
act_gate
,
act_cand
,
act_cell
,
true
));
EXPECT_TRUE
(
plstm1
!=
peephole
);
const
auto
&
pvmul_f
=
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录