Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
3d8b6ebc
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
12 个月 前同步成功
通知
692
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3d8b6ebc
编写于
10月 24, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add LSTM backward implenmentation.
上级
3f1062d7
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
237 addition
and
45 deletion
+237
-45
paddle/operators/lstm_op.cc
paddle/operators/lstm_op.cc
+37
-19
paddle/operators/lstm_op.h
paddle/operators/lstm_op.h
+189
-25
paddle/operators/math/sequence2batch.h
paddle/operators/math/sequence2batch.h
+11
-1
未找到文件。
paddle/operators/lstm_op.cc
浏览文件 @
3d8b6ebc
...
...
@@ -21,7 +21,6 @@ class LSTMOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of LSTM should not be null."
);
...
...
@@ -30,8 +29,8 @@ class LSTMOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
"Output(Cell) of LSTM should not be null."
);
auto
x
_dims
=
ctx
->
GetInputDim
(
"Input"
);
PADDLE_ENFORCE_EQ
(
x
_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
auto
in
_dims
=
ctx
->
GetInputDim
(
"Input"
);
PADDLE_ENFORCE_EQ
(
in
_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
if
(
ctx
->
HasInput
(
"H0"
))
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C0"
),
...
...
@@ -44,7 +43,7 @@ class LSTMOp : public framework::OperatorWithKernel {
"should be the same."
);
}
int
frame_size
=
x
_dims
[
1
]
/
4
;
int
frame_size
=
in
_dims
[
1
]
/
4
;
auto
w_dims
=
ctx
->
GetInputDim
(
"Weight"
);
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
2
,
"The rank of Input(Weight) should be 2."
);
...
...
@@ -71,9 +70,11 @@ class LSTMOp : public framework::OperatorWithKernel {
"4 * %d if disable peepholes connection"
,
frame_size
);
}
ctx
->
SetOutputDim
(
"Hidden"
,
{
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Cell"
,
{
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"BatchGate"
,
x_dims
);
framework
::
DDim
out_dims
({
in_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchGate"
,
in_dims
);
ctx
->
SetOutputDim
(
"BatchCellPreAct"
,
out_dims
);
ctx
->
ShareLoD
(
"Input"
,
"Hidden"
);
ctx
->
ShareLoD
(
"Input"
,
"Cell"
);
}
...
...
@@ -86,7 +87,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Input"
,
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X 4D), where
,
T is the "
"this LoDTensor is a matrix with shape (T X 4D), where T is the "
"total time steps in this mini-batch, D is the hidden size."
);
AddInput
(
"H0"
,
"(Tensor, optional) the initial hidden state is an optional "
...
...
@@ -110,21 +111,25 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"2. `usePeepholes = True` "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."
);
AddOutput
(
"Hidden"
,
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`."
);
AddOutput
(
"Cell"
,
"(LoDTensor) the cell state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`."
);
AddOutput
(
"BatchGate"
,
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after the nonlinear computation. This "
"LoDTensor has the same shape with the reorganized input, which "
"
wa
s also be called batch input. The LoD size is 2. The first "
"
i
s also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the "
"indexes, which denote the position of reorganized sequence "
"in the raw input."
)
.
AsIntermediate
();
AddOutput
(
"Hidden"
,
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`."
);
AddOutput
(
"Cell"
,
"(LoDTensor) the cell state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`."
);
AddOutput
(
"BatchCellPreAct"
,
"(LoDTensor) This LoDTensor is get in the forward and used "
"in the backward."
)
.
AsIntermediate
();
AddAttr
<
bool
>
(
"usePeepholes"
,
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections."
)
...
...
@@ -202,15 +207,28 @@ class LSTMGradOp : public framework::OperatorWithKernel {
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Hidden"
)),
"Input(Hidden@GRAD) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Cell"
)),
"Input(Cell@GRAD) should not be null"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Weight"
),
ctx
->
GetInputDim
(
"Weight"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Bias"
),
ctx
->
GetInputDim
(
"Bias"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Input"
),
ctx
->
GetInputDim
(
"Input"
));
if
(
ctx
->
HasInput
(
"Weight"
))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Weight"
),
ctx
->
GetInputDim
(
"Weight"
));
}
if
(
ctx
->
HasInput
(
"Bias"
))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Bias"
),
ctx
->
GetInputDim
(
"Bias"
));
}
if
(
ctx
->
HasInput
(
"H0"
))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"H0"
),
ctx
->
GetInputDim
(
"H0"
));
}
if
(
ctx
->
HasInput
(
"C0"
))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"C0"
),
ctx
->
GetInputDim
(
"C0"
));
}
}
};
...
...
paddle/operators/lstm_op.h
浏览文件 @
3d8b6ebc
...
...
@@ -21,8 +21,9 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
using
framework
::
LoDTensor
;
using
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
...
...
@@ -31,15 +32,15 @@ template <typename Place, typename T>
class
LSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
);
auto
*
weight
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Weight"
);
auto
*
bias
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Bias"
);
auto
*
input
=
ctx
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
batch_gate
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"BatchGate"
);
auto
*
batch_gate
=
ctx
.
Output
<
LoDTensor
>
(
"BatchGate"
);
batch_gate
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
hidden_out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Hidden"
);
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
cell_out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Cell"
);
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Now the function ShareLoD in InferShape is not implemented.
...
...
@@ -49,7 +50,8 @@ class LSTMKernel : public framework::OpKernel<T> {
bool
is_reverse
=
ctx
.
Attr
<
bool
>
(
"isReverse"
);
math
::
LoDTensor2BatchFunctor
<
Place
,
T
>
to_batch
;
to_batch
(
ctx
.
device_context
(),
*
input
,
*
batch_gate
,
is_reverse
);
auto
&
device_ctx
=
ctx
.
device_context
();
to_batch
(
device_ctx
,
*
input
,
*
batch_gate
,
true
,
is_reverse
);
auto
in_dims
=
input
->
dims
();
int
frame_size
=
static_cast
<
int
>
(
in_dims
[
1
]
/
4
);
...
...
@@ -69,15 +71,23 @@ class LSTMKernel : public framework::OpKernel<T> {
}
math
::
LstmMetaValue
<
T
>
lstm_value
;
T
*
bias_data
=
const_cast
<
T
*>
(
bias
->
data
<
T
>
());
// the code style in LstmMetaValue will be updated later.
lstm_value
.
checkIg
=
bias_data
+
4
*
frame_size
;
lstm_value
.
checkFg
=
lstm_value
.
checkIg
+
frame_size
;
lstm_value
.
checkOg
=
lstm_value
.
checkFg
+
frame_size
;
if
(
bias
)
{
T
*
bias_data
=
const_cast
<
T
*>
(
bias
->
data
<
T
>
());
// the code style in LstmMetaValue will be updated later.
lstm_value
.
checkIg
=
bias_data
+
4
*
frame_size
;
lstm_value
.
checkFg
=
lstm_value
.
checkIg
+
frame_size
;
lstm_value
.
checkOg
=
lstm_value
.
checkFg
+
frame_size
;
}
else
{
lstm_value
.
checkIg
=
nullptr
;
lstm_value
.
checkFg
=
nullptr
;
lstm_value
.
checkOg
=
nullptr
;
}
lstm_value
.
prevStateValue
=
nullptr
;
framework
::
LoDTensor
batch_out
,
batch_cell
,
batch_cell_pre_act
;
batch_out
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
// Use the local variable as here.
LoDTensor
batch_hidden
,
batch_cell
;
auto
batch_cell_pre_act
=
*
(
ctx
.
Output
<
LoDTensor
>
(
"BatchCellPreAct"
));
batch_hidden
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
batch_cell
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
batch_cell_pre_act
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
...
...
@@ -92,7 +102,7 @@ class LSTMKernel : public framework::OpKernel<T> {
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
out_t
=
batch_
out
.
Slice
(
bstart
,
bend
);
Tensor
out_t
=
batch_
hidden
.
Slice
(
bstart
,
bend
);
Tensor
cell_t
=
batch_cell
.
Slice
(
bstart
,
bend
);
Tensor
cell_pre_act_t
=
batch_cell_pre_act
.
Slice
(
bstart
,
bend
);
...
...
@@ -101,9 +111,9 @@ class LSTMKernel : public framework::OpKernel<T> {
if
(
n
!=
0
)
{
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_t
=
batch_
out
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
Place
,
T
>
(
ctx
.
device_context
(),
pre_hidden_
t
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
auto
pre_hidden_t
=
batch_
hidden
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
Place
,
T
>
(
device_ctx
,
pre_hidden_t
,
false
,
*
weigh
t
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
// else if : FIXME support the initial hidden and cell
...
...
@@ -112,27 +122,181 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value
.
outputValue
=
out_t
.
data
<
T
>
();
lstm_value
.
stateValue
=
cell_t
.
data
<
T
>
();
lstm_value
.
stateActiveValue
=
cell_pre_act_t
.
data
<
T
>
();
math
::
LstmUnitFunctor
<
Place
,
T
>::
compute
(
ctx
.
device_context
()
,
lstm_value
,
math
::
LstmUnitFunctor
<
Place
,
T
>::
compute
(
device_ctx
,
lstm_value
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
lstm_value
.
prevStateValue
=
lstm_value
.
stateValue
;
}
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
batch_
out
.
set_lod
(
batch_gate
->
lod
());
batch_
hidden
.
set_lod
(
batch_gate
->
lod
());
// restore the output hidden in LoDTensor from the batch hidden
to_seq
(
ctx
.
device_context
(),
batch_out
,
*
hidden_out
);
to_seq
(
device_ctx
,
batch_hidden
,
*
hidden_out
);
batch_cell
.
set_lod
(
batch_gate
->
lod
());
// restore the output cell state in LoDTensor from the batch cell
to_seq
(
ctx
.
device_context
()
,
batch_cell
,
*
cell_out
);
to_seq
(
device_ctx
,
batch_cell
,
*
cell_out
);
}
};
template
<
typename
Place
,
typename
T
>
class
LSTMGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{}
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
hidden_out
=
ctx
.
Input
<
LoDTensor
>
(
"Hidden"
);
auto
*
cell_out
=
ctx
.
Input
<
LoDTensor
>
(
"Cell"
);
auto
*
batch_gate
=
ctx
.
Input
<
LoDTensor
>
(
"BatchGate"
);
auto
*
batch_cell_pre_act
=
ctx
.
Input
<
LoDTensor
>
(
"BatchCellPreAct"
);
auto
*
hidden_g
=
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Hidden"
));
auto
*
cell_g
=
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Cell"
));
auto
*
in_g
=
ctx
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"Input"
));
auto
*
weight_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Weight"
));
auto
*
bias_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
auto
&
device_ctx
=
ctx
.
device_context
();
if
(
weight_g
)
{
math
::
SetConstant
<
Place
,
T
>
zero
;
zero
(
device_ctx
,
weight_g
,
static_cast
<
T
>
(
0.0
));
}
auto
in_dims
=
input
->
dims
();
auto
out_dims
=
hidden_g
->
dims
();
int
frame_size
=
static_cast
<
int
>
(
in_dims
[
1
]
/
4
);
PADDLE_ENFORCE_EQ
(
frame_size
,
out_dims
[
1
]);
math
::
LstmMetaValue
<
T
>
lstm_value
;
if
(
bias
)
{
T
*
bias_data
=
const_cast
<
T
*>
(
bias
->
data
<
T
>
());
lstm_value
.
checkIg
=
bias_data
+
4
*
frame_size
;
lstm_value
.
checkFg
=
lstm_value
.
checkIg
+
frame_size
;
lstm_value
.
checkOg
=
lstm_value
.
checkFg
+
frame_size
;
}
else
{
lstm_value
.
checkIg
=
nullptr
;
lstm_value
.
checkFg
=
nullptr
;
lstm_value
.
checkOg
=
nullptr
;
}
math
::
LstmMetaGrad
<
T
>
lstm_grad
;
if
(
bias
&&
bias_g
)
{
T
*
bias_g_data
=
const_cast
<
T
*>
(
bias_g
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
lstm_grad
.
checkIgGrad
=
bias_g_data
+
4
*
frame_size
;
lstm_grad
.
checkFgGrad
=
lstm_grad
.
checkIgGrad
+
frame_size
;
lstm_grad
.
checkOgGrad
=
lstm_grad
.
checkFgGrad
+
frame_size
;
}
else
{
lstm_grad
.
checkIgGrad
=
nullptr
;
lstm_grad
.
checkFgGrad
=
nullptr
;
lstm_grad
.
checkOgGrad
=
nullptr
;
}
math
::
LoDTensor2BatchFunctor
<
Place
,
T
>
to_batch
;
// use the local variable as here.
LoDTensor
batch_hidden
;
batch_hidden
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
batch_hidden
.
set_lod
(
batch_gate
->
lod
());
to_batch
(
device_ctx
,
*
hidden_out
,
batch_hidden
,
false
);
LoDTensor
batch_hidden_g
;
batch_hidden_g
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
batch_hidden_g
.
set_lod
(
batch_gate
->
lod
());
to_batch
(
device_ctx
,
*
hidden_g
,
batch_hidden_g
,
false
);
LoDTensor
batch_cell
;
batch_cell
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
batch_cell
.
set_lod
(
batch_gate
->
lod
());
to_batch
(
device_ctx
,
*
cell_out
,
batch_cell
,
false
);
LoDTensor
batch_cell_g
;
batch_cell_g
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
batch_cell_g
.
set_lod
(
batch_gate
->
lod
());
to_batch
(
device_ctx
,
*
cell_g
,
batch_cell_g
,
false
);
LoDTensor
batch_gate_g
;
batch_gate_g
.
mutable_data
<
T
>
(
batch_gate
->
dims
(),
ctx
.
GetPlace
());
batch_gate_g
.
set_lod
(
batch_gate
->
lod
());
auto
gate_act
=
ctx
.
Attr
<
std
::
string
>
(
"gateActivation"
);
auto
cell_act
=
ctx
.
Attr
<
std
::
string
>
(
"cellActivation"
);
auto
cand_act
=
ctx
.
Attr
<
std
::
string
>
(
"candidateActivation"
);
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
for
(
int
n
=
static_cast
<
int
>
(
num_batch
);
n
>=
0
;
n
--
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
Tensor
gate
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
cell
=
batch_cell
.
Slice
(
bstart
,
bend
);
Tensor
cell_pre_act
=
batch_cell_pre_act
->
Slice
(
bstart
,
bend
);
lstm_value
.
gateValue
=
gate
.
data
<
T
>
();
lstm_value
.
stateValue
=
cell
.
data
<
T
>
();
lstm_value
.
stateActiveValue
=
cell_pre_act
.
data
<
T
>
();
Tensor
out_g
=
batch_hidden_g
.
Slice
(
bstart
,
bend
);
Tensor
gate_g
=
batch_gate_g
.
Slice
(
bstart
,
bend
);
Tensor
cell_g
=
batch_cell_g
.
Slice
(
bstart
,
bend
);
lstm_grad
.
stateGrad
=
cell_g
.
data
<
T
>
();
lstm_grad
.
gateGrad
=
gate_g
.
data
<
T
>
();
lstm_grad
.
outputGrad
=
out_g
.
data
<
T
>
();
if
(
n
!=
0
)
{
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
Tensor
cell_pre
=
batch_cell
.
Slice
(
bstart_pre
,
bstart
);
Tensor
cell_pre_g
=
batch_cell_g
.
Slice
(
bstart_pre
,
bstart
);
lstm_value
.
prevStateValue
=
cell_pre
.
data
<
T
>
();
lstm_grad
.
prevStateGrad
=
cell_pre_g
.
data
<
T
>
();
}
else
{
lstm_value
.
prevStateValue
=
nullptr
;
lstm_grad
.
prevStateGrad
=
nullptr
;
}
int
cur_batch_size
=
bend
-
bstart
;
math
::
LstmUnitGradFunctor
<
Place
,
T
>::
compute
(
device_ctx
,
lstm_value
,
lstm_grad
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
if
(
n
!=
0
)
{
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_g
=
batch_hidden_g
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
Place
,
T
>
(
device_ctx
,
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
pre_hidden_g
,
static_cast
<
T
>
(
1.0
));
if
(
weight_g
)
{
/* backward weight */
auto
pre_hidden
=
batch_hidden
.
Slice
(
pre_h_start
,
pre_h_end
);
math
::
matmul
<
Place
,
T
>
(
device_ctx
,
pre_hidden
,
true
,
gate_g
,
false
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
}
}
}
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
if
(
in_g
)
{
/* backward data */
to_seq
(
device_ctx
,
batch_gate_g
,
*
in_g
);
}
if
(
bias
&&
bias_g
)
{
/* backward bias */
bias_g
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
bias_g_e
=
EigenMatrix
<
T
>::
From
(
*
bias_g
);
auto
gate_g_e
=
EigenMatrix
<
T
>::
From
(
batch_gate_g
);
Eigen
::
array
<
int
,
2
>
extents
({{
1
,
4
*
frame_size
}});
Eigen
::
array
<
int
,
2
>
offsets
({{
0
,
0
}});
auto
bg
=
bias_g_e
.
slice
(
offsets
,
extents
)
.
reshape
(
Eigen
::
array
<
int
,
2
>
({{
1
,
frame_size
*
4
}}));
bg
.
device
(
ctx
.
GetEigenDevice
<
Place
>
())
=
gate_g_e
.
sum
(
Eigen
::
array
<
int
,
1
>
({{
0
}}));
}
}
};
}
// namespace operators
...
...
paddle/operators/math/sequence2batch.h
浏览文件 @
3d8b6ebc
...
...
@@ -53,7 +53,17 @@ class LoDTensor2BatchFunctor {
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
lod_tensor
,
framework
::
LoDTensor
&
batch
,
bool
is_reverse
)
const
{
framework
::
LoDTensor
&
batch
,
bool
is_cal_batch_lod
,
bool
is_reverse
=
false
)
const
{
if
(
!
is_cal_batch_lod
)
{
auto
lods
=
batch
.
lod
();
PADDLE_ENFORCE_EQ
(
lods
.
size
(),
2UL
);
PADDLE_ENFORCE_EQ
(
lods
[
1
].
size
(),
lod_tensor
.
dims
()[
1
]);
CopyMatrixRowsFunctor
<
Place
,
T
>
to_batch
;
to_batch
(
context
,
lod_tensor
,
lods
[
1
].
data
(),
batch
,
true
);
return
;
}
auto
lods
=
lod_tensor
.
lod
();
PADDLE_ENFORCE_EQ
(
lods
.
size
(),
1UL
,
"Only support one level sequence now."
);
auto
lod
=
lods
[
0
];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录