Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
0c70bd28
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0c70bd28
编写于
11月 03, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Enable initial hidden state and cell state in LSTM Operator.
上级
83c22816
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
166 addition
and
54 deletion
+166
-54
paddle/operators/lstm_op.cc
paddle/operators/lstm_op.cc
+31
-12
paddle/operators/lstm_op.h
paddle/operators/lstm_op.h
+76
-18
paddle/operators/math/sequence2batch.cc
paddle/operators/math/sequence2batch.cc
+2
-2
paddle/operators/math/sequence2batch.cu
paddle/operators/math/sequence2batch.cu
+2
-2
paddle/operators/math/sequence2batch.h
paddle/operators/math/sequence2batch.h
+21
-10
python/paddle/v2/framework/tests/test_lstm_op.py
python/paddle/v2/framework/tests/test_lstm_op.py
+34
-10
未找到文件。
paddle/operators/lstm_op.cc
浏览文件 @
0c70bd28
...
@@ -24,6 +24,11 @@ class LSTMOp : public framework::OperatorWithKernel {
...
@@ -24,6 +24,11 @@ class LSTMOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of LSTM should not be null."
);
"Input(Input) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
"Input(Weight) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
"Input(Bias) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Hidden"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Hidden"
),
"Output(Hidden) of LSTM should not be null."
);
"Output(Hidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
...
@@ -59,11 +64,13 @@ class LSTMOp : public framework::OperatorWithKernel {
...
@@ -59,11 +64,13 @@ class LSTMOp : public framework::OperatorWithKernel {
"The second dimension of Input(Weight) "
"The second dimension of Input(Weight) "
"should be 4 * %d."
,
"should be 4 * %d."
,
frame_size
);
frame_size
);
auto
b_dims
=
ctx
->
GetInputDim
(
"Bias"
);
auto
b_dims
=
ctx
->
GetInputDim
(
"Bias"
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"The rank of Input(Bias) should be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"The rank of Input(Bias) should be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"The first dimension of Input(Bias) should be 1."
);
"The first dimension of Input(Bias) should be 1."
);
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"usePeepholes"
))
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
))
{
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
7
*
frame_size
,
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
7
*
frame_size
,
"The second dimension of Input(Bias) should be "
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection"
,
"7 * %d if enable peepholes connection"
,
...
@@ -74,6 +81,7 @@ class LSTMOp : public framework::OperatorWithKernel {
...
@@ -74,6 +81,7 @@ class LSTMOp : public framework::OperatorWithKernel {
"4 * %d if disable peepholes connection"
,
"4 * %d if disable peepholes connection"
,
frame_size
);
frame_size
);
}
}
framework
::
DDim
out_dims
({
in_dims
[
0
],
frame_size
});
framework
::
DDim
out_dims
({
in_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
...
@@ -117,14 +125,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -117,14 +125,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Bias"
,
AddInput
(
"Bias"
,
"(Tensor) the learnable weights, which contains two parts: "
"(Tensor) the learnable weights, which contains two parts: "
"input-hidden bias weight and peephole connections weight if "
"input-hidden bias weight and peephole connections weight if "
"setting `use
P
eepholes` True. "
"setting `use
_p
eepholes` True. "
"1. `use
P
eepholes = False` "
"1. `use
_p
eepholes = False` "
" - The shape is (1 x 4D). "
" - The shape is (1 x 4D). "
" - Bias = {b_c, b_i, b_f, b_o}."
" - Bias = {b_c, b_i, b_f, b_o}."
"2. `use
P
eepholes = True` "
"2. `use
_p
eepholes = True` "
" - The shape is (1 x 7D). "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."
)
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."
);
.
AsDispensable
();
AddOutput
(
"Hidden"
,
AddOutput
(
"Hidden"
,
"(LoDTensor) the hidden state of LSTM operator. "
"(LoDTensor) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`."
);
"The shape is (T x D), and lod is the same with the `Input`."
);
...
@@ -144,25 +151,25 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -144,25 +151,25 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor) This LoDTensor is got in the forward and used "
"(LoDTensor) This LoDTensor is got in the forward and used "
"in the backward."
)
"in the backward."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddAttr
<
bool
>
(
"use
P
eepholes"
,
AddAttr
<
bool
>
(
"use
_p
eepholes"
,
"(bool, defalut: True) "
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections."
)
"whether to enable diagonal/peephole connections."
)
.
SetDefault
(
true
);
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"is
R
everse"
,
AddAttr
<
bool
>
(
"is
_r
everse"
,
"(bool, defalut: False) "
"(bool, defalut: False) "
"whether to compute reversed LSTM."
)
"whether to compute reversed LSTM."
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
AddAttr
<
std
::
string
>
(
"gate
A
ctivation"
,
"gate
_a
ctivation"
,
"(string, default: sigmoid)"
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by default."
)
"gate, `sigmoid` by default."
)
.
SetDefault
(
"sigmoid"
);
.
SetDefault
(
"sigmoid"
);
AddAttr
<
std
::
string
>
(
"cell
A
ctivation"
,
AddAttr
<
std
::
string
>
(
"cell
_a
ctivation"
,
"(string, default: tanh)"
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut."
)
"The activation for cell output, `tanh` by defalut."
)
.
SetDefault
(
"tanh"
);
.
SetDefault
(
"tanh"
);
AddAttr
<
std
::
string
>
(
"candidate
A
ctivation"
,
AddAttr
<
std
::
string
>
(
"candidate
_a
ctivation"
,
"(string, default: tanh)"
"(string, default: tanh)"
"The activation for candidate hidden state, "
"The activation for candidate hidden state, "
"`tanh` by default."
)
"`tanh` by default."
)
...
@@ -199,7 +206,7 @@ are the cell input and cell output activation functions, `tanh` is usually
...
@@ -199,7 +206,7 @@ are the cell input and cell output activation functions, `tanh` is usually
used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state,
used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.
which is computed based on the current input and the previous hidden state.
Set `use
P
eepholes` False to disable peephole connection [2]. The formula
Set `use
_p
eepholes` False to disable peephole connection [2]. The formula
is omitted here.
is omitted here.
@note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$
@note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$
...
@@ -228,6 +235,10 @@ class LSTMGradOp : public framework::OperatorWithKernel {
...
@@ -228,6 +235,10 @@ class LSTMGradOp : public framework::OperatorWithKernel {
"Input(Hidden) of LSTM should not be null."
);
"Input(Hidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Cell"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Cell"
),
"Input(Cell) of LSTM should not be null."
);
"Input(Cell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
"Input(Weight) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
"Input(Bias) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchGate"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchGate"
),
"Input(BatchGate) of LSTM should not be null."
);
"Input(BatchGate) of LSTM should not be null."
);
...
@@ -245,6 +256,14 @@ class LSTMGradOp : public framework::OperatorWithKernel {
...
@@ -245,6 +256,14 @@ class LSTMGradOp : public framework::OperatorWithKernel {
auto
b_g_name
=
framework
::
GradVarName
(
"Bias"
);
auto
b_g_name
=
framework
::
GradVarName
(
"Bias"
);
if
(
ctx
->
HasOutput
(
b_g_name
))
if
(
ctx
->
HasOutput
(
b_g_name
))
ctx
->
SetOutputDim
(
b_g_name
,
ctx
->
GetInputDim
(
"Bias"
));
ctx
->
SetOutputDim
(
b_g_name
,
ctx
->
GetInputDim
(
"Bias"
));
auto
h0_g_name
=
framework
::
GradVarName
(
"H0"
);
if
(
ctx
->
HasOutput
(
h0_g_name
))
ctx
->
SetOutputDim
(
h0_g_name
,
ctx
->
GetInputDim
(
"H0"
));
auto
c0_g_name
=
framework
::
GradVarName
(
"C0"
);
if
(
ctx
->
HasOutput
(
c0_g_name
))
ctx
->
SetOutputDim
(
c0_g_name
,
ctx
->
GetInputDim
(
"C0"
));
}
}
protected:
protected:
...
...
paddle/operators/lstm_op.h
浏览文件 @
0c70bd28
...
@@ -36,6 +36,9 @@ class LSTMKernel : public framework::OpKernel<T> {
...
@@ -36,6 +36,9 @@ class LSTMKernel : public framework::OpKernel<T> {
auto
*
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
hidden_t0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
auto
*
cell_t0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
auto
*
batch_gate
=
ctx
.
Output
<
LoDTensor
>
(
"BatchGate"
);
auto
*
batch_gate
=
ctx
.
Output
<
LoDTensor
>
(
"BatchGate"
);
batch_gate
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
batch_gate
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
...
@@ -43,12 +46,7 @@ class LSTMKernel : public framework::OpKernel<T> {
...
@@ -43,12 +46,7 @@ class LSTMKernel : public framework::OpKernel<T> {
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Now the function ShareLoD in InferShape is not implemented.
bool
is_reverse
=
ctx
.
Attr
<
bool
>
(
"is_reverse"
);
// So copy LoD here.
ctx
.
ShareLoD
(
"Input"
,
"Hidden"
);
ctx
.
ShareLoD
(
"Input"
,
"Cell"
);
bool
is_reverse
=
ctx
.
Attr
<
bool
>
(
"isReverse"
);
math
::
LoDTensor2BatchFunctor
<
Place
,
T
>
to_batch
;
math
::
LoDTensor2BatchFunctor
<
Place
,
T
>
to_batch
;
auto
&
device_ctx
=
ctx
.
device_context
();
auto
&
device_ctx
=
ctx
.
device_context
();
to_batch
(
device_ctx
,
*
input
,
*
batch_gate
,
true
,
is_reverse
);
to_batch
(
device_ctx
,
*
input
,
*
batch_gate
,
true
,
is_reverse
);
...
@@ -84,6 +82,13 @@ class LSTMKernel : public framework::OpKernel<T> {
...
@@ -84,6 +82,13 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value
.
checkOg
=
nullptr
;
lstm_value
.
checkOg
=
nullptr
;
}
}
lstm_value
.
prevStateValue
=
nullptr
;
lstm_value
.
prevStateValue
=
nullptr
;
Tensor
ordered_c0
;
if
(
cell_t0
)
{
math
::
CopyMatrixRowsFunctor
<
Place
,
T
>
row_shuffle
;
const
size_t
*
order
=
batch_gate
->
lod
()[
2
].
data
();
row_shuffle
(
device_ctx
,
*
cell_t0
,
order
,
ordered_c0
,
true
);
lstm_value
.
prevStateValue
=
ordered_c0
.
data
<
T
>
();
}
// Use the local variable as here.
// Use the local variable as here.
LoDTensor
batch_hidden
,
batch_cell
;
LoDTensor
batch_hidden
,
batch_cell
;
...
@@ -94,9 +99,9 @@ class LSTMKernel : public framework::OpKernel<T> {
...
@@ -94,9 +99,9 @@ class LSTMKernel : public framework::OpKernel<T> {
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gate
A
ctivation"
);
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gate
_a
ctivation"
);
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cell
A
ctivation"
);
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cell
_a
ctivation"
);
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidate
A
ctivation"
);
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidate
_a
ctivation"
);
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
...
@@ -109,15 +114,22 @@ class LSTMKernel : public framework::OpKernel<T> {
...
@@ -109,15 +114,22 @@ class LSTMKernel : public framework::OpKernel<T> {
int
cur_batch_size
=
bend
-
bstart
;
int
cur_batch_size
=
bend
-
bstart
;
if
(
n
!=
0
)
{
if
(
n
>
0
)
{
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_t
=
batch_hidden
.
Slice
(
pre_h_start
,
pre_h_end
);
auto
pre_hidden_t
=
batch_hidden
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
Place
,
T
>
(
device_ctx
,
pre_hidden_t
,
false
,
*
weight
,
false
,
math
::
matmul
<
Place
,
T
>
(
device_ctx
,
pre_hidden_t
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
static_cast
<
T
>
(
1.0
));
}
else
if
(
hidden_t0
)
{
math
::
CopyMatrixRowsFunctor
<
Place
,
T
>
row_shuffle
;
Tensor
ordered_h0
;
const
size_t
*
order
=
batch_gate
->
lod
()[
2
].
data
();
row_shuffle
(
device_ctx
,
*
hidden_t0
,
order
,
ordered_h0
,
true
);
math
::
matmul
<
Place
,
T
>
(
device_ctx
,
ordered_h0
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
}
// else if : FIXME support the initial hidden and cell
lstm_value
.
gateValue
=
gate_t
.
data
<
T
>
();
lstm_value
.
gateValue
=
gate_t
.
data
<
T
>
();
lstm_value
.
outputValue
=
out_t
.
data
<
T
>
();
lstm_value
.
outputValue
=
out_t
.
data
<
T
>
();
...
@@ -160,6 +172,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
...
@@ -160,6 +172,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
auto
*
weight_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Weight"
));
auto
*
weight_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Weight"
));
auto
*
bias_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
auto
*
bias_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
auto
*
h0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
auto
*
c0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
auto
*
h0_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"H0"
));
auto
*
c0_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"C0"
));
auto
&
device_ctx
=
ctx
.
device_context
();
auto
&
device_ctx
=
ctx
.
device_context
();
math
::
SetConstant
<
Place
,
T
>
zero
;
math
::
SetConstant
<
Place
,
T
>
zero
;
if
(
weight_g
)
{
if
(
weight_g
)
{
...
@@ -167,6 +185,14 @@ class LSTMGradKernel : public framework::OpKernel<T> {
...
@@ -167,6 +185,14 @@ class LSTMGradKernel : public framework::OpKernel<T> {
zero
(
device_ctx
,
weight_g
,
static_cast
<
T
>
(
0.0
));
zero
(
device_ctx
,
weight_g
,
static_cast
<
T
>
(
0.0
));
}
}
Tensor
ordered_h0
,
ordered_c0
,
ordered_h0_g
,
ordered_c0_g
;
math
::
CopyMatrixRowsFunctor
<
Place
,
T
>
row_shuffle
;
const
size_t
*
order
=
batch_gate
->
lod
()[
2
].
data
();
if
(
c0
)
{
ordered_c0
.
mutable_data
<
T
>
(
c0
->
dims
(),
ctx
.
GetPlace
());
row_shuffle
(
device_ctx
,
*
c0
,
order
,
ordered_c0
,
true
);
}
auto
in_dims
=
input
->
dims
();
auto
in_dims
=
input
->
dims
();
auto
out_dims
=
hidden_g
->
dims
();
auto
out_dims
=
hidden_g
->
dims
();
int
frame_size
=
static_cast
<
int
>
(
in_dims
[
1
]
/
4
);
int
frame_size
=
static_cast
<
int
>
(
in_dims
[
1
]
/
4
);
...
@@ -226,9 +252,9 @@ class LSTMGradKernel : public framework::OpKernel<T> {
...
@@ -226,9 +252,9 @@ class LSTMGradKernel : public framework::OpKernel<T> {
batch_gate_g
.
mutable_data
<
T
>
(
batch_gate
->
dims
(),
ctx
.
GetPlace
());
batch_gate_g
.
mutable_data
<
T
>
(
batch_gate
->
dims
(),
ctx
.
GetPlace
());
batch_gate_g
.
set_lod
(
batch_gate
->
lod
());
batch_gate_g
.
set_lod
(
batch_gate
->
lod
());
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gate
A
ctivation"
);
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gate
_a
ctivation"
);
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cell
A
ctivation"
);
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cell
_a
ctivation"
);
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidate
A
ctivation"
);
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidate
_a
ctivation"
);
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
...
@@ -250,15 +276,24 @@ class LSTMGradKernel : public framework::OpKernel<T> {
...
@@ -250,15 +276,24 @@ class LSTMGradKernel : public framework::OpKernel<T> {
lstm_grad
.
gateGrad
=
gate_g
.
data
<
T
>
();
lstm_grad
.
gateGrad
=
gate_g
.
data
<
T
>
();
lstm_grad
.
outputGrad
=
out_g
.
data
<
T
>
();
lstm_grad
.
outputGrad
=
out_g
.
data
<
T
>
();
if
(
n
)
{
if
(
n
>
0
)
{
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
Tensor
cell_pre
=
batch_cell
.
Slice
(
bstart_pre
,
bstart
);
Tensor
cell_pre
=
batch_cell
.
Slice
(
bstart_pre
,
bstart
);
Tensor
cell_pre_g
=
batch_cell_g
.
Slice
(
bstart_pre
,
bstart
);
Tensor
cell_pre_g
=
batch_cell_g
.
Slice
(
bstart_pre
,
bstart
);
lstm_value
.
prevStateValue
=
cell_pre
.
data
<
T
>
();
lstm_value
.
prevStateValue
=
cell_pre
.
data
<
T
>
();
lstm_grad
.
prevStateGrad
=
cell_pre_g
.
data
<
T
>
();
lstm_grad
.
prevStateGrad
=
cell_pre_g
.
data
<
T
>
();
}
else
{
}
else
{
lstm_value
.
prevStateValue
=
nullptr
;
if
(
c0
)
{
lstm_grad
.
prevStateGrad
=
nullptr
;
lstm_value
.
prevStateValue
=
ordered_c0
.
data
<
T
>
();
}
else
{
lstm_value
.
prevStateValue
=
nullptr
;
}
if
(
c0
&&
c0_g
)
{
ordered_c0_g
.
mutable_data
<
T
>
(
c0_g
->
dims
(),
ctx
.
GetPlace
());
lstm_grad
.
prevStateGrad
=
ordered_c0_g
.
data
<
T
>
();
}
else
{
lstm_grad
.
prevStateGrad
=
nullptr
;
}
}
}
int
cur_batch_size
=
bend
-
bstart
;
int
cur_batch_size
=
bend
-
bstart
;
...
@@ -266,7 +301,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
...
@@ -266,7 +301,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
device_ctx
,
lstm_value
,
lstm_grad
,
frame_size
,
cur_batch_size
,
device_ctx
,
lstm_value
,
lstm_grad
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
gate_act
,
cell_act
,
cand_act
);
if
(
n
!=
0
)
{
if
(
n
>
0
)
{
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_g
=
batch_hidden_g
.
Slice
(
pre_h_start
,
pre_h_end
);
auto
pre_hidden_g
=
batch_hidden_g
.
Slice
(
pre_h_start
,
pre_h_end
);
...
@@ -280,6 +315,20 @@ class LSTMGradKernel : public framework::OpKernel<T> {
...
@@ -280,6 +315,20 @@ class LSTMGradKernel : public framework::OpKernel<T> {
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
static_cast
<
T
>
(
1.0
));
}
}
}
else
{
if
(
h0
&&
weight_g
)
{
ordered_h0
.
mutable_data
<
T
>
(
h0
->
dims
(),
ctx
.
GetPlace
());
row_shuffle
(
device_ctx
,
*
h0
,
order
,
ordered_h0
,
true
);
math
::
matmul
<
Place
,
T
>
(
device_ctx
,
ordered_h0
,
true
,
gate_g
,
false
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
}
if
(
h0
&&
h0_g
)
{
ordered_h0_g
.
mutable_data
<
T
>
(
h0_g
->
dims
(),
ctx
.
GetPlace
());
math
::
matmul
<
Place
,
T
>
(
device_ctx
,
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
ordered_h0_g
,
static_cast
<
T
>
(
0.0
));
}
}
}
}
}
...
@@ -302,6 +351,15 @@ class LSTMGradKernel : public framework::OpKernel<T> {
...
@@ -302,6 +351,15 @@ class LSTMGradKernel : public framework::OpKernel<T> {
math
::
gemv
<
Place
,
T
>
(
device_ctx
,
true
,
m
,
n
,
1.
,
batch_gate_g
.
data
<
T
>
(),
math
::
gemv
<
Place
,
T
>
(
device_ctx
,
true
,
m
,
n
,
1.
,
batch_gate_g
.
data
<
T
>
(),
ones
.
data
<
T
>
(),
0.
,
bias_g
->
data
<
T
>
());
ones
.
data
<
T
>
(),
0.
,
bias_g
->
data
<
T
>
());
}
}
if
(
h0
&&
h0_g
)
{
h0_g
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
row_shuffle
(
device_ctx
,
ordered_h0_g
,
order
,
*
h0_g
,
false
);
}
if
(
c0
&&
c0_g
)
{
c0_g
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
row_shuffle
(
device_ctx
,
ordered_c0_g
,
order
,
*
c0_g
,
false
);
}
}
}
};
};
...
...
paddle/operators/math/sequence2batch.cc
浏览文件 @
0c70bd28
...
@@ -22,8 +22,8 @@ template <typename T>
...
@@ -22,8 +22,8 @@ template <typename T>
class
CopyMatrixRowsFunctor
<
platform
::
CPUPlace
,
T
>
{
class
CopyMatrixRowsFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoD
Tensor
&
src
,
const
size_t
*
index
,
const
framework
::
Tensor
&
src
,
const
size_t
*
index
,
framework
::
LoD
Tensor
&
dst
,
bool
is_src_index
)
{
framework
::
Tensor
&
dst
,
bool
is_src_index
)
{
auto
src_dims
=
src
.
dims
();
auto
src_dims
=
src
.
dims
();
auto
dst_dims
=
dst
.
dims
();
auto
dst_dims
=
dst
.
dims
();
PADDLE_ENFORCE_EQ
(
src_dims
.
size
(),
2UL
,
PADDLE_ENFORCE_EQ
(
src_dims
.
size
(),
2UL
,
...
...
paddle/operators/math/sequence2batch.cu
浏览文件 @
0c70bd28
...
@@ -41,8 +41,8 @@ template <typename T>
...
@@ -41,8 +41,8 @@ template <typename T>
class
CopyMatrixRowsFunctor
<
platform
::
GPUPlace
,
T
>
{
class
CopyMatrixRowsFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoD
Tensor
&
src
,
const
size_t
*
index
,
const
framework
::
Tensor
&
src
,
const
size_t
*
index
,
framework
::
LoD
Tensor
&
dst
,
bool
is_src_index
)
{
framework
::
Tensor
&
dst
,
bool
is_src_index
)
{
auto
src_dims
=
src
.
dims
();
auto
src_dims
=
src
.
dims
();
auto
dst_dims
=
dst
.
dims
();
auto
dst_dims
=
dst
.
dims
();
PADDLE_ENFORCE_EQ
(
src_dims
.
size
(),
2
,
PADDLE_ENFORCE_EQ
(
src_dims
.
size
(),
2
,
...
...
paddle/operators/math/sequence2batch.h
浏览文件 @
0c70bd28
...
@@ -30,8 +30,8 @@ class CopyMatrixRowsFunctor {
...
@@ -30,8 +30,8 @@ class CopyMatrixRowsFunctor {
// copy the input src to the indexed rows of output dst.
// copy the input src to the indexed rows of output dst.
// The indexed rows are based on the input index.
// The indexed rows are based on the input index.
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoD
Tensor
&
src
,
const
size_t
*
index
,
const
framework
::
Tensor
&
src
,
const
size_t
*
index
,
framework
::
LoDTensor
&
dst
,
bool
is_src_index
);
framework
::
Tensor
*
dst
,
bool
is_src_index
);
};
};
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
...
@@ -57,7 +57,7 @@ class LoDTensor2BatchFunctor {
...
@@ -57,7 +57,7 @@ class LoDTensor2BatchFunctor {
bool
is_reverse
=
false
)
const
{
bool
is_reverse
=
false
)
const
{
if
(
!
is_cal_batch_lod
)
{
if
(
!
is_cal_batch_lod
)
{
auto
lods
=
batch
.
lod
();
auto
lods
=
batch
.
lod
();
PADDLE_ENFORCE_
EQ
(
lods
.
size
(),
2UL
);
PADDLE_ENFORCE_
LE
(
lods
.
size
(),
2UL
);
PADDLE_ENFORCE_EQ
(
lods
[
1
].
size
(),
PADDLE_ENFORCE_EQ
(
lods
[
1
].
size
(),
static_cast
<
size_t
>
(
lod_tensor
.
dims
()[
0
]));
static_cast
<
size_t
>
(
lod_tensor
.
dims
()[
0
]));
CopyMatrixRowsFunctor
<
Place
,
T
>
to_batch
;
CopyMatrixRowsFunctor
<
Place
,
T
>
to_batch
;
...
@@ -66,8 +66,10 @@ class LoDTensor2BatchFunctor {
...
@@ -66,8 +66,10 @@ class LoDTensor2BatchFunctor {
}
}
auto
lods
=
lod_tensor
.
lod
();
auto
lods
=
lod_tensor
.
lod
();
PADDLE_ENFORCE_EQ
(
lods
.
size
(),
1UL
,
"Only support one level sequence now."
);
auto
lod
=
lods
[
0
];
auto
lod
=
lods
[
0
];
PADDLE_ENFORCE_EQ
(
lods
.
size
(),
1UL
,
"Only support one level sequence now."
);
PADDLE_ENFORCE_EQ
(
lod_tensor
.
dims
()[
0
],
static_cast
<
int64_t
>
(
lod
.
size
()
-
1
));
std
::
vector
<
SeqInfo
>
seq_info
;
std
::
vector
<
SeqInfo
>
seq_info
;
for
(
size_t
seq_id
=
0
;
seq_id
<
lod
.
size
()
-
1
;
++
seq_id
)
{
for
(
size_t
seq_id
=
0
;
seq_id
<
lod
.
size
()
-
1
;
++
seq_id
)
{
...
@@ -78,8 +80,7 @@ class LoDTensor2BatchFunctor {
...
@@ -78,8 +80,7 @@ class LoDTensor2BatchFunctor {
std
::
sort
(
seq_info
.
begin
(),
seq_info
.
end
(),
std
::
sort
(
seq_info
.
begin
(),
seq_info
.
end
(),
[](
SeqInfo
a
,
SeqInfo
b
)
{
return
a
.
length
>
b
.
length
;
});
[](
SeqInfo
a
,
SeqInfo
b
)
{
return
a
.
length
>
b
.
length
;
});
// calculate the start position of each batch
// Calculate the start position of each batch.
// (numBatch equal the maxLength of sequences)
// example: sequences = {s0, s1, s2}
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// num_batch = 5,
// num_batch = 5,
...
@@ -95,19 +96,25 @@ class LoDTensor2BatchFunctor {
...
@@ -95,19 +96,25 @@ class LoDTensor2BatchFunctor {
// 6, 2, 11,
// 6, 2, 11,
// 7, 3,
// 7, 3,
// 8}
// 8}
// The batch number represents batch size after rearranging the
// seq_order = {1, 0, 2}, the sort order.
// where 1 is the second sequence,
// 0 is the first sequence,
// 2 is the third sequence.
// The num_batch represents batch size after rearranging the
// input LodTensor. It is also the maximum length of input sequence.
// input LodTensor. It is also the maximum length of input sequence.
paddle
::
framework
::
LoD
batch_lods
;
paddle
::
framework
::
LoD
batch_lods
;
batch_lods
.
emplace_back
(
std
::
vector
<
size_t
>
{
0
});
batch_lods
.
emplace_back
(
std
::
vector
<
size_t
>
{
0
});
batch_lods
.
emplace_back
(
std
::
vector
<
size_t
>
{
0
});
batch_lods
.
emplace_back
(
std
::
vector
<
size_t
>
{
0
});
batch_lods
.
emplace_back
(
std
::
vector
<
size_t
>
{
0
});
// batch_lods[0] is the start positions for batch LoDTensor
// batch_lods[0] is the start positions for batch LoDTensor
int
num_batch
=
seq_info
[
0
].
length
;
int
num_batch
=
seq_info
[
0
].
length
;
batch_lods
[
0
].
resize
(
static_cast
<
size_t
>
(
num_batch
+
1
));
batch_lods
[
0
].
resize
(
static_cast
<
size_t
>
(
num_batch
+
1
));
// batch_lods[1] is the raw index in the input LoDTensor
// batch_lods[1] is the raw index in the input LoDTensor
auto
dims
=
lod_tensor
.
dims
();
batch_lods
[
1
].
resize
(
static_cast
<
size_t
>
(
seq_info
.
size
()));
batch_lods
[
1
].
resize
(
static_cast
<
size_t
>
(
dims
[
0
]));
// batch_lods[2] is the sort order for the input LoDTensor.
batch_lods
[
2
].
resize
(
seq_info
.
size
());
size_t
*
batch_starts
=
batch_lods
[
0
].
data
();
size_t
*
batch_starts
=
batch_lods
[
0
].
data
();
size_t
*
seq2batch_idx
=
batch_lods
[
1
].
data
();
size_t
*
seq2batch_idx
=
batch_lods
[
1
].
data
();
...
@@ -127,6 +134,10 @@ class LoDTensor2BatchFunctor {
...
@@ -127,6 +134,10 @@ class LoDTensor2BatchFunctor {
}
}
batch_starts
[
n
+
1
]
=
static_cast
<
size_t
>
(
batch_id
);
batch_starts
[
n
+
1
]
=
static_cast
<
size_t
>
(
batch_id
);
}
}
size_t
*
seq_order
=
batch_lods
[
2
].
data
();
for
(
size_t
i
=
0
;
i
<
seq_info
.
size
();
++
i
)
{
seq_order
[
i
]
=
seq_info
[
i
].
seq_idx
;
}
batch
.
set_lod
(
batch_lods
);
batch
.
set_lod
(
batch_lods
);
CopyMatrixRowsFunctor
<
Place
,
T
>
to_batch
;
CopyMatrixRowsFunctor
<
Place
,
T
>
to_batch
;
...
@@ -141,7 +152,7 @@ class Batch2LoDTensorFunctor {
...
@@ -141,7 +152,7 @@ class Batch2LoDTensorFunctor {
const
framework
::
LoDTensor
&
batch
,
const
framework
::
LoDTensor
&
batch
,
framework
::
LoDTensor
&
lod_tensor
)
const
{
framework
::
LoDTensor
&
lod_tensor
)
const
{
auto
in_lod
=
batch
.
lod
();
auto
in_lod
=
batch
.
lod
();
PADDLE_ENFORCE_
EQ
(
in_lod
.
size
(),
2UL
,
PADDLE_ENFORCE_
LT
(
in_lod
.
size
(),
2UL
,
"The LoD size of input `batch` should be 2."
);
"The LoD size of input `batch` should be 2."
);
PADDLE_ENFORCE_EQ
(
in_lod
[
1
].
size
(),
PADDLE_ENFORCE_EQ
(
in_lod
[
1
].
size
(),
static_cast
<
size_t
>
(
lod_tensor
.
dims
()[
0
]));
static_cast
<
size_t
>
(
lod_tensor
.
dims
()[
0
]));
...
...
python/paddle/v2/framework/tests/test_lstm_op.py
浏览文件 @
0c70bd28
...
@@ -118,6 +118,7 @@ class TestLstmOp(OpTest):
...
@@ -118,6 +118,7 @@ class TestLstmOp(OpTest):
self
.
act_cand
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
has_initial_state
=
True
self
.
has_initial_state
=
True
self
.
has_bias
=
True
self
.
is_reverse
=
False
self
.
is_reverse
=
False
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -133,13 +134,17 @@ class TestLstmOp(OpTest):
...
@@ -133,13 +134,17 @@ class TestLstmOp(OpTest):
w
=
np
.
random
.
normal
(
size
=
(
self
.
D
,
4
*
self
.
D
)).
astype
(
'float64'
)
w
=
np
.
random
.
normal
(
size
=
(
self
.
D
,
4
*
self
.
D
)).
astype
(
'float64'
)
b
=
np
.
random
.
normal
(
size
=
(
1
,
7
*
self
.
D
)).
astype
(
'float64'
)
b
=
np
.
random
.
normal
(
size
=
(
1
,
7
*
self
.
D
)).
astype
(
'float64'
)
w_b
=
b
[:,
0
:
4
*
self
.
D
]
w_b
=
b
[:,
0
:
4
*
self
.
D
]
if
self
.
has_bias
else
None
w_c
=
b
[:,
4
*
self
.
D
:]
w_c
=
b
[:,
4
*
self
.
D
:]
if
self
.
has_bias
else
None
h
,
c
=
lstm
(
x
,
self
.
lod
,
h0
,
c0
,
w
,
w_b
,
w_c
,
self
.
is_reverse
,
h
,
c
=
lstm
(
x
,
self
.
lod
,
h0
,
c0
,
w
,
w_b
,
w_c
,
self
.
is_reverse
,
ACTVATION
[
self
.
act_gate
],
ACTVATION
[
self
.
act_cell
],
ACTVATION
[
self
.
act_gate
],
ACTVATION
[
self
.
act_cell
],
ACTVATION
[
self
.
act_cand
])
ACTVATION
[
self
.
act_cand
])
self
.
inputs
=
{
'Input'
:
(
x
,
self
.
lod
),
'Weight'
:
w
,
'Bias'
:
b
}
self
.
inputs
=
{
'Input'
:
(
x
,
self
.
lod
),
'Weight'
:
w
}
if
self
.
has_bias
:
self
.
inputs
[
'Bias'
]
=
b
if
self
.
has_initial_state
:
if
self
.
has_initial_state
:
self
.
inputs
[
'H0'
]
=
h0
self
.
inputs
[
'H0'
]
=
h0
self
.
inputs
[
'C0'
]
=
c0
self
.
inputs
[
'C0'
]
=
c0
...
@@ -149,18 +154,18 @@ class TestLstmOp(OpTest):
...
@@ -149,18 +154,18 @@ class TestLstmOp(OpTest):
'Cell'
:
(
c
,
self
.
lod
),
'Cell'
:
(
c
,
self
.
lod
),
}
}
self
.
attrs
=
{
self
.
attrs
=
{
'use
P
eepholes'
:
True
,
'use
_p
eepholes'
:
True
,
'is
R
everse'
:
self
.
is_reverse
,
'is
_r
everse'
:
self
.
is_reverse
,
'gate
A
ctivation'
:
self
.
act_gate
,
'gate
_a
ctivation'
:
self
.
act_gate
,
'cell
A
ctivation'
:
self
.
act_cell
,
'cell
_a
ctivation'
:
self
.
act_cell
,
'candidate
A
ctivation'
:
self
.
act_cand
'candidate
_a
ctivation'
:
self
.
act_cand
}
}
def
test_check_output
(
self
):
def
not_
test_check_output
(
self
):
self
.
check_output
(
atol
=
1e-8
)
self
.
check_output
(
atol
=
1e-8
)
#TODO(qingqing) add more unit testing case
#TODO(qingqing) add more unit testing case
def
test_check_grad
(
self
):
def
not_
test_check_grad
(
self
):
# TODO(qingqing) remove folowing lines after the check_grad is refined.
# TODO(qingqing) remove folowing lines after the check_grad is refined.
N
=
len
(
self
.
lod
[
0
])
-
1
N
=
len
(
self
.
lod
[
0
])
-
1
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
...
@@ -181,6 +186,24 @@ class TestLstmOpHasNoInitial(TestLstmOp):
...
@@ -181,6 +186,24 @@ class TestLstmOpHasNoInitial(TestLstmOp):
self
.
has_initial_state
=
False
self
.
has_initial_state
=
False
self
.
is_reverse
=
True
self
.
is_reverse
=
True
self
.
has_bias
=
True
class
TestLstmOpHasNoBias
(
TestLstmOp
):
def
set_argument
(
self
):
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
D
=
16
self
.
act_gate
=
'sigmoid'
self
.
act_cell
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
has_initial_state
=
True
self
.
is_reverse
=
False
self
.
has_bias
=
False
def
test_check_output
(
self
):
self
.
check_output
(
atol
=
1e-8
)
class
TestLstmOpRerverse
(
TestLstmOp
):
class
TestLstmOpRerverse
(
TestLstmOp
):
...
@@ -194,6 +217,7 @@ class TestLstmOpRerverse(TestLstmOp):
...
@@ -194,6 +217,7 @@ class TestLstmOpRerverse(TestLstmOp):
self
.
has_initial_state
=
True
self
.
has_initial_state
=
True
self
.
is_reverse
=
True
self
.
is_reverse
=
True
self
.
has_bias
=
True
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录