Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
76beff86
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
76beff86
编写于
1月 24, 2018
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make the projection activation configurable
上级
db1f6a59
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
66 addition
and
65 deletion
+66
-65
paddle/operators/lstmp_op.cc
paddle/operators/lstmp_op.cc
+38
-38
paddle/operators/lstmp_op.h
paddle/operators/lstmp_op.h
+8
-6
python/paddle/v2/fluid/tests/test_lstmp_op.py
python/paddle/v2/fluid/tests/test_lstmp_op.py
+20
-21
未找到文件。
paddle/operators/lstmp_op.cc
浏览文件 @
76beff86
...
@@ -23,27 +23,29 @@ class LSTMPOp : public framework::OperatorWithKernel {
...
@@ -23,27 +23,29 @@ class LSTMPOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of LSTMP should not be null."
);
"Input(Input) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
"Input(Weight) of LSTMP should not be null."
);
"Input(Weight) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"ProjWeight"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"ProjWeight"
),
"Input(ProjWeight) of LSTMP should not be null."
);
"Input(ProjWeight) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
"Input(Bias) of LSTMP should not be null."
);
"Input(Bias) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Projection"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Projection"
),
"Output(Projection) of LSTMP should not be null."
);
"Output(Projection) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
"Output(Cell) of LSTMP should not be null."
);
"Output(Cell) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchGate"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchGate"
),
"Output(BatchGate) of LSTMP should not be null."
);
"Output(BatchGate) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchCellPreAct"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchCellPreAct"
),
"Output(BatchGate) of LSTMP should not be null."
);
"Output(BatchCellPreAct) of LSTMP operator should not be "
"null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchHidden"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchHidden"
),
"Output(BatchHidden) of LSTMP should not be null."
);
"Output(BatchHidden) of LSTMP
operator
should not be null."
);
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
2
,
"Input(X)'s rank of LSTMP operator must be 2."
);
int
frame_size
=
in_dims
[
1
]
/
4
;
int
frame_size
=
in_dims
[
1
]
/
4
;
auto
w_dims
=
ctx
->
GetInputDim
(
"Weight"
);
auto
w_dims
=
ctx
->
GetInputDim
(
"Weight"
);
...
@@ -68,8 +70,8 @@ class LSTMPOp : public framework::OperatorWithKernel {
...
@@ -68,8 +70,8 @@ class LSTMPOp : public framework::OperatorWithKernel {
if
(
ctx
->
HasInput
(
"H0"
))
{
if
(
ctx
->
HasInput
(
"H0"
))
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C0"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C0"
),
"Input(C0)
and Input(H0) of LSTMP should not
"
"Input(C0)
of LSTMP operator should not be null after
"
"
be null at the same time
."
);
"
Input(H0) provided
."
);
auto
h_dims
=
ctx
->
GetInputDim
(
"H0"
);
auto
h_dims
=
ctx
->
GetInputDim
(
"H0"
);
auto
c_dims
=
ctx
->
GetInputDim
(
"C0"
);
auto
c_dims
=
ctx
->
GetInputDim
(
"C0"
);
PADDLE_ENFORCE
(
h_dims
==
c_dims
,
PADDLE_ENFORCE
(
h_dims
==
c_dims
,
...
@@ -132,8 +134,7 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -132,8 +134,7 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"C0"
,
AddInput
(
"C0"
,
"(Tensor, optional) the initial cell state is an optional "
"(Tensor, optional) the initial cell state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"input. This is a tensor with shape (N x D), where N is the "
"batch size. Only one of `H0` and `C0` can be NULL at the same "
"batch size. `C0` should not be null if `H0` provided."
)
"time."
)
.
AsDispensable
();
.
AsDispensable
();
AddInput
(
"Weight"
,
AddInput
(
"Weight"
,
"(Tensor) the learnable hidden-hidden weights."
"(Tensor) the learnable hidden-hidden weights."
...
@@ -211,13 +212,12 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -211,13 +212,12 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
"`tanh` by default."
)
"`tanh` by default."
)
.
SetDefault
(
"tanh"
)
.
SetDefault
(
"tanh"
)
.
InEnum
({
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
});
.
InEnum
({
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
});
AddAttr
<
bool
>
(
"share_cell_act"
,
AddAttr
<
std
::
string
>
(
"proj_activation"
,
"(bool, defalut: True) "
"(string, default: tanh)"
"whether to share the activation of cell output with the "
"The activation for projection output, "
"projection layer. When set to `False`, the projection "
"`tanh` by defalut."
)
"is simple linear, otherwise it will go through an "
.
SetDefault
(
"tanh"
)
"activation function same as `cell_activation`."
)
.
InEnum
({
"sigmoid"
,
"tanh"
,
"relu"
,
"identity"
});
.
SetDefault
(
true
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Long-Short Term Memory with recurrent Projection layer (LSTMP) Operator.
Long-Short Term Memory with recurrent Projection layer (LSTMP) Operator.
...
@@ -226,20 +226,21 @@ original hidden state to a lower-dimensional one, which is proposed to reduce
...
@@ -226,20 +226,21 @@ original hidden state to a lower-dimensional one, which is proposed to reduce
the number of total parameters and furthermore computational complexity for
the number of total parameters and furthermore computational complexity for
the LSTM, espeacially for the case that the size of output units is relative
the LSTM, espeacially for the case that the size of output units is relative
large (https://research.google.com/pubs/archive/43905.pdf).
large (https://research.google.com/pubs/archive/43905.pdf).
The formula is as follows:
The formula is as follows:
$$
$$
i_t = \sigma(W_{ix}x_{t} + W_{i
h
}r_{t-1} + W_{ic}c_{t-1} + b_i) \\
i_t = \sigma(W_{ix}x_{t} + W_{i
r
}r_{t-1} + W_{ic}c_{t-1} + b_i) \\
f_t = \sigma(W_{fx}x_{t} + W_{f
h
}r_{t-1} + W_{fc}c_{t-1} + b_f) \\
f_t = \sigma(W_{fx}x_{t} + W_{f
r
}r_{t-1} + W_{fc}c_{t-1} + b_f) \\
\tilde{c_t} = act_g(W_{cx}x_t + W_{c
h
}r_{t-1} + b_c) \\
\tilde{c_t} = act_g(W_{cx}x_t + W_{c
r
}r_{t-1} + b_c) \\
o_t = \sigma(W_{ox}x_{t} + W_{o
h
}r_{t-1} + W_{oc}c_t + b_o) \\
o_t = \sigma(W_{ox}x_{t} + W_{o
r
}r_{t-1} + W_{oc}c_t + b_o) \\
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t}
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t}
\\
h_t = o_t \odot act_h(c_t)
h_t = o_t \odot act_h(c_t)
\\
r_t = \overline{act_h}(W_{rh}h_t)
r_t = \overline{act_h}(W_{rh}h_t)
$$
$$
...
@@ -259,9 +260,8 @@ input and previous hidden state.
...
@@ -259,9 +260,8 @@ input and previous hidden state.
The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$
The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$
are the cell input and cell output activation functions and `tanh` is usually
are the cell input and cell output activation functions and `tanh` is usually
used for them. $\overline{act_h}$ is the activation function for the projection
used for them. $\overline{act_h}$ is the activation function for the
layer. When `share_cell_act` set to `False`, $\overline{act_h}$ is an
projection output, usually using `identity` or same as $act_h$.
identity activation, otherwise it will be same as $act_h$.
Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$
Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$
operations on the input $x_{t}$ are NOT included in this operator.
operations on the input $x_{t}$ are NOT included in this operator.
...
@@ -277,22 +277,22 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
...
@@ -277,22 +277,22 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of LSTMP should not be null."
);
"Input(Input) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Projection"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Projection"
),
"Input(Projection) of LSTMP should not be null."
);
"Input(Projection) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Cell"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Cell"
),
"Input(Cell) of LSTMP should not be null."
);
"Input(Cell) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
"Input(Weight) of LSTMP should not be null."
);
"Input(Weight) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"ProjWeight"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"ProjWeight"
),
"Input(ProjWeight) of LSTMP should not be null."
);
"Input(ProjWeight) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
"Input(Bias) of LSTMP should not be null."
);
"Input(Bias) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchGate"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchGate"
),
"Input(BatchGate) of LSTMP should not be null."
);
"Input(BatchGate) of LSTMP
operator
should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchCellPreAct"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchCellPreAct"
),
"Input(BatchGate) of LSTMP should not be null."
);
"Input(BatchGate) of LSTMP
operator
should not be null."
);
auto
SetOutGradDim
=
[
&
ctx
](
const
std
::
string
&
name
)
{
auto
SetOutGradDim
=
[
&
ctx
](
const
std
::
string
&
name
)
{
auto
g_name
=
framework
::
GradVarName
(
name
);
auto
g_name
=
framework
::
GradVarName
(
name
);
...
...
paddle/operators/lstmp_op.h
浏览文件 @
76beff86
...
@@ -136,7 +136,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
...
@@ -136,7 +136,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
));
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
));
auto
cand_act
=
math
::
detail
::
GetActivationType
(
auto
cand_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
auto
share_cell_act
=
ctx
.
Attr
<
bool
>
(
"share_cell_act"
);
auto
proj_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"proj_activation"
));
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
for
(
size_t
n
=
0
;
n
<
num_batch
;
n
++
)
{
...
@@ -174,7 +175,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
...
@@ -174,7 +175,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
ordered_h0
,
false
,
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
ordered_h0
,
false
,
*
proj_weight
,
false
,
static_cast
<
T
>
(
1.0
),
*
proj_weight
,
false
,
static_cast
<
T
>
(
1.0
),
ordered_proj0
,
static_cast
<
T
>
(
0.0
));
ordered_proj0
,
static_cast
<
T
>
(
0.0
));
if
(
share_cell_act
)
{
if
(
proj_act
!=
math
::
detail
::
ActivationType
::
kIdentity
)
{
auto
proj0_dev
=
EigenMatrix
<
T
>::
From
(
*
ordered_proj0
);
auto
proj0_dev
=
EigenMatrix
<
T
>::
From
(
*
ordered_proj0
);
ActCompute
(
cell_act
,
place
,
proj0_dev
,
proj0_dev
);
ActCompute
(
cell_act
,
place
,
proj0_dev
,
proj0_dev
);
}
}
...
@@ -194,7 +195,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
...
@@ -194,7 +195,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
hidden_t
,
false
,
*
proj_weight
,
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
hidden_t
,
false
,
*
proj_weight
,
false
,
static_cast
<
T
>
(
1.0
),
&
proj_t
,
false
,
static_cast
<
T
>
(
1.0
),
&
proj_t
,
static_cast
<
T
>
(
0.0
));
static_cast
<
T
>
(
0.0
));
if
(
share_cell_act
)
{
if
(
proj_act
!=
math
::
detail
::
ActivationType
::
kIdentity
)
{
auto
proj_t_dev
=
EigenMatrix
<
T
>::
From
(
proj_t
);
auto
proj_t_dev
=
EigenMatrix
<
T
>::
From
(
proj_t
);
ActCompute
(
cell_act
,
place
,
proj_t_dev
,
proj_t_dev
);
ActCompute
(
cell_act
,
place
,
proj_t_dev
,
proj_t_dev
);
}
}
...
@@ -348,7 +349,8 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
...
@@ -348,7 +349,8 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
));
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
));
auto
cand_act
=
math
::
detail
::
GetActivationType
(
auto
cand_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
auto
share_cell_act
=
ctx
.
Attr
<
bool
>
(
"share_cell_act"
);
auto
proj_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"proj_activation"
));
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
...
@@ -359,7 +361,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
...
@@ -359,7 +361,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
Tensor
cur_proj
=
batch_proj
.
Slice
(
bstart
,
bend
);
Tensor
cur_proj
=
batch_proj
.
Slice
(
bstart
,
bend
);
Tensor
proj_g
=
batch_proj_g
.
Slice
(
bstart
,
bend
);
Tensor
proj_g
=
batch_proj_g
.
Slice
(
bstart
,
bend
);
if
(
share_cell_act
)
{
if
(
proj_act
!=
math
::
detail
::
ActivationType
::
kIdentity
)
{
auto
cur_proj_dev
=
EigenMatrix
<
T
>::
From
(
cur_proj
);
auto
cur_proj_dev
=
EigenMatrix
<
T
>::
From
(
cur_proj
);
auto
proj_g_dev
=
EigenMatrix
<
T
>::
From
(
proj_g
);
auto
proj_g_dev
=
EigenMatrix
<
T
>::
From
(
proj_g
);
ActGradCompute
(
cell_act
,
place
,
cur_proj_dev
,
cur_proj_dev
,
proj_g_dev
,
ActGradCompute
(
cell_act
,
place
,
cur_proj_dev
,
cur_proj_dev
,
proj_g_dev
,
...
@@ -439,7 +441,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
...
@@ -439,7 +441,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
gate_g
,
false
,
*
weight
,
math
::
matmul
<
DeviceContext
,
T
>
(
device_ctx
,
gate_g
,
false
,
*
weight
,
true
,
static_cast
<
T
>
(
1.0
),
&
proj0_g
,
true
,
static_cast
<
T
>
(
1.0
),
&
proj0_g
,
static_cast
<
T
>
(
0.0
));
static_cast
<
T
>
(
0.0
));
if
(
share_cell_act
)
{
if
(
proj_act
!=
math
::
detail
::
ActivationType
::
kIdentity
)
{
auto
proj0_dev
=
EigenMatrix
<
T
>::
From
(
*
ordered_proj0
);
auto
proj0_dev
=
EigenMatrix
<
T
>::
From
(
*
ordered_proj0
);
auto
proj0_g_dev
=
EigenMatrix
<
T
>::
From
(
proj0_g
);
auto
proj0_g_dev
=
EigenMatrix
<
T
>::
From
(
proj0_g
);
ActGradCompute
(
cell_act
,
place
,
proj0_dev
,
proj0_dev
,
proj0_g_dev
,
ActGradCompute
(
cell_act
,
place
,
proj0_dev
,
proj0_dev
,
proj0_g_dev
,
...
...
python/paddle/v2/fluid/tests/test_lstmp_op.py
浏览文件 @
76beff86
...
@@ -41,7 +41,7 @@ def relu(x):
...
@@ -41,7 +41,7 @@ def relu(x):
return
np
.
maximum
(
x
,
0
)
return
np
.
maximum
(
x
,
0
)
ACTVATION
=
{
ACT
I
VATION
=
{
'identity'
:
identity
,
'identity'
:
identity
,
'sigmoid'
:
sigmoid
,
'sigmoid'
:
sigmoid
,
'tanh'
:
tanh
,
'tanh'
:
tanh
,
...
@@ -63,8 +63,9 @@ def lstmp(
...
@@ -63,8 +63,9 @@ def lstmp(
act_gate
=
None
,
act_gate
=
None
,
act_cell
=
None
,
act_cell
=
None
,
act_cand
=
None
,
act_cand
=
None
,
share_cell_act
=
True
):
act_proj
=
None
):
def
_step
(
x
,
w_r
,
w_rh
,
w_c
,
r_pre
,
c_pre
,
act_gate
,
act_cell
,
act_cand
):
def
_step
(
x
,
w_r
,
w_rh
,
w_c
,
r_pre
,
c_pre
,
act_gate
,
act_cell
,
act_cand
,
act_proj
):
g
=
np
.
dot
(
r_pre
,
w_r
)
# 1 x 4D
g
=
np
.
dot
(
r_pre
,
w_r
)
# 1 x 4D
g
=
g
+
x
g
=
g
+
x
g
=
np
.
reshape
(
g
,
(
1
,
g
.
size
))
g
=
np
.
reshape
(
g
,
(
1
,
g
.
size
))
...
@@ -86,8 +87,7 @@ def lstmp(
...
@@ -86,8 +87,7 @@ def lstmp(
h
=
g_o
*
act_cell
(
c
)
h
=
g_o
*
act_cell
(
c
)
# projection
# projection
r
=
np
.
dot
(
h
,
w_rh
)
r
=
np
.
dot
(
h
,
w_rh
)
if
share_cell_act
:
r
=
act_proj
(
r
)
r
=
act_cell
(
r
)
return
r
,
c
return
r
,
c
def
_reverse
(
x
,
lod
):
def
_reverse
(
x
,
lod
):
...
@@ -110,13 +110,12 @@ def lstmp(
...
@@ -110,13 +110,12 @@ def lstmp(
seq_len
=
offset
[
i
+
1
]
-
offset
[
i
]
seq_len
=
offset
[
i
+
1
]
-
offset
[
i
]
x
=
input
[
offset
[
i
]:
offset
[
i
+
1
],
:]
x
=
input
[
offset
[
i
]:
offset
[
i
+
1
],
:]
r_pre
=
np
.
dot
(
h0
[
i
],
w_rh
)
# 1 x P
r_pre
=
np
.
dot
(
h0
[
i
],
w_rh
)
# 1 x P
if
share_cell_act
:
r_pre
=
act_proj
(
r_pre
)
r_pre
=
act_cell
(
r_pre
)
c_pre
=
c0
[
i
]
# 1 x D
c_pre
=
c0
[
i
]
# 1 x D
for
j
in
range
(
seq_len
):
for
j
in
range
(
seq_len
):
# compute one step
# compute one step
r_pre
,
c_pre
=
_step
(
x
[
j
],
w_r
,
w_rh
,
w_c
,
r_pre
,
c_pre
,
act_gate
,
r_pre
,
c_pre
=
_step
(
x
[
j
],
w_r
,
w_rh
,
w_c
,
r_pre
,
c_pre
,
act_gate
,
act_cell
,
act_cand
)
act_cell
,
act_cand
,
act_proj
)
projection
.
append
(
r_pre
.
flatten
())
projection
.
append
(
r_pre
.
flatten
())
cell
.
append
(
c_pre
.
flatten
())
cell
.
append
(
c_pre
.
flatten
())
...
@@ -131,7 +130,7 @@ def lstmp(
...
@@ -131,7 +130,7 @@ def lstmp(
return
projection
,
cell
return
projection
,
cell
class
TestLstmOp
(
OpTest
):
class
TestLstm
p
Op
(
OpTest
):
def
set_argument
(
self
):
def
set_argument
(
self
):
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
# hidden size
# hidden size
...
@@ -142,8 +141,8 @@ class TestLstmOp(OpTest):
...
@@ -142,8 +141,8 @@ class TestLstmOp(OpTest):
self
.
act_gate
=
'sigmoid'
self
.
act_gate
=
'sigmoid'
self
.
act_cell
=
'tanh'
self
.
act_cell
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_proj
=
self
.
act_cell
self
.
share_cell_act
=
True
self
.
has_initial_state
=
False
self
.
has_initial_state
=
False
self
.
is_reverse
=
False
self
.
is_reverse
=
False
self
.
use_peepholes
=
True
self
.
use_peepholes
=
True
...
@@ -172,8 +171,8 @@ class TestLstmOp(OpTest):
...
@@ -172,8 +171,8 @@ class TestLstmOp(OpTest):
w_c
=
b
[:,
4
*
self
.
D
:]
if
self
.
use_peepholes
else
None
w_c
=
b
[:,
4
*
self
.
D
:]
if
self
.
use_peepholes
else
None
w_rh
=
np
.
random
.
normal
(
size
=
(
self
.
D
,
self
.
P
)).
astype
(
'float64'
)
w_rh
=
np
.
random
.
normal
(
size
=
(
self
.
D
,
self
.
P
)).
astype
(
'float64'
)
r
,
c
=
lstmp
(
x
,
self
.
lod
,
h0
,
c0
,
w
,
w_rh
,
w_b
,
w_c
,
self
.
is_reverse
,
r
,
c
=
lstmp
(
x
,
self
.
lod
,
h0
,
c0
,
w
,
w_rh
,
w_b
,
w_c
,
self
.
is_reverse
,
ACT
VATION
[
self
.
act_gate
],
ACT
VATION
[
self
.
act_cell
],
ACT
IVATION
[
self
.
act_gate
],
ACTI
VATION
[
self
.
act_cell
],
ACT
VATION
[
self
.
act_cand
],
self
.
share_cell_act
)
ACT
IVATION
[
self
.
act_cand
],
ACTIVATION
[
self
.
act_proj
]
)
self
.
inputs
=
{
'Input'
:
(
x
,
self
.
lod
),
'Weight'
:
w
,
'ProjWeight'
:
w_rh
}
self
.
inputs
=
{
'Input'
:
(
x
,
self
.
lod
),
'Weight'
:
w
,
'ProjWeight'
:
w_rh
}
...
@@ -193,7 +192,7 @@ class TestLstmOp(OpTest):
...
@@ -193,7 +192,7 @@ class TestLstmOp(OpTest):
'gate_activation'
:
self
.
act_gate
,
'gate_activation'
:
self
.
act_gate
,
'cell_activation'
:
self
.
act_cell
,
'cell_activation'
:
self
.
act_cell
,
'candidate_activation'
:
self
.
act_cand
,
'candidate_activation'
:
self
.
act_cand
,
'
share_cell_act'
:
self
.
share_cell_act
'
proj_activation'
:
self
.
act_proj
}
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
...
@@ -212,7 +211,7 @@ class TestLstmOp(OpTest):
...
@@ -212,7 +211,7 @@ class TestLstmOp(OpTest):
max_relative_error
=
1e-2
)
max_relative_error
=
1e-2
)
class
TestLstm
OpHasInitial
(
TestLstm
Op
):
class
TestLstm
pOpHasInitial
(
TestLstmp
Op
):
def
set_argument
(
self
):
def
set_argument
(
self
):
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
D
=
16
self
.
D
=
16
...
@@ -221,8 +220,8 @@ class TestLstmOpHasInitial(TestLstmOp):
...
@@ -221,8 +220,8 @@ class TestLstmOpHasInitial(TestLstmOp):
self
.
act_gate
=
'sigmoid'
self
.
act_gate
=
'sigmoid'
self
.
act_cell
=
'tanh'
self
.
act_cell
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_proj
=
self
.
act_cell
self
.
share_cell_act
=
True
self
.
has_initial_state
=
True
self
.
has_initial_state
=
True
self
.
is_reverse
=
True
self
.
is_reverse
=
True
self
.
use_peepholes
=
True
self
.
use_peepholes
=
True
...
@@ -313,7 +312,7 @@ class TestLstmOpHasInitial(TestLstmOp):
...
@@ -313,7 +312,7 @@ class TestLstmOpHasInitial(TestLstmOp):
no_grad_set
=
set
(
'C0'
))
no_grad_set
=
set
(
'C0'
))
class
TestLstm
OpRerverse
(
TestLstm
Op
):
class
TestLstm
pOpRerverse
(
TestLstmp
Op
):
def
set_argument
(
self
):
def
set_argument
(
self
):
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
D
=
16
self
.
D
=
16
...
@@ -322,14 +321,14 @@ class TestLstmOpRerverse(TestLstmOp):
...
@@ -322,14 +321,14 @@ class TestLstmOpRerverse(TestLstmOp):
self
.
act_gate
=
'sigmoid'
self
.
act_gate
=
'sigmoid'
self
.
act_cell
=
'tanh'
self
.
act_cell
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_proj
=
self
.
act_cell
self
.
share_cell_act
=
True
self
.
has_initial_state
=
False
self
.
has_initial_state
=
False
self
.
is_reverse
=
True
self
.
is_reverse
=
True
self
.
use_peepholes
=
True
self
.
use_peepholes
=
True
class
TestLstm
OpNotUsePeepholes
(
TestLstm
Op
):
class
TestLstm
pOpNotUsePeepholes
(
TestLstmp
Op
):
def
set_argument
(
self
):
def
set_argument
(
self
):
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
D
=
16
self
.
D
=
16
...
@@ -338,14 +337,14 @@ class TestLstmOpNotUsePeepholes(TestLstmOp):
...
@@ -338,14 +337,14 @@ class TestLstmOpNotUsePeepholes(TestLstmOp):
self
.
act_gate
=
'sigmoid'
self
.
act_gate
=
'sigmoid'
self
.
act_cell
=
'tanh'
self
.
act_cell
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_proj
=
self
.
act_cell
self
.
share_cell_act
=
True
self
.
has_initial_state
=
False
self
.
has_initial_state
=
False
self
.
is_reverse
=
False
self
.
is_reverse
=
False
self
.
use_peepholes
=
False
self
.
use_peepholes
=
False
class
TestLstm
OpNotShareCellAct
(
TestLstm
Op
):
class
TestLstm
pOpLinearProjection
(
TestLstmp
Op
):
def
set_argument
(
self
):
def
set_argument
(
self
):
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
lod
=
[[
0
,
2
,
5
,
7
]]
self
.
D
=
16
self
.
D
=
16
...
@@ -354,8 +353,8 @@ class TestLstmOpNotShareCellAct(TestLstmOp):
...
@@ -354,8 +353,8 @@ class TestLstmOpNotShareCellAct(TestLstmOp):
self
.
act_gate
=
'sigmoid'
self
.
act_gate
=
'sigmoid'
self
.
act_cell
=
'tanh'
self
.
act_cell
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_proj
=
'identity'
self
.
share_cell_act
=
False
self
.
has_initial_state
=
False
self
.
has_initial_state
=
False
self
.
is_reverse
=
False
self
.
is_reverse
=
False
self
.
use_peepholes
=
True
self
.
use_peepholes
=
True
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录