Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
81870723
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
81870723
编写于
2月 22, 2019
作者:
X
xuezhong
提交者:
GitHub
2月 22, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15605 from xuezhong/fix_bug_for_lstmp
Fix bug for lstmp
上级
3b08c9ab
6b83845c
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
259 addition
and
135 deletion
+259
-135
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-1
paddle/fluid/operators/lstm_op.h
paddle/fluid/operators/lstm_op.h
+5
-3
paddle/fluid/operators/lstmp_op.cc
paddle/fluid/operators/lstmp_op.cc
+10
-11
paddle/fluid/operators/lstmp_op.h
paddle/fluid/operators/lstmp_op.h
+67
-35
paddle/fluid/operators/math/detail/lstm_cpu_kernel.h
paddle/fluid/operators/math/detail/lstm_cpu_kernel.h
+20
-18
paddle/fluid/operators/math/detail/lstm_gpu_kernel.h
paddle/fluid/operators/math/detail/lstm_gpu_kernel.h
+16
-14
paddle/fluid/operators/math/detail/lstm_kernel.h
paddle/fluid/operators/math/detail/lstm_kernel.h
+49
-13
paddle/fluid/operators/math/lstm_compute.cc
paddle/fluid/operators/math/lstm_compute.cc
+5
-4
paddle/fluid/operators/math/lstm_compute.cu
paddle/fluid/operators/math/lstm_compute.cu
+6
-6
paddle/fluid/operators/math/lstm_compute.h
paddle/fluid/operators/math/lstm_compute.h
+2
-2
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+39
-8
python/paddle/fluid/tests/unittests/test_lstmp_op.py
python/paddle/fluid/tests/unittests/test_lstmp_op.py
+39
-20
未找到文件。
paddle/fluid/API.spec
浏览文件 @
81870723
...
@@ -71,7 +71,7 @@ paddle.fluid.initializer.NumpyArrayInitializer.__init__ ArgSpec(args=['self', 'v
...
@@ -71,7 +71,7 @@ paddle.fluid.initializer.NumpyArrayInitializer.__init__ ArgSpec(args=['self', 'v
paddle.fluid.layers.fc ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'is_test', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, False, None))
paddle.fluid.layers.fc ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'is_test', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, False, None))
paddle.fluid.layers.embedding ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32'))
paddle.fluid.layers.embedding ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32'))
paddle.fluid.layers.dynamic_lstm ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None))
paddle.fluid.layers.dynamic_lstm ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None))
paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name'
], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32'
, None))
paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name'
, 'h_0', 'c_0', 'cell_clip', 'proj_clip'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None, None, None, None
, None))
paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False))
paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False))
paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', False))
paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', False))
paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,))
...
...
paddle/fluid/operators/lstm_op.h
浏览文件 @
81870723
...
@@ -151,9 +151,10 @@ class LSTMKernel : public framework::OpKernel<T> {
...
@@ -151,9 +151,10 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value
.
output_value
=
out_t
.
data
<
T
>
();
lstm_value
.
output_value
=
out_t
.
data
<
T
>
();
lstm_value
.
state_value
=
cell_t
.
data
<
T
>
();
lstm_value
.
state_value
=
cell_t
.
data
<
T
>
();
lstm_value
.
state_active_value
=
cell_pre_act_t
.
data
<
T
>
();
lstm_value
.
state_active_value
=
cell_pre_act_t
.
data
<
T
>
();
T
cell_clip
=
0.0
;
math
::
LstmUnitFunctor
<
DeviceContext
,
T
>::
compute
(
math
::
LstmUnitFunctor
<
DeviceContext
,
T
>::
compute
(
device_ctx
,
lstm_value
,
frame_size
,
cur_batch_size
,
gate_act
,
device_ctx
,
lstm_value
,
frame_size
,
cur_batch_size
,
cell_clip
,
cell_act
,
cand_act
);
gate_act
,
cell_act
,
cand_act
);
lstm_value
.
prev_state_value
=
lstm_value
.
state_value
;
lstm_value
.
prev_state_value
=
lstm_value
.
state_value
;
}
}
...
@@ -316,9 +317,10 @@ class LSTMGradKernel : public framework::OpKernel<T> {
...
@@ -316,9 +317,10 @@ class LSTMGradKernel : public framework::OpKernel<T> {
lstm_value
.
output_value
=
nullptr
;
lstm_value
.
output_value
=
nullptr
;
lstm_grad
.
state_active_grad
=
nullptr
;
lstm_grad
.
state_active_grad
=
nullptr
;
int
cur_batch_size
=
bend
-
bstart
;
int
cur_batch_size
=
bend
-
bstart
;
T
cell_clip
=
0.0
;
math
::
LstmUnitGradFunctor
<
DeviceContext
,
T
>::
compute
(
math
::
LstmUnitGradFunctor
<
DeviceContext
,
T
>::
compute
(
device_ctx
,
lstm_value
,
lstm_grad
,
frame_size
,
cur_batch_size
,
device_ctx
,
lstm_value
,
lstm_grad
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
cell_clip
,
gate_act
,
cell_act
,
cand_act
);
if
(
n
>
0
)
{
if
(
n
>
0
)
{
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
...
...
paddle/fluid/operators/lstmp_op.cc
浏览文件 @
81870723
...
@@ -73,12 +73,6 @@ class LSTMPOp : public framework::OperatorWithKernel {
...
@@ -73,12 +73,6 @@ class LSTMPOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C0"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C0"
),
"Input(C0) of LSTMP operator should not be null after "
"Input(C0) of LSTMP operator should not be null after "
"Input(H0) provided."
);
"Input(H0) provided."
);
auto
h_dims
=
ctx
->
GetInputDim
(
"H0"
);
auto
c_dims
=
ctx
->
GetInputDim
(
"C0"
);
PADDLE_ENFORCE
(
h_dims
==
c_dims
,
"The dimension of Input(H0) and Input(C0) "
"should be the same."
);
ctx
->
SetOutputDim
(
"OrderedP0"
,
{
h_dims
[
0
],
proj_dims
[
1
]});
}
}
auto
b_dims
=
ctx
->
GetInputDim
(
"Bias"
);
auto
b_dims
=
ctx
->
GetInputDim
(
"Bias"
);
...
@@ -180,11 +174,6 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -180,11 +174,6 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
"This LoDTensor is obtained in the forward and used in the "
"This LoDTensor is obtained in the forward and used in the "
"backward."
)
"backward."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
"OrderedP0"
,
"(Tensor) the projection of the initial hidden state "
"H0. This is a tensor with shape (N x P), where N is the "
"batch size and P is the hidden size."
)
.
AsIntermediate
();
AddAttr
<
bool
>
(
"use_peepholes"
,
AddAttr
<
bool
>
(
"use_peepholes"
,
"(bool, defalut: True) "
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections."
)
"whether to enable diagonal/peephole connections."
)
...
@@ -193,6 +182,16 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -193,6 +182,16 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, defalut: False) "
"(bool, defalut: False) "
"whether to compute reversed LSTMP."
)
"whether to compute reversed LSTMP."
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
float
>
(
"cell_clip"
,
"(float, defalut: 0.0) "
"Clip for Tensor for cell state tensor when clip value is "
"greater than 0.0"
)
.
SetDefault
(
0.0
);
AddAttr
<
float
>
(
"proj_clip"
,
"(float, defalut: 0.0) "
"Clip for Tensor for projection tensor when clip value is "
"greater than 0.0"
)
.
SetDefault
(
0.0
);
AddAttr
<
std
::
string
>
(
AddAttr
<
std
::
string
>
(
"gate_activation"
,
"gate_activation"
,
"(string, default: sigmoid)"
"(string, default: sigmoid)"
...
...
paddle/fluid/operators/lstmp_op.h
浏览文件 @
81870723
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <string>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/activation_op.h"
...
@@ -21,17 +22,50 @@ limitations under the License. */
...
@@ -21,17 +22,50 @@ limitations under the License. */
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/transform.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
using
platform
::
Transform
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
>
class
_ClipFunctor
{
public:
explicit
_ClipFunctor
(
const
T
min
,
const
T
max
)
:
min_
(
min
),
max_
(
max
)
{}
HOSTDEVICE
T
operator
()(
const
T
&
x
)
const
{
if
(
x
<
min_
)
return
min_
;
else
if
(
x
>
max_
)
return
max_
;
else
return
x
;
}
private:
T
min_
;
T
max_
;
};
template
<
typename
T
>
class
_ClipGradFunctor
{
public:
explicit
_ClipGradFunctor
(
const
T
min
,
const
T
max
)
:
min_
(
min
),
max_
(
max
)
{}
HOSTDEVICE
T
operator
()(
const
T
&
x
,
const
T
&
y
)
const
{
return
(
y
>
min_
&&
y
<
max_
)
?
x
:
0
;
}
private:
T
min_
;
T
max_
;
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
inline
void
ReorderInitState
(
const
DeviceContext
&
ctx
,
inline
void
ReorderInitState
(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
src
,
const
framework
::
Tensor
&
src
,
...
@@ -67,9 +101,11 @@ class LSTMPKernel : public framework::OpKernel<T> {
...
@@ -67,9 +101,11 @@ class LSTMPKernel : public framework::OpKernel<T> {
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
hidden_t0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
auto
*
hidden_t0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
auto
*
ordered_proj0
=
ctx
.
Output
<
Tensor
>
(
"OrderedP0"
);
auto
*
cell_t0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
auto
*
cell_t0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
auto
proj_clip
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"proj_clip"
));
auto
cell_clip
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"cell_clip"
));
auto
*
batch_gate
=
ctx
.
Output
<
LoDTensor
>
(
"BatchGate"
);
auto
*
batch_gate
=
ctx
.
Output
<
LoDTensor
>
(
"BatchGate"
);
batch_gate
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
batch_gate
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
proj_out
=
ctx
.
Output
<
LoDTensor
>
(
"Projection"
);
auto
*
proj_out
=
ctx
.
Output
<
LoDTensor
>
(
"Projection"
);
...
@@ -110,6 +146,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
...
@@ -110,6 +146,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
}
}
lstmp_value
.
prev_state_value
=
nullptr
;
lstmp_value
.
prev_state_value
=
nullptr
;
Tensor
ordered_c0
;
Tensor
ordered_c0
;
Tensor
ordered_h0
;
framework
::
Vector
<
size_t
>
order
(
batch_gate
->
lod
()[
2
]);
framework
::
Vector
<
size_t
>
order
(
batch_gate
->
lod
()[
2
]);
...
@@ -169,18 +206,9 @@ class LSTMPKernel : public framework::OpKernel<T> {
...
@@ -169,18 +206,9 @@ class LSTMPKernel : public framework::OpKernel<T> {
// Since the batch computing for LSTMP reorders the input sequence
// Since the batch computing for LSTMP reorders the input sequence
// according to their length. The initialized hidden state also needs
// according to their length. The initialized hidden state also needs
// to reorder.
// to reorder.
Tensor
ordered_h0
;
ordered_proj0
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ReorderInitState
<
DeviceContext
,
T
>
(
device_ctx
,
*
hidden_t0
,
order
,
ReorderInitState
<
DeviceContext
,
T
>
(
device_ctx
,
*
hidden_t0
,
order
,
&
ordered_h0
,
true
);
&
ordered_h0
,
true
);
blas
.
MatMul
(
ordered_h0
,
false
,
*
proj_weight
,
false
,
static_cast
<
T
>
(
1.0
),
blas
.
MatMul
(
ordered_h0
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
ordered_proj0
,
static_cast
<
T
>
(
0.0
));
if
(
proj_act
!=
math
::
detail
::
ActivationType
::
kIdentity
)
{
auto
proj0_dev
=
EigenMatrix
<
T
>::
From
(
*
ordered_proj0
);
ActCompute
(
cell_act
,
place
,
proj0_dev
,
proj0_dev
);
}
blas
.
MatMul
(
*
ordered_proj0
,
false
,
*
weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
}
...
@@ -189,8 +217,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
...
@@ -189,8 +217,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
lstmp_value
.
state_value
=
cell_t
.
data
<
T
>
();
lstmp_value
.
state_value
=
cell_t
.
data
<
T
>
();
lstmp_value
.
state_active_value
=
cell_pre_act_t
.
data
<
T
>
();
lstmp_value
.
state_active_value
=
cell_pre_act_t
.
data
<
T
>
();
math
::
LstmUnitFunctor
<
DeviceContext
,
T
>::
compute
(
math
::
LstmUnitFunctor
<
DeviceContext
,
T
>::
compute
(
device_ctx
,
lstmp_value
,
frame_size
,
cur_batch_size
,
gate_act
,
device_ctx
,
lstmp_value
,
frame_size
,
cur_batch_size
,
cell_clip
,
cell_act
,
cand_act
);
gate_act
,
cell_act
,
cand_act
);
lstmp_value
.
prev_state_value
=
lstmp_value
.
state_value
;
lstmp_value
.
prev_state_value
=
lstmp_value
.
state_value
;
blas
.
MatMul
(
hidden_t
,
false
,
*
proj_weight
,
false
,
static_cast
<
T
>
(
1.0
),
blas
.
MatMul
(
hidden_t
,
false
,
*
proj_weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
proj_t
,
static_cast
<
T
>
(
0.0
));
&
proj_t
,
static_cast
<
T
>
(
0.0
));
...
@@ -198,6 +226,14 @@ class LSTMPKernel : public framework::OpKernel<T> {
...
@@ -198,6 +226,14 @@ class LSTMPKernel : public framework::OpKernel<T> {
auto
proj_t_dev
=
EigenMatrix
<
T
>::
From
(
proj_t
);
auto
proj_t_dev
=
EigenMatrix
<
T
>::
From
(
proj_t
);
ActCompute
(
cell_act
,
place
,
proj_t_dev
,
proj_t_dev
);
ActCompute
(
cell_act
,
place
,
proj_t_dev
,
proj_t_dev
);
}
}
if
(
proj_clip
&&
proj_clip
>
0.0
)
{
T
*
x_data
=
proj_t
.
data
<
T
>
();
int64_t
numel
=
proj_t
.
numel
();
Transform
<
DeviceContext
>
trans
;
trans
(
ctx
.
template
device_context
<
DeviceContext
>(),
x_data
,
x_data
+
numel
,
x_data
,
_ClipFunctor
<
T
>
(
-
1.0
*
proj_clip
,
proj_clip
));
}
}
}
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
...
@@ -239,6 +275,9 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
...
@@ -239,6 +275,9 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
auto
*
proj_out
=
ctx
.
Input
<
LoDTensor
>
(
"Projection"
);
auto
*
proj_out
=
ctx
.
Input
<
LoDTensor
>
(
"Projection"
);
auto
*
cell_out
=
ctx
.
Input
<
LoDTensor
>
(
"Cell"
);
auto
*
cell_out
=
ctx
.
Input
<
LoDTensor
>
(
"Cell"
);
auto
proj_clip
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"proj_clip"
));
auto
cell_clip
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"cell_clip"
));
auto
*
batch_gate
=
ctx
.
Input
<
LoDTensor
>
(
"BatchGate"
);
auto
*
batch_gate
=
ctx
.
Input
<
LoDTensor
>
(
"BatchGate"
);
auto
*
batch_cell_pre_act
=
ctx
.
Input
<
LoDTensor
>
(
"BatchCellPreAct"
);
auto
*
batch_cell_pre_act
=
ctx
.
Input
<
LoDTensor
>
(
"BatchCellPreAct"
);
auto
*
batch_hidden
=
ctx
.
Input
<
LoDTensor
>
(
"BatchHidden"
);
auto
*
batch_hidden
=
ctx
.
Input
<
LoDTensor
>
(
"BatchHidden"
);
...
@@ -253,7 +292,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
...
@@ -253,7 +292,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
auto
*
bias_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
auto
*
bias_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
auto
*
h0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
auto
*
h0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
auto
*
ordered_proj0
=
ctx
.
Input
<
Tensor
>
(
"OrderedP0"
);
auto
*
c0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
auto
*
c0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
auto
*
h0_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"H0"
));
auto
*
h0_g
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"H0"
));
...
@@ -363,6 +401,17 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
...
@@ -363,6 +401,17 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
Tensor
cur_proj
=
batch_proj
.
Slice
(
bstart
,
bend
);
Tensor
cur_proj
=
batch_proj
.
Slice
(
bstart
,
bend
);
Tensor
proj_g
=
batch_proj_g
.
Slice
(
bstart
,
bend
);
Tensor
proj_g
=
batch_proj_g
.
Slice
(
bstart
,
bend
);
if
(
proj_clip
&&
proj_clip
>
0.0
)
{
T
*
dx_data
=
proj_g
.
data
<
T
>
();
T
*
x_data
=
cur_proj
.
data
<
T
>
();
int64_t
numel
=
proj_g
.
numel
();
Transform
<
DeviceContext
>
trans
;
trans
(
ctx
.
template
device_context
<
DeviceContext
>(),
dx_data
,
dx_data
+
numel
,
x_data
,
dx_data
,
_ClipGradFunctor
<
T
>
(
-
1.0
*
proj_clip
,
proj_clip
));
}
if
(
proj_act
!=
math
::
detail
::
ActivationType
::
kIdentity
)
{
if
(
proj_act
!=
math
::
detail
::
ActivationType
::
kIdentity
)
{
auto
cur_proj_dev
=
EigenMatrix
<
T
>::
From
(
cur_proj
);
auto
cur_proj_dev
=
EigenMatrix
<
T
>::
From
(
cur_proj
);
auto
proj_g_dev
=
EigenMatrix
<
T
>::
From
(
proj_g
);
auto
proj_g_dev
=
EigenMatrix
<
T
>::
From
(
proj_g
);
...
@@ -412,7 +461,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
...
@@ -412,7 +461,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
math
::
LstmUnitGradFunctor
<
DeviceContext
,
T
>::
compute
(
math
::
LstmUnitGradFunctor
<
DeviceContext
,
T
>::
compute
(
device_ctx
,
lstmp_value
,
lstmp_grad
,
frame_size
,
cur_batch_size
,
device_ctx
,
lstmp_value
,
lstmp_grad
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
cell_clip
,
gate_act
,
cell_act
,
cand_act
);
if
(
n
>
0
)
{
if
(
n
>
0
)
{
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
...
@@ -431,31 +480,14 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
...
@@ -431,31 +480,14 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
ReorderInitState
<
DeviceContext
,
T
>
(
device_ctx
,
*
h0
,
order
,
ReorderInitState
<
DeviceContext
,
T
>
(
device_ctx
,
*
h0
,
order
,
&
ordered_h0
,
true
);
&
ordered_h0
,
true
);
if
(
weight_g
)
{
if
(
weight_g
)
{
blas
.
MatMul
(
*
ordered_proj0
,
true
,
gate_g
,
false
,
blas
.
MatMul
(
ordered_h0
,
true
,
gate_g
,
false
,
static_cast
<
T
>
(
1.0
)
,
static_cast
<
T
>
(
1.0
),
weight_g
,
static_cast
<
T
>
(
1.0
));
weight_g
,
static_cast
<
T
>
(
1.0
));
}
}
}
}
if
(
h0
&&
(
h0_g
||
proj_weight_g
))
{
if
(
h0
&&
(
h0_g
||
proj_weight_g
))
{
ordered_h0_g
.
mutable_data
<
T
>
(
h0_g
->
dims
(),
ctx
.
GetPlace
());
ordered_h0_g
.
mutable_data
<
T
>
(
h0_g
->
dims
(),
ctx
.
GetPlace
());
Tensor
proj0_g
;
proj0_g
.
Resize
({
in_dims
[
0
],
proj_weight
->
dims
()[
1
]});
proj0_g
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
blas
.
MatMul
(
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
blas
.
MatMul
(
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
proj0_g
,
static_cast
<
T
>
(
0.0
));
&
ordered_h0_g
,
static_cast
<
T
>
(
0.0
));
if
(
proj_act
!=
math
::
detail
::
ActivationType
::
kIdentity
)
{
auto
proj0_dev
=
EigenMatrix
<
T
>::
From
(
*
ordered_proj0
);
auto
proj0_g_dev
=
EigenMatrix
<
T
>::
From
(
proj0_g
);
ActGradCompute
(
cell_act
,
place
,
proj0_dev
,
proj0_dev
,
proj0_g_dev
,
proj0_g_dev
);
}
if
(
h0_g
)
{
blas
.
MatMul
(
proj0_g
,
false
,
*
proj_weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
ordered_h0_g
,
static_cast
<
T
>
(
0.0
));
}
if
(
proj_weight_g
)
{
blas
.
MatMul
(
ordered_h0
,
true
,
proj0_g
,
false
,
static_cast
<
T
>
(
1.0
),
proj_weight_g
,
static_cast
<
T
>
(
1.0
));
}
}
}
}
}
}
}
...
...
paddle/fluid/operators/math/detail/lstm_cpu_kernel.h
浏览文件 @
81870723
...
@@ -32,7 +32,8 @@ namespace detail {
...
@@ -32,7 +32,8 @@ namespace detail {
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
naive_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
void
naive_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
ActivationType
active_node
,
int
frame_size
,
T
cell_clip
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
ActivationType
active_state
)
{
T
r_value_in
;
T
r_value_in
;
...
@@ -67,7 +68,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -67,7 +68,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
op
(
&
r_value_in
,
&
r_value_ig
,
&
r_value_fg
,
&
r_value_og
,
&
r_prev_state
,
op
(
&
r_value_in
,
&
r_value_ig
,
&
r_value_fg
,
&
r_value_og
,
&
r_prev_state
,
&
r_state
,
&
r_state_atv
,
&
r_out
,
&
r_checkI
,
&
r_checkF
,
&
r_checkO
,
&
r_state
,
&
r_state_atv
,
&
r_out
,
&
r_checkI
,
&
r_checkF
,
&
r_checkO
,
active_node
,
active_gate
,
active_state
);
&
cell_clip
,
active_node
,
active_gate
,
active_state
);
value_in
[
i
]
=
r_value_in
;
value_in
[
i
]
=
r_value_in
;
value_ig
[
i
]
=
r_value_ig
;
value_ig
[
i
]
=
r_value_ig
;
...
@@ -82,7 +83,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -82,7 +83,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
naive_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
void
naive_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
ActivationType
active_node
,
T
cell_clip
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
ActivationType
active_state
)
{
T
r_value_in
;
T
r_value_in
;
...
@@ -135,7 +136,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -135,7 +136,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
&
r_grad_ig
,
&
r_grad_fg
,
&
r_grad_og
,
&
r_prev_state
,
&
r_prev_state_grad
,
&
r_grad_ig
,
&
r_grad_fg
,
&
r_grad_og
,
&
r_prev_state
,
&
r_prev_state_grad
,
&
r_state
,
&
r_state_grad
,
&
r_state_atv
,
&
r_output_grad
,
&
r_checkI
,
&
r_state
,
&
r_state_grad
,
&
r_state_atv
,
&
r_output_grad
,
&
r_checkI
,
&
r_checkF
,
&
r_checkO
,
&
r_checkIGrad
,
&
r_checkFGrad
,
&
r_checkOGrad
,
&
r_checkF
,
&
r_checkO
,
&
r_checkIGrad
,
&
r_checkFGrad
,
&
r_checkOGrad
,
active_node
,
active_gate
,
active_state
);
&
cell_clip
,
active_node
,
active_gate
,
active_state
);
grad_in
[
i
]
=
r_grad_in
;
grad_in
[
i
]
=
r_grad_in
;
grad_ig
[
i
]
=
r_grad_ig
;
grad_ig
[
i
]
=
r_grad_ig
;
...
@@ -154,7 +155,8 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -154,7 +155,8 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
avx_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
void
avx_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
ActivationType
active_node
,
int
frame_size
,
T
cell_clip
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
ActivationType
active_state
)
{
#ifdef __AVX__
#ifdef __AVX__
...
@@ -194,7 +196,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -194,7 +196,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
op
(
&
r_value_in
,
&
r_value_ig
,
&
r_value_fg
,
&
r_value_og
,
&
r_prev_state
,
op
(
&
r_value_in
,
&
r_value_ig
,
&
r_value_fg
,
&
r_value_og
,
&
r_prev_state
,
&
r_state
,
&
r_state_atv
,
&
r_out
,
&
r_checkI
,
&
r_checkF
,
&
r_checkO
,
&
r_state
,
&
r_state_atv
,
&
r_out
,
&
r_checkI
,
&
r_checkF
,
&
r_checkO
,
active_node
,
active_gate
,
active_state
);
&
cell_clip
,
active_node
,
active_gate
,
active_state
);
value_in
[
i
]
=
r_value_in
;
value_in
[
i
]
=
r_value_in
;
value_ig
[
i
]
=
r_value_ig
;
value_ig
[
i
]
=
r_value_ig
;
...
@@ -210,7 +212,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -210,7 +212,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
avx_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
void
avx_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
ActivationType
active_node
,
T
cell_clip
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
ActivationType
active_state
)
{
#ifdef __AVX__
#ifdef __AVX__
...
@@ -268,7 +270,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -268,7 +270,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
&
r_grad_ig
,
&
r_grad_fg
,
&
r_grad_og
,
&
r_prev_state
,
&
r_prev_state_grad
,
&
r_grad_ig
,
&
r_grad_fg
,
&
r_grad_og
,
&
r_prev_state
,
&
r_prev_state_grad
,
&
r_state
,
&
r_state_grad
,
&
r_state_atv
,
&
r_output_grad
,
&
r_checkI
,
&
r_state
,
&
r_state_grad
,
&
r_state_atv
,
&
r_output_grad
,
&
r_checkI
,
&
r_checkF
,
&
r_checkO
,
&
r_checkIGrad
,
&
r_checkFGrad
,
&
r_checkOGrad
,
&
r_checkF
,
&
r_checkO
,
&
r_checkIGrad
,
&
r_checkFGrad
,
&
r_checkOGrad
,
active_node
,
active_gate
,
active_state
);
&
cell_clip
,
active_node
,
active_gate
,
active_state
);
grad_in
[
i
]
=
r_grad_in
;
grad_in
[
i
]
=
r_grad_in
;
grad_ig
[
i
]
=
r_grad_ig
;
grad_ig
[
i
]
=
r_grad_ig
;
...
@@ -292,27 +294,27 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
...
@@ -292,27 +294,27 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
cpu_lstm_forward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
void
cpu_lstm_forward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
ActivationType
active_node
,
ActivationType
active_gat
e
,
T
cell_clip
,
ActivationType
active_nod
e
,
ActivationType
active_state
)
{
ActivationType
active_
gate
,
ActivationType
active_
state
)
{
if
(
Op
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
if
(
Op
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
avx_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frame_size
,
active_node
,
avx_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frame_size
,
cell_clip
,
active_gate
,
active_state
);
active_
node
,
active_
gate
,
active_state
);
}
else
{
}
else
{
naive_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frame_size
,
active_node
,
naive_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frame_size
,
cell_clip
,
active_gate
,
active_state
);
active_
node
,
active_
gate
,
active_state
);
}
}
}
}
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
cpu_lstm_backward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
void
cpu_lstm_backward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
ActivationType
active_node
,
int
frame_size
,
T
cell_clip
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
ActivationType
active_state
)
{
if
(
Op
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
if
(
Op
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
avx_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frame_size
,
active_node
,
avx_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frame_size
,
cell_clip
,
active_gate
,
active_state
);
active_
node
,
active_
gate
,
active_state
);
}
else
{
}
else
{
naive_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frame_size
,
naive_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frame_size
,
cell_clip
,
active_node
,
active_gate
,
active_state
);
active_node
,
active_gate
,
active_state
);
}
}
}
}
...
...
paddle/fluid/operators/math/detail/lstm_gpu_kernel.h
浏览文件 @
81870723
...
@@ -31,7 +31,8 @@ namespace detail {
...
@@ -31,7 +31,8 @@ namespace detail {
*/
*/
template
<
class
T
,
class
Op
,
bool
is_batch
>
template
<
class
T
,
class
Op
,
bool
is_batch
>
__global__
void
KeLstmForward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
__global__
void
KeLstmForward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
ActivationType
active_node
,
int
batch_size
,
T
cell_clip
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
ActivationType
active_state
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
...
@@ -72,7 +73,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
...
@@ -72,7 +73,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
op
(
&
r_value_in
,
&
r_value_ig
,
&
r_value_fg
,
&
r_value_og
,
&
r_prev_state
,
op
(
&
r_value_in
,
&
r_value_ig
,
&
r_value_fg
,
&
r_value_og
,
&
r_prev_state
,
&
r_state
,
&
r_state_atv
,
&
r_out
,
&
r_checkI
,
&
r_checkF
,
&
r_checkO
,
&
r_state
,
&
r_state_atv
,
&
r_out
,
&
r_checkI
,
&
r_checkF
,
&
r_checkO
,
active_node
,
active_gate
,
active_state
);
&
cell_clip
,
active_node
,
active_gate
,
active_state
);
value
.
gate_value
[
frame_idx
]
=
r_value_in
;
value
.
gate_value
[
frame_idx
]
=
r_value_in
;
value
.
gate_value
[
frame_idx
+
frame_size
]
=
r_value_ig
;
value
.
gate_value
[
frame_idx
+
frame_size
]
=
r_value_ig
;
...
@@ -91,7 +92,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
...
@@ -91,7 +92,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
template
<
class
T
,
class
Op
,
bool
is_batch
>
template
<
class
T
,
class
Op
,
bool
is_batch
>
__global__
void
KeLstmBackward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
__global__
void
KeLstmBackward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
ActivationType
active_node
,
int
batch_size
,
T
cell_clip
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
ActivationType
active_state
)
{
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
frame_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
...
@@ -148,8 +150,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
...
@@ -148,8 +150,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
op
(
&
r_value_in
,
&
r_value_ig
,
&
r_value_fg
,
&
r_value_og
,
&
r_grad_in
,
&
r_grad_ig
,
op
(
&
r_value_in
,
&
r_value_ig
,
&
r_value_fg
,
&
r_value_og
,
&
r_grad_in
,
&
r_grad_ig
,
&
r_grad_fg
,
&
r_grad_og
,
&
r_prev_state
,
&
r_prev_state_grad
,
&
r_state
,
&
r_grad_fg
,
&
r_grad_og
,
&
r_prev_state
,
&
r_prev_state_grad
,
&
r_state
,
&
r_state_grad
,
&
r_state_atv
,
&
r_output_grad
,
&
r_checkI
,
&
r_checkF
,
&
r_state_grad
,
&
r_state_atv
,
&
r_output_grad
,
&
r_checkI
,
&
r_checkF
,
&
r_checkO
,
&
r_checkIGrad
,
&
r_checkFGrad
,
&
r_checkOGrad
,
active_node
,
&
r_checkO
,
&
r_checkIGrad
,
&
r_checkFGrad
,
&
r_checkOGrad
,
&
cell_clip
,
active_gate
,
active_state
);
active_
node
,
active_
gate
,
active_state
);
grad
.
gate_grad
[
frame_idx
]
=
r_grad_in
;
grad
.
gate_grad
[
frame_idx
]
=
r_grad_in
;
grad
.
gate_grad
[
frame_idx
+
frame_size
]
=
r_grad_ig
;
grad
.
gate_grad
[
frame_idx
+
frame_size
]
=
r_grad_ig
;
...
@@ -185,8 +187,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
...
@@ -185,8 +187,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
gpu_lstm_forward
(
const
platform
::
DeviceContext
&
context
,
Op
op
,
void
gpu_lstm_forward
(
const
platform
::
DeviceContext
&
context
,
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
ActivationType
active_node
,
ActivationType
active_gat
e
,
T
cell_clip
,
ActivationType
active_nod
e
,
ActivationType
active_state
)
{
ActivationType
active_
gate
,
ActivationType
active_
state
)
{
dim3
threads
;
dim3
threads
;
dim3
grid
;
dim3
grid
;
if
(
batch_size
==
1
)
{
if
(
batch_size
==
1
)
{
...
@@ -205,12 +207,12 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
...
@@ -205,12 +207,12 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
if
(
batch_size
==
1
)
{
if
(
batch_size
==
1
)
{
KeLstmForward
<
T
,
Op
,
KeLstmForward
<
T
,
Op
,
/* is_batch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is_batch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
frame_size
,
batch_size
,
active_node
,
active_gate
,
op
,
value
,
frame_size
,
batch_size
,
cell_clip
,
active_node
,
active_gate
,
active_state
);
active_state
);
}
else
{
}
else
{
KeLstmForward
<
T
,
Op
,
KeLstmForward
<
T
,
Op
,
/* is_batch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is_batch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
frame_size
,
batch_size
,
active_node
,
active_gate
,
op
,
value
,
frame_size
,
batch_size
,
cell_clip
,
active_node
,
active_gate
,
active_state
);
active_state
);
}
}
}
}
...
@@ -218,7 +220,7 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
...
@@ -218,7 +220,7 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
void
gpu_lstm_backward
(
const
platform
::
DeviceContext
&
context
,
Op
op
,
void
gpu_lstm_backward
(
const
platform
::
DeviceContext
&
context
,
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
int
frame_size
,
int
batch_size
,
T
cell_clip
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
ActivationType
active_state
)
{
dim3
threads
;
dim3
threads
;
...
@@ -239,13 +241,13 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
...
@@ -239,13 +241,13 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
if
(
batch_size
==
1
)
{
if
(
batch_size
==
1
)
{
KeLstmBackward
<
T
,
Op
,
KeLstmBackward
<
T
,
Op
,
/* is_batch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is_batch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
grad
,
frame_size
,
batch_size
,
active_node
,
active_gat
e
,
op
,
value
,
grad
,
frame_size
,
batch_size
,
cell_clip
,
active_nod
e
,
active_state
);
active_
gate
,
active_
state
);
}
else
{
}
else
{
KeLstmBackward
<
T
,
Op
,
KeLstmBackward
<
T
,
Op
,
/* is_batch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is_batch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
grad
,
frame_size
,
batch_size
,
active_node
,
active_gat
e
,
op
,
value
,
grad
,
frame_size
,
batch_size
,
cell_clip
,
active_nod
e
,
active_state
);
active_
gate
,
active_
state
);
}
}
}
}
...
...
paddle/fluid/operators/math/detail/lstm_kernel.h
浏览文件 @
81870723
...
@@ -29,7 +29,7 @@ class lstm {
...
@@ -29,7 +29,7 @@ class lstm {
public:
public:
HOSTDEVICE
void
operator
()(
T
*
value_in
,
T
*
value_ig
,
T
*
value_fg
,
T
*
value_og
,
HOSTDEVICE
void
operator
()(
T
*
value_in
,
T
*
value_ig
,
T
*
value_fg
,
T
*
value_og
,
T
*
prev_state
,
T
*
state
,
T
*
state_atv
,
T
*
output
,
T
*
prev_state
,
T
*
state
,
T
*
state_atv
,
T
*
output
,
T
*
checkI
,
T
*
checkF
,
T
*
checkO
,
T
*
checkI
,
T
*
checkF
,
T
*
checkO
,
T
*
cell_clip
,
ActivationType
active_node
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
ActivationType
active_state
)
{
...
@@ -37,6 +37,15 @@ class lstm {
...
@@ -37,6 +37,15 @@ class lstm {
*
value_ig
=
activation
(
*
value_ig
+
(
*
prev_state
)
*
(
*
checkI
),
active_gate
);
*
value_ig
=
activation
(
*
value_ig
+
(
*
prev_state
)
*
(
*
checkI
),
active_gate
);
*
value_fg
=
activation
(
*
value_fg
+
(
*
prev_state
)
*
(
*
checkF
),
active_gate
);
*
value_fg
=
activation
(
*
value_fg
+
(
*
prev_state
)
*
(
*
checkF
),
active_gate
);
*
state
=
(
*
value_in
)
*
(
*
value_ig
)
+
(
*
prev_state
)
*
(
*
value_fg
);
*
state
=
(
*
value_in
)
*
(
*
value_ig
)
+
(
*
prev_state
)
*
(
*
value_fg
);
if
(
*
cell_clip
>
0.0
)
{
if
(
*
state
<
-
1.0
*
(
*
cell_clip
))
{
*
state
=
-
1.0
*
(
*
cell_clip
);
}
if
(
*
state
>
*
cell_clip
)
{
*
state
=
*
cell_clip
;
}
}
*
value_og
=
activation
(
*
value_og
+
(
*
state
)
*
(
*
checkO
),
active_gate
);
*
value_og
=
activation
(
*
value_og
+
(
*
state
)
*
(
*
checkO
),
active_gate
);
*
state_atv
=
activation
(
*
state
,
active_state
);
*
state_atv
=
activation
(
*
state
,
active_state
);
*
output
=
(
*
value_og
)
*
(
*
state_atv
);
*
output
=
(
*
value_og
)
*
(
*
state_atv
);
...
@@ -52,7 +61,7 @@ class lstm {
...
@@ -52,7 +61,7 @@ class lstm {
__m256
*
value_fg
,
__m256
*
value_og
,
__m256
*
value_fg
,
__m256
*
value_og
,
__m256
*
prev_state
,
__m256
*
state
,
__m256
*
prev_state
,
__m256
*
state
,
__m256
*
state_atv
,
__m256
*
output
,
__m256
*
checkI
,
__m256
*
state_atv
,
__m256
*
output
,
__m256
*
checkI
,
__m256
*
checkF
,
__m256
*
checkO
,
__m256
*
checkF
,
__m256
*
checkO
,
T
*
cell_clip
,
ActivationType
active_node
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
ActivationType
active_state
)
{
...
@@ -65,6 +74,13 @@ class lstm {
...
@@ -65,6 +74,13 @@ class lstm {
active_gate
);
active_gate
);
*
state
=
_mm256_add_ps
(
_mm256_mul_ps
(
*
value_in
,
*
value_ig
),
*
state
=
_mm256_add_ps
(
_mm256_mul_ps
(
*
value_in
,
*
value_ig
),
_mm256_mul_ps
(
*
prev_state
,
*
value_fg
));
_mm256_mul_ps
(
*
prev_state
,
*
value_fg
));
if
(
*
cell_clip
>
0.0
f
)
{
__m256
min
=
_mm256_set1_ps
(
0.0
f
-
*
cell_clip
);
__m256
max
=
_mm256_set1_ps
(
*
cell_clip
);
*
state
=
_mm256_min_ps
(
max
,
*
state
);
*
state
=
_mm256_max_ps
(
min
,
*
state
);
}
*
value_og
=
activation
(
*
value_og
=
activation
(
_mm256_add_ps
(
*
value_og
,
_mm256_mul_ps
(
*
state
,
*
checkO
)),
active_gate
);
_mm256_add_ps
(
*
value_og
,
_mm256_mul_ps
(
*
state
,
*
checkO
)),
active_gate
);
*
state_atv
=
activation
(
*
state
,
active_state
);
*
state_atv
=
activation
(
*
state
,
active_state
);
...
@@ -86,15 +102,26 @@ class lstm {
...
@@ -86,15 +102,26 @@ class lstm {
T
*
prev_state
,
T
*
prev_state_grad
,
T
*
state
,
T
*
prev_state
,
T
*
prev_state_grad
,
T
*
state
,
T
*
state_grad
,
T
*
state_atv
,
T
*
output_grad
,
T
*
state_grad
,
T
*
state_atv
,
T
*
output_grad
,
T
*
checkI
,
T
*
checkF
,
T
*
checkO
,
T
*
checkIGrad
,
T
*
checkI
,
T
*
checkF
,
T
*
checkO
,
T
*
checkIGrad
,
T
*
checkFGrad
,
T
*
checkOGrad
,
T
*
checkFGrad
,
T
*
checkOGrad
,
T
*
cell_clip
,
ActivationType
active_node
,
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
ActivationType
active_state
)
{
*
grad_og
=
*
grad_og
=
activation
((
*
output_grad
)
*
(
*
state_atv
),
*
value_og
,
active_gate
);
activation
((
*
output_grad
)
*
(
*
state_atv
),
*
value_og
,
active_gate
);
*
state_grad
+=
if
(
*
cell_clip
>
0.0
f
)
{
activation
((
*
output_grad
)
*
(
*
value_og
),
*
state_atv
,
active_state
)
+
if
(
*
state
>=
(
*
cell_clip
)
||
*
state
<=
(
0.0
f
-
(
*
cell_clip
)))
{
(
*
grad_og
)
*
(
*
checkO
);
*
state_grad
=
0.0
f
;
}
else
{
*
state_grad
+=
activation
((
*
output_grad
)
*
(
*
value_og
),
*
state_atv
,
active_state
)
+
(
*
grad_og
)
*
(
*
checkO
);
}
}
else
{
*
state_grad
+=
activation
((
*
output_grad
)
*
(
*
value_og
),
*
state_atv
,
active_state
)
+
(
*
grad_og
)
*
(
*
checkO
);
}
*
grad_in
=
activation
((
*
state_grad
)
*
(
*
value_ig
),
*
value_in
,
active_node
);
*
grad_in
=
activation
((
*
state_grad
)
*
(
*
value_ig
),
*
value_in
,
active_node
);
*
grad_ig
=
activation
((
*
state_grad
)
*
(
*
value_in
),
*
value_ig
,
active_gate
);
*
grad_ig
=
activation
((
*
state_grad
)
*
(
*
value_in
),
*
value_ig
,
active_gate
);
*
grad_fg
=
*
grad_fg
=
...
@@ -117,15 +144,24 @@ class lstm {
...
@@ -117,15 +144,24 @@ class lstm {
__m256
*
prev_state
,
__m256
*
prev_state_grad
,
__m256
*
state
,
__m256
*
prev_state
,
__m256
*
prev_state_grad
,
__m256
*
state
,
__m256
*
state_grad
,
__m256
*
state_atv
,
__m256
*
output_grad
,
__m256
*
state_grad
,
__m256
*
state_atv
,
__m256
*
output_grad
,
__m256
*
checkI
,
__m256
*
checkF
,
__m256
*
checkO
,
__m256
*
checkIGrad
,
__m256
*
checkI
,
__m256
*
checkF
,
__m256
*
checkO
,
__m256
*
checkIGrad
,
__m256
*
checkFGrad
,
__m256
*
checkOGrad
,
ActivationType
active_node
,
__m256
*
checkFGrad
,
__m256
*
checkOGrad
,
T
*
cell_clip
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
ActivationType
active_node
,
ActivationType
active_gate
,
ActivationType
active_state
)
{
*
grad_og
=
activation
(
_mm256_mul_ps
(
*
output_grad
,
*
state_atv
),
*
value_og
,
*
grad_og
=
activation
(
_mm256_mul_ps
(
*
output_grad
,
*
state_atv
),
*
value_og
,
active_gate
);
active_gate
);
*
state_grad
=
if
(
*
cell_clip
>
0.0
f
)
{
_mm256_add_ps
(
activation
(
_mm256_mul_ps
(
*
output_grad
,
*
value_og
),
T
*
state_
=
reinterpret_cast
<
T
*>
(
state
);
*
state_atv
,
active_state
),
if
(
*
state_
>=
(
*
cell_clip
)
||
*
state_
<=
(
0.0
f
-
(
*
cell_clip
)))
{
*
state_grad
);
*
state_grad
=
_mm256_set1_ps
(
0.0
f
);
*
state_grad
=
_mm256_add_ps
(
_mm256_mul_ps
(
*
grad_og
,
*
checkO
),
*
state_grad
);
}
else
{
*
state_grad
=
_mm256_add_ps
(
activation
(
_mm256_mul_ps
(
*
output_grad
,
*
value_og
),
*
state_atv
,
active_state
),
*
state_grad
);
*
state_grad
=
_mm256_add_ps
(
_mm256_mul_ps
(
*
grad_og
,
*
checkO
),
*
state_grad
);
}
}
*
grad_in
=
activation
(
_mm256_mul_ps
(
*
state_grad
,
*
value_ig
),
*
value_in
,
*
grad_in
=
activation
(
_mm256_mul_ps
(
*
state_grad
,
*
value_ig
),
*
value_in
,
active_node
);
active_node
);
*
grad_ig
=
activation
(
_mm256_mul_ps
(
*
state_grad
,
*
value_in
),
*
value_ig
,
*
grad_ig
=
activation
(
_mm256_mul_ps
(
*
state_grad
,
*
value_in
),
*
value_ig
,
...
...
paddle/fluid/operators/math/lstm_compute.cc
浏览文件 @
81870723
...
@@ -24,12 +24,12 @@ template <class T>
...
@@ -24,12 +24,12 @@ template <class T>
struct
LstmUnitFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
struct
LstmUnitFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
static
void
compute
(
const
platform
::
CPUDeviceContext
&
context
,
static
void
compute
(
const
platform
::
CPUDeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
const
detail
::
ActivationType
&
gate_act
,
T
cell_clip
,
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
)
{
const
detail
::
ActivationType
&
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
cpu_lstm_forward
(
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
detail
::
cpu_lstm_forward
(
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
cand_act
,
gate_act
,
cell_act
);
c
ell_clip
,
c
and_act
,
gate_act
,
cell_act
);
value
.
gate_value
+=
frame_size
*
4
;
value
.
gate_value
+=
frame_size
*
4
;
value
.
state_value
+=
frame_size
;
value
.
state_value
+=
frame_size
;
value
.
state_active_value
+=
frame_size
;
value
.
state_active_value
+=
frame_size
;
...
@@ -45,13 +45,14 @@ template <class T>
...
@@ -45,13 +45,14 @@ template <class T>
struct
LstmUnitGradFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
struct
LstmUnitGradFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
static
void
compute
(
const
platform
::
CPUDeviceContext
&
context
,
static
void
compute
(
const
platform
::
CPUDeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
int
frame_size
,
int
batch_size
,
T
cell_clip
,
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
)
{
const
detail
::
ActivationType
&
cand_act
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
detail
::
cpu_lstm_backward
(
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
detail
::
cpu_lstm_backward
(
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frame_size
,
cand_act
,
gate_act
,
cell_act
);
frame_size
,
cell_clip
,
cand_act
,
gate_act
,
cell_act
);
value
.
gate_value
+=
frame_size
*
4
;
value
.
gate_value
+=
frame_size
*
4
;
value
.
state_value
+=
frame_size
;
value
.
state_value
+=
frame_size
;
...
...
paddle/fluid/operators/math/lstm_compute.cu
浏览文件 @
81870723
...
@@ -24,12 +24,12 @@ template <class T>
...
@@ -24,12 +24,12 @@ template <class T>
struct
LstmUnitFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
struct
LstmUnitFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
static
void
compute
(
const
platform
::
CUDADeviceContext
&
context
,
static
void
compute
(
const
platform
::
CUDADeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
const
detail
::
ActivationType
&
gate_act
,
T
cell_clip
,
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
)
{
const
detail
::
ActivationType
&
cand_act
)
{
detail
::
gpu_lstm_forward
<
T
>
(
context
,
detail
::
forward
::
lstm
<
T
>
(),
value
,
detail
::
gpu_lstm_forward
<
T
>
(
context
,
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
batch_size
,
c
and_act
,
gate
_act
,
frame_size
,
batch_size
,
c
ell_clip
,
cand
_act
,
cell_act
);
gate_act
,
cell_act
);
}
}
};
};
...
@@ -37,13 +37,13 @@ template <class T>
...
@@ -37,13 +37,13 @@ template <class T>
struct
LstmUnitGradFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
struct
LstmUnitGradFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
static
void
compute
(
const
platform
::
CUDADeviceContext
&
context
,
static
void
compute
(
const
platform
::
CUDADeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
int
frame_size
,
int
batch_size
,
T
cell_clip
,
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
)
{
const
detail
::
ActivationType
&
cand_act
)
{
detail
::
gpu_lstm_backward
(
context
,
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
detail
::
gpu_lstm_backward
(
context
,
detail
::
backward
::
lstm
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
c
and_act
,
gate
_act
,
frame_size
,
batch_size
,
c
ell_clip
,
cand
_act
,
cell_act
);
gate_act
,
cell_act
);
}
}
};
};
...
...
paddle/fluid/operators/math/lstm_compute.h
浏览文件 @
81870723
...
@@ -50,7 +50,7 @@ template <typename DeviceContext, typename T>
...
@@ -50,7 +50,7 @@ template <typename DeviceContext, typename T>
class
LstmUnitFunctor
{
class
LstmUnitFunctor
{
public:
public:
static
void
compute
(
const
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
static
void
compute
(
const
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
int
batch_size
,
int
frame_size
,
int
batch_size
,
T
cell_clip
,
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
);
const
detail
::
ActivationType
&
cand_act
);
...
@@ -61,7 +61,7 @@ class LstmUnitGradFunctor {
...
@@ -61,7 +61,7 @@ class LstmUnitGradFunctor {
public:
public:
static
void
compute
(
const
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
static
void
compute
(
const
DeviceContext
&
context
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
LstmMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
const
detail
::
ActivationType
&
gate_act
,
T
cell_clip
,
const
detail
::
ActivationType
&
gate_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cell_act
,
const
detail
::
ActivationType
&
cand_act
);
const
detail
::
ActivationType
&
cand_act
);
};
};
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
81870723
...
@@ -668,7 +668,11 @@ def dynamic_lstmp(input,
...
@@ -668,7 +668,11 @@ def dynamic_lstmp(input,
candidate_activation
=
'tanh'
,
candidate_activation
=
'tanh'
,
proj_activation
=
'tanh'
,
proj_activation
=
'tanh'
,
dtype
=
'float32'
,
dtype
=
'float32'
,
name
=
None
):
name
=
None
,
h_0
=
None
,
c_0
=
None
,
cell_clip
=
None
,
proj_clip
=
None
):
"""
"""
**Dynamic LSTMP Layer**
**Dynamic LSTMP Layer**
...
@@ -785,6 +789,17 @@ def dynamic_lstmp(input,
...
@@ -785,6 +789,17 @@ def dynamic_lstmp(input,
dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
name(str|None): A name for this layer(optional). If set None, the layer
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
will be named automatically.
h_0(Variable): The initial hidden state is an optional input, default is zero.
This is a tensor with shape (N x D), where N is the
batch size and D is the projection size.
c_0(Variable): The initial cell state is an optional input, default is zero.
This is a tensor with shape (N x D), where N is the
batch size. `h_0` and `c_0` can be NULL but only at the same time.
cell_clip(float): If provided the cell state is clipped
by this value prior to the cell output activation.
proj_clip(float): If `num_proj > 0` and `proj_clip` is
provided, then the projected values are clipped elementwise to within
`[-proj_clip, proj_clip]`.
Returns:
Returns:
tuple: A tuple of two output variable: the projection of hidden state,
\
tuple: A tuple of two output variable: the projection of hidden state,
\
...
@@ -831,25 +846,41 @@ def dynamic_lstmp(input,
...
@@ -831,25 +846,41 @@ def dynamic_lstmp(input,
batch_hidden
=
helper
.
create_variable_for_type_inference
(
dtype
)
batch_hidden
=
helper
.
create_variable_for_type_inference
(
dtype
)
batch_gate
=
helper
.
create_variable_for_type_inference
(
dtype
)
batch_gate
=
helper
.
create_variable_for_type_inference
(
dtype
)
batch_cell_pre_act
=
helper
.
create_variable_for_type_inference
(
dtype
)
batch_cell_pre_act
=
helper
.
create_variable_for_type_inference
(
dtype
)
inputs
=
{
'Input'
:
input
,
'Weight'
:
weight
,
'ProjWeight'
:
proj_weight
,
'Bias'
:
bias
}
batch_size
=
input
.
shape
[
0
]
if
h_0
:
assert
h_0
.
shape
==
(
batch_size
,
proj_size
),
\
'The shape of h0 should be (batch_size, %d)'
%
proj_size
inputs
[
'H0'
]
=
h_0
if
c_0
:
assert
c_0
.
shape
==
(
batch_size
,
size
),
\
'The shape of c0 should be (batch_size, %d)'
%
size
inputs
[
'C0'
]
=
c_0
if
cell_clip
:
assert
cell_clip
>=
0
,
"cell_clip should not be negtive."
if
proj_clip
:
assert
proj_clip
>=
0
,
"proj_clip should not be negtive."
helper
.
append_op
(
helper
.
append_op
(
type
=
'lstmp'
,
type
=
'lstmp'
,
inputs
=
{
inputs
=
inputs
,
'Input'
:
input
,
'Weight'
:
weight
,
'ProjWeight'
:
proj_weight
,
'Bias'
:
bias
},
outputs
=
{
outputs
=
{
'Projection'
:
projection
,
'Projection'
:
projection
,
'Cell'
:
cell
,
'Cell'
:
cell
,
'OrderedP0'
:
ordered_proj0
,
'BatchHidden'
:
batch_hidden
,
'BatchHidden'
:
batch_hidden
,
'BatchGate'
:
batch_gate
,
'BatchGate'
:
batch_gate
,
'BatchCellPreAct'
:
batch_cell_pre_act
'BatchCellPreAct'
:
batch_cell_pre_act
},
},
attrs
=
{
attrs
=
{
'use_peepholes'
:
use_peepholes
,
'use_peepholes'
:
use_peepholes
,
'cell_clip'
:
cell_clip
,
'proj_clip'
:
proj_clip
,
'is_reverse'
:
is_reverse
,
'is_reverse'
:
is_reverse
,
'gate_activation'
:
gate_activation
,
'gate_activation'
:
gate_activation
,
'cell_activation'
:
cell_activation
,
'cell_activation'
:
cell_activation
,
...
...
python/paddle/fluid/tests/unittests/test_lstmp_op.py
浏览文件 @
81870723
...
@@ -36,12 +36,14 @@ def lstmp(
...
@@ -36,12 +36,14 @@ def lstmp(
w_b
=
None
,
# 1 x 4D
w_b
=
None
,
# 1 x 4D
w_c
=
None
,
# 1 x 3D
w_c
=
None
,
# 1 x 3D
is_reverse
=
False
,
is_reverse
=
False
,
proj_clip
=
0.0
,
cell_clip
=
0.0
,
act_gate
=
None
,
act_gate
=
None
,
act_cell
=
None
,
act_cell
=
None
,
act_cand
=
None
,
act_cand
=
None
,
act_proj
=
None
):
act_proj
=
None
):
def
_step
(
x
,
w_r
,
w_rh
,
w_c
,
r_pre
,
c_pre
,
act_gate
,
act_cell
,
act_cand
,
def
_step
(
x
,
w_r
,
w_rh
,
w_c
,
r_pre
,
c_pre
,
proj_clip
,
cell_clip
,
act_gate
,
act_proj
):
act_
cell
,
act_cand
,
act_
proj
):
g
=
np
.
dot
(
r_pre
,
w_r
)
# 1 x 4D
g
=
np
.
dot
(
r_pre
,
w_r
)
# 1 x 4D
g
=
g
+
x
g
=
g
+
x
g
=
np
.
reshape
(
g
,
(
1
,
g
.
size
))
g
=
np
.
reshape
(
g
,
(
1
,
g
.
size
))
...
@@ -55,6 +57,17 @@ def lstmp(
...
@@ -55,6 +57,17 @@ def lstmp(
g_f
=
act_gate
(
g_f
+
w_fc
*
c_pre
)
# 1 x D
g_f
=
act_gate
(
g_f
+
w_fc
*
c_pre
)
# 1 x D
c
=
g_f
*
c_pre
+
g_i
*
act_cand
(
c
)
# 1 x D
c
=
g_f
*
c_pre
+
g_i
*
act_cand
(
c
)
# 1 x D
def
array_clip
(
a
,
clip
):
size
=
np
.
prod
(
a
.
shape
)
new_a
=
np
.
reshape
(
a
,
(
size
))
for
i
in
range
(
size
):
new_a
[
i
]
=
max
(
new_a
[
i
],
-
1.0
*
clip
)
new_a
[
i
]
=
min
(
new_a
[
i
],
clip
)
new_a
=
np
.
reshape
(
new_a
,
a
.
shape
)
return
new_a
if
cell_clip
>
0.0
:
c
=
array_clip
(
c
,
cell_clip
)
if
w_c
is
None
:
if
w_c
is
None
:
g_o
=
act_gate
(
g_o
)
# 1 x D
g_o
=
act_gate
(
g_o
)
# 1 x D
else
:
else
:
...
@@ -64,6 +77,8 @@ def lstmp(
...
@@ -64,6 +77,8 @@ def lstmp(
# projection
# projection
r
=
np
.
dot
(
h
,
w_rh
)
r
=
np
.
dot
(
h
,
w_rh
)
r
=
act_proj
(
r
)
r
=
act_proj
(
r
)
if
proj_clip
>
0.0
:
r
=
array_clip
(
r
,
proj_clip
)
return
r
,
c
return
r
,
c
def
_reverse
(
x
,
offset
):
def
_reverse
(
x
,
offset
):
...
@@ -87,13 +102,13 @@ def lstmp(
...
@@ -87,13 +102,13 @@ def lstmp(
# compute one sequence
# compute one sequence
seq_len
=
lod
[
0
][
i
]
seq_len
=
lod
[
0
][
i
]
x
=
input
[
offset
[
i
]:
offset
[
i
+
1
],
:]
x
=
input
[
offset
[
i
]:
offset
[
i
+
1
],
:]
r_pre
=
np
.
dot
(
h0
[
i
],
w_rh
)
# 1 x P
r_pre
=
h0
[
i
]
r_pre
=
act_proj
(
r_pre
)
c_pre
=
c0
[
i
]
# 1 x D
c_pre
=
c0
[
i
]
# 1 x D
for
j
in
range
(
seq_len
):
for
j
in
range
(
seq_len
):
# compute one step
# compute one step
r_pre
,
c_pre
=
_step
(
x
[
j
],
w_r
,
w_rh
,
w_c
,
r_pre
,
c_pre
,
act_gate
,
r_pre
,
c_pre
=
_step
(
x
[
j
],
w_r
,
w_rh
,
w_c
,
r_pre
,
c_pre
,
proj_clip
,
act_cell
,
act_cand
,
act_proj
)
cell_clip
,
act_gate
,
act_cell
,
act_cand
,
act_proj
)
projection
.
append
(
r_pre
.
flatten
())
projection
.
append
(
r_pre
.
flatten
())
cell
.
append
(
c_pre
.
flatten
())
cell
.
append
(
c_pre
.
flatten
())
...
@@ -123,13 +138,12 @@ class TestLstmpOp(LstmTest.TestLstmOp):
...
@@ -123,13 +138,12 @@ class TestLstmpOp(LstmTest.TestLstmOp):
T
=
sum
(
self
.
lod
[
0
])
T
=
sum
(
self
.
lod
[
0
])
N
=
len
(
self
.
lod
[
0
])
N
=
len
(
self
.
lod
[
0
])
x
=
np
.
random
.
normal
(
size
=
(
T
,
4
*
self
.
D
)).
astype
(
'float64'
)
x
=
np
.
random
.
normal
(
size
=
(
T
,
4
*
self
.
D
)).
astype
(
'float64'
)
if
self
.
has_initial_state
:
if
self
.
has_initial_state
:
h0
=
np
.
random
.
normal
(
size
=
(
N
,
self
.
D
)).
astype
(
'float64'
)
h0
=
np
.
random
.
normal
(
size
=
(
N
,
self
.
P
)).
astype
(
'float64'
)
c0
=
np
.
random
.
normal
(
size
=
(
N
,
self
.
D
)).
astype
(
'float64'
)
c0
=
np
.
random
.
normal
(
size
=
(
N
,
self
.
D
)).
astype
(
'float64'
)
else
:
else
:
h0
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
h0
=
np
.
zeros
((
N
,
self
.
P
)).
astype
(
'float64'
)
c0
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
c0
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
w
=
np
.
random
.
normal
(
size
=
(
self
.
P
,
4
*
self
.
D
)).
astype
(
'float64'
)
w
=
np
.
random
.
normal
(
size
=
(
self
.
P
,
4
*
self
.
D
)).
astype
(
'float64'
)
if
self
.
use_peepholes
:
if
self
.
use_peepholes
:
...
@@ -140,9 +154,12 @@ class TestLstmpOp(LstmTest.TestLstmOp):
...
@@ -140,9 +154,12 @@ class TestLstmpOp(LstmTest.TestLstmOp):
w_b
=
b
[:,
0
:
4
*
self
.
D
]
w_b
=
b
[:,
0
:
4
*
self
.
D
]
w_c
=
b
[:,
4
*
self
.
D
:]
if
self
.
use_peepholes
else
None
w_c
=
b
[:,
4
*
self
.
D
:]
if
self
.
use_peepholes
else
None
w_rh
=
np
.
random
.
normal
(
size
=
(
self
.
D
,
self
.
P
)).
astype
(
'float64'
)
w_rh
=
np
.
random
.
normal
(
size
=
(
self
.
D
,
self
.
P
)).
astype
(
'float64'
)
proj_clip
=
0.1
cell_clip
=
0.1
r
,
c
=
lstmp
(
x
,
self
.
lod
,
h0
,
c0
,
w
,
w_rh
,
w_b
,
w_c
,
self
.
is_reverse
,
r
,
c
=
lstmp
(
x
,
self
.
lod
,
h0
,
c0
,
w
,
w_rh
,
w_b
,
w_c
,
self
.
is_reverse
,
ACTIVATION
[
self
.
act_gate
],
ACTIVATION
[
self
.
act_cell
],
proj_clip
,
cell_clip
,
ACTIVATION
[
self
.
act_gate
],
ACTIVATION
[
self
.
act_cand
],
ACTIVATION
[
self
.
act_proj
])
ACTIVATION
[
self
.
act_cell
],
ACTIVATION
[
self
.
act_cand
],
ACTIVATION
[
self
.
act_proj
])
self
.
inputs
=
{
'Input'
:
(
x
,
self
.
lod
),
'Weight'
:
w
,
'ProjWeight'
:
w_rh
}
self
.
inputs
=
{
'Input'
:
(
x
,
self
.
lod
),
'Weight'
:
w
,
'ProjWeight'
:
w_rh
}
...
@@ -159,6 +176,8 @@ class TestLstmpOp(LstmTest.TestLstmOp):
...
@@ -159,6 +176,8 @@ class TestLstmpOp(LstmTest.TestLstmOp):
self
.
attrs
=
{
self
.
attrs
=
{
'use_peepholes'
:
self
.
use_peepholes
,
'use_peepholes'
:
self
.
use_peepholes
,
'is_reverse'
:
self
.
is_reverse
,
'is_reverse'
:
self
.
is_reverse
,
'proj_clip'
:
proj_clip
,
'cell_clip'
:
cell_clip
,
'gate_activation'
:
self
.
act_gate
,
'gate_activation'
:
self
.
act_gate
,
'cell_activation'
:
self
.
act_cell
,
'cell_activation'
:
self
.
act_cell
,
'candidate_activation'
:
self
.
act_cand
,
'candidate_activation'
:
self
.
act_cand
,
...
@@ -171,14 +190,14 @@ class TestLstmpOp(LstmTest.TestLstmOp):
...
@@ -171,14 +190,14 @@ class TestLstmpOp(LstmTest.TestLstmOp):
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
# TODO(qingqing) remove folowing lines after the check_grad is refined.
# TODO(qingqing) remove folowing lines after the check_grad is refined.
N
=
len
(
self
.
lod
[
0
])
N
=
len
(
self
.
lod
[
0
])
self
.
outputs
[
'OrderedP0'
]
=
np
.
zeros
((
N
,
self
.
P
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
(
N
,
self
.
D
)).
astype
(
'float64'
)
(
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
check_grad
(
self
.
check_grad
(
[
'Input'
,
'Weight'
,
'ProjWeight'
,
'Bias'
],
[
'Projection'
],
[
'Input'
,
'Weight'
,
'ProjWeight'
,
'Bias'
],
[
'Projection'
],
max_relative_error
=
1e-2
)
max_relative_error
=
1e-2
,
numeric_grad_delta
=
0.0000005
)
class
TestLstmpOpHasInitial
(
TestLstmpOp
):
class
TestLstmpOpHasInitial
(
TestLstmpOp
):
...
@@ -188,7 +207,6 @@ class TestLstmpOpHasInitial(TestLstmpOp):
...
@@ -188,7 +207,6 @@ class TestLstmpOpHasInitial(TestLstmpOp):
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
# TODO(qingqing) remove folowing lines after the check_grad is refined.
# TODO(qingqing) remove folowing lines after the check_grad is refined.
N
=
len
(
self
.
lod
[
0
])
N
=
len
(
self
.
lod
[
0
])
self
.
outputs
[
'OrderedP0'
]
=
np
.
zeros
((
N
,
self
.
P
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
...
@@ -196,11 +214,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
...
@@ -196,11 +214,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self
.
check_grad
(
self
.
check_grad
(
[
'Input'
,
'Weight'
,
'ProjWeight'
,
'Bias'
,
'H0'
,
'C0'
],
[
'Input'
,
'Weight'
,
'ProjWeight'
,
'Bias'
,
'H0'
,
'C0'
],
[
'Projection'
],
[
'Projection'
],
numeric_grad_delta
=
0.0000005
,
max_relative_error
=
1e-2
)
max_relative_error
=
1e-2
)
def
test_check_grad_ingore_bias
(
self
):
def
test_check_grad_ingore_bias
(
self
):
N
=
len
(
self
.
lod
[
0
])
N
=
len
(
self
.
lod
[
0
])
self
.
outputs
[
'OrderedP0'
]
=
np
.
zeros
((
N
,
self
.
P
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
...
@@ -208,11 +226,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
...
@@ -208,11 +226,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self
.
check_grad
(
self
.
check_grad
(
[
'Input'
,
'ProjWeight'
,
'Weight'
],
[
'Projection'
],
[
'Input'
,
'ProjWeight'
,
'Weight'
],
[
'Projection'
],
max_relative_error
=
1e-2
,
max_relative_error
=
1e-2
,
numeric_grad_delta
=
0.0000005
,
no_grad_set
=
set
(
'Bias'
))
no_grad_set
=
set
(
'Bias'
))
def
test_check_grad_ingore_weight
(
self
):
def
test_check_grad_ingore_weight
(
self
):
N
=
len
(
self
.
lod
[
0
])
N
=
len
(
self
.
lod
[
0
])
self
.
outputs
[
'OrderedP0'
]
=
np
.
zeros
((
N
,
self
.
P
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
...
@@ -220,11 +238,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
...
@@ -220,11 +238,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self
.
check_grad
(
self
.
check_grad
(
[
'Input'
,
'ProjWeight'
,
'Bias'
],
[
'Projection'
],
[
'Input'
,
'ProjWeight'
,
'Bias'
],
[
'Projection'
],
max_relative_error
=
1e-2
,
max_relative_error
=
1e-2
,
numeric_grad_delta
=
0.0000005
,
no_grad_set
=
set
(
'Weight'
))
no_grad_set
=
set
(
'Weight'
))
def
test_check_grad_ingore_proj_weight
(
self
):
def
test_check_grad_ingore_proj_weight
(
self
):
N
=
len
(
self
.
lod
[
0
])
N
=
len
(
self
.
lod
[
0
])
self
.
outputs
[
'OrderedP0'
]
=
np
.
zeros
((
N
,
self
.
P
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
...
@@ -232,11 +250,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
...
@@ -232,11 +250,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self
.
check_grad
(
self
.
check_grad
(
[
'Input'
,
'Weight'
,
'Bias'
],
[
'Projection'
],
[
'Input'
,
'Weight'
,
'Bias'
],
[
'Projection'
],
max_relative_error
=
1e-2
,
max_relative_error
=
1e-2
,
numeric_grad_delta
=
0.0000005
,
no_grad_set
=
set
(
'ProjWeight'
))
no_grad_set
=
set
(
'ProjWeight'
))
def
test_check_grad_ingore_input
(
self
):
def
test_check_grad_ingore_input
(
self
):
N
=
len
(
self
.
lod
[
0
])
N
=
len
(
self
.
lod
[
0
])
self
.
outputs
[
'OrderedP0'
]
=
np
.
zeros
((
N
,
self
.
P
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
...
@@ -244,11 +262,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
...
@@ -244,11 +262,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self
.
check_grad
(
self
.
check_grad
(
[
'Weight'
,
'ProjWeight'
,
'Bias'
],
[
'Projection'
],
[
'Weight'
,
'ProjWeight'
,
'Bias'
],
[
'Projection'
],
max_relative_error
=
1e-2
,
max_relative_error
=
1e-2
,
numeric_grad_delta
=
0.0000005
,
no_grad_set
=
set
(
'Input'
))
no_grad_set
=
set
(
'Input'
))
def
test_check_grad_ingore_h0
(
self
):
def
test_check_grad_ingore_h0
(
self
):
N
=
len
(
self
.
lod
[
0
])
N
=
len
(
self
.
lod
[
0
])
self
.
outputs
[
'OrderedP0'
]
=
np
.
zeros
((
N
,
self
.
P
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
...
@@ -256,11 +274,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
...
@@ -256,11 +274,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self
.
check_grad
(
self
.
check_grad
(
[
'Input'
,
'Weight'
,
'ProjWeight'
,
'Bias'
,
'C0'
],
[
'Projection'
],
[
'Input'
,
'Weight'
,
'ProjWeight'
,
'Bias'
,
'C0'
],
[
'Projection'
],
max_relative_error
=
1e-2
,
max_relative_error
=
1e-2
,
numeric_grad_delta
=
0.0000005
,
no_grad_set
=
set
(
'H0'
))
no_grad_set
=
set
(
'H0'
))
def
test_check_grad_ingore_c0
(
self
):
def
test_check_grad_ingore_c0
(
self
):
N
=
len
(
self
.
lod
[
0
])
N
=
len
(
self
.
lod
[
0
])
self
.
outputs
[
'OrderedP0'
]
=
np
.
zeros
((
N
,
self
.
P
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchGate'
]
=
np
.
zeros
((
N
,
4
*
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchHidden'
]
=
np
.
zeros
((
N
,
self
.
D
)).
astype
(
'float64'
)
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
self
.
outputs
[
'BatchCellPreAct'
]
=
np
.
zeros
(
...
@@ -268,6 +286,7 @@ class TestLstmpOpHasInitial(TestLstmpOp):
...
@@ -268,6 +286,7 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self
.
check_grad
(
self
.
check_grad
(
[
'Input'
,
'Weight'
,
'ProjWeight'
,
'Bias'
,
'H0'
],
[
'Projection'
],
[
'Input'
,
'Weight'
,
'ProjWeight'
,
'Bias'
,
'H0'
],
[
'Projection'
],
max_relative_error
=
1e-2
,
max_relative_error
=
1e-2
,
numeric_grad_delta
=
0.0000005
,
no_grad_set
=
set
(
'C0'
))
no_grad_set
=
set
(
'C0'
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录