Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
8e182170
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
699
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8e182170
编写于
10月 12, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine and replace lstm peephole kernel
上级
7ef2699e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
181 addition
and
297 deletion
+181
-297
paddle/fluid/operators/fusion_lstm_op.cc
paddle/fluid/operators/fusion_lstm_op.cc
+102
-245
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+7
-0
paddle/fluid/operators/math/jit_kernel_lstm.cc
paddle/fluid/operators/math/jit_kernel_lstm.cc
+72
-52
未找到文件。
paddle/fluid/operators/fusion_lstm_op.cc
浏览文件 @
8e182170
...
@@ -15,11 +15,9 @@ limitations under the License. */
...
@@ -15,11 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fusion_lstm_op.h"
#include "paddle/fluid/operators/fusion_lstm_op.h"
#include <string>
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -219,116 +217,55 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
...
@@ -219,116 +217,55 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
template
<
typename
T
>
template
<
typename
T
>
class
FuisonLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
class
FuisonLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
#define INIT_VEC_FUNC \
#define INIT_BASE_DEFINES \
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; \
using DeviceContext = paddle::platform::CPUDeviceContext; \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
auto* x = ctx.Input<LoDTensor>("X"); \
auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); \
auto* c0 = ctx.Input<Tensor>("C0"); \
if (platform::jit::MayIUse(platform::jit::avx)) { \
auto* wx = ctx.Input<Tensor>("WeightX"); \
math::VecActivations<T, platform::jit::avx> act_functor; \
auto* wh = ctx.Input<Tensor>("WeightH"); \
act_gate = act_functor(act_gate_str); \
auto* bias = ctx.Input<Tensor>("Bias"); \
act_cell = act_functor(act_cell_str); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
act_cand = act_functor(act_cand_str); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
} else { \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
math::VecActivations<T, platform::jit::isa_any> act_functor; \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
act_gate = act_functor(act_gate_str); \
bool use_peepholes = ctx.Attr<bool>("use_peepholes"); \
act_cell = act_functor(act_cell_str); \
auto x_dims = x->dims();
/* T x M*/
\
act_cand = act_functor(act_cand_str); \
auto wh_dims = wh->dims();
/* D x 4D*/
\
}
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
#define INIT_BASE_INPUT_OUTPUT \
const int D4 = wh_dims[1]
auto* x = ctx.Input<LoDTensor>("X"); \
auto* h0 = ctx.Input<Tensor>("H0"); \
#define INIT_OTHER_DEFINES \
auto* c0 = ctx.Input<Tensor>("C0"); \
const T* x_data = x->data<T>(); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
const T* wx_data = wx->data<T>(); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
const T* wh_data = wh->data<T>(); \
auto* bias = ctx.Input<Tensor>("Bias"); \
/* diagonal weight*/
\
auto* xx = ctx.Output<LoDTensor>("XX"); \
const T* wp_data = bias->data<T>() + D4; \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
/* for peephole only*/
\
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
T* checked_cell_data = nullptr; \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
auto place = ctx.GetPlace(); \
bool use_peepholes = ctx.Attr<bool>("use_peepholes");
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
#define INIT_BASE_SIZES \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
auto x_dims = x->dims();
/* T x M*/
\
checked_cell_data = checked_cell->mutable_data<T>(place); \
auto wh_dims = wh->dims();
/* D x 4D*/
\
} \
const int M = x_dims[1]; \
const auto& ker = \
const int D = wh_dims[0]; \
math::jitkernel::KernelPool::Instance() \
const int D2 = D * 2; \
.template Get<math::jitkernel::LSTMKernel<T>, const std::string&, \
const int D3 = D * 3; \
const std::string&, const std::string&>( \
const int D4 = wh_dims[1];
ctx.Attr<std::string>("gate_activation"), \
ctx.Attr<std::string>("candidate_activation"), \
#define INIT_BASE_INPUT_DATAS \
ctx.Attr<std::string>("cell_activation"), D, use_peepholes)
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
// Wh GEMM
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/
\
const T* wc_data = bias->data<T>() + D4; \
/* for peephole only*/
\
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \
}
/// Compute LSTM
#define GEMM_WH_ADDON(bs, prev, out) \
#define GEMM_WH_ADDON(bs, prev, out) \
blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
wh_data, D4, static_cast<T>(1), out, D4)
wh_data, D4, static_cast<T>(1), out, D4)
#define GET_Ct(ct_1, gates, ct) \
/* C_t = C_t-1 * fgated + cand_gated * igated*/
\
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, gates + D); \
blas.VMUL(D, ct_1, gates + D2, gates + D2); \
blas.VADD(D, gates + D, gates + D2, ct)
#define GET_Ht(ct, gates, ht) \
/* H_t = act_cell(C_t) * ogated */
\
act_cell(D, ct, gates + D2); \
blas.VMUL(D, gates + D2, gates + D3, ht)
#define GET_Ct_NOH0C0(gates, ct) \
/* C_t = igated * cgated*/
\
act_gate(D, gates + D, gates + D); \
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, ct)
#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
/* get outgated, put W_oc * C_t on igated */
\
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht) \
/* get fgated and igated*/
\
blas.VMUL(D, wc_data, ct_1, checked_cell_data); \
blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \
blas.VADD(D2, checked_cell_data, gates + D, gates + D); \
act_gate(D2, gates + D, gates + D); \
GET_Ct(ct_1, gates, ct); \
/* get ogated*/
\
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
INIT_BASE_DEFINES
;
INIT_BASE_INPUT_OUTPUT
INIT_OTHER_DEFINES
;
INIT_BASE_SIZES
INIT_VEC_FUNC
INIT_BASE_INPUT_DATAS
auto
x_lod
=
x
->
lod
();
auto
x_lod
=
x
->
lod
();
const
int
total_T
=
x_dims
[
0
];
const
int
total_T
=
x_dims
[
0
];
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
...
@@ -352,84 +289,47 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -352,84 +289,47 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
gate_offset
=
-
D
;
gate_offset
=
-
D
;
}
}
#define MOVE_ONE_STEP \
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
prev_h_data = h_out_data; \
int
bid
=
is_reverse
?
N
-
1
-
i
:
i
;
prev_c_data = c_out_data; \
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
xx_data = xx_data + xx_offset; \
const
T
*
prev_c_data
=
nullptr
;
h_out_data = h_out_data + gate_offset; \
const
T
*
prev_h_data
=
nullptr
;
c_out_data = c_out_data + gate_offset
int
tstart
=
0
;
if
(
h0_data
)
{
#define PROCESS_H0C0_DEFINES \
prev_h_data
=
h0_data
+
bid
*
D
;
int bid = is_reverse ? N - 1 - i : i; \
prev_c_data
=
c0_data
+
bid
*
D
;
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
}
else
{
const T* prev_c_data = nullptr; \
ker
->
ComputeC1H1
(
xx_data
,
c_out_data
,
h_out_data
,
wp_data
);
const T* prev_h_data = nullptr; \
tstart
=
1
;
int tstart = 0
// move one step
prev_h_data
=
h_out_data
;
#define PROCESS_H0C0_PEEPHOLE \
prev_c_data
=
c_out_data
;
PROCESS_H0C0_DEFINES; \
xx_data
=
xx_data
+
xx_offset
;
if (h0_data) { \
h_out_data
=
h_out_data
+
gate_offset
;
prev_h_data = h0_data + bid * D; \
c_out_data
=
c_out_data
+
gate_offset
;
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
#define PROCESS_H0C0 \
PROCESS_H0C0_DEFINES; \
if (h0_data) { \
prev_h_data = h0_data + bid * D; \
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
if
(
use_peepholes
)
{
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
PROCESS_H0C0_PEEPHOLE
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
COMPUTE_CtHt_PEEPHOLE
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
);
MOVE_ONE_STEP
;
}
}
}
}
else
{
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
const
auto
&
ker
=
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
math
::
jitkernel
::
KernelPool
::
Instance
()
ker
->
ComputeCtHt
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
,
wp_data
,
.
template
Get
<
math
::
jitkernel
::
LSTMKernel
<
T
>,
const
std
::
string
&
,
checked_cell_data
);
const
std
::
string
&
,
const
std
::
string
&>
(
// move one step
act_gate_str
,
act_cand_str
,
act_cell_str
,
D
,
false
);
prev_h_data
=
h_out_data
;
prev_c_data
=
c_out_data
;
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
xx_data
=
xx_data
+
xx_offset
;
PROCESS_H0C0
h_out_data
=
h_out_data
+
gate_offset
;
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
c_out_data
=
c_out_data
+
gate_offset
;
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
ker
->
ComputeCtHt
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
);
MOVE_ONE_STEP
;
}
}
}
}
}
#undef PROCESS_H0C0_DEFINES
#undef PROCESS_H0C0_PEEPHOLE
#undef PROCESS_H0C0
#undef MOVE_ONE_STEP
}
}
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
platform
::
CPUDeviceContext
;
INIT_BASE_DEFINES
;
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
xx
->
Resize
({
x_dims
[
0
],
D4
});
xx
->
Resize
({
x_dims
[
0
],
D4
});
SeqCompute
(
ctx
);
SeqCompute
(
ctx
);
return
;
return
;
}
}
INIT_VEC_FUNC
INIT_OTHER_DEFINES
;
INIT_BASE_INPUT_DATAS
auto
*
reordered_h0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedH0"
);
auto
*
reordered_h0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedH0"
);
auto
*
reordered_c0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedC0"
);
auto
*
reordered_c0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedC0"
);
...
@@ -477,8 +377,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -477,8 +377,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_c_data
=
reordered_c0_data
;
prev_c_data
=
reordered_c0_data
;
size_t
sz
=
sizeof
(
T
)
*
D
;
size_t
sz
=
sizeof
(
T
)
*
D
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
std
::
memcpy
(
reordered_h0_data
,
h0_data
+
seq_order
[
i
]
*
D
,
sz
);
blas
.
VCOPY
(
sz
,
h0_data
+
seq_order
[
i
]
*
D
,
reordered_h0_data
);
std
::
memcpy
(
reordered_c0_data
,
c0_data
+
seq_order
[
i
]
*
D
,
sz
);
blas
.
VCOPY
(
sz
,
c0_data
+
seq_order
[
i
]
*
D
,
reordered_c0_data
);
reordered_h0_data
+=
D
;
reordered_h0_data
+=
D
;
reordered_c0_data
+=
D
;
reordered_c0_data
+=
D
;
}
}
...
@@ -488,13 +388,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -488,13 +388,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
GET_Ct_NOH0C0
(
cur_in_data
,
cur_c_out_data
);
ker
->
ComputeC1H1
(
cur_in_data
,
cur_c_out_data
,
cur_h_out_data
,
wp_data
);
if
(
use_peepholes
)
{
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
cur_in_data
+
D
);
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
}
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
GET_Ht
(
cur_c_out_data
,
cur_in_data
,
cur_h_out_data
);
cur_in_data
+=
D4
;
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
cur_h_out_data
+=
D
;
...
@@ -503,66 +397,37 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -503,66 +397,37 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_h_data
=
batched_h_out_data
;
prev_h_data
=
batched_h_out_data
;
prev_c_data
=
batched_c_out_data
;
prev_c_data
=
batched_c_out_data
;
}
}
// compute kernel part
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
const
int
offset
=
tstart
*
max_bs
*
D
;
const
int
offset
=
tstart
*
max_bs
*
D
;
batched_input_data
=
batched_input_data
+
offset
*
4
;
batched_input_data
=
batched_input_data
+
offset
*
4
;
batched_h_out_data
=
batched_h_out_data
+
offset
;
batched_h_out_data
=
batched_h_out_data
+
offset
;
batched_c_out_data
=
batched_c_out_data
+
offset
;
batched_c_out_data
=
batched_c_out_data
+
offset
;
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
#define DEFINE_CUR \
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
T* cur_in_data = batched_input_data; \
GEMM_WH_ADDON
(
cur_bs
,
prev_h_data
,
batched_input_data
);
T* cur_prev_c_data = prev_c_data; \
T
*
cur_in_data
=
batched_input_data
;
T* cur_c_out_data = batched_c_out_data; \
T
*
cur_prev_c_data
=
prev_c_data
;
T* cur_h_out_data = batched_h_out_data
T
*
cur_c_out_data
=
batched_c_out_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
#define MOVE_ONE_BATCH \
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
cur_in_data += D4; \
ker
->
ComputeCtHt
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
cur_prev_c_data += D; \
cur_h_out_data
,
wp_data
,
checked_cell_data
);
cur_c_out_data += D; \
// move one batch
cur_h_out_data += D
cur_in_data
+=
D4
;
cur_prev_c_data
+=
D
;
#define MOVE_ONE_STEP \
cur_c_out_data
+=
D
;
prev_c_data = batched_c_out_data; \
cur_h_out_data
+=
D
;
prev_h_data = batched_h_out_data; \
batched_c_out_data = cur_c_out_data; \
batched_h_out_data = cur_h_out_data; \
batched_input_data = cur_in_data
if
(
use_peepholes
)
{
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
GEMM_WH_ADDON
(
cur_bs
,
prev_h_data
,
batched_input_data
);
DEFINE_CUR
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
COMPUTE_CtHt_PEEPHOLE
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
cur_h_out_data
);
MOVE_ONE_BATCH
;
}
MOVE_ONE_STEP
;
}
}
else
{
const
auto
&
ker
=
math
::
jitkernel
::
KernelPool
::
Instance
()
.
template
Get
<
math
::
jitkernel
::
LSTMKernel
<
T
>,
const
std
::
string
&
,
const
std
::
string
&
,
const
std
::
string
&>
(
act_gate_str
,
act_cand_str
,
act_cell_str
,
D
,
false
);
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
GEMM_WH_ADDON
(
cur_bs
,
prev_h_data
,
batched_input_data
);
DEFINE_CUR
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
ker
->
ComputeCtHt
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
cur_h_out_data
);
MOVE_ONE_BATCH
;
}
MOVE_ONE_STEP
;
}
}
// move one step
prev_c_data
=
batched_c_out_data
;
prev_h_data
=
batched_h_out_data
;
batched_c_out_data
=
cur_c_out_data
;
batched_h_out_data
=
cur_h_out_data
;
batched_input_data
=
cur_in_data
;
}
}
#undef MOVE_ONE_STEP
#undef MOVE_ONE_BATCH
#undef DEFINE_CUR
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batched_h_out
->
set_lod
(
batched_lod
);
batched_h_out
->
set_lod
(
batched_lod
);
...
@@ -579,17 +444,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -579,17 +444,9 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
}
}
}
}
#undef COMPUTE_CtHt_PEEPHOLE
#undef GET_Ct_NOH0C0
#undef COMPUTE_CtHt_NOH0C0
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
#undef GET_Ht
#undef GET_Ct
#undef GEMM_WH_ADDON
#undef GEMM_WH_ADDON
#undef INIT_BASE_INPUT_DATAS
#undef INIT_OTHER_DEFINES
#undef INIT_BASE_SIZES
#undef INIT_BASE_DEFINES
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_VEC_FUNC
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
8e182170
...
@@ -126,7 +126,14 @@ template <typename T>
...
@@ -126,7 +126,14 @@ template <typename T>
class
LSTMKernel
:
public
Kernel
{
class
LSTMKernel
:
public
Kernel
{
public:
public:
virtual
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
virtual
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
/* below only used in peephole*/
const
T
*
wp_data
=
nullptr
,
T
*
checked
=
nullptr
)
const
=
0
;
T
*
checked
=
nullptr
)
const
=
0
;
// compute c1 and h1 without c0 or h0
virtual
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
/* below only used in peephole*/
const
T
*
wp_data
=
nullptr
)
const
=
0
;
};
};
}
// namespace jitkernel
}
// namespace jitkernel
...
...
paddle/fluid/operators/math/jit_kernel_lstm.cc
浏览文件 @
8e182170
...
@@ -82,6 +82,26 @@ __m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
...
@@ -82,6 +82,26 @@ __m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
}
}
#endif
#endif
template
<
typename
T
>
static
std
::
shared_ptr
<
const
VActKernel
<
T
>>
GetActKernel
(
const
std
::
string
&
type
,
int
n
)
{
if
(
type
==
"sigmoid"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VSigmoidKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"relu"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VReluKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"tanh"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VTanhKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VIdentityKernel
<
T
>
>
(
n
));
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
return
nullptr
;
}
/* LSTM JitKernel */
/* LSTM JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
LSTMKernelImpl
:
public
LSTMKernel
<
T
>
{
class
LSTMKernelImpl
:
public
LSTMKernel
<
T
>
{
...
@@ -93,26 +113,10 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -93,26 +113,10 @@ class LSTMKernelImpl : public LSTMKernel<T> {
d_
=
d
;
d_
=
d
;
d2_
=
d
*
2
;
d2_
=
d
*
2
;
d3_
=
d
*
3
;
d3_
=
d
*
3
;
auto
GetActKernel
=
[
&
](
const
std
::
string
&
type
,
act_gate_d3_
=
GetActKernel
<
T
>
(
act_gate
,
d3_
);
int
n
)
->
std
::
shared_ptr
<
const
VActKernel
<
T
>>
{
act_gate_d_
=
GetActKernel
<
T
>
(
act_gate
,
d
);
if
(
type
==
"sigmoid"
)
{
act_cand_d_
=
GetActKernel
<
T
>
(
act_cand
,
d
);
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
act_cell_d_
=
GetActKernel
<
T
>
(
act_cell
,
d
);
KernelPool
::
Instance
().
template
Get
<
VSigmoidKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"relu"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VReluKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"tanh"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VTanhKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VIdentityKernel
<
T
>
>
(
n
));
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
};
act_gate_3d_
=
GetActKernel
(
act_gate
,
d
*
3
);
act_cand_d_
=
GetActKernel
(
act_cand
,
d
);
act_cell_d_
=
GetActKernel
(
act_cell
,
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
vadd_d_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d
);
vadd_d_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d
);
#ifdef __AVX__
#ifdef __AVX__
...
@@ -134,10 +138,10 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -134,10 +138,10 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif
#endif
}
}
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
T
*
checked
)
const
override
{
T
*
checked
)
const
override
{
// gates: W_ch, W_ih, W_fh, W_oh
// gates: W_ch, W_ih, W_fh, W_oh
act_gate_
3d
_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_gate_
d3
_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
/* C_t = C_t-1 * fgated + cand_gated * igated */
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_
->
Compute
(
gates
,
gates
);
act_cand_d_
->
Compute
(
gates
,
gates
);
...
@@ -149,10 +153,21 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -149,10 +153,21 @@ class LSTMKernelImpl : public LSTMKernel<T> {
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
}
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
}
private:
private:
int
d_
,
d2_
,
d3_
;
int
d_
,
d2_
,
d3_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_3d_
,
act_cand_d_
,
act_cell_d_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_d3_
,
act_gate_d_
,
act_cand_d_
,
act_cell_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
std
::
shared_ptr
<
const
VAddKernel
<
T
>>
vadd_d_
;
std
::
shared_ptr
<
const
VAddKernel
<
T
>>
vadd_d_
;
#ifdef __AVX__
#ifdef __AVX__
...
@@ -163,8 +178,8 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -163,8 +178,8 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#define INTRI8_FLOAT(isa) \
#define INTRI8_FLOAT(isa) \
template <> \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht,
float* checked)
\
float* gates, const float* ct_1, float* ct, float* ht,
\
const
{
\
const
float* wp_data, float* checked) const {
\
/* gates: W_ch, W_ih, W_fh, W_oh */
\
/* gates: W_ch, W_ih, W_fh, W_oh */
\
__m256 c, i, f, o; \
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
c = _mm256_loadu_ps(gates); \
...
@@ -205,51 +220,56 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
...
@@ -205,51 +220,56 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
d_
=
d
;
d_
=
d
;
d2_
=
d
*
2
;
d2_
=
d
*
2
;
d3_
=
d
*
3
;
d3_
=
d
*
3
;
auto
GetActKernel
=
[
&
](
const
std
::
string
&
type
,
act_gate_d_
=
GetActKernel
<
T
>
(
act_gate
,
d
);
int
n
)
->
std
::
shared_ptr
<
const
VActKernel
<
T
>>
{
act_cand_d_
=
GetActKernel
<
T
>
(
act_cand
,
d
);
if
(
type
==
"sigmoid"
)
{
act_cell_d_
=
GetActKernel
<
T
>
(
act_cell
,
d
);
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VSigmoidKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"relu"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VReluKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"tanh"
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VTanhKernel
<
T
>
>
(
n
));
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
std
::
dynamic_pointer_cast
<
const
VActKernel
<
T
>>
(
KernelPool
::
Instance
().
template
Get
<
VIdentityKernel
<
T
>
>
(
n
));
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
};
act_gate_3d_
=
GetActKernel
(
act_gate
,
d
*
3
);
act_cand_d_
=
GetActKernel
(
act_cand
,
d
);
act_cell_d_
=
GetActKernel
(
act_cell
,
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
vadd_d_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d
);
vadd_d_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d
);
vadd_d2_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d2_
);
act_gate_d2_
=
GetActKernel
<
T
>
(
act_gate
,
d2_
);
}
}
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
T
*
checked
)
const
override
{
T
*
checked
)
const
override
{
// gates: W_ch, W_ih, W_fh, W_oh
/* get fgated and igated*/
act_gate_3d_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
);
/* C_t = C_t-1 * fgated + cand_gated * igated */
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
);
act_gate_d2_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_
->
Compute
(
gates
,
gates
);
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
);
/* get ogated*/
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
);
/* get outgated, put W_oc * C_t on igated */
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
);
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
);
}
}
private:
private:
int
d_
,
d2_
,
d3_
;
int
d_
,
d2_
,
d3_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_3d_
,
act_cand_d_
,
act_cell_d_
;
std
::
shared_ptr
<
const
VActKernel
<
T
>>
act_gate_d2_
,
act_gate_d_
,
act_cand_d_
,
act_cell_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
std
::
shared_ptr
<
const
VMulKernel
<
T
>>
vmul_d_
;
std
::
shared_ptr
<
const
VAddKernel
<
T
>>
vadd_d_
;
std
::
shared_ptr
<
const
VAddKernel
<
T
>>
vadd_d_
,
vadd_d2_
;
};
};
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录