Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
36d20609
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
36d20609
编写于
11月 01, 2017
作者:
Q
qingqing01
提交者:
GitHub
11月 01, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5115 from qingqing01/lstm_bp
Add backward implementation for LSTM operator.
上级
3d567864
7061e013
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
579 addition
and
165 deletion
+579
-165
paddle/framework/lod_tensor_test.cu
paddle/framework/lod_tensor_test.cu
+4
-4
paddle/operators/lstm_op.cc
paddle/operators/lstm_op.cc
+67
-26
paddle/operators/lstm_op.h
paddle/operators/lstm_op.h
+197
-27
paddle/operators/math/detail/lstm_cpu_kernel.h
paddle/operators/math/detail/lstm_cpu_kernel.h
+6
-17
paddle/operators/math/detail/lstm_gpu_kernel.h
paddle/operators/math/detail/lstm_gpu_kernel.h
+8
-20
paddle/operators/math/detail/lstm_kernel.h
paddle/operators/math/detail/lstm_kernel.h
+51
-8
paddle/operators/math/math_function.cc
paddle/operators/math/math_function.cc
+20
-0
paddle/operators/math/math_function.cu
paddle/operators/math/math_function.cu
+27
-0
paddle/operators/math/math_function.h
paddle/operators/math/math_function.h
+5
-0
paddle/operators/math/math_function_test.cc
paddle/operators/math/math_function_test.cc
+50
-0
paddle/operators/math/math_function_test.cu
paddle/operators/math/math_function_test.cu
+62
-0
paddle/operators/math/sequence2batch.h
paddle/operators/math/sequence2batch.h
+16
-8
python/paddle/v2/framework/tests/test_lstm_op.py
python/paddle/v2/framework/tests/test_lstm_op.py
+66
-55
未找到文件。
paddle/framework/lod_tensor_test.cu
浏览文件 @
36d20609
...
@@ -36,8 +36,8 @@ TEST(LoDTensor, LoDInGPU) {
...
@@ -36,8 +36,8 @@ TEST(LoDTensor, LoDInGPU) {
lod_tensor
.
mutable_data
<
float
>
(
place
);
lod_tensor
.
mutable_data
<
float
>
(
place
);
lod_tensor
.
set_lod
(
src_lod
);
lod_tensor
.
set_lod
(
src_lod
);
CHECK
_EQ
(
lod_tensor
.
lod_element
(
0
,
2
).
first
,
4UL
);
EXPECT
_EQ
(
lod_tensor
.
lod_element
(
0
,
2
).
first
,
4UL
);
CHECK
_EQ
(
lod_tensor
.
lod_element
(
0
,
4
).
first
,
8UL
);
EXPECT
_EQ
(
lod_tensor
.
lod_element
(
0
,
4
).
first
,
8UL
);
auto
lod
=
lod_tensor
.
lod
();
auto
lod
=
lod_tensor
.
lod
();
...
@@ -45,6 +45,6 @@ TEST(LoDTensor, LoDInGPU) {
...
@@ -45,6 +45,6 @@ TEST(LoDTensor, LoDInGPU) {
cudaDeviceSynchronize
();
cudaDeviceSynchronize
();
for
(
size_t
i
=
0
;
i
<
src_lod
[
0
].
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
src_lod
[
0
].
size
();
++
i
)
{
CHECK
_EQ
(
lod
[
0
].
data
()[
i
],
src_lod
[
0
].
data
()[
i
]
*
2
);
EXPECT
_EQ
(
lod
[
0
].
data
()[
i
],
src_lod
[
0
].
data
()[
i
]
*
2
);
}
}
}
}
\ No newline at end of file
paddle/operators/lstm_op.cc
浏览文件 @
36d20609
...
@@ -21,7 +21,6 @@ class LSTMOp : public framework::OperatorWithKernel {
...
@@ -21,7 +21,6 @@ class LSTMOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of LSTM should not be null."
);
"Input(Input) of LSTM should not be null."
);
...
@@ -29,9 +28,13 @@ class LSTMOp : public framework::OperatorWithKernel {
...
@@ -29,9 +28,13 @@ class LSTMOp : public framework::OperatorWithKernel {
"Output(Hidden) of LSTM should not be null."
);
"Output(Hidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
"Output(Cell) of LSTM should not be null."
);
"Output(Cell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchGate"
),
"Output(BatchGate) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchCellPreAct"
),
"Output(BatchGate) of LSTM should not be null."
);
auto
x
_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
in
_dims
=
ctx
->
GetInputDim
(
"Input"
);
PADDLE_ENFORCE_EQ
(
x
_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
in
_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
if
(
ctx
->
HasInput
(
"H0"
))
{
if
(
ctx
->
HasInput
(
"H0"
))
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C0"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C0"
),
...
@@ -44,7 +47,7 @@ class LSTMOp : public framework::OperatorWithKernel {
...
@@ -44,7 +47,7 @@ class LSTMOp : public framework::OperatorWithKernel {
"should be the same."
);
"should be the same."
);
}
}
int
frame_size
=
x
_dims
[
1
]
/
4
;
int
frame_size
=
in
_dims
[
1
]
/
4
;
auto
w_dims
=
ctx
->
GetInputDim
(
"Weight"
);
auto
w_dims
=
ctx
->
GetInputDim
(
"Weight"
);
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
2
,
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
2
,
"The rank of Input(Weight) should be 2."
);
"The rank of Input(Weight) should be 2."
);
...
@@ -71,12 +74,21 @@ class LSTMOp : public framework::OperatorWithKernel {
...
@@ -71,12 +74,21 @@ class LSTMOp : public framework::OperatorWithKernel {
"4 * %d if disable peepholes connection"
,
"4 * %d if disable peepholes connection"
,
frame_size
);
frame_size
);
}
}
ctx
->
SetOutputDim
(
"Hidden"
,
{
x_dims
[
0
],
frame_size
});
framework
::
DDim
out_dims
({
in_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Cell"
,
{
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchGate"
,
x_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchGate"
,
in_dims
);
ctx
->
SetOutputDim
(
"BatchCellPreAct"
,
out_dims
);
ctx
->
ShareLoD
(
"Input"
,
"Hidden"
);
ctx
->
ShareLoD
(
"Input"
,
"Hidden"
);
ctx
->
ShareLoD
(
"Input"
,
"Cell"
);
ctx
->
ShareLoD
(
"Input"
,
"Cell"
);
}
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
());
}
};
};
class
LSTMOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
LSTMOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
@@ -86,16 +98,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -86,16 +98,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Input"
,
AddInput
(
"Input"
,
"(LoDTensor) the first input is a LodTensor, which support "
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X 4D), where
,
T is the "
"this LoDTensor is a matrix with shape (T X 4D), where T is the "
"total time steps in this mini-batch, D is the hidden size."
);
"total time steps in this mini-batch, D is the hidden size."
);
AddInput
(
"H0"
,
AddInput
(
"H0"
,
"(Tensor, optional) the initial hidden state is an optional "
"(Tensor, optional) the initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size."
);
"batch size, D is the hidden size."
)
.
AsDispensable
();
AddInput
(
"C0"
,
AddInput
(
"C0"
,
"(Tensor, optional) the initial cell state is an optional "
"(Tensor, optional) the initial cell state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"input. This is a tensor with shape (N x D), where N is the "
"batch size. `H0` and `C0` can be NULL but only at the same time"
);
"batch size. `H0` and `C0` can be NULL but only at the same time"
)
.
AsDispensable
();
AddInput
(
"Weight"
,
AddInput
(
"Weight"
,
"(Tensor) the learnable hidden-hidden weights."
"(Tensor) the learnable hidden-hidden weights."
" - The shape is (D x 4D), where D is the hidden size. "
" - The shape is (D x 4D), where D is the hidden size. "
...
@@ -109,22 +123,27 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -109,22 +123,27 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
" - Bias = {b_c, b_i, b_f, b_o}."
" - Bias = {b_c, b_i, b_f, b_o}."
"2. `usePeepholes = True` "
"2. `usePeepholes = True` "
" - The shape is (1 x 7D). "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."
);
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."
)
.
AsDispensable
();
AddOutput
(
"Hidden"
,
"(LoDTensor) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`."
);
AddOutput
(
"Cell"
,
"(LoDTensor) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`."
);
AddOutput
(
"BatchGate"
,
AddOutput
(
"BatchGate"
,
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after the nonlinear computation. This "
"and output gate after the nonlinear computation. This "
"LoDTensor has the same shape with the reorganized input, which "
"LoDTensor has the same shape with the reorganized input, which "
"
wa
s also be called batch input. The LoD size is 2. The first "
"
i
s also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the "
"LoD is the batch offsets and the second LoD contains the "
"indexes, which denote the position of reorganized sequence "
"indexes, which denote the position of reorganized sequence "
"in the raw input."
)
"in the raw input."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
"Hidden"
,
AddOutput
(
"BatchCellPreAct"
,
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
"(LoDTensor) This LoDTensor is got in the forward and used "
"The shape and lod is the same with the `Input`."
);
"in the backward."
)
AddOutput
(
"Cell"
,
.
AsIntermediate
();
"(LoDTensor) the cell state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`."
);
AddAttr
<
bool
>
(
"usePeepholes"
,
AddAttr
<
bool
>
(
"usePeepholes"
,
"(bool, defalut: True) "
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections."
)
"whether to enable diagonal/peephole connections."
)
...
@@ -202,15 +221,37 @@ class LSTMGradOp : public framework::OperatorWithKernel {
...
@@ -202,15 +221,37 @@ class LSTMGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Hidden"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Hidden@GRAD) should not be null"
);
"Input(Input) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Cell"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Hidden"
),
"Input(Cell@GRAD) should not be null"
);
"Input(Hidden) of LSTM should not be null."
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Weight"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Cell"
),
ctx
->
GetInputDim
(
"Weight"
));
"Input(Cell) of LSTM should not be null."
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Bias"
),
ctx
->
GetInputDim
(
"Bias"
));
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchGate"
),
"Input(BatchGate) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchCellPreAct"
),
"Input(BatchGate) of LSTM should not be null."
);
auto
in_g_name
=
framework
::
GradVarName
(
"Input"
);
if
(
ctx
->
HasOutput
(
in_g_name
))
ctx
->
SetOutputDim
(
in_g_name
,
ctx
->
GetInputDim
(
"Input"
));
auto
w_g_name
=
framework
::
GradVarName
(
"Weight"
);
if
(
ctx
->
HasOutput
(
w_g_name
))
ctx
->
SetOutputDim
(
w_g_name
,
ctx
->
GetInputDim
(
"Weight"
));
auto
b_g_name
=
framework
::
GradVarName
(
"Bias"
);
if
(
ctx
->
HasOutput
(
b_g_name
))
ctx
->
SetOutputDim
(
b_g_name
,
ctx
->
GetInputDim
(
"Bias"
));
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
());
}
}
};
};
...
...
paddle/operators/lstm_op.h
浏览文件 @
36d20609
...
@@ -21,8 +21,9 @@ limitations under the License. */
...
@@ -21,8 +21,9 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
framework
::
LoDTensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
...
@@ -31,15 +32,15 @@ template <typename Place, typename T>
...
@@ -31,15 +32,15 @@ template <typename Place, typename T>
class
LSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
class
LSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
);
auto
*
input
=
ctx
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
weight
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Weight"
);
auto
*
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
bias
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Bias"
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
batch_gate
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"BatchGate"
);
auto
*
batch_gate
=
ctx
.
Output
<
LoDTensor
>
(
"BatchGate"
);
batch_gate
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
batch_gate
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
hidden_out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Hidden"
);
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
cell_out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Cell"
);
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Now the function ShareLoD in InferShape is not implemented.
// Now the function ShareLoD in InferShape is not implemented.
...
@@ -49,7 +50,8 @@ class LSTMKernel : public framework::OpKernel<T> {
...
@@ -49,7 +50,8 @@ class LSTMKernel : public framework::OpKernel<T> {
bool
is_reverse
=
ctx
.
Attr
<
bool
>
(
"isReverse"
);
bool
is_reverse
=
ctx
.
Attr
<
bool
>
(
"isReverse"
);
math
::
LoDTensor2BatchFunctor
<
Place
,
T
>
to_batch
;
math
::
LoDTensor2BatchFunctor
<
Place
,
T
>
to_batch
;
to_batch
(
ctx
.
device_context
(),
*
input
,
*
batch_gate
,
is_reverse
);
auto
&
device_ctx
=
ctx
.
device_context
();
to_batch
(
device_ctx
,
*
input
,
*
batch_gate
,
true
,
is_reverse
);
auto
in_dims
=
input
->
dims
();
auto
in_dims
=
input
->
dims
();
int
frame_size
=
static_cast
<
int
>
(
in_dims
[
1
]
/
4
);
int
frame_size
=
static_cast
<
int
>
(
in_dims
[
1
]
/
4
);
...
@@ -69,17 +71,26 @@ class LSTMKernel : public framework::OpKernel<T> {
...
@@ -69,17 +71,26 @@ class LSTMKernel : public framework::OpKernel<T> {
}
}
math
::
LstmMetaValue
<
T
>
lstm_value
;
math
::
LstmMetaValue
<
T
>
lstm_value
;
T
*
bias_data
=
const_cast
<
T
*>
(
bias
->
data
<
T
>
());
if
(
bias
)
{
// the code style in LstmMetaValue will be updated later.
T
*
bias_data
=
const_cast
<
T
*>
(
bias
->
data
<
T
>
());
lstm_value
.
checkIg
=
bias_data
+
4
*
frame_size
;
// the code style in LstmMetaValue will be updated later.
lstm_value
.
checkFg
=
lstm_value
.
checkIg
+
frame_size
;
lstm_value
.
checkOg
=
lstm_value
.
checkFg
+
frame_size
;
lstm_value
.
checkIg
=
bias_data
+
4
*
frame_size
;
lstm_value
.
checkFg
=
lstm_value
.
checkIg
+
frame_size
;
lstm_value
.
checkOg
=
lstm_value
.
checkFg
+
frame_size
;
}
else
{
lstm_value
.
checkIg
=
nullptr
;
lstm_value
.
checkFg
=
nullptr
;
lstm_value
.
checkOg
=
nullptr
;
}
lstm_value
.
prevStateValue
=
nullptr
;
lstm_value
.
prevStateValue
=
nullptr
;
framework
::
LoDTensor
batch_out
,
batch_cell
,
batch_cell_pre_act
;
// Use the local variable as here.
batch_out
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
LoDTensor
batch_hidden
,
batch_cell
;
auto
*
batch_cell_pre_act
=
ctx
.
Output
<
LoDTensor
>
(
"BatchCellPreAct"
);
batch_hidden
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
batch_cell
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
batch_cell
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
batch_cell_pre_act
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
batch_cell_pre_act
->
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
...
@@ -92,18 +103,18 @@ class LSTMKernel : public framework::OpKernel<T> {
...
@@ -92,18 +103,18 @@ class LSTMKernel : public framework::OpKernel<T> {
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
out_t
=
batch_
out
.
Slice
(
bstart
,
bend
);
Tensor
out_t
=
batch_
hidden
.
Slice
(
bstart
,
bend
);
Tensor
cell_t
=
batch_cell
.
Slice
(
bstart
,
bend
);
Tensor
cell_t
=
batch_cell
.
Slice
(
bstart
,
bend
);
Tensor
cell_pre_act_t
=
batch_cell_pre_act
.
Slice
(
bstart
,
bend
);
Tensor
cell_pre_act_t
=
batch_cell_pre_act
->
Slice
(
bstart
,
bend
);
int
cur_batch_size
=
bend
-
bstart
;
int
cur_batch_size
=
bend
-
bstart
;
if
(
n
!=
0
)
{
if
(
n
!=
0
)
{
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_t
=
batch_
out
.
Slice
(
pre_h_start
,
pre_h_end
);
auto
pre_hidden_t
=
batch_
hidden
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
Place
,
T
>
(
ctx
.
device_context
(),
pre_hidden_
t
,
false
,
math
::
matmul
<
Place
,
T
>
(
device_ctx
,
pre_hidden_t
,
false
,
*
weigh
t
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
static_cast
<
T
>
(
1.0
));
}
}
// else if : FIXME support the initial hidden and cell
// else if : FIXME support the initial hidden and cell
...
@@ -112,27 +123,186 @@ class LSTMKernel : public framework::OpKernel<T> {
...
@@ -112,27 +123,186 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value
.
outputValue
=
out_t
.
data
<
T
>
();
lstm_value
.
outputValue
=
out_t
.
data
<
T
>
();
lstm_value
.
stateValue
=
cell_t
.
data
<
T
>
();
lstm_value
.
stateValue
=
cell_t
.
data
<
T
>
();
lstm_value
.
stateActiveValue
=
cell_pre_act_t
.
data
<
T
>
();
lstm_value
.
stateActiveValue
=
cell_pre_act_t
.
data
<
T
>
();
math
::
LstmUnitFunctor
<
Place
,
T
>::
compute
(
ctx
.
device_context
()
,
lstm_value
,
math
::
LstmUnitFunctor
<
Place
,
T
>::
compute
(
device_ctx
,
lstm_value
,
frame_size
,
cur_batch_size
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
gate_act
,
cell_act
,
cand_act
);
lstm_value
.
prevStateValue
=
lstm_value
.
stateValue
;
lstm_value
.
prevStateValue
=
lstm_value
.
stateValue
;
}
}
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
batch_
out
.
set_lod
(
batch_gate
->
lod
());
batch_
hidden
.
set_lod
(
batch_gate
->
lod
());
// restore the output hidden in LoDTensor from the batch hidden
// restore the output hidden in LoDTensor from the batch hidden
to_seq
(
ctx
.
device_context
(),
batch_out
,
*
hidden_out
);
to_seq
(
device_ctx
,
batch_hidden
,
*
hidden_out
);
batch_cell
.
set_lod
(
batch_gate
->
lod
());
batch_cell
.
set_lod
(
batch_gate
->
lod
());
// restore the output cell state in LoDTensor from the batch cell
// restore the output cell state in LoDTensor from the batch cell
to_seq
(
ctx
.
device_context
()
,
batch_cell
,
*
cell_out
);
to_seq
(
device_ctx
,
batch_cell
,
*
cell_out
);
}
}
};
};
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
LSTMGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
LSTMGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{}
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
hidden_out
=
ctx
.
Input
<
LoDTensor
>
(
"Hidden"
);
auto
*
cell_out
=
ctx
.
Input
<
LoDTensor
>
(
"Cell"
);
auto
*
batch_gate
=
ctx
.
Input
<
LoDTensor
>
(
"BatchGate"
);
auto
*
batch_cell_pre_act
=
ctx
.
Input
<
LoDTensor
>
(
"BatchCellPreAct"
);
auto
*
hidden_g
=
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Hidden"
));
auto
*
in_g
=
ctx
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"Input"
));
auto
*
weight_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Weight"
));
auto
*
bias_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
auto
&
device_ctx
=
ctx
.
device_context
();
math
::
SetConstant
<
Place
,
T
>
zero
;
if
(
weight_g
)
{
weight_g
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
zero
(
device_ctx
,
weight_g
,
static_cast
<
T
>
(
0.0
));
}
auto
in_dims
=
input
->
dims
();
auto
out_dims
=
hidden_g
->
dims
();
int
frame_size
=
static_cast
<
int
>
(
in_dims
[
1
]
/
4
);
PADDLE_ENFORCE_EQ
(
frame_size
,
out_dims
[
1
]);
math
::
LstmMetaValue
<
T
>
lstm_value
;
if
(
bias
)
{
T
*
bias_data
=
const_cast
<
T
*>
(
bias
->
data
<
T
>
());
lstm_value
.
checkIg
=
bias_data
+
4
*
frame_size
;
lstm_value
.
checkFg
=
lstm_value
.
checkIg
+
frame_size
;
lstm_value
.
checkOg
=
lstm_value
.
checkFg
+
frame_size
;
}
else
{
lstm_value
.
checkIg
=
nullptr
;
lstm_value
.
checkFg
=
nullptr
;
lstm_value
.
checkOg
=
nullptr
;
}
math
::
LstmMetaGrad
<
T
>
lstm_grad
;
if
(
bias
&&
bias_g
)
{
T
*
bias_g_data
=
const_cast
<
T
*>
(
bias_g
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
zero
(
device_ctx
,
bias_g
,
static_cast
<
T
>
(
0.0
));
lstm_grad
.
checkIgGrad
=
bias_g_data
+
4
*
frame_size
;
lstm_grad
.
checkFgGrad
=
lstm_grad
.
checkIgGrad
+
frame_size
;
lstm_grad
.
checkOgGrad
=
lstm_grad
.
checkFgGrad
+
frame_size
;
}
else
{
lstm_grad
.
checkIgGrad
=
nullptr
;
lstm_grad
.
checkFgGrad
=
nullptr
;
lstm_grad
.
checkOgGrad
=
nullptr
;
}
math
::
LoDTensor2BatchFunctor
<
Place
,
T
>
to_batch
;
// use the local variable as here.
LoDTensor
batch_hidden
;
batch_hidden
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
batch_hidden
.
set_lod
(
batch_gate
->
lod
());
to_batch
(
device_ctx
,
*
hidden_out
,
batch_hidden
,
false
);
LoDTensor
batch_hidden_g
;
batch_hidden_g
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
batch_hidden_g
.
set_lod
(
batch_gate
->
lod
());
to_batch
(
device_ctx
,
*
hidden_g
,
batch_hidden_g
,
false
);
LoDTensor
batch_cell
;
batch_cell
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
batch_cell
.
set_lod
(
batch_gate
->
lod
());
to_batch
(
device_ctx
,
*
cell_out
,
batch_cell
,
false
);
LoDTensor
batch_cell_g
;
batch_cell_g
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
batch_cell_g
.
set_lod
(
batch_gate
->
lod
());
// TODO(qingqing) support the case output cell has gradient.
// to_batch(device_ctx, *cell_g, batch_cell_g, false);
zero
(
device_ctx
,
&
batch_cell_g
,
static_cast
<
T
>
(
0.0
));
LoDTensor
batch_gate_g
;
batch_gate_g
.
mutable_data
<
T
>
(
batch_gate
->
dims
(),
ctx
.
GetPlace
());
batch_gate_g
.
set_lod
(
batch_gate
->
lod
());
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gateActivation"
);
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cellActivation"
);
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidateActivation"
);
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
for
(
int
n
=
static_cast
<
int
>
(
num_batch
)
-
1
;
n
>=
0
;
n
--
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
Tensor
gate
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
cell
=
batch_cell
.
Slice
(
bstart
,
bend
);
Tensor
cell_pre_act
=
batch_cell_pre_act
->
Slice
(
bstart
,
bend
);
lstm_value
.
gateValue
=
gate
.
data
<
T
>
();
lstm_value
.
stateValue
=
cell
.
data
<
T
>
();
lstm_value
.
stateActiveValue
=
cell_pre_act
.
data
<
T
>
();
Tensor
out_g
=
batch_hidden_g
.
Slice
(
bstart
,
bend
);
Tensor
gate_g
=
batch_gate_g
.
Slice
(
bstart
,
bend
);
Tensor
cell_g
=
batch_cell_g
.
Slice
(
bstart
,
bend
);
lstm_grad
.
stateGrad
=
cell_g
.
data
<
T
>
();
lstm_grad
.
gateGrad
=
gate_g
.
data
<
T
>
();
lstm_grad
.
outputGrad
=
out_g
.
data
<
T
>
();
if
(
n
)
{
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
Tensor
cell_pre
=
batch_cell
.
Slice
(
bstart_pre
,
bstart
);
Tensor
cell_pre_g
=
batch_cell_g
.
Slice
(
bstart_pre
,
bstart
);
lstm_value
.
prevStateValue
=
cell_pre
.
data
<
T
>
();
lstm_grad
.
prevStateGrad
=
cell_pre_g
.
data
<
T
>
();
}
else
{
lstm_value
.
prevStateValue
=
nullptr
;
lstm_grad
.
prevStateGrad
=
nullptr
;
}
int
cur_batch_size
=
bend
-
bstart
;
math
::
LstmUnitGradFunctor
<
Place
,
T
>::
compute
(
device_ctx
,
lstm_value
,
lstm_grad
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
if
(
n
!=
0
)
{
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_g
=
batch_hidden_g
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
Place
,
T
>
(
device_ctx
,
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
pre_hidden_g
,
static_cast
<
T
>
(
1.0
));
if
(
weight_g
)
{
/* backward weight */
auto
pre_hidden
=
batch_hidden
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
Place
,
T
>
(
device_ctx
,
pre_hidden
,
true
,
gate_g
,
false
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
}
}
}
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
if
(
in_g
)
{
/* backward data */
in_g
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
to_seq
(
device_ctx
,
batch_gate_g
,
*
in_g
);
}
if
(
bias
&&
bias_g
)
{
/* backward bias */
int
m
=
static_cast
<
int
>
(
batch_gate_g
.
dims
()[
0
]);
int
n
=
static_cast
<
int
>
(
batch_gate_g
.
dims
()[
1
]);
Tensor
ones
;
ones
.
mutable_data
<
T
>
({
m
},
ctx
.
GetPlace
());
math
::
SetConstant
<
Place
,
T
>
set
;
set
(
device_ctx
,
&
ones
,
static_cast
<
T
>
(
1.0
));
math
::
gemv
<
Place
,
T
>
(
device_ctx
,
true
,
m
,
n
,
1.
,
batch_gate_g
.
data
<
T
>
(),
ones
.
data
<
T
>
(),
0.
,
bias_g
->
data
<
T
>
());
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/operators/math/detail/lstm_cpu_kernel.h
浏览文件 @
36d20609
...
@@ -26,10 +26,7 @@ namespace detail {
...
@@ -26,10 +26,7 @@ namespace detail {
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
naive_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
void
naive_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frameSize
,
int
frameSize
)
{
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
T
rValueIn
;
T
rValueIn
;
T
rValueIg
;
T
rValueIg
;
T
rValueFg
;
T
rValueFg
;
...
@@ -60,10 +57,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -60,10 +57,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
rPrevState
=
value
.
prevStateValue
[
i
];
rPrevState
=
value
.
prevStateValue
[
i
];
}
}
hppl
::
cpu
::
ForwardAct
<
T
>
act
;
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
act
(
active_node
),
act
(
active_gate
),
rOut
,
rCheckI
,
rCheckF
,
rCheckO
);
act
(
active_state
));
valueIn
[
i
]
=
rValueIn
;
valueIn
[
i
]
=
rValueIn
;
valueIg
[
i
]
=
rValueIg
;
valueIg
[
i
]
=
rValueIg
;
...
@@ -77,10 +72,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -77,10 +72,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
naive_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
void
naive_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frameSize
,
LstmMetaGrad
<
T
>
grad
,
int
frameSize
)
{
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
T
rValueIn
;
T
rValueIn
;
T
rValueIg
;
T
rValueIg
;
T
rValueFg
;
T
rValueFg
;
...
@@ -127,11 +119,10 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -127,11 +119,10 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
rPrevState
=
value
.
prevStateValue
[
i
];
rPrevState
=
value
.
prevStateValue
[
i
];
}
}
hppl
::
cpu
::
BackwardAct
<
T
>
act
;
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
,
act
(
active_node
),
act
(
active_gate
),
act
(
active_state
)
);
rCheckOGrad
);
gradIn
[
i
]
=
rGradIn
;
gradIn
[
i
]
=
rGradIn
;
gradIg
[
i
]
=
rGradIg
;
gradIg
[
i
]
=
rGradIg
;
...
@@ -283,8 +274,7 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
...
@@ -283,8 +274,7 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
avx_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frameSize
,
active_node
,
avx_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frameSize
,
active_node
,
active_gate
,
active_state
);
active_gate
,
active_state
);
}
else
{
}
else
{
naive_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frameSize
,
active_node
,
naive_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frameSize
);
active_gate
,
active_state
);
}
}
}
}
...
@@ -297,8 +287,7 @@ void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
...
@@ -297,8 +287,7 @@ void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
avx_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frameSize
,
active_node
,
avx_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frameSize
,
active_node
,
active_gate
,
active_state
);
active_gate
,
active_state
);
}
else
{
}
else
{
naive_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frameSize
,
active_node
,
naive_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frameSize
);
active_gate
,
active_state
);
}
}
}
}
...
...
paddle/operators/math/detail/lstm_gpu_kernel.h
浏览文件 @
36d20609
...
@@ -32,9 +32,7 @@ namespace detail {
...
@@ -32,9 +32,7 @@ namespace detail {
*/
*/
template
<
class
T
,
class
Op
,
bool
isBatch
>
template
<
class
T
,
class
Op
,
bool
isBatch
>
__global__
void
KeLstmForward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frameSize
,
__global__
void
KeLstmForward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
int
batchSize
)
{
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frameIdx
>=
frameSize
)
return
;
if
(
frameIdx
>=
frameSize
)
return
;
...
@@ -70,10 +68,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
...
@@ -70,10 +68,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
rPrevState
=
value
.
prevStateValue
[
frameIdx
];
rPrevState
=
value
.
prevStateValue
[
frameIdx
];
}
}
hppl
::
gpu
::
ForwardAct
<
T
>
act
;
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
act
(
active_node
),
act
(
active_gate
),
rOut
,
rCheckI
,
rCheckF
,
rCheckO
);
act
(
active_state
));
value
.
gateValue
[
frameIdx
]
=
rValueIn
;
value
.
gateValue
[
frameIdx
]
=
rValueIn
;
value
.
gateValue
[
frameIdx
+
frameSize
]
=
rValueIg
;
value
.
gateValue
[
frameIdx
+
frameSize
]
=
rValueIg
;
...
@@ -92,9 +88,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
...
@@ -92,9 +88,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
template
<
class
T
,
class
Op
,
bool
isBatch
>
template
<
class
T
,
class
Op
,
bool
isBatch
>
__global__
void
KeLstmBackward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
__global__
void
KeLstmBackward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frameSize
,
LstmMetaGrad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
int
batchSize
)
{
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
frameIdx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frameIdx
>=
frameSize
)
return
;
if
(
frameIdx
>=
frameSize
)
return
;
...
@@ -145,11 +139,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
...
@@ -145,11 +139,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
rPrevState
=
value
.
prevStateValue
[
frameIdx
];
rPrevState
=
value
.
prevStateValue
[
frameIdx
];
}
}
hppl
::
gpu
::
BackwardAct
<
T
>
act
;
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
);
act
(
active_node
),
act
(
active_gate
),
act
(
active_state
));
grad
.
gateGrad
[
frameIdx
]
=
rGradIn
;
grad
.
gateGrad
[
frameIdx
]
=
rGradIn
;
grad
.
gateGrad
[
frameIdx
+
frameSize
]
=
rGradIg
;
grad
.
gateGrad
[
frameIdx
+
frameSize
]
=
rGradIg
;
...
@@ -205,13 +197,11 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
...
@@ -205,13 +197,11 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
if
(
batchSize
==
1
)
{
if
(
batchSize
==
1
)
{
KeLstmForward
<
T
,
Op
,
KeLstmForward
<
T
,
Op
,
/* isBatch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* isBatch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
op
,
value
,
frameSize
,
batchSize
);
active_state
);
}
else
{
}
else
{
KeLstmForward
<
T
,
Op
,
KeLstmForward
<
T
,
Op
,
/* isBatch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* isBatch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
op
,
value
,
frameSize
,
batchSize
);
active_state
);
}
}
}
}
...
@@ -240,13 +230,11 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
...
@@ -240,13 +230,11 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
if
(
batchSize
==
1
)
{
if
(
batchSize
==
1
)
{
KeLstmBackward
<
T
,
Op
,
KeLstmBackward
<
T
,
Op
,
/* isBatch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* isBatch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
grad
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
op
,
value
,
grad
,
frameSize
,
batchSize
);
active_state
);
}
else
{
}
else
{
KeLstmBackward
<
T
,
Op
,
KeLstmBackward
<
T
,
Op
,
/* isBatch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* isBatch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
grad
,
frameSize
,
batchSize
,
active_node
,
active_gate
,
op
,
value
,
grad
,
frameSize
,
batchSize
);
active_state
);
}
}
}
}
...
...
paddle/operators/math/detail/lstm_kernel.h
浏览文件 @
36d20609
...
@@ -24,15 +24,29 @@ namespace detail {
...
@@ -24,15 +24,29 @@ namespace detail {
namespace
forward
{
namespace
forward
{
template
<
typename
T
>
DEVICE
inline
T
sigmoid
(
const
T
a
)
{
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
T
tmp
=
(
a
<
min
)
?
min
:
((
a
>
max
)
?
max
:
a
);
return
static_cast
<
T
>
(
1.0
)
/
(
static_cast
<
T
>
(
1.0
)
+
exp
(
-
tmp
));
}
template
<
typename
T
>
DEVICE
inline
T
tanh
(
const
T
a
)
{
T
tmp
=
-
2.0
*
a
;
tmp
=
(
tmp
>
EXP_MAX_INPUT
)
?
EXP_MAX_INPUT
:
tmp
;
return
(
2.0
/
(
1.0
+
exp
(
tmp
)))
-
1.0
;
}
template
<
class
T
>
template
<
class
T
>
class
lstm
{
class
lstm
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
HOSTDEVICE
void
operator
()(
T
&
valueIn
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueOg
,
T
&
prevState
,
T
&
state
,
T
&
stateAtv
,
T
&
output
,
T
&
prevState
,
T
&
state
,
T
&
stateAtv
,
T
&
output
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
)
{
typename
hppl
::
ForwardActType
<
T
>::
type
actInput
,
#if 0
typename
hppl
::
ForwardActType
<
T
>::
type
actGate
,
// TODO(qingqing) support to activation speficed by users
typename
hppl
::
ForwardActType
<
T
>::
type
actState
)
{
valueIn = actInput(valueIn);
valueIn = actInput(valueIn);
valueIg = actGate(valueIg + prevState * checkI);
valueIg = actGate(valueIg + prevState * checkI);
valueFg = actGate(valueFg + prevState * checkF);
valueFg = actGate(valueFg + prevState * checkF);
...
@@ -40,6 +54,15 @@ class lstm {
...
@@ -40,6 +54,15 @@ class lstm {
valueOg = actGate(valueOg + state * checkO);
valueOg = actGate(valueOg + state * checkO);
stateAtv = actState(state);
stateAtv = actState(state);
output = valueOg * stateAtv;
output = valueOg * stateAtv;
#else
valueIn
=
tanh
<
T
>
(
valueIn
);
valueIg
=
sigmoid
<
T
>
(
valueIg
+
prevState
*
checkI
);
valueFg
=
sigmoid
<
T
>
(
valueFg
+
prevState
*
checkF
);
state
=
valueIn
*
valueIg
+
prevState
*
valueFg
;
valueOg
=
sigmoid
<
T
>
(
valueOg
+
state
*
checkO
);
stateAtv
=
tanh
<
T
>
(
state
);
output
=
valueOg
*
stateAtv
;
#endif
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
...
@@ -72,6 +95,16 @@ class lstm {
...
@@ -72,6 +95,16 @@ class lstm {
namespace
backward
{
namespace
backward
{
template
<
typename
T
>
DEVICE
inline
T
sigmoid
(
const
T
a
,
const
T
b
)
{
return
a
*
b
*
(
1.0
-
b
);
}
template
<
typename
T
>
DEVICE
inline
T
tanh
(
const
T
a
,
const
T
b
)
{
return
a
*
(
1.0
-
b
*
b
);
}
template
<
class
T
>
template
<
class
T
>
class
lstm
{
class
lstm
{
public:
public:
...
@@ -80,10 +113,9 @@ class lstm {
...
@@ -80,10 +113,9 @@ class lstm {
T
&
prevState
,
T
&
prevStateGrad
,
T
&
state
,
T
&
prevState
,
T
&
prevStateGrad
,
T
&
state
,
T
&
stateGrad
,
T
&
stateAtv
,
T
&
outputGrad
,
T
&
stateGrad
,
T
&
stateAtv
,
T
&
outputGrad
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkIGrad
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkIGrad
,
T
&
checkFGrad
,
T
&
checkOGrad
,
T
&
checkFGrad
,
T
&
checkOGrad
)
{
typename
hppl
::
BackwardActType
<
T
>::
type
actInput
,
#if 0
typename
hppl
::
BackwardActType
<
T
>::
type
actGate
,
// TODO(qingqing) support to activation speficed by users
typename
hppl
::
BackwardActType
<
T
>::
type
actState
)
{
gradOg = actGate(outputGrad * stateAtv, valueOg);
gradOg = actGate(outputGrad * stateAtv, valueOg);
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = actInput(stateGrad * valueIg, valueIn);
gradIn = actInput(stateGrad * valueIg, valueIn);
...
@@ -93,6 +125,17 @@ class lstm {
...
@@ -93,6 +125,17 @@ class lstm {
checkIGrad = gradIg * prevState;
checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState;
checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state;
checkOGrad = gradOg * state;
#else
gradOg
=
sigmoid
<
T
>
(
outputGrad
*
stateAtv
,
valueOg
);
stateGrad
+=
tanh
<
T
>
(
outputGrad
*
valueOg
,
stateAtv
)
+
gradOg
*
checkO
;
gradIn
=
tanh
<
T
>
(
stateGrad
*
valueIg
,
valueIn
);
gradIg
=
sigmoid
<
T
>
(
stateGrad
*
valueIn
,
valueIg
);
gradFg
=
sigmoid
<
T
>
(
stateGrad
*
prevState
,
valueFg
);
prevStateGrad
=
gradIg
*
checkI
+
gradFg
*
checkF
+
stateGrad
*
valueFg
;
checkIGrad
=
gradIg
*
prevState
;
checkFGrad
=
gradFg
*
prevState
;
checkOGrad
=
gradOg
*
state
;
#endif
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
...
...
paddle/operators/math/math_function.cc
浏览文件 @
36d20609
...
@@ -211,6 +211,26 @@ void batched_gemm<platform::CPUPlace, double>(
...
@@ -211,6 +211,26 @@ void batched_gemm<platform::CPUPlace, double>(
}
}
#endif
#endif
template
<
>
void
gemv
<
platform
::
CPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
bool
trans_a
,
const
int
M
,
const
int
N
,
const
float
alpha
,
const
float
*
A
,
const
float
*
B
,
const
float
beta
,
float
*
C
)
{
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
cblas_sgemv
(
CblasRowMajor
,
transA
,
M
,
N
,
alpha
,
A
,
N
,
B
,
1
,
beta
,
C
,
1
);
}
template
<
>
void
gemv
<
platform
::
CPUPlace
,
double
>
(
const
platform
::
DeviceContext
&
context
,
const
bool
trans_a
,
const
int
M
,
const
int
N
,
const
double
alpha
,
const
double
*
A
,
const
double
*
B
,
const
double
beta
,
double
*
C
)
{
CBLAS_TRANSPOSE
transA
=
(
trans_a
==
false
)
?
CblasNoTrans
:
CblasTrans
;
cblas_dgemv
(
CblasRowMajor
,
transA
,
M
,
N
,
alpha
,
A
,
N
,
B
,
1
,
beta
,
C
,
1
);
}
template
struct
SetConstant
<
platform
::
CPUPlace
,
float
>;
template
struct
SetConstant
<
platform
::
CPUPlace
,
float
>;
}
// namespace math
}
// namespace math
...
...
paddle/operators/math/math_function.cu
浏览文件 @
36d20609
...
@@ -203,6 +203,33 @@ void batched_gemm<platform::GPUPlace, double>(
...
@@ -203,6 +203,33 @@ void batched_gemm<platform::GPUPlace, double>(
&
beta
,
C
,
ldc
,
strideC
,
batchCount
));
&
beta
,
C
,
ldc
,
strideC
,
batchCount
));
}
}
template
<
>
void
gemv
<
platform
::
GPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
bool
trans_a
,
const
int
M
,
const
int
N
,
const
float
alpha
,
const
float
*
A
,
const
float
*
B
,
const
float
beta
,
float
*
C
)
{
cublasOperation_t
cuTransA
=
(
trans_a
==
false
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemv
(
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
cublas_handle
(),
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
B
,
1
,
&
beta
,
C
,
1
));
}
template
<
>
void
gemv
<
platform
::
GPUPlace
,
double
>
(
const
platform
::
DeviceContext
&
context
,
const
bool
trans_a
,
const
int
M
,
const
int
N
,
const
double
alpha
,
const
double
*
A
,
const
double
*
B
,
const
double
beta
,
double
*
C
)
{
cublasOperation_t
cuTransA
=
(
trans_a
==
false
)
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasDgemv
(
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
cublas_handle
(),
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
B
,
1
,
&
beta
,
C
,
1
));
}
template
struct
SetConstant
<
platform
::
GPUPlace
,
float
>;
template
struct
SetConstant
<
platform
::
GPUPlace
,
float
>;
}
// namespace math
}
// namespace math
...
...
paddle/operators/math/math_function.h
浏览文件 @
36d20609
...
@@ -93,6 +93,11 @@ void batched_gemm(const platform::DeviceContext& context,
...
@@ -93,6 +93,11 @@ void batched_gemm(const platform::DeviceContext& context,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
,
const
int
batchCount
,
const
int
strideA
,
const
int
strideB
);
const
int
batchCount
,
const
int
strideA
,
const
int
strideB
);
template
<
typename
Place
,
typename
T
>
void
gemv
(
const
platform
::
DeviceContext
&
context
,
const
bool
trans_a
,
const
int
M
,
const
int
N
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
);
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
struct
SetConstant
{
struct
SetConstant
{
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
...
...
paddle/operators/math/math_function_test.cc
浏览文件 @
36d20609
...
@@ -89,3 +89,53 @@ TEST(math_function, zero) {
...
@@ -89,3 +89,53 @@ TEST(math_function, zero) {
EXPECT_EQ
(
t
[
2
],
1
);
EXPECT_EQ
(
t
[
2
],
1
);
EXPECT_EQ
(
t
[
3
],
1
);
EXPECT_EQ
(
t
[
3
],
1
);
}
}
template
<
typename
T
>
void
GemvTest
(
int
m
,
int
n
,
bool
trans
)
{
paddle
::
framework
::
Tensor
mat_a
;
paddle
::
framework
::
Tensor
vec_b
;
paddle
::
framework
::
Tensor
vec_c
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
int
b_num
=
trans
?
m
:
n
;
int
c_num
=
trans
?
n
:
m
;
T
*
data_a
=
mat_a
.
mutable_data
<
T
>
({
m
,
n
},
*
cpu_place
);
T
*
data_b
=
vec_b
.
mutable_data
<
T
>
({
b_num
},
*
cpu_place
);
T
*
data_c
=
vec_c
.
mutable_data
<
T
>
({
c_num
},
*
cpu_place
);
for
(
int
i
=
0
;
i
<
mat_a
.
numel
();
++
i
)
{
data_a
[
i
]
=
static_cast
<
T
>
(
i
);
}
for
(
int
i
=
0
;
i
<
vec_b
.
numel
();
++
i
)
{
data_b
[
i
]
=
static_cast
<
T
>
(
i
);
}
paddle
::
platform
::
CPUDeviceContext
context
(
*
cpu_place
);
paddle
::
operators
::
math
::
gemv
<
paddle
::
platform
::
CPUPlace
,
T
>
(
context
,
trans
,
static_cast
<
int
>
(
m
),
static_cast
<
int
>
(
n
),
1.
,
data_a
,
data_b
,
0.
,
data_c
);
if
(
!
trans
)
{
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
T
sum
=
0.0
;
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
sum
+=
data_a
[
i
*
n
+
j
]
*
data_b
[
j
];
}
ASSERT_FLOAT_EQ
(
data_c
[
i
],
sum
);
}
}
else
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
T
sum
=
0.0
;
for
(
int
j
=
0
;
j
<
m
;
++
j
)
{
sum
+=
data_a
[
j
*
n
+
i
]
*
data_b
[
j
];
}
ASSERT_FLOAT_EQ
(
data_c
[
i
],
sum
);
}
}
}
TEST
(
math_function
,
gemv
)
{
GemvTest
<
float
>
(
3
,
13
,
false
);
GemvTest
<
double
>
(
4
,
5
,
false
);
GemvTest
<
float
>
(
12
,
7
,
true
);
GemvTest
<
double
>
(
7
,
9
,
true
);
}
paddle/operators/math/math_function_test.cu
浏览文件 @
36d20609
...
@@ -177,3 +177,65 @@ TEST(math_function, gemm_trans_cublas) {
...
@@ -177,3 +177,65 @@ TEST(math_function, gemm_trans_cublas) {
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
delete
gpu_place
;
delete
gpu_place
;
}
}
template
<
typename
T
>
void
GemvTest
(
int
m
,
int
n
,
bool
trans
)
{
paddle
::
framework
::
Tensor
mat_a
;
paddle
::
framework
::
Tensor
vec_b
;
paddle
::
framework
::
Tensor
vec_c
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
T
*
data_a
=
mat_a
.
mutable_data
<
T
>
({
m
,
n
},
*
cpu_place
);
T
*
data_b
=
vec_b
.
mutable_data
<
T
>
({
trans
?
m
:
n
},
*
cpu_place
);
T
*
data_c
=
vec_c
.
mutable_data
<
T
>
({
trans
?
n
:
m
},
*
cpu_place
);
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
framework
::
Tensor
g_mat_a
;
paddle
::
framework
::
Tensor
g_vec_b
;
paddle
::
framework
::
Tensor
g_vec_c
;
T
*
g_data_a
=
g_mat_a
.
mutable_data
<
T
>
(
mat_a
.
dims
(),
*
gpu_place
);
T
*
g_data_b
=
g_vec_b
.
mutable_data
<
T
>
(
vec_b
.
dims
(),
*
gpu_place
);
T
*
g_data_c
=
g_vec_c
.
mutable_data
<
T
>
(
vec_c
.
dims
(),
*
gpu_place
);
for
(
int
i
=
0
;
i
<
mat_a
.
numel
();
++
i
)
{
data_a
[
i
]
=
static_cast
<
T
>
(
i
);
}
for
(
int
i
=
0
;
i
<
vec_b
.
numel
();
++
i
)
{
data_b
[
i
]
=
static_cast
<
T
>
(
i
);
}
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
g_mat_a
.
CopyFrom
(
mat_a
,
*
gpu_place
,
context
);
g_vec_b
.
CopyFrom
(
vec_b
,
*
gpu_place
,
context
);
paddle
::
operators
::
math
::
gemv
<
paddle
::
platform
::
GPUPlace
,
T
>
(
context
,
trans
,
static_cast
<
int
>
(
m
),
static_cast
<
int
>
(
n
),
1.
,
g_data_a
,
g_data_b
,
0.
,
g_data_c
);
vec_c
.
CopyFrom
(
g_vec_c
,
paddle
::
platform
::
CPUPlace
(),
context
);
if
(
!
trans
)
{
for
(
int
i
=
0
;
i
<
m
;
++
i
)
{
T
sum
=
0.0
;
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
sum
+=
data_a
[
i
*
n
+
j
]
*
data_b
[
j
];
}
ASSERT_FLOAT_EQ
(
data_c
[
i
],
sum
);
}
}
else
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
T
sum
=
0.0
;
for
(
int
j
=
0
;
j
<
m
;
++
j
)
{
sum
+=
data_a
[
j
*
n
+
i
]
*
data_b
[
j
];
}
ASSERT_FLOAT_EQ
(
data_c
[
i
],
sum
);
}
}
}
TEST
(
math_function
,
gemv
)
{
GemvTest
<
float
>
(
3
,
13
,
false
);
GemvTest
<
double
>
(
3
,
13
,
false
);
GemvTest
<
float
>
(
3
,
13
,
true
);
GemvTest
<
double
>
(
3
,
13
,
true
);
}
paddle/operators/math/sequence2batch.h
浏览文件 @
36d20609
...
@@ -53,7 +53,18 @@ class LoDTensor2BatchFunctor {
...
@@ -53,7 +53,18 @@ class LoDTensor2BatchFunctor {
public:
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
lod_tensor
,
const
framework
::
LoDTensor
&
lod_tensor
,
framework
::
LoDTensor
&
batch
,
bool
is_reverse
)
const
{
framework
::
LoDTensor
&
batch
,
bool
is_cal_batch_lod
,
bool
is_reverse
=
false
)
const
{
if
(
!
is_cal_batch_lod
)
{
auto
lods
=
batch
.
lod
();
PADDLE_ENFORCE_EQ
(
lods
.
size
(),
2UL
);
PADDLE_ENFORCE_EQ
(
lods
[
1
].
size
(),
static_cast
<
size_t
>
(
lod_tensor
.
dims
()[
0
]));
CopyMatrixRowsFunctor
<
Place
,
T
>
to_batch
;
to_batch
(
context
,
lod_tensor
,
lods
[
1
].
data
(),
batch
,
true
);
return
;
}
auto
lods
=
lod_tensor
.
lod
();
auto
lods
=
lod_tensor
.
lod
();
PADDLE_ENFORCE_EQ
(
lods
.
size
(),
1UL
,
"Only support one level sequence now."
);
PADDLE_ENFORCE_EQ
(
lods
.
size
(),
1UL
,
"Only support one level sequence now."
);
auto
lod
=
lods
[
0
];
auto
lod
=
lods
[
0
];
...
@@ -101,10 +112,10 @@ class LoDTensor2BatchFunctor {
...
@@ -101,10 +112,10 @@ class LoDTensor2BatchFunctor {
size_t
*
batch_starts
=
batch_lods
[
0
].
data
();
size_t
*
batch_starts
=
batch_lods
[
0
].
data
();
size_t
*
seq2batch_idx
=
batch_lods
[
1
].
data
();
size_t
*
seq2batch_idx
=
batch_lods
[
1
].
data
();
batch_starts
[
0
]
=
0
;
batch_starts
[
0
]
=
0
;
for
(
size_
t
n
=
0
;
n
<
num_batch
;
n
++
)
{
for
(
in
t
n
=
0
;
n
<
num_batch
;
n
++
)
{
auto
batch_id
=
static_cast
<
int
>
(
batch_starts
[
n
]);
auto
batch_id
=
static_cast
<
int
>
(
batch_starts
[
n
]);
for
(
size_t
i
=
0
;
i
<
seq_info
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
seq_info
.
size
();
++
i
)
{
size_
t
seq_len
=
seq_info
[
i
].
length
;
in
t
seq_len
=
seq_info
[
i
].
length
;
int
start
=
seq_info
[
i
].
start
;
int
start
=
seq_info
[
i
].
start
;
if
(
n
<
seq_len
)
{
if
(
n
<
seq_len
)
{
seq2batch_idx
[
batch_id
]
=
seq2batch_idx
[
batch_id
]
=
...
@@ -132,11 +143,8 @@ class Batch2LoDTensorFunctor {
...
@@ -132,11 +143,8 @@ class Batch2LoDTensorFunctor {
auto
in_lod
=
batch
.
lod
();
auto
in_lod
=
batch
.
lod
();
PADDLE_ENFORCE_EQ
(
in_lod
.
size
(),
2UL
,
PADDLE_ENFORCE_EQ
(
in_lod
.
size
(),
2UL
,
"The LoD size of input `batch` should be 2."
);
"The LoD size of input `batch` should be 2."
);
auto
out_lod
=
lod_tensor
.
lod
()[
0
];
PADDLE_ENFORCE_EQ
(
in_lod
[
1
].
size
(),
auto
num
=
out_lod
[
out_lod
.
size
()
-
1
];
static_cast
<
size_t
>
(
lod_tensor
.
dims
()[
0
]));
PADDLE_ENFORCE_EQ
(
num
,
lod_tensor
.
dims
()[
0
]);
PADDLE_ENFORCE_EQ
(
num
,
in_lod
[
1
].
size
());
PADDLE_ENFORCE_EQ
(
num
,
batch
.
dims
()[
0
]);
CopyMatrixRowsFunctor
<
Place
,
T
>
to_seq
;
CopyMatrixRowsFunctor
<
Place
,
T
>
to_seq
;
size_t
*
index
=
in_lod
[
1
].
data
();
size_t
*
index
=
in_lod
[
1
].
data
();
to_seq
(
context
,
batch
,
index
,
lod_tensor
,
false
);
to_seq
(
context
,
batch
,
index
,
lod_tensor
,
false
);
...
...
python/paddle/v2/framework/tests/test_lstm_op.py
浏览文件 @
36d20609
...
@@ -52,7 +52,7 @@ def lstm(
...
@@ -52,7 +52,7 @@ def lstm(
g
=
np
.
dot
(
h_pre
,
w_h
)
# 1 x 4D
g
=
np
.
dot
(
h_pre
,
w_h
)
# 1 x 4D
g
=
g
+
x
g
=
g
+
x
g
=
np
.
reshape
(
g
,
(
1
,
g
.
size
))
g
=
np
.
reshape
(
g
,
(
1
,
g
.
size
))
c
_tmp
,
g_i
,
g_f
,
g_o
=
np
.
split
(
g
,
4
,
axis
=
1
)
c
,
g_i
,
g_f
,
g_o
=
np
.
split
(
g
,
4
,
axis
=
1
)
if
w_c
is
None
:
if
w_c
is
None
:
g_i
=
act_gate
(
g_i
)
# 1 x D
g_i
=
act_gate
(
g_i
)
# 1 x D
g_f
=
act_gate
(
g_f
)
# 1 x D
g_f
=
act_gate
(
g_f
)
# 1 x D
...
@@ -60,7 +60,7 @@ def lstm(
...
@@ -60,7 +60,7 @@ def lstm(
w_ic
,
w_fc
,
w_oc
=
np
.
split
(
w_c
,
3
,
axis
=
1
)
w_ic
,
w_fc
,
w_oc
=
np
.
split
(
w_c
,
3
,
axis
=
1
)
g_i
=
act_gate
(
g_i
+
w_ic
*
c_pre
)
# 1 x D
g_i
=
act_gate
(
g_i
+
w_ic
*
c_pre
)
# 1 x D
g_f
=
act_gate
(
g_f
+
w_fc
*
c_pre
)
# 1 x D
g_f
=
act_gate
(
g_f
+
w_fc
*
c_pre
)
# 1 x D
c
=
g_f
*
c_pre
+
g_i
*
act_cand
(
c
_tmp
)
# 1 x D
c
=
g_f
*
c_pre
+
g_i
*
act_cand
(
c
)
# 1 x D
if
w_c
is
None
:
if
w_c
is
None
:
g_o
=
act_gate
(
g_o
)
# 1 x D
g_o
=
act_gate
(
g_o
)
# 1 x D
...
@@ -68,8 +68,7 @@ def lstm(
...
@@ -68,8 +68,7 @@ def lstm(
_
,
_
,
w_oc
=
np
.
split
(
w_c
,
3
,
axis
=
1
)
_
,
_
,
w_oc
=
np
.
split
(
w_c
,
3
,
axis
=
1
)
g_o
=
act_gate
(
g_o
+
w_oc
*
c
)
# 1 x D
g_o
=
act_gate
(
g_o
+
w_oc
*
c
)
# 1 x D
h
=
g_o
*
act_cell
(
c
)
h
=
g_o
*
act_cell
(
c
)
bg
=
np
.
concatenate
((
act_cand
(
c_tmp
),
g_i
,
g_f
,
g_o
),
axis
=
1
)
return
h
,
c
return
h
,
c
,
bg
def
_reverse
(
x
,
lod
):
def
_reverse
(
x
,
lod
):
y
=
np
.
zeros_like
(
x
)
y
=
np
.
zeros_like
(
x
)
...
@@ -82,7 +81,6 @@ def lstm(
...
@@ -82,7 +81,6 @@ def lstm(
batch_size
=
len
(
offset
)
-
1
batch_size
=
len
(
offset
)
-
1
hidden
=
[]
hidden
=
[]
cell
=
[]
cell
=
[]
gate
=
[]
input
=
_reverse
(
input
,
offset
)
if
is_reverse
else
input
input
=
_reverse
(
input
,
offset
)
if
is_reverse
else
input
if
w_b
is
not
None
:
if
w_b
is
not
None
:
input
=
input
+
np
.
tile
(
w_b
,
(
offset
[
-
1
],
1
))
input
=
input
+
np
.
tile
(
w_b
,
(
offset
[
-
1
],
1
))
...
@@ -94,96 +92,109 @@ def lstm(
...
@@ -94,96 +92,109 @@ def lstm(
c_pre
=
c0
[
i
]
# 1 x D
c_pre
=
c0
[
i
]
# 1 x D
for
j
in
range
(
seq_len
):
for
j
in
range
(
seq_len
):
# compute one step
# compute one step
h_pre
,
c_pre
,
g_pre
=
_step
(
x
[
j
],
w_h
,
w_c
,
h_pre
,
c_pre
,
act_gate
,
h_pre
,
c_pre
=
_step
(
x
[
j
],
w_h
,
w_c
,
h_pre
,
c_pre
,
act_gate
,
act_cell
,
act_cand
)
act_cell
,
act_cand
)
hidden
.
append
(
h_pre
.
flatten
())
hidden
.
append
(
h_pre
.
flatten
())
cell
.
append
(
c_pre
.
flatten
())
cell
.
append
(
c_pre
.
flatten
())
gate
.
append
(
g_pre
.
flatten
())
hidden
=
np
.
array
(
hidden
).
astype
(
"float64"
)
hidden
=
np
.
array
(
hidden
).
astype
(
'float64'
)
cell
=
np
.
array
(
cell
).
astype
(
"float64"
)
cell
=
np
.
array
(
cell
).
astype
(
'float64'
)
gate
=
np
.
array
(
gate
).
astype
(
"float64"
)
hidden
=
_reverse
(
hidden
,
offset
)
if
is_reverse
else
hidden
hidden
=
_reverse
(
hidden
,
offset
)
if
is_reverse
else
hidden
cell
=
_reverse
(
cell
,
offset
)
if
is_reverse
else
cell
cell
=
_reverse
(
cell
,
offset
)
if
is_reverse
else
cell
assert
gate
.
shape
==
input
.
shape
assert
hidden
.
shape
==
(
input
.
shape
[
0
],
input
.
shape
[
1
]
/
4
)
assert
hidden
.
shape
==
(
input
.
shape
[
0
],
input
.
shape
[
1
]
/
4
)
assert
cell
.
shape
==
(
input
.
shape
[
0
],
input
.
shape
[
1
]
/
4
)
assert
cell
.
shape
==
(
input
.
shape
[
0
],
input
.
shape
[
1
]
/
4
)
return
hidden
,
cell
,
gate
return
hidden
,
cell
class
TestLstmOp
(
OpTest
):
class
TestLstmOp
(
OpTest
):
def
set_data
(
self
):
def
set_argument
(
self
):
self
.
lod
=
[[
0
,
2
,
6
,
9
]]
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
D
=
64
self
.
D
=
16
self
.
sort_idx
=
[
2
,
6
,
0
,
3
,
7
,
1
,
4
,
8
,
5
]
self
.
act_gate
=
"sigmoid"
self
.
act_gate
=
'sigmoid'
self
.
act_cell
=
"tanh"
self
.
act_cell
=
'tanh'
self
.
act_cand
=
"tanh"
self
.
act_cand
=
'tanh'
self
.
has_initial_state
=
True
self
.
is_reverse
=
False
self
.
is_reverse
=
False
def
setUp
(
self
):
def
setUp
(
self
):
self
.
set_
data
()
self
.
set_
argument
()
self
.
op_type
=
"lstm"
self
.
op_type
=
'lstm'
T
=
self
.
lod
[
0
][
-
1
]
T
=
self
.
lod
[
0
][
-
1
]
N
=
len
(
self
.
lod
[
0
])
-
1
N
=
len
(
self
.
lod
[
0
])
-
1
x
=
np
.
random
.
normal
(
size
=
(
T
,
4
*
self
.
D
)).
astype
(
"float64"
)
x
=
np
.
random
.
normal
(
size
=
(
T
,
4
*
self
.
D
)).
astype
(
'float64'
)
h0
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
"float64"
)
h0
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
c0
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
"float64"
)
c0
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
w
=
np
.
random
.
normal
(
size
=
(
self
.
D
,
4
*
self
.
D
)).
astype
(
"float64"
)
w
=
np
.
random
.
normal
(
size
=
(
self
.
D
,
4
*
self
.
D
)).
astype
(
'float64'
)
b
=
np
.
random
.
normal
(
size
=
(
1
,
7
*
self
.
D
)).
astype
(
"float64"
)
b
=
np
.
random
.
normal
(
size
=
(
1
,
7
*
self
.
D
)).
astype
(
'float64'
)
w_b
=
b
[:,
0
:
4
*
self
.
D
]
w_b
=
b
[:,
0
:
4
*
self
.
D
]
w_c
=
b
[:,
4
*
self
.
D
:]
w_c
=
b
[:,
4
*
self
.
D
:]
h
,
c
,
g
=
lstm
(
x
,
self
.
lod
,
h0
,
c0
,
w
,
w_b
,
w_c
,
self
.
is_reverse
,
h
,
c
=
lstm
(
x
,
self
.
lod
,
h0
,
c0
,
w
,
w_b
,
w_c
,
self
.
is_reverse
,
ACTVATION
[
self
.
act_gate
],
ACTVATION
[
self
.
act_cell
],
ACTVATION
[
self
.
act_gate
],
ACTVATION
[
self
.
act_cell
],
ACTVATION
[
self
.
act_cand
])
ACTVATION
[
self
.
act_cand
])
g_sort
=
np
.
zeros_like
(
x
)
self
.
inputs
=
{
'Input'
:
(
x
,
self
.
lod
),
'Weight'
:
w
,
'Bias'
:
b
}
for
i
,
j
in
enumerate
(
self
.
sort_idx
):
if
self
.
has_initial_state
:
g_sort
[
i
,
:]
=
g
[
j
,
:]
self
.
inputs
[
'H0'
]
=
h0
self
.
inputs
[
'C0'
]
=
c0
self
.
inputs
=
{
'Input'
:
(
x
,
self
.
lod
),
'H0'
:
h0
,
'C0'
:
c0
,
'Weight'
:
w
,
'Bias'
:
b
}
self
.
outputs
=
{
self
.
outputs
=
{
'Hidden'
:
(
h
,
self
.
lod
),
'Hidden'
:
(
h
,
self
.
lod
),
'Cell'
:
(
c
,
self
.
lod
),
'Cell'
:
(
c
,
self
.
lod
),
'BatchGate'
:
g_sort
}
}
self
.
attrs
=
{
self
.
attrs
=
{
'usePeepholes'
:
True
,
'usePeepholes'
:
True
,
'isReverse'
:
self
.
is_reverse
,
'isReverse'
:
self
.
is_reverse
,
'gateActivation'
:
'sigmoid'
,
'gateActivation'
:
self
.
act_gate
,
'cellActivation'
:
'tanh'
,
'cellActivation'
:
self
.
act_cell
,
'candidateActivation'
:
'tanh'
'candidateActivation'
:
self
.
act_cand
}
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
(
atol
=
1e-8
)
#TODO(qingqing) add more unit testing case
def
test_check_grad
(
self
):
# TODO(qingqing) remove folowing lines after the check_grad is refined.
N
=
len
(
self
.
lod
[
0
])
-
1
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
(
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
check_grad
(
[
'Input'
,
'Weight'
,
'Bias'
],
[
'Hidden'
],
max_relative_error
=
5e-4
)
class
TestLstmOpHasNoInitial
(
TestLstmOp
):
def
set_argument
(
self
):
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
D
=
16
self
.
act_gate
=
'sigmoid'
self
.
act_cell
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
has_initial_state
=
False
self
.
is_reverse
=
True
class
TestLstmOpRerverse
(
TestLstmOp
):
class
TestLstmOpRerverse
(
TestLstmOp
):
def
set_data
(
self
):
def
set_argument
(
self
):
self
.
lod
=
[[
0
,
2
,
6
,
9
]]
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
D
=
64
self
.
D
=
16
self
.
sort_idx
=
[
2
,
6
,
0
,
3
,
7
,
1
,
4
,
8
,
5
]
self
.
act_gate
=
"sigmoid"
self
.
act_gate
=
'sigmoid'
self
.
act_cell
=
"tanh"
self
.
act_cell
=
'tanh'
self
.
act_cand
=
"tanh"
self
.
act_cand
=
'tanh'
self
.
has_initial_state
=
True
self
.
is_reverse
=
True
self
.
is_reverse
=
True
if
__name__
==
"__main__"
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录