Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8e182170
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8e182170
编写于
10月 12, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine and replace lstm peephole kernel
上级
7ef2699e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
181 addition
and
297 deletion
+181
-297
paddle/fluid/operators/fusion_lstm_op.cc
paddle/fluid/operators/fusion_lstm_op.cc
+102
-245
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+7
-0
paddle/fluid/operators/math/jit_kernel_lstm.cc
paddle/fluid/operators/math/jit_kernel_lstm.cc
+72
-52
未找到文件。
paddle/fluid/operators/fusion_lstm_op.cc
浏览文件 @
8e182170
...
@@ -15,11 +15,9 @@ limitations under the License. */
...
@@ -15,11 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fusion_lstm_op.h"
#include "paddle/fluid/operators/fusion_lstm_op.h"
#include <string>
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -219,116 +217,55 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
...
@@ -219,116 +217,55 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
template
<
typename
T
>
template
<
typename
T
>
class
FuisonLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
class
FuisonLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
#define INIT_VEC_FUNC \
#define INIT_BASE_DEFINES \
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; \
using DeviceContext = paddle::platform::CPUDeviceContext; \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
auto* x = ctx.Input<LoDTensor>("X"); \
auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); \
auto* c0 = ctx.Input<Tensor>("C0"); \
if (platform::jit::MayIUse(platform::jit::avx)) { \
auto* wx = ctx.Input<Tensor>("WeightX"); \
math::VecActivations<T, platform::jit::avx> act_functor; \
auto* wh = ctx.Input<Tensor>("WeightH"); \
act_gate = act_functor(act_gate_str); \
auto* bias = ctx.Input<Tensor>("Bias"); \
act_cell = act_functor(act_cell_str); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
act_cand = act_functor(act_cand_str); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
} else { \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
math::VecActivations<T, platform::jit::isa_any> act_functor; \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
act_gate = act_functor(act_gate_str); \
bool use_peepholes = ctx.Attr<bool>("use_peepholes"); \
act_cell = act_functor(act_cell_str); \
auto x_dims = x->dims();
/* T x M*/
\
act_cand = act_functor(act_cand_str); \
auto wh_dims = wh->dims();
/* D x 4D*/
\
}
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
#define INIT_BASE_INPUT_OUTPUT \
const int D4 = wh_dims[1]
auto* x = ctx.Input<LoDTensor>("X"); \
auto* h0 = ctx.Input<Tensor>("H0"); \
#define INIT_OTHER_DEFINES \
auto* c0 = ctx.Input<Tensor>("C0"); \
const T* x_data = x->data<T>(); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
const T* wx_data = wx->data<T>(); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
const T* wh_data = wh->data<T>(); \
auto* bias = ctx.Input<Tensor>("Bias"); \
/* diagonal weight*/
\
auto* xx = ctx.Output<LoDTensor>("XX"); \
const T* wp_data = bias->data<T>() + D4; \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
/* for peephole only*/
\
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
T* checked_cell_data = nullptr; \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
auto place = ctx.GetPlace(); \
bool use_peepholes = ctx.Attr<bool>("use_peepholes");
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
#define INIT_BASE_SIZES \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
auto x_dims = x->dims();
/* T x M*/
\
checked_cell_data = checked_cell->mutable_data<T>(place); \
auto wh_dims = wh->dims();
/* D x 4D*/
\
} \
const int M = x_dims[1]; \
const auto& ker = \
const int D = wh_dims[0]; \
math::jitkernel::KernelPool::Instance() \
const int D2 = D * 2; \
.template Get<math::jitkernel::LSTMKernel<T>, const std::string&, \
const int D3 = D * 3; \
const std::string&, const std::string&>( \
const int D4 = wh_dims[1];
ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("candidate_activation"), \
#define INIT_BASE_INPUT_DATAS \
ctx.Attr<std::string>("cell_activation"), D, use_peepholes)
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
// Wh GEMM
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/
\
const T* wc_data = bias->data<T>() + D4; \
/* for peephole only*/
\
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
}
/// Compute LSTM
#define GEMM_WH_ADDON(bs, prev, out) \
#define GEMM_WH_ADDON(bs, prev, out) \
blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
wh_data, D4, static_cast<T>(1), out, D4)
wh_data, D4, static_cast<T>(1), out, D4)
#define GET_Ct(ct_1, gates, ct) \
/* C_t = C_t-1 * fgated + cand_gated * igated*/
\
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, gates + D); \
blas.VMUL(D, ct_1, gates + D2, gates + D2); \
blas.VADD(D, gates + D, gates + D2, ct)
#define GET_Ht(ct, gates, ht) \
/* H_t = act_cell(C_t) * ogated */
\
act_cell(D, ct, gates + D2); \
blas.VMUL(D, gates + D2, gates + D3, ht)
#define GET_Ct_NOH0C0(gates, ct) \
/* C_t = igated * cgated*/
\
act_gate(D, gates + D, gates + D); \
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, ct)
#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
/* get outgated, put W_oc * C_t on igated */
\
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht) \
/* get fgated and igated*/
\
blas.VMUL(D, wc_data, ct_1, checked_cell_data); \
blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \
blas.VADD(D2, checked_cell_data, gates + D, gates + D); \
act_gate(D2, gates + D, gates + D); \
GET_Ct(ct_1, gates, ct); \
/* get ogated*/
\
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
INIT_BASE_DEFINES
;
INIT_BASE_INPUT_OUTPUT
INIT_OTHER_DEFINES
;
INIT_BASE_SIZES
INIT_VEC_FUNC
INIT_BASE_INPUT_DATAS
auto
x_lod
=
x
->
lod
();
auto
x_lod
=
x
->
lod
();
const
int
total_T
=
x_dims
[
0
];
const
int
total_T
=
x_dims
[
0
];
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
...
@@ -352,84 +289,47 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -352,84 +289,47 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
gate_offset
=
-
D
;
gate_offset
=
-
D
;
}
}
#define MOVE_ONE_STEP \
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
prev_h_data = h_out_data; \
int
bid
=
is_reverse
?
N
-
1
-
i
:
i
;
prev_c_data = c_out_data; \
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
xx_data = xx_data + xx_offset; \
const
T
*
prev_c_data
=
nullptr
;
h_out_data = h_out_data + gate_offset; \
const
T
*
prev_h_data
=
nullptr
;
c_out_data = c_out_data + gate_offset
int
tstart
=
0
;
if
(
h0_data
)
{
#define PROCESS_H0C0_DEFINES \
prev_h_data
=
h0_data
+
bid
*
D
;
int bid = is_reverse ? N - 1 - i : i; \
prev_c_data
=
c0_data
+
bid
*
D
;
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
}
else
{
const T* prev_c_data = nullptr; \
ker
->
ComputeC1H1
(
xx_data
,
c_out_data
,
h_out_data
,
wp_data
);
const T* prev_h_data = nullptr; \
tstart
=
1
;
int tstart = 0
// move one step
prev_h_data
=
h_out_data
;
#define PROCESS_H0C0_PEEPHOLE \
prev_c_data
=
c_out_data
;
PROCESS_H0C0_DEFINES; \
xx_data
=
xx_data
+
xx_offset
;
if (h0_data) { \
h_out_data
=
h_out_data
+
gate_offset
;
prev_h_data = h0_data + bid * D; \
c_out_data
=
c_out_data
+
gate_offset
;
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
#define PROCESS_H0C0 \
PROCESS_H0C0_DEFINES; \
if (h0_data) { \
prev_h_data = h0_data + bid * D; \
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
if
(
use_peepholes
)
{
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
PROCESS_H0C0_PEEPHOLE
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
COMPUTE_CtHt_PEEPHOLE
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
);
MOVE_ONE_STEP
;
}
}
}
}
else
{
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
const
auto
&
ker
=
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
math
::
jitkernel
::
KernelPool
::
Instance
()
ker
->
ComputeCtHt
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
,
wp_data
,
.
template
Get
<
math
::
jitkernel
::
LSTMKernel
<
T
>,
const
std
::
string
&
,
checked_cell_data
);
const
std
::
string
&
,
const
std
::
string
&>
(
// move one step
act_gate_str
,
act_cand_str
,
act_cell_str
,
D
,
false
);
prev_h_data
=
h_out_data
;
prev_c_data
=
c_out_data
;
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
xx_data
=
xx_data
+
xx_offset
;
PROCESS_H0C0
h_out_data
=
h_out_data
+
gate_offset
;
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
c_out_data
=
c_out_data
+
gate_offset
;
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
ker
->
ComputeCtHt
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
);
MOVE_ONE_STEP
;
}
}
}
}
}
#undef PROCESS_H0C0_DEFINES
#undef PROCESS_H0C0_PEEPHOLE
#undef PROCESS_H0C0
#undef MOVE_ONE_STEP
}
}
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
platform
::
CPUDeviceContext
;
INIT_BASE_DEFINES
;
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
xx
->
Resize
({
x_dims
[
0
],
D4
});
xx
->
Resize
({
x_dims
[
0
],
D4
});
SeqCompute
(
ctx
);
SeqCompute
(
ctx
);
return
;
return
;
}
}
INIT_VEC_FUNC
INIT_OTHER_DEFINES
;
INIT_BASE_INPUT_DATAS
auto
*
reordered_h0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedH0"
);
auto
*
reordered_h0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedH0"
);
auto
*
reordered_c0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedC0"
);
auto
*
reordered_c0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedC0"
);
...
@@ -477,8 +377,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -477,8 +377,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_c_data
=
reordered_c0_data
;
prev_c_data
=
reordered_c0_data
;
size_t
sz
=
sizeof
(
T
)
*
D
;
size_t
sz
=
sizeof
(
T
)
*
D
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
std
::
memcpy
(
reordered_h0_data
,
h0_data
+
seq_order
[
i
]
*
D
,
sz
);
blas
.
VCOPY
(
sz
,
h0_data
+
seq_order
[
i
]
*
D
,
reordered_h0_data
);
std
::
memcpy
(
reordered_c0_data
,
c0_data
+
seq_order
[
i
]
*
D
,
sz
);
blas
.
VCOPY
(
sz
,
c0_data
+
seq_order
[
i
]
*
D
,
reordered_c0_data
);
reordered_h0_data
+=
D
;
reordered_h0_data
+=
D
;
reordered_c0_data
+=
D
;
reordered_c0_data
+=
D
;
}
}
...
@@ -488,13 +388,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -488,13 +388,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
GET_Ct_NOH0C0
(
cur_in_data
,
cur_c_out_data
);
ker
->
ComputeC1H1
(
cur_in_data
,
cur_c_out_data
,
cur_h_out_data
,
wp_data
);
if
(
use_peepholes
)
{
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
cur_in_data
+
D
);
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
}
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
GET_Ht
(
cur_c_out_data
,
cur_in_data
,
cur_h_out_data
);
cur_in_data
+=
D4
;
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
cur_h_out_data
+=
D
;
...
@@ -503,66 +397,37 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -503,66 +397,37 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_h_data
=
batched_h_out_data
;
prev_h_data
=
batched_h_out_data
;
prev_c_data
=
batched_c_out_data
;
prev_c_data
=
batched_c_out_data
;
}
}
// compute kernel part
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
const
int
offset
=
tstart
*
max_bs
*
D
;
const
int
offset
=
tstart
*
max_bs
*
D
;
batched_input_data
=
batched_input_data
+
offset
*
4
;
batched_input_data
=
batched_input_data
+
offset
*
4
;
batched_h_out_data
=
batched_h_out_data
+
offset
;
batched_h_out_data
=
batched_h_out_data
+
offset
;
batched_c_out_data
=
batched_c_out_data
+
offset
;
batched_c_out_data
=
batched_c_out_data
+
offset
;
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
#define DEFINE_CUR \
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
T* cur_in_data = batched_input_data; \
GEMM_WH_ADDON
(
cur_bs
,
prev_h_data
,
batched_input_data
);
T* cur_prev_c_data = prev_c_data; \
T
*
cur_in_data
=
batched_input_data
;
T* cur_c_out_data = batched_c_out_data; \
T
*
cur_prev_c_data
=
prev_c_data
;
T* cur_h_out_data = batched_h_out_data
T
*
cur_c_out_data
=
batched_c_out_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
#define MOVE_ONE_BATCH \
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
cur_in_data += D4; \
ker
->
ComputeCtHt
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
cur_prev_c_data += D; \
cur_h_out_data
,
wp_data
,
checked_cell_data
);
cur_c_out_data += D; \
// move one batch
cur_h_out_data += D
cur_in_data
+=
D4
;
cur_prev_c_data
+=
D
;
#define MOVE_ONE_STEP \
cur_c_out_data
+=
D
;
prev_c_data = batched_c_out_data; \
cur_h_out_data
+=
D
;
prev_h_data = batched_h_out_data; \
batched_c_out_data = cur_c_out_data; \
batched_h_out_data = cur_h_out_data; \
batched_input_data = cur_in_data
if
(
use_peepholes
)
{
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
GEMM_WH_ADDON
(
cur_bs
,
prev_h_data
,
batched_input_data
);
DEFINE_CUR
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
COMPUTE_CtHt_PEEPHOLE
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
cur_h_out_data
);
MOVE_ONE_BATCH
;
}
MOVE_ONE_STEP
;
}
}
else
{
const
auto
&
ker
=
math
::
jitkernel
::
KernelPool
::
Instance
()
.
template
Get
<
math
::
jitkernel
::
LSTMKernel
<
T
>,
const
std
::
string
&
,
const
std
::
string
&
,
const
std
::
string
&>
(
act_gate_str
,
act_cand_str
,
act_cell_str
,
D
,
false
);
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
GEMM_WH_ADDON
(
cur_bs
,
prev_h_data
,
batched_input_data
);
DEFINE_CUR
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
ker
->
ComputeCtHt
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
cur_h_out_data
);
MOVE_ONE_BATCH
;
}
MOVE_ONE_STEP
;
}
}
// move one step
prev_c_data
=
batched_c_out_data
;
prev_h_data
=
batched_h_out_data
;
batched_c_out_data
=
cur_c_out_data
;
batched_h_out_data
=
cur_h_out_data
;
batched_input_data
=
cur_in_data
;
}
}
#undef MOVE_ONE_STEP
#undef MOVE_ONE_BATCH
#undef DEFINE_CUR
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batched_h_out
->
set_lod
(
batched_lod
);
batched_h_out
->
set_lod
(
batched_lod
);
...
@@ -579,17 +444,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -579,17 +444,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
}
}
}
}
#undef COMPUTE_CtHt_PEEPHOLE
#undef GET_Ct_NOH0C0
#undef COMPUTE_CtHt_NOH0C0
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
#undef GET_Ht
#undef GET_Ct
#undef GEMM_WH_ADDON
#undef GEMM_WH_ADDON
#undef INIT_BASE_INPUT_DATAS
#undef INIT_OTHER_DEFINES
#undef INIT_BASE_SIZES
#undef INIT_BASE_DEFINES
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_VEC_FUNC
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
8e182170
...
@@ -126,7 +126,14 @@ template <typename T>
...
@@ -126,7 +126,14 @@ template <typename T>
class
LSTMKernel
:
public
Kernel
{
class
LSTMKernel
:
public
Kernel
{
public:
public:
virtual
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
virtual
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
/* below only used in peephole*/
const
T
*
wp_data
=
nullptr
,
T
*
checked
=
nullptr
)
const
=
0
;
T
*
checked
=
nullptr
)
const
=
0
;
// compute c1 and h1 without c0 or h0
virtual
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
/* below only used in peephole*/
const
T
*
wp_data
=
nullptr
)
const
=
0
;
};
};
}
// namespace jitkernel
}
// namespace jitkernel
...
...
paddle/fluid/operators/math/jit_kernel_lstm.cc
浏览文件 @
8e182170
...
@@ -82,6 +82,26 @@ __m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
...
@@ -82,6 +82,26 @@ __m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
}
}
#endif
#endif
template
<
typename
T
>
static
std
::
shared_ptr
<
const
VActKernel
<
T
>>
GetActKernel
(
const
std
::
string
&
type
,
int
n
)
{
if
(
type
==
"sigmoid"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VSigmoidKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"relu"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VReluKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"tanh"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VTanhKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VIdentityKernel
<
T
>
>
(
n
));
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
return
nullptr
;
}
/* LSTM JitKernel */
/* LSTM JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
LSTMKernelImpl
:
public
LSTMKernel
<
T
>
{
class
LSTMKernelImpl
:
public
LSTMKernel
<
T
>
{
...
@@ -93,26 +113,10 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -93,26 +113,10 @@ class LSTMKernelImpl : public LSTMKernel<T> {
d_
=
d
;
d_
=
d
;
d2_
=
d
*
2
;
d2_
=
d
*
2
;
d3_
=
d
*
3
;
d3_
=
d
*
3
;
auto
GetActKernel
=
[
&
](
const
std
::
string
&
type
,
act_gate_d3_
=
GetActKernel
<
T
>
(
act_gate
,
d3_
);
int
n
)
->
std
::
shared_ptr
<
const
VActKernel
<
T
>>
{
act_gate_d_
=
GetActKernel
<
T
>
(
act_gate
,
d
);
if
(
type
==
"sigmoid"
)
{
act_cand_d_
=
GetActKernel
<
T
>
(
act_cand
,
d
);
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
act_cell_d_
=
GetActKernel
<
T
>
(
act_cell
,
d
);
KernelPool
::
Instance
().
template
Get
<
VSigmoidKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"relu"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VReluKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"tanh"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VTanhKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VIdentityKernel
<
T
>
>
(
n
));
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
};
act_gate_3d_
=
GetActKernel
(
act_gate
,
d
*
3
);
act_cand_d_
=
GetActKernel
(
act_cand
,
d
);
act_cell_d_
=
GetActKernel
(
act_cell
,
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
vadd_d_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d
);
vadd_d_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d
);
#ifdef __AVX__
#ifdef __AVX__
...
@@ -134,10 +138,10 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -134,10 +138,10 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif
#endif
}
}
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
T
*
checked
)
const
override
{
T
*
checked
)
const
override
{
// gates: W_ch, W_ih, W_fh, W_oh
// gates: W_ch, W_ih, W_fh, W_oh
act_gate_
3d
_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_gate_
d3
_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
/* C_t = C_t-1 * fgated + cand_gated * igated */
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_
->
Compute
(
gates
,
gates
);
act_cand_d_
->
Compute
(
gates
,
gates
);
...
@@ -149,10 +153,21 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -149,10 +153,21 @@ class LSTMKernelImpl : public LSTMKernel<T> {
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
}
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
}
private:
private:
int
d_
,
d2_
,
d3_
;
int
d_
,
d2_
,
d3_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_3d_
,
act_cand_d_
,
act_cell_d_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_d3_
,
act_gate_d_
,
act_cand_d_
,
act_cell_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
std
::
shared_ptr
<
const
VAddKernel
<
T
>>
vadd_d_
;
std
::
shared_ptr
<
const
VAddKernel
<
T
>>
vadd_d_
;
#ifdef __AVX__
#ifdef __AVX__
...
@@ -163,8 +178,8 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -163,8 +178,8 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#define INTRI8_FLOAT(isa) \
#define INTRI8_FLOAT(isa) \
template <> \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht,
float* checked)
\
float* gates, const float* ct_1, float* ct, float* ht,
\
const
{
\
const
float* wp_data, float* checked) const {
\
/* gates: W_ch, W_ih, W_fh, W_oh */
\
/* gates: W_ch, W_ih, W_fh, W_oh */
\
__m256 c, i, f, o; \
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
c = _mm256_loadu_ps(gates); \
...
@@ -205,51 +220,56 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
...
@@ -205,51 +220,56 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
d_
=
d
;
d_
=
d
;
d2_
=
d
*
2
;
d2_
=
d
*
2
;
d3_
=
d
*
3
;
d3_
=
d
*
3
;
auto
GetActKernel
=
[
&
](
const
std
::
string
&
type
,
act_gate_d_
=
GetActKernel
<
T
>
(
act_gate
,
d
);
int
n
)
->
std
::
shared_ptr
<
const
VActKernel
<
T
>>
{
act_cand_d_
=
GetActKernel
<
T
>
(
act_cand
,
d
);
if
(
type
==
"sigmoid"
)
{
act_cell_d_
=
GetActKernel
<
T
>
(
act_cell
,
d
);
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VSigmoidKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"relu"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VReluKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"tanh"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VTanhKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VIdentityKernel
<
T
>
>
(
n
));
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
};
act_gate_3d_
=
GetActKernel
(
act_gate
,
d
*
3
);
act_cand_d_
=
GetActKernel
(
act_cand
,
d
);
act_cell_d_
=
GetActKernel
(
act_cell
,
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
vadd_d_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d
);
vadd_d_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d
);
vadd_d2_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d2_
);
act_gate_d2_
=
GetActKernel
<
T
>
(
act_gate
,
d2_
);
}
}
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
T
*
checked
)
const
override
{
T
*
checked
)
const
override
{
// gates: W_ch, W_ih, W_fh, W_oh
/* get fgated and igated*/
act_gate_3d_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
);
/* C_t = C_t-1 * fgated + cand_gated * igated */
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
);
act_gate_d2_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_
->
Compute
(
gates
,
gates
);
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
);
/* get ogated*/
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
);
/* get outgated, put W_oc * C_t on igated */
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
);
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
}
}
private:
private:
int
d_
,
d2_
,
d3_
;
int
d_
,
d2_
,
d3_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_3d_
,
act_cand_d_
,
act_cell_d_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_d2_
,
act_gate_d_
,
act_cand_d_
,
act_cell_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
std
::
shared_ptr
<
const
VAddKernel
<
T
>>
vadd_d_
;
std
::
shared_ptr
<
const
VAddKernel
<
T
>>
vadd_d_
,
vadd_d2_
;
};
};
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录