Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
508548f8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
508548f8
编写于
8月 22, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implement attention lstm cpu forward
上级
9affc36c
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
278 addition
and
194 deletion
+278
-194
paddle/fluid/operators/attention_lstm_op.cc
paddle/fluid/operators/attention_lstm_op.cc
+276
-190
paddle/fluid/operators/attention_lstm_op.h
paddle/fluid/operators/attention_lstm_op.h
+2
-3
paddle/fluid/operators/fusion_lstm_op.h
paddle/fluid/operators/fusion_lstm_op.h
+0
-1
未找到文件。
paddle/fluid/operators/attention_lstm_op.cc
浏览文件 @
508548f8
...
@@ -20,10 +20,12 @@ limitations under the License. */
...
@@ -20,10 +20,12 @@ limitations under the License. */
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
void
Fus
ionLSTMOp
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
void
Attent
ionLSTMOp
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"WeightX"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"WeightX"
),
"Input(WeightX) of LSTM should not be null."
);
"Input(WeightX) of LSTM should not be null."
);
...
@@ -57,6 +59,9 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -57,6 +59,9 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"should be the same."
);
"should be the same."
);
}
}
// fc_out , shape (maxseqlen,1)
int
max_seq_len
=
0
;
auto
wx_dims
=
ctx
->
GetInputDim
(
"WeightX"
);
auto
wx_dims
=
ctx
->
GetInputDim
(
"WeightX"
);
PADDLE_ENFORCE_EQ
(
wx_dims
.
size
(),
2
,
PADDLE_ENFORCE_EQ
(
wx_dims
.
size
(),
2
,
"The rank of Input(WeightX) should be 2."
);
"The rank of Input(WeightX) should be 2."
);
...
@@ -103,241 +108,321 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -103,241 +108,321 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
ctx
->
ShareLoD
(
"X"
,
"XX"
);
ctx
->
ShareLoD
(
"X"
,
"XX"
);
}
}
framework
::
OpKernelType
Fus
ionLSTMOp
::
GetExpectedKernelType
(
framework
::
OpKernelType
Attent
ionLSTMOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
framework
::
ExecutionContext
&
ctx
)
const
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
void
Fus
ionLSTMOpMaker
::
Make
()
{
void
Attent
ionLSTMOpMaker
::
Make
()
{
AddInput
(
"X"
,
AddInput
(
"X"
,
"(LoDTensor) the input is a LodTensor, which support "
"(LoDTensor) the input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X M), where T is the "
"this LoDTensor is a matrix with shape (T X M), where T is the "
"total time steps in this mini-batch, M is the dim size of x."
);
"total time steps in this mini-batch, M is the dim size of x."
);
AddInput
(
"WeightX"
,
AddInput
(
"C0"
,
"(Tensor) the learnable weights of X."
"(Tensor) LSTM C0"
" - The shape is (M x 4D), where M is the dim size of x, D is the "
"This is a tensor with shape (N x D), where N is the batch size, D "
"hidden size. "
"is the gate size."
" - Weight = {W_cx, W_ix, W_fx, W_ox}"
);
"C0 is necessary because of attention."
);
AddInput
(
"WeightH"
,
"(Tensor) same as LSTMOp, the learnable hidden-hidden weights."
" - The shape is (D x 4D), where D is the hidden size. "
" - Weight = {W_ch, W_ih, W_fh, W_oh}"
);
AddInput
(
"Bias"
,
"(Tensor) the learnable weights. Almost same as LSTMOp"
"Note: we should add the fc bias into this (1x4D) in bias."
"input-hidden bias weight and peephole connections weight if "
"setting `use_peepholes` True. "
"1. `use_peepholes = False` "
" - The shape is (1 x 4D). "
" - Bias = {b_c, b_i, b_f, b_o}."
"2. `use_peepholes = True` "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."
);
AddInput
(
"H0"
,
AddInput
(
"H0"
,
"(Tensor, optional) (same as LSTMOp) the initial hidden state is an "
"(Tensor, optional) LSTM H0"
"optional "
"This is a tensor with shape (N x D), where N is the "
"input. This is a tensor with shape (N x D), where N is the "
"batch size and D is the gate size."
)
"batch size and D is the hidden size."
)
.
AsDispensable
();
.
AsDispensable
();
AddInput
(
"C0"
,
AddInput
(
"AttentionWeight"
,
"(Tensor, optional) (same as LSTMOp) (the initial cell state is an "
"(Tensor) the weights of attention fc. Always relu the fc result."
"optional "
"The shape is ((M+D) x 1), where M is the dim size of x, D is the "
"input. This is a tensor with shape (N x D), where N is the "
"gate size of LSTM."
);
"batch size. `H0` and `C0` can be NULL but only at the same time."
)
AddInput
(
"AttentionBias, optional"
,
"(Tensor) the bias of attention fc."
"The shape is (1 x 1)"
)
.
AsDispensable
();
AddInput
(
"AttentionScalar"
,
"(Tensor, optional) the scalar on the result of attentioned fc. "
"Always relu the Scalar."
"The shape is (1 x 1)"
)
.
AsDispensable
();
AddInput
(
"AttentionScalarBias"
,
"(Tensor, optional) the scalar bias of attention fc."
"The shape is (1 x 1)"
)
.
AsDispensable
();
.
AsDispensable
();
AddInput
(
"LSTMWeight"
,
"(Tensor) the combined weight of LSTM"
" - The shape is ((D+M) x 4D), where D is the hidden gate size, M "
"is the dim size of x"
" - Weight = {W_forget, W_input, W_output, W_cell}"
);
AddInput
(
"LSTMBias"
,
"(Tensor) the combined bias of LSTM, shape (1x4D)."
"Note: we should add the bias of hidden and context accorindg to "
"the same gate: "
"{B_forget, B_input, B_output, B_cell}"
);
AddOutput
(
"Hidden"
,
AddOutput
(
"Hidden"
,
"(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
"(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`."
);
"The shape is (T x D), and lod is the same with the `Input`."
);
AddOutput
(
"Cell"
,
AddOutput
(
"Cell"
,
"(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
"(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`."
);
"The shape is (T x D), and lod is the same with the `Input`."
);
AddOutput
(
"XX"
,
AddOutput
(
"(LoDTensor) the result after X * WeightX (size is T x 4D)"
"AttentionedX"
,
" or batched_X (size is T x M), this will be automatically chosen
,"
"(LodTensor) shape is (T x 1), the result after X * AttentionWeight
,"
" where T is the total time steps in this mini-batch,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size, M is the dim size of x input
."
)
" D is the hidden size
."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
"
BatchedGate"
,
"(LoDTensor) (same as LSTMOp)."
).
AsIntermediate
();
AddOutput
(
"
AttentionFCOut"
,
AddOutput
(
"BatchCellPreAct"
,
"(LoDTensor) (same as LSTMOp)
."
)
"(Tensor) (max_seq_len, 1), compute at each step
."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddAttr
<
bool
>
(
"use_peepholes"
,
AddOutput
(
"LSTMX"
,
"(bool, defalut: True) "
"(Tensor) the input X of LSTM for each step."
"whether to enable diagonal/peephole connections."
)
"Shape is (1 x M), where M is the x frame size"
)
.
SetDefault
(
true
);
.
AsIntermediate
();
AddAttr
<
bool
>
(
"is_reverse"
,
AddOutput
(
"(bool, defalut: False) "
"LSTMOUT"
,
"whether to compute reversed LSTM."
)
"(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step."
.
SetDefault
(
false
);
"Shape is (1 x 4D), where M is the x frame size"
)
.
AsIntermediate
();
// TODO(TJ): InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr
<
std
::
string
>
(
"gate_activation"
,
AddAttr
<
std
::
string
>
(
"gate_activation"
,
"(string, default: sigmoid)"
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by default."
)
"gate, `sigmoid` by default."
)
.
SetDefault
(
"sigmoid"
)
.
SetDefault
(
"sigmoid"
)
.
InEnum
({
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
});
.
InEnum
({
"sigmoid"
});
AddAttr
<
std
::
string
>
(
"cell_activation"
,
AddAttr
<
std
::
string
>
(
"cell_activation"
,
"(string, default: tanh)"
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut."
)
"The activation for cell output, `tanh` by defalut."
)
.
SetDefault
(
"tanh"
)
.
SetDefault
(
"tanh"
)
.
InEnum
({
"
sigmoid"
,
"tanh"
,
"relu"
,
"identity
"
});
.
InEnum
({
"
tanh
"
});
AddAttr
<
std
::
string
>
(
"candidate_activation"
,
AddAttr
<
std
::
string
>
(
"candidate_activation"
,
"(string, default: tanh)"
"(string, default: tanh)"
"The activation for candidate hidden state, "
"The activation for candidate hidden state, "
"`tanh` by default."
)
"`tanh` by default."
)
.
SetDefault
(
"tanh"
)
.
SetDefault
(
"tanh"
)
.
InEnum
({
"
sigmoid"
,
"tanh"
,
"relu"
,
"identity
"
});
.
InEnum
({
"
tanh
"
});
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Fusion Long-Short Term Memory (LSTM) Operator.
Attention Long-Short Term Memory (LSTM) Operator.
This operator fuse the X into LSTM, more details can refer to LSTM op.
Attention part:
concat( x(seqlen * M), expand( cell_t-1(1,D) ) ) => tmp(seqlen*(M+D))
tmp(seqlen*(M+D)) * fc((M+D)*1) => fcout(seqlen*1) with bias, relu
fcout(seqlen*1) * scalar => fcout(seqlen*1) with bias, relu
dotmul and sum pool ( fcout(seqlen*1), x(seqlen * M) ) => lstm_x_t(1, M)
LSTM part:
use lstm_x_t as input and compute as standard LSTM.
)DOC"
);
)DOC"
);
}
}
// y[i] = (x[i] + bias[0]) > 0 ? (x[i] + bias[0]) : 0;
template
<
typename
T
>
inline
void
bias_relu
(
const
int
n
,
const
T
*
x
,
const
T
*
bias
,
T
*
y
)
{
if
(
bias
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
x
[
i
]
+
bias
[
0
];
}
vec_relu
(
n
,
y
,
y
);
}
else
{
vec_relu
(
n
,
x
,
y
);
}
}
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
inline
void
ReorderInitState
(
const
DeviceContext
&
ctx
,
inline
void
vec_softmax
(
const
BlasT
<
DeviceContext
,
T
>&
blas
,
const
int
n
,
const
framework
::
Tensor
&
src
,
const
T
*
x
,
T
*
y
)
{
framework
::
Vector
<
size_t
>
index_lod
,
T
scalar
=
x
[
0
];
framework
::
Tensor
*
dst
,
bool
indexed_src
)
{
// max
math
::
CopyMatrixRowsFunctor
<
DeviceContext
,
T
>
row_shuffle
;
for
(
int
i
=
1
;
i
<
n
;
++
i
)
{
dst
->
mutable_data
<
T
>
(
src
.
dims
(),
ctx
.
GetPlace
());
scalar
=
scalar
<
x
[
i
]
?
x
[
i
]
:
scalar
;
// TODO(TJ): check mem copy perf
}
row_shuffle
(
ctx
,
src
,
index_lod
,
dst
,
indexed_src
);
// sub
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
c
]
=
x
[
c
]
-
alpha
;
}
// exp
blas
.
VEXP
(
n
,
y
,
y
);
// sum
scalar
=
T
(
0
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
scalar
+=
y
[
i
];
}
// scale
blas
.
VSCAL
(
n
,
static_cast
<
T
>
(
1
)
/
scalar
,
y
);
}
__m256
exp
(
__m256
a
)
{
return
exp256_ps
(
a
);
}
__m256
log
(
__m256
a
)
{
return
log256_ps
(
a
);
}
__m256
sin
(
__m256
a
)
{
return
sin256_ps
(
a
);
}
__m256
cos
(
__m256
a
)
{
return
cos256_ps
(
a
);
}
__m256
relu
(
const
__m256
a
)
{
__m256
tmp
=
_mm256_set1_ps
(
0.0
f
);
return
_mm256_max_ps
(
a
,
tmp
);
}
__m256
sigmoid
(
const
__m256
a
)
{
__m256
max
=
_mm256_set1_ps
(
SIGMOID_THRESHOLD_MAX
);
__m256
min
=
_mm256_set1_ps
(
SIGMOID_THRESHOLD_MIN
);
__m256
tmp
=
_mm256_max_ps
(
a
,
min
);
tmp
=
_mm256_min_ps
(
tmp
,
max
);
tmp
=
_mm256_sub_ps
(
_mm256_set1_ps
(
0.0
f
),
tmp
);
tmp
=
exp
(
tmp
);
tmp
=
_mm256_add_ps
(
_mm256_set1_ps
(
1.0
f
),
tmp
);
tmp
=
_mm256_div_ps
(
_mm256_set1_ps
(
1.0
f
),
tmp
);
return
tmp
;
}
__m256
tanh
(
const
__m256
a
)
{
__m256
max
=
_mm256_set1_ps
(
EXP_MAX_INPUT
);
__m256
tmp
=
_mm256_mul_ps
(
_mm256_set1_ps
(
-
2.0
f
),
a
);
tmp
=
_mm256_min_ps
(
tmp
,
max
);
tmp
=
exp
(
tmp
);
return
_mm256_sub_ps
(
_mm256_div_ps
(
_mm256_set1_ps
(
2.0
f
),
_mm256_add_ps
(
_mm256_set1_ps
(
1.0
f
),
tmp
)),
_mm256_set1_ps
(
1.0
f
));
}
__m256
linear
(
const
__m256
a
)
{
return
a
;
}
inline
void
vec_sigmoid
(
const
T
*
x
,
T
*
y
)
{
const
real
min
=
SIGMOID_THRESHOLD_MIN
;
const
real
max
=
SIGMOID_THRESHOLD_MAX
;
real
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
1.0
/
(
1.0
+
exp
(
-
tmp
));
}
}
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
Fuis
onLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
class
Attenti
onLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
// T x M
auto
*
wx
=
ctx
.
Input
<
Tensor
>
(
"WeightX"
);
auto
*
h0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
// N x D
auto
*
wh
=
ctx
.
Input
<
Tensor
>
(
"WeightH"
);
auto
*
c0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
// N x D
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
atten_w
=
ctx
.
Input
<
Tensor
>
(
"AttentionWeight"
);
// (M+D) x 1
auto
*
hidden_t0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
auto
*
atten_b
=
ctx
.
Input
<
Tensor
>
(
"AttentionBias"
);
// 1x1
auto
*
cell_t0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
auto
*
atten_scalar
=
ctx
.
Input
<
Tensor
>
(
"AttentionScalar"
);
// 1x1
auto
*
atten_scalar_bias
=
ctx
.
Input
<
Tensor
>
(
"AttentionScalar"
);
// 1x1
auto
*
xx
=
ctx
.
Output
<
LoDTensor
>
(
"XX"
);
auto
*
lstm_w
=
ctx
.
Input
<
Tensor
>
(
"LSTMWeight"
);
// (D+M) x D*4
auto
*
batched_gate
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedGate"
);
auto
*
lstm_b
=
ctx
.
Input
<
Tensor
>
(
"LSTMBias"
);
// 1 x D*4
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
// TxD
bool
is_reverse
=
ctx
.
Attr
<
bool
>
(
"is_reverse"
);
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
// TxD
auto
*
atted_x
=
ctx
.
Output
<
LoDTensor
>
(
"AttentionedX"
);
// T x 1
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
fc_out
=
ctx
.
Output
<
Tensor
>
(
'
AttentionFCOut
'
);
// max_seq_len x 1
T
*
batched_gate_data
=
batched_gate
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
lstm_x
=
ctx
.
Output
<
Tensor
>
(
"LSTMX"
);
// 1 x M
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
lstm_out
=
ctx
.
Output
<
Tensor
>
(
"LSTMOUT"
);
// 1 x 4D
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
h0_data
=
h0
->
data
<
T
>
();
auto
x_dims
=
x
->
dims
();
const
T
*
c0_data
=
c0
->
data
<
T
>
();
auto
wx_dims
=
wx
->
dims
();
const
T
*
lstm_w_data
=
lstm_w
->
data
<
T
>
();
const
T
*
lstm_b_data
=
lstm_b
->
data
<
T
>
();
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
const
T
*
atten_w_data
=
atten_w
->
data
<
T
>
();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
const
T
*
atten_b_data
=
atten_b
?
atten_b
->
data
<
T
>
()
:
NULL
;
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
const
T
*
atten_scalar_data
=
atten_scalar
?
atten_scalar
->
data
<
T
>
()
:
NULL
;
if
(
x_dims
[
1
]
>
wx_dims
[
1
])
{
const
T
*
atten_scalar_bias_data
=
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
wx_dims
[
1
],
x_dims
[
1
],
atten_scalar_bias
?
atten_scalar_bias
->
data
<
T
>
()
:
NULL
;
x_data
,
wx_data
,
xx_data
,
bias
->
data
<
T
>
());
T
*
hidden_out_data
=
hidden_out
->
mutable_data
<
T
>
();
to_batch
(
dev_ctx
,
*
xx
,
batched_gate
,
true
,
is_reverse
);
T
*
cell_out_data
=
cell_out
->
mutable_data
<
T
>
();
}
else
{
T
*
atted_x_data
=
atted_x
->
mutable_data
<
T
>
();
to_batch
(
dev_ctx
,
*
x
,
xx
,
true
,
is_reverse
);
T
*
fc_out_data
=
fc_out
->
mutable_data
<
T
>
();
batched_gate
->
set_lod
(
xx
->
lod
());
T
*
lstm_x_data
=
lstm_x
->
mutable_data
<
T
>
();
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
wx_dims
[
1
],
x_dims
[
1
],
T
*
lstm_out_data
=
lstm_out
->
mutable_data
<
T
>
();
xx_data
,
wx_data
,
batched_gate_data
,
bias
->
data
<
T
>
());
auto
x_lod
=
x
->
lod
();
auto
x_dims
=
x
->
dims
();
// T x M
auto
w_dims
=
w
->
dims
();
// (D+M) x 4D
const
int
M
=
x_dims
[
1
];
// x frame size
const
int
D
=
w_dims
[
1
]
/
4
;
// gate frame size
const
int
D2
=
D
*
2
;
const
int
D3
=
D
*
3
;
const
int
D4
=
w_dims
[
1
];
const
int
batch_size
=
x_lod
[
0
].
size
()
-
1
;
// assert lod.size() == 1
// x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
T
,
1
,
M
,
x_data
,
atten_w_data
,
atted_x_data
,
atten_b_data
);
const
T
*
cur_x_data
=
x_data
;
const
T
*
prev_cell_data
=
NULL
;
const
T
*
prev_hidden_data
=
NULL
;
T
*
cur_cell_out_data
=
cell_out_data
;
T
*
cur_hidden_out_data
=
hidden_out_data
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
int
seq_len
=
x_lod
[
0
][
i
+
1
];
prev_cell_data
=
c0_data
+
i
*
D
;
prev_hidden_data
=
h0
?
h0_data
+
i
*
D
:
NULL
;
for
(
int
step
=
0
;
step
<
seq_len
;
++
step
)
{
/// compute attention vector
// prev_cell(1xD) * fc(D) rest part of atten_wgt
// T = cblas_dot();
T
prev_cell_bias
=
blas
.
VDOT
(
D
,
prev_cell_data
,
atten_w_data
+
M
);
// add cell bias and relu
bias_relu
<
T
>
(
seq_len
,
atted_x_data
,
&
prev_cell_bias
,
fc_out_data
);
// fc2: scalar
if
(
atten_scalar_data
)
{
// x = a*x
blas
.
VSCAL
(
seq_len
,
atten_scalar_data
,
fc_out_data
);
bias_relu
<
T
>
(
seq_len
,
fc_out_data
,
atten_scalar_bias_data
,
fc_out_data
);
}
}
vec_softmax
<
DeviceContext
,
T
>
(
blas
,
seq_len
,
fc_out_data
,
fc_out_data
);
int
frame_size
=
static_cast
<
int
>
(
wx_dims
[
1
]
/
4
);
// mul x(seq_len*M) and sum pool
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
1
,
M
,
seq_len
,
fc_out_data
,
math
::
LstmMetaValue
<
T
>
lstm_value
;
cur_x_data
,
lstm_x_data
);
// no peephole
lstm_value
.
check_ig
=
nullptr
;
/// compute LSTM step
lstm_value
.
check_fg
=
nullptr
;
// lstm weight : concat[forget , input , output , tilde]
lstm_value
.
check_og
=
nullptr
;
// shape : (D + M) x (4 * D)
lstm_value
.
prev_state_value
=
nullptr
;
// fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D
Tensor
ordered_c0
;
blas
.
MatMul
(
1
,
D4
,
M
,
lstm_x_data
,
lstm_w_data
+
D
*
D4
,
lstm_out_data
);
if
(
prev_hidden_data
)
{
framework
::
Vector
<
size_t
>
order
(
batched_gate
->
lod
()[
2
]);
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_hidden_data
,
D
,
lstm_w_data
,
D4
,
static_cast
<
T
>
(
1
),
if
(
cell_t0
)
{
lstm_out_data
,
D4
);
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState
<
DeviceContext
,
T
>
(
dev_ctx
,
*
cell_t0
,
order
,
&
ordered_c0
,
true
);
lstm_value
.
prev_state_value
=
ordered_c0
.
data
<
T
>
();
}
}
// since input is 1xM, so can use add bias
blas
.
VADD
(
D4
,
lstm_b_data
,
lstm_out_data
,
lstm_out_data
);
// Use the local variable as here.
// gate act: sigmoid
LoDTensor
batch_hidden
,
batch_cell
;
vec_sigmoid
(
D3
,
lstm_out_data
,
lstm_out_data
);
auto
*
batch_cell_pre_act
=
ctx
.
Output
<
LoDTensor
>
(
"BatchCellPreAct"
);
// candicate act: tanh
batch_hidden
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
vec_tanh
(
D
,
lstm_out_data
+
D3
,
lstm_out_data
+
D3
);
batch_cell
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
batch_cell_pre_act
->
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
auto
batch_starts
=
batched_gate
->
lod
()[
0
];
size_t
max_seq_len
=
batch_starts
.
size
()
-
1
;
auto
gate_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
));
auto
cell_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
));
auto
cand_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
for
(
size_t
n
=
0
;
n
<
max_seq_len
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
Tensor
gate_t
=
batched_gate
->
Slice
(
bstart
,
bend
);
Tensor
out_t
=
batch_hidden
.
Slice
(
bstart
,
bend
);
Tensor
cell_t
=
batch_cell
.
Slice
(
bstart
,
bend
);
Tensor
cell_pre_act_t
=
batch_cell_pre_act
->
Slice
(
bstart
,
bend
);
int
cur_batch_size
=
bend
-
bstart
;
if
(
n
>
0
)
{
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_t
=
batch_hidden
.
Slice
(
pre_h_start
,
pre_h_end
);
// TODO(TJ): use gemm directly
blas
.
MatMul
(
pre_hidden_t
,
false
,
*
wh
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
else
if
(
hidden_t0
)
{
// TODO(TJ): move h0 outside for
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skiped.
// If n == 0 and there is initialized hidden state, calculate W_h * H0.
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized hidden state also needs
// to reorder.
Tensor
ordered_h0
;
ReorderInitState
<
DeviceContext
,
T
>
(
dev_ctx
,
*
hidden_t0
,
order
,
&
ordered_h0
,
true
);
// TODO(TJ): use gemm directly
blas
.
MatMul
(
ordered_h0
,
false
,
*
wh
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
lstm_value
.
gate_value
=
gate_t
.
data
<
T
>
();
// a = forget * prev_cell
lstm_value
.
output_value
=
out_t
.
data
<
T
>
();
blas
.
VMUL
(
D
,
lstm_out_data
,
prev_cell_data
,
lstm_out_data
);
lstm_value
.
state_value
=
cell_t
.
data
<
T
>
();
lstm_value
.
state_active_value
=
cell_pre_act_t
.
data
<
T
>
();
// b = input * tilde
math
::
LstmUnitFunctor
<
DeviceContext
,
T
>::
compute
(
blas
.
VMUL
(
D
,
lstm_out_data
+
D
,
lstm_out
+
D3
,
lstm_out_data
+
D
);
dev_ctx
,
lstm_value
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
lstm_value
.
prev_state_value
=
lstm_value
.
state_value
;
}
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
// cell_out = a + b
batch_hidden
.
set_lod
(
batched_gate
->
lod
());
blas
.
VADD
(
D
,
lstm_out_data
,
lstm_out_data
+
D
,
cur_cell_out_data
);
// restore the output hidden in LoDTensor from the batch hidden
to_seq
(
dev_ctx
,
batch_hidden
,
hidden_out
);
batch_cell
.
set_lod
(
batched_gate
->
lod
());
// state act tanh(cell_out) * output_gate
// restore the output cell state in LoDTensor from the batch cell
vec_tanh
(
D
,
cur_cell_out_data
,
lstm_out_data
);
to_seq
(
dev_ctx
,
batch_cell
,
cell_out
);
blas
.
VMUL
(
D
,
lstm_out_data
,
lstm_out
+
D2
,
cur_hidden_out_data
);
prev_hidden_data
=
hidden_out
+
i
*
gate_size
;
prev_cell_data
=
cur_cell_out_data
;
cur_cell_out_data
=
cur_cell_out_data
+
D
;
cur_hidden_out_data
=
cur_hidden_out_data
+
D
;
}
cur_x_data
=
cur_x_data
+
seq_len
*
M
;
}
}
}
};
};
...
@@ -345,10 +430,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -345,10 +430,11 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
fusion_lstm
,
ops
::
FusionLSTMOp
,
ops
::
FusionLSTMOpMaker
,
REGISTER_OPERATOR
(
attention_lstm
,
ops
::
AttentionLSTMOp
,
ops
::
AttentionLSTMOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
fus
ion_lstm
,
attent
ion_lstm
,
ops
::
Fuis
onLSTMKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
Attenti
onLSTMKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
Fuis
onLSTMKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
ops
::
Attenti
onLSTMKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/attention_lstm_op.h
浏览文件 @
508548f8
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
// #include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -22,7 +21,7 @@ namespace operators {
...
@@ -22,7 +21,7 @@ namespace operators {
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
class
Fus
ionLSTMOp
:
public
framework
::
OperatorWithKernel
{
class
Attent
ionLSTMOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
@@ -33,7 +32,7 @@ class FusionLSTMOp : public framework::OperatorWithKernel {
...
@@ -33,7 +32,7 @@ class FusionLSTMOp : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
};
class
Fus
ionLSTMOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
Attent
ionLSTMOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
override
;
void
Make
()
override
;
};
};
...
...
paddle/fluid/operators/fusion_lstm_op.h
浏览文件 @
508548f8
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
// #include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录