Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1013d6d0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1013d6d0
编写于
11月 30, 2018
作者:
C
chengduozh
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'add_cudnn_lstm' of
https://github.com/PaddlePaddle/Paddle
into add_cudnn_lstm
上级
5d334ff0
6ce42501
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
939 addition
and
1 deletion
+939
-1
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/operators/lstm_cudnn_op.cc
paddle/fluid/operators/lstm_cudnn_op.cc
+218
-0
paddle/fluid/operators/lstm_cudnn_op.cu.cc
paddle/fluid/operators/lstm_cudnn_op.cu.cc
+494
-0
paddle/fluid/operators/lstm_cudnn_op.h
paddle/fluid/operators/lstm_cudnn_op.h
+45
-0
paddle/fluid/platform/dynload/cudnn.h
paddle/fluid/platform/dynload/cudnn.h
+17
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+164
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
1013d6d0
...
...
@@ -187,6 +187,7 @@ paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=Non
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'dropout_prob', 'input_size', 'hidden_size', 'num_layers', 'is_bidirec', 'dtype', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(False, 'float32', False, None, None, False, 0))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
...
...
paddle/fluid/operators/lstm_cudnn_op.cc
0 → 100644
浏览文件 @
1013d6d0
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lstm_cudnn_op.h"
#include <string>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
class
CudnnLSTMOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
"Input(Weight) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"InitH"
),
"Input(init_h) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"InitC"
),
"Input(init_c) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Cache"
),
"Input(Cache) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"last_h"
),
"Output(last_h) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"last_c"
),
"Output(last_c) of LSTM should not be null."
);
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
3
,
"Input(X)'s rank must be 3."
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"Input"
));
ctx
->
SetOutputDim
(
"last_h"
,
ctx
->
GetInputDim
(
"InitH"
));
ctx
->
SetOutputDim
(
"last_c"
,
ctx
->
GetInputDim
(
"InitC"
));
}
};
class
CudnnLSTMOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Input"
,
"(Tensor) RNN input tensor, which support variable-time length input "
"sequence."
"The shape of the Tensor MUST be ( seq_len * batch_size * input_size)"
"seq_len is the total time step in this mini-batch (CAN be change in "
"different batch)"
"batch_size is the instance number of this batch"
"input_size is the hidden size of the input."
"input_hidden_size and the hidden_size in the next may not be same"
);
AddInput
(
"InitH"
,
"(Tensor) the initial hidden state of the LSTM"
"input. This is a tensor with shape (num_layers x batch_size x "
"hidden_size)"
"and When is_bidirec is True, the shape will be (num_layers*2 x "
"batch_size x hidden_size)"
);
AddInput
(
"InitC"
,
"(Tensor) the initial cell state of the LSTm "
"input. This is a tensor with shape (num_layers x batch_size x "
"hidden_size)"
"and When is_bidirec is True, the shape will be (num_layers*2 x "
"batch_size x hidden_size)"
);
AddInput
(
"W"
,
"(Tensor) the learnable hidden-hidden weights."
" The shape is (N), where N is total weight size of the LSTM. "
" cudnn concatenate all the weight to one Tensor"
);
AddInput
(
"Cache"
,
"The cache of dropout op, a RAW type variable including random "
"number generator states and some descriptors, which is used in "
"cudnn kernel."
)
.
AsDispensable
();
AddOutput
(
"Out"
,
"(Tensor) the hidden state of LSTM operator. "
"The shape is ( seq_len x batch_size x hidden_size) if "
"is_bidirec is False"
"and When is_bidirec is True, the shape will be ( seq_len x "
"batch_size x hidden_size * 2) "
);
AddOutput
(
"last_h"
,
"(Tensor) the hidden state of the last step. "
"The shape is ( num_layers x batch_size x hidden_size) if "
"is_bidirec is False"
"and When is_bidirec is True, the shape will be (num_layers*2 x "
"batch_size x hidden_size)"
);
AddOutput
(
"last_c"
,
"(Tensor) the cell state of the last step"
"The shape is ( num_layers x batch_size x hidden_size) if "
"is_bidirec is False"
"and When is_bidirect is True, the shape will be (num_layers*2 x "
"batch_size x hidden_size*2)"
);
AddAttr
<
int
>
(
"max_len"
,
"max length of the LSTM op"
"the first dim of the Input can NOT be greater than max_len"
)
.
SetDefault
(
20
);
AddAttr
<
float
>
(
"dropout_prob"
,
"dropout prob of the dropout op"
"the dropout ONLY work between lstm layers, not between time steps"
"There is no dropout work on the Out tensor"
)
.
SetDefault
(
0.0
);
AddAttr
<
bool
>
(
"is_bidirec"
,
"is_bidirec"
"if it is bidirection rnn"
"The will affect the shape of the Out, last_h, and last_c"
)
.
SetDefault
(
false
);
AddAttr
<
int
>
(
"input_size"
,
"input size ot the Input Tensor"
).
SetDefault
(
10
);
AddAttr
<
int
>
(
"hidden_size"
,
"hidden size of the LSTM"
).
SetDefault
(
100
);
AddAttr
<
int
>
(
"num_layers"
,
"the total layer number of the LSTM"
)
.
SetDefault
(
1
);
AddAttr
<
bool
>
(
"is_test"
,
"True if in test phase."
).
SetDefault
(
false
);
AddAttr
<
int
>
(
"seed"
,
"seed to used if fix_seed is True"
).
SetDefault
(
-
1
);
AddComment
(
R"DOC(
CUDNN LSTM implementation
A four-gate Long Short-Term Memory network with no peephole connections.
In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1,
the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations:
$$ i_t = sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) $$
$$ f_t = sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f) $$
$$ o_t = sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o) $$
$$ \\tilde{c_t} = tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c) $$
$$ c_t = f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} $$
$$ h_t = o_t \\odot tanh(c_t) $$
- W terms denote weight matrices (e.g. $W_{ix}$ is the matrix
of weights from the input gate to the input)
- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector).
- sigmoid is the logistic sigmoid function.
- $i, f, o$ and $c$ are the input gate, forget gate, output gate,
and cell activation vectors, respectively, all of which have the same size as
the cell output activation vector $h$.
- The $\odot$ is the element-wise product of the vectors.
- `tanh` is the activation functions.
- $\tilde{c_t}$ is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.
Where sigmoid is the sigmoid operator: sigmoid(x) = 1 / (1 + e^-x), * represents a point-wise multiplication,
X represensts a matrix multiplication
)DOC"
);
}
};
class
CudnnLSTMGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
"Input(W) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"last_h"
),
"Input(last_h) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"last_c"
),
"Input(last_c) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Cache"
),
"Input(last_c) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"InitH"
),
"Input(init_h) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"InitC"
),
"Input(init_c) of LSTM should not be null."
);
auto
SetOutGradDim
=
[
&
ctx
](
const
std
::
string
&
name
)
{
auto
g_name
=
framework
::
GradVarName
(
name
);
if
(
ctx
->
HasOutput
(
g_name
))
{
ctx
->
SetOutputDim
(
g_name
,
ctx
->
GetInputDim
(
name
));
}
};
SetOutGradDim
(
"Input"
);
SetOutGradDim
(
"W"
);
SetOutGradDim
(
"InitH"
);
SetOutGradDim
(
"InitC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
cudnn_lstm
,
ops
::
CudnnLSTMOp
,
ops
::
CudnnLSTMOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
cudnn_lstm_grad
,
ops
::
CudnnLSTMGradOp
);
REGISTER_OP_CPU_KERNEL
(
cudnn_lstm
,
ops
::
CudnnLSTMKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
cudnn_lstm_grad
,
ops
::
CudnnLSTMGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
paddle/fluid/operators/lstm_cudnn_op.cu.cc
0 → 100644
浏览文件 @
1013d6d0
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lstm_cudnn_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace
paddle
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
using
Tensor
=
framework
::
Tensor
;
struct
CudnnRNNCache
{
CudnnRNNCache
()
{
x_desc_
=
NULL
;
y_desc_
=
NULL
;
dx_desc_
=
NULL
;
dy_desc_
=
NULL
;
}
~
CudnnRNNCache
()
{
release
();
}
cudnnRNNDescriptor_t
rnn_desc_
;
cudnnTensorDescriptor_t
*
x_desc_
;
cudnnTensorDescriptor_t
*
y_desc_
;
cudnnTensorDescriptor_t
*
dx_desc_
;
cudnnTensorDescriptor_t
*
dy_desc_
;
cudnnTensorDescriptor_t
hx_desc_
;
cudnnTensorDescriptor_t
cx_desc_
;
cudnnTensorDescriptor_t
hy_desc_
;
cudnnTensorDescriptor_t
cy_desc_
;
cudnnTensorDescriptor_t
dhx_desc_
;
cudnnTensorDescriptor_t
dcx_desc_
;
cudnnTensorDescriptor_t
dhy_desc_
;
cudnnTensorDescriptor_t
dcy_desc_
;
cudnnTensorDescriptor_t
output_x_desc_
;
cudnnTensorDescriptor_t
output_y_desc_
;
cudnnDropoutDescriptor_t
dropout_desc_
;
size_t
weights_size_
;
cudnnFilterDescriptor_t
w_desc_
;
cudnnFilterDescriptor_t
dw_desc_
;
size_t
workspace_size_
;
size_t
reserve_size_
;
Tensor
reserve_data_
;
Tensor
workspace_data_
;
Tensor
dropout_state_
;
size_t
max_length_
;
float
dropout_prob_
;
bool
is_bidirec_
;
int
batch_size_
;
int
input_size_
;
int
hidden_size_
;
int
num_layers_
;
int
seed_
;
void
init
(
cudnnHandle_t
handle
,
const
framework
::
ExecutionContext
&
ctx
,
size_t
max_len
,
int
batch_size
,
int
input_size
,
int
hidden_size
,
int
num_layers
,
float
dropout_prob
,
bool
is_bidirec
,
int
seed
,
int
weight_numel
)
{
max_length_
=
max_len
;
batch_size_
=
batch_size
;
input_size_
=
input_size
;
hidden_size_
=
hidden_size
;
num_layers_
=
num_layers
;
dropout_prob_
=
dropout_prob
;
is_bidirec_
=
is_bidirec
;
seed_
=
seed
;
x_desc_
=
new
cudnnTensorDescriptor_t
[
max_length_
];
y_desc_
=
new
cudnnTensorDescriptor_t
[
max_length_
];
dx_desc_
=
new
cudnnTensorDescriptor_t
[
max_length_
];
dy_desc_
=
new
cudnnTensorDescriptor_t
[
max_length_
];
int
dim_a
[
3
];
int
stride_a
[
3
];
for
(
size_t
i
=
0
;
i
<
max_length_
;
++
i
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
x_desc_
[
i
]));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
y_desc_
[
i
]));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
dx_desc_
[
i
]));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
dy_desc_
[
i
]));
dim_a
[
0
]
=
batch_size_
;
dim_a
[
1
]
=
input_size_
;
dim_a
[
2
]
=
1
;
stride_a
[
0
]
=
dim_a
[
2
]
*
dim_a
[
1
];
stride_a
[
1
]
=
dim_a
[
2
];
stride_a
[
2
]
=
1
;
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
x_desc_
[
i
],
CUDNN_DATA_FLOAT
,
3
,
dim_a
,
stride_a
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
dx_desc_
[
i
],
CUDNN_DATA_FLOAT
,
3
,
dim_a
,
stride_a
));
dim_a
[
0
]
=
batch_size_
;
dim_a
[
1
]
=
is_bidirec_
?
hidden_size_
*
2
:
hidden_size_
;
dim_a
[
2
]
=
1
;
stride_a
[
0
]
=
dim_a
[
2
]
*
dim_a
[
1
];
stride_a
[
1
]
=
dim_a
[
2
];
stride_a
[
2
]
=
1
;
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
y_desc_
[
i
],
CUDNN_DATA_FLOAT
,
3
,
dim_a
,
stride_a
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
dy_desc_
[
i
],
CUDNN_DATA_FLOAT
,
3
,
dim_a
,
stride_a
));
}
dim_a
[
0
]
=
num_layers_
*
(
is_bidirec_
?
2
:
1
);
dim_a
[
1
]
=
batch_size_
;
dim_a
[
2
]
=
hidden_size_
;
stride_a
[
0
]
=
dim_a
[
2
]
*
dim_a
[
1
];
stride_a
[
1
]
=
dim_a
[
2
];
stride_a
[
2
]
=
1
;
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
hx_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
cx_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
hy_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
cy_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
dhx_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
dcx_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
dhy_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
dcy_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
hx_desc_
,
CUDNN_DATA_FLOAT
,
3
,
dim_a
,
stride_a
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
cx_desc_
,
CUDNN_DATA_FLOAT
,
3
,
dim_a
,
stride_a
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
hy_desc_
,
CUDNN_DATA_FLOAT
,
3
,
dim_a
,
stride_a
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
cy_desc_
,
CUDNN_DATA_FLOAT
,
3
,
dim_a
,
stride_a
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
dhx_desc_
,
CUDNN_DATA_FLOAT
,
3
,
dim_a
,
stride_a
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
dcx_desc_
,
CUDNN_DATA_FLOAT
,
3
,
dim_a
,
stride_a
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
dhy_desc_
,
CUDNN_DATA_FLOAT
,
3
,
dim_a
,
stride_a
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
dcy_desc_
,
CUDNN_DATA_FLOAT
,
3
,
dim_a
,
stride_a
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateDropoutDescriptor
(
&
dropout_desc_
));
size_t
state_size
;
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDropoutGetStatesSize
(
handle
,
&
state_size
);
dropout_state_
.
Resize
({
static_cast
<
int64_t
>
(
state_size
)}));
auto
*
dropout_state_data
=
dropout_state_
.
mutable_data
<
uint8_t
>
(
ctx
.
GetPlace
());
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetDropoutDescriptor
(
dropout_desc_
,
handle
,
dropout_prob_
,
dropout_state_data
,
state_size
,
seed_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateRNNDescriptor
(
&
rnn_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetRNNDescriptor_v6
(
handle
,
rnn_desc_
,
hidden_size_
,
num_layers_
,
dropout_desc_
,
CUDNN_LINEAR_INPUT
,
is_bidirec_
?
CUDNN_BIDIRECTIONAL
:
CUDNN_UNIDIRECTIONAL
,
CUDNN_LSTM
,
CUDNN_RNN_ALGO_STANDARD
,
CUDNN_DATA_FLOAT
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateFilterDescriptor
(
&
w_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnCreateFilterDescriptor
(
&
dw_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetRNNParamsSize
(
handle
,
rnn_desc_
,
x_desc_
[
0
],
&
weights_size_
,
CUDNN_DATA_FLOAT
));
PADDLE_ENFORCE_EQ
(
weights_size_
,
sizeof
(
float
)
*
weight_numel
,
"cudnn lstm weight size should be SAME"
);
int
dim_w
[
3
];
dim_w
[
0
]
=
weights_size_
/
sizeof
(
float
);
dim_w
[
1
]
=
1
;
dim_w
[
2
]
=
1
;
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetFilterNdDescriptor
(
w_desc_
,
CUDNN_DATA_FLOAT
,
CUDNN_TENSOR_NCHW
,
3
,
dim_w
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnSetFilterNdDescriptor
(
dw_desc_
,
CUDNN_DATA_FLOAT
,
CUDNN_TENSOR_NCHW
,
3
,
dim_w
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetRNNWorkspaceSize
(
handle
,
rnn_desc_
,
max_length_
,
x_desc_
,
&
workspace_size_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetRNNTrainingReserveSize
(
handle
,
rnn_desc_
,
max_length_
,
x_desc_
,
&
reserve_size_
));
reserve_data_
.
Resize
({
static_cast
<
int64_t
>
(
reserve_size_
)});
reserve_data_
.
mutable_data
<
uint8_t
>
(
ctx
.
GetPlace
());
workspace_data_
.
Resize
({
static_cast
<
int64_t
>
(
workspace_size_
)});
workspace_data_
.
mutable_data
<
uint8_t
>
(
ctx
.
GetPlace
());
}
void
release
()
{
for
(
size_t
i
=
0
;
i
<
max_length_
;
++
i
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
x_desc_
[
i
]));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
y_desc_
[
i
]));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
dx_desc_
[
i
]));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
dy_desc_
[
i
]));
}
delete
[]
x_desc_
;
delete
[]
y_desc_
;
delete
[]
dx_desc_
;
delete
[]
dy_desc_
;
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
hx_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
cx_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
hy_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
cy_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
dhx_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
dcx_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
dhy_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
dcy_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyDropoutDescriptor
(
dropout_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyRNNDescriptor
(
rnn_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyFilterDescriptor
(
w_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyFilterDescriptor
(
dw_desc_
));
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
CudnnLSTMGPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
const
Tensor
*
init_h
=
ctx
.
Input
<
Tensor
>
(
"InitH"
);
const
Tensor
*
init_c
=
ctx
.
Input
<
Tensor
>
(
"InitC"
);
auto
w
=
ctx
.
Input
<
Tensor
>
(
"W"
);
Tensor
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
Tensor
*
last_h
=
ctx
.
Output
<
Tensor
>
(
"last_h"
);
Tensor
*
last_c
=
ctx
.
Output
<
Tensor
>
(
"last_c"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
init_h_data
=
init_h
->
data
<
T
>
();
const
T
*
init_c_data
=
init_c
->
data
<
T
>
();
const
T
*
w_data
=
w
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
last_h_data
=
last_h
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
last_c_data
=
last_c
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
size_t
max_len
=
ctx
.
Attr
<
int
>
(
"max_len"
);
float
dropout_prob
=
ctx
.
Attr
<
float
>
(
"dropout_prob"
);
bool
is_bidirec
=
ctx
.
Attr
<
bool
>
(
"is_bidirec"
);
int
input_size
=
ctx
.
Attr
<
int
>
(
"input_size"
);
int
hidden_size
=
ctx
.
Attr
<
int
>
(
"hidden_size"
);
int
num_layers
=
ctx
.
Attr
<
int
>
(
"num_layers"
);
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
/*
if (is_test) {
TensorCopy(*x, ctx.GetPlace(), out);
return;
}*/
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
*
cache_var
=
ctx
.
InputVar
(
"Cache"
);
if
(
!
cache_var
)
{
// The RAW type cache variable wouldn't be created and broadcasted on
// multi-devices before the first running.
// use parent scope to make cache persistable
auto
*
scope
=
const_cast
<
framework
::
Scope
*>
(
ctx
.
scope
().
parent
());
auto
cache_var_name
=
ctx
.
Inputs
(
"Cache"
)[
0
];
cache_var
=
scope
->
Var
(
cache_var_name
);
}
CudnnRNNCache
*
cudnn_rnn_cache
=
nullptr
;
if
(
cache_var
->
IsInitialized
())
{
cudnn_rnn_cache
=
const_cast
<
framework
::
Variable
*>
(
cache_var
)
->
GetMutable
<
CudnnRNNCache
>
();
}
else
{
cudnn_rnn_cache
=
const_cast
<
framework
::
Variable
*>
(
cache_var
)
->
GetMutable
<
CudnnRNNCache
>
();
std
::
random_device
rnd
;
int
seed
=
ctx
.
Attr
<
int
>
(
"seed"
);
if
(
seed
==
-
1
)
{
seed
=
rnd
();
}
auto
input_w_numel
=
w
->
numel
();
auto
batch_size
=
x
->
dims
()[
1
];
cudnn_rnn_cache
->
init
(
handle
,
ctx
,
max_len
,
batch_size
,
input_size
,
hidden_size
,
num_layers
,
dropout_prob
,
is_bidirec
,
seed
,
input_w_numel
);
}
auto
run_seq_len
=
x
->
dims
()[
0
];
if
(
is_test
)
{
// for inference
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnRNNForwardInference
(
handle
,
cudnn_rnn_cache
->
rnn_desc_
,
run_seq_len
,
cudnn_rnn_cache
->
x_desc_
,
x_data
,
cudnn_rnn_cache
->
hx_desc_
,
init_h_data
,
cudnn_rnn_cache
->
cx_desc_
,
init_c_data
,
cudnn_rnn_cache
->
w_desc_
,
w_data
,
cudnn_rnn_cache
->
y_desc_
,
out_data
,
cudnn_rnn_cache
->
hy_desc_
,
last_h_data
,
cudnn_rnn_cache
->
cy_desc_
,
last_c_data
,
cudnn_rnn_cache
->
workspace_data_
.
data
<
uint8_t
>
(),
cudnn_rnn_cache
->
workspace_size_
));
}
else
{
// for train
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnRNNForwardTraining
(
handle
,
cudnn_rnn_cache
->
rnn_desc_
,
run_seq_len
,
cudnn_rnn_cache
->
x_desc_
,
x_data
,
cudnn_rnn_cache
->
hx_desc_
,
init_h_data
,
cudnn_rnn_cache
->
cx_desc_
,
init_c_data
,
cudnn_rnn_cache
->
w_desc_
,
w_data
,
cudnn_rnn_cache
->
y_desc_
,
out_data
,
cudnn_rnn_cache
->
hy_desc_
,
last_h_data
,
cudnn_rnn_cache
->
cy_desc_
,
last_c_data
,
cudnn_rnn_cache
->
workspace_data_
.
data
<
uint8_t
>
(),
cudnn_rnn_cache
->
workspace_size_
,
cudnn_rnn_cache
->
reserve_data_
.
data
<
uint8_t
>
(),
cudnn_rnn_cache
->
reserve_size_
));
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
CudnnLSTMGPUGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
auto
*
weight
=
ctx
.
Input
<
Tensor
>
(
"W"
);
auto
*
init_h
=
ctx
.
Input
<
Tensor
>
(
"InitH"
);
auto
*
init_c
=
ctx
.
Input
<
Tensor
>
(
"InitC"
);
// auto * last_h = ctx.Input<Tensor>("last_h");
// auto * last_c = ctx.Input<Tensor>("last_c");
auto
*
out
=
ctx
.
Input
<
Tensor
>
(
"Out"
);
auto
*
out_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
last_h_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"last_h"
));
auto
*
last_c_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"last_c"
));
// auto* init_h = ctx.Input<Tensor>("init_h");
// auto* init_c = ctx.Input<Tensor>("init_c");
auto
*
in_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Input"
));
auto
*
weight_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"W"
));
auto
*
init_h_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"InitH"
));
auto
*
init_c_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"InitC"
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
handle
=
dev_ctx
.
cudnn_handle
();
auto
*
cache_var
=
ctx
.
InputVar
(
"Cache"
);
PADDLE_ENFORCE
(
cache_var
->
IsInitialized
());
CudnnRNNCache
*
cudnn_rnn_cache
=
const_cast
<
framework
::
Variable
*>
(
cache_var
)
->
GetMutable
<
CudnnRNNCache
>
();
auto
input_dims
=
input
->
dims
();
auto
weight_dims
=
weight
->
dims
();
auto
init_h_dims
=
init_h
->
dims
();
auto
init_c_dims
=
init_c
->
dims
();
in_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
weight_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
zero
;
zero
(
dev_ctx
,
in_grad
,
static_cast
<
T
>
(
0.0
));
zero
(
dev_ctx
,
weight_grad
,
static_cast
<
T
>
(
0.0
));
T
*
init_h_grad_data
=
NULL
;
if
(
init_h_grad
==
nullptr
)
{
Tensor
init_h_grad_temp
;
init_h_grad_temp
.
mutable_data
<
T
>
(
init_h_dims
,
ctx
.
GetPlace
());
zero
(
dev_ctx
,
&
init_h_grad_temp
,
static_cast
<
T
>
(
0.0
));
init_h_grad_data
=
init_h_grad_temp
.
data
<
T
>
();
}
else
{
init_h_grad
->
mutable_data
<
T
>
(
init_h_dims
,
ctx
.
GetPlace
());
zero
(
dev_ctx
,
init_h_grad
,
static_cast
<
T
>
(
0.0
));
init_h_grad_data
=
init_h_grad
->
data
<
T
>
();
}
T
*
init_c_grad_data
=
NULL
;
if
(
init_c_grad
==
nullptr
)
{
Tensor
init_c_grad_temp
;
init_c_grad_temp
.
mutable_data
<
T
>
(
init_c_dims
,
ctx
.
GetPlace
());
zero
(
dev_ctx
,
&
init_c_grad_temp
,
static_cast
<
T
>
(
0.0
));
init_c_grad_data
=
init_c_grad_temp
.
data
<
T
>
();
}
else
{
init_c_grad
->
mutable_data
<
T
>
(
init_c_dims
,
ctx
.
GetPlace
());
zero
(
dev_ctx
,
init_c_grad
,
static_cast
<
T
>
(
0.0
));
init_c_grad_data
=
init_c_grad
->
data
<
T
>
();
}
const
T
*
last_h_grad_data
=
NULL
;
if
(
last_h_grad
==
nullptr
)
{
Tensor
last_h_grad_temp
;
last_h_grad_temp
.
mutable_data
<
T
>
(
init_h_dims
,
ctx
.
GetPlace
());
zero
(
dev_ctx
,
&
last_h_grad_temp
,
static_cast
<
T
>
(
0.0
));
last_h_grad_data
=
(
const
T
*
)
last_h_grad_temp
.
data
<
T
>
();
}
else
{
last_h_grad_data
=
last_h_grad
->
data
<
T
>
();
}
const
T
*
last_c_grad_data
=
NULL
;
if
(
last_c_grad
==
nullptr
)
{
Tensor
last_c_grad_temp
;
last_c_grad_temp
.
mutable_data
<
T
>
(
init_c_dims
,
ctx
.
GetPlace
());
zero
(
dev_ctx
,
&
last_c_grad_temp
,
static_cast
<
T
>
(
0.0
));
last_c_grad_data
=
(
const
T
*
)
last_c_grad_temp
.
data
<
T
>
();
}
else
{
last_c_grad_data
=
last_c_grad
->
data
<
T
>
();
}
const
T
*
out_grad_data
=
NULL
;
if
(
out_grad
==
nullptr
)
{
Tensor
out_grad_temp
;
out_grad_temp
.
mutable_data
<
T
>
(
out
->
dims
(),
ctx
.
GetPlace
());
zero
(
dev_ctx
,
&
out_grad_temp
,
static_cast
<
T
>
(
0.0
));
out_grad_data
=
(
const
T
*
)
out_grad_temp
.
data
<
T
>
();
}
else
{
out_grad_data
=
out_grad
->
data
<
T
>
();
}
// zero( dev_ctx, last_h_grad, static_cast<T>(0.0));
// zero( dev_ctx, last_c_grad, static_cast<T>(0.0));
auto
out_data
=
out
->
data
<
T
>
();
// auto out_grad_data = out_grad->data<T>();
auto
weight_data
=
weight
->
data
<
T
>
();
auto
init_h_data
=
init_h
->
data
<
T
>
();
auto
init_c_data
=
init_c
->
data
<
T
>
();
auto
in_grad_data
=
in_grad
->
data
<
T
>
();
auto
work_data
=
cudnn_rnn_cache
->
workspace_data_
.
data
<
uint8_t
>
();
auto
reserve_data
=
cudnn_rnn_cache
->
reserve_data_
.
data
<
uint8_t
>
();
auto
run_seq_len
=
input_dims
[
0
];
PADDLE_ENFORCE_LE
((
size_t
)
run_seq_len
,
cudnn_rnn_cache
->
max_length_
,
"cudnn running seq_len CAN not greater max_lengh"
);
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnRNNBackwardData
(
handle
,
cudnn_rnn_cache
->
rnn_desc_
,
run_seq_len
,
cudnn_rnn_cache
->
y_desc_
,
out_data
,
cudnn_rnn_cache
->
dy_desc_
,
out_grad_data
,
cudnn_rnn_cache
->
dhy_desc_
,
last_h_grad_data
,
cudnn_rnn_cache
->
dcy_desc_
,
last_c_grad_data
,
cudnn_rnn_cache
->
w_desc_
,
weight_data
,
cudnn_rnn_cache
->
hx_desc_
,
init_h_data
,
cudnn_rnn_cache
->
cx_desc_
,
init_c_data
,
cudnn_rnn_cache
->
dx_desc_
,
in_grad_data
,
cudnn_rnn_cache
->
dhx_desc_
,
init_h_grad_data
,
cudnn_rnn_cache
->
dcx_desc_
,
init_c_grad_data
,
work_data
,
cudnn_rnn_cache
->
workspace_size_
,
reserve_data
,
cudnn_rnn_cache
->
reserve_size_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnRNNBackwardWeights
(
handle
,
cudnn_rnn_cache
->
rnn_desc_
,
run_seq_len
,
cudnn_rnn_cache
->
x_desc_
,
input
->
data
<
T
>
(),
cudnn_rnn_cache
->
hx_desc_
,
init_h
->
data
<
T
>
(),
cudnn_rnn_cache
->
y_desc_
,
out
->
data
<
T
>
(),
cudnn_rnn_cache
->
workspace_data_
.
data
<
uint8_t
>
(),
cudnn_rnn_cache
->
workspace_size_
,
cudnn_rnn_cache
->
dw_desc_
,
weight_grad
->
data
<
T
>
(),
cudnn_rnn_cache
->
reserve_data_
.
data
<
uint8_t
>
(),
cudnn_rnn_cache
->
reserve_size_
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
cudnn_lstm
,
ops
::
CudnnLSTMGPUKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
cudnn_lstm_grad
,
ops
::
CudnnLSTMGPUGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
paddle/fluid/operators/lstm_cudnn_op.h
0 → 100644
浏览文件 @
1013d6d0
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
namespace
paddle
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
>
class
CudnnLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_THROW
(
"CPU is not support for this kernel now. Will be add in the future"
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
CudnnLSTMGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/platform/dynload/cudnn.h
浏览文件 @
1013d6d0
...
...
@@ -111,7 +111,23 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnFindConvolutionForwardAlgorithmEx); \
__macro(cudnnFindConvolutionBackwardFilterAlgorithmEx); \
__macro(cudnnFindConvolutionBackwardDataAlgorithmEx); \
__macro(cudnnGetErrorString);
__macro(cudnnGetErrorString); \
__macro(cudnnCreateDropoutDescriptor); \
__macro(cudnnDropoutGetStatesSize); \
__macro(cudnnSetDropoutDescriptor); \
__macro(cudnnCreateRNNDescriptor); \
__macro(cudnnSetRNNDescriptor); \
__macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNWorkspaceSize); \
__macro(cudnnGetRNNTrainingReserveSize); \
__macro(cudnnRNNForwardTraining); \
__macro(cudnnRNNBackwardData); \
__macro(cudnnRNNBackwardWeights); \
__macro(cudnnRNNForwardInference); \
__macro(cudnnDestroyDropoutDescriptor); \
__macro(cudnnDestroyRNNDescriptor); \
__macro(cudnnSetRNNDescriptor_v6);
CUDNN_DNN_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_CUDNN_WRAP
)
#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
1013d6d0
...
...
@@ -169,6 +169,7 @@ __all__ = [
'log_loss'
,
'add_position_encoding'
,
'bilinear_tensor_product'
,
'lstm'
,
]
...
...
@@ -472,6 +473,169 @@ def dynamic_lstm(input,
return
hidden
,
cell
def
lstm
(
input
,
init_h
,
init_c
,
max_len
,
dropout_prob
,
input_size
,
hidden_size
,
num_layers
,
is_bidirec
=
False
,
dtype
=
'float32'
,
is_test
=
False
,
name
=
None
,
default_initializer
=
None
,
seed
=-
1
):
"""
If Device is GPU, This op will use cudnn LSTM implementation
A four-gate Long Short-Term Memory network with no peephole connections.
In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1,
the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations:
$$ i_t =
\\
sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) $$
$$ f_t =
\\
sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f) $$
$$ o_t =
\\
sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o) $$
$$
\\
tilde{c_t} = tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c) $$
$$ c_t = f_t
\\
odot c_{t-1} + i_t
\\
odot
\\
tilde{c_t} $$
$$ h_t = o_t
\\
odot tanh(c_t) $$
- W terms denote weight matrices (e.g. $W_{ix}$ is the matrix
of weights from the input gate to the input)
- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector).
- sigmoid is the logistic sigmoid function.
- $i, f, o$ and $c$ are the input gate, forget gate, output gate,
and cell activation vectors, respectively, all of which have the same size as
the cell output activation vector $h$.
- The $\odot$ is the element-wise product of the vectors.
- `tanh` is the activation functions.
- $
\t
ilde{c_t}$ is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.
Where sigmoid is the sigmoid operator: sigmoid(x) = 1 / (1 + e^-x), * represents a point-wise multiplication,
X represensts a matrix multiplication
Args:
input (Variable): LSTM input tensor, shape MUST be ( seq_len x batch_size x input_size )
init_h(Variable): The initial hidden state of the LSTM
This is a tensor with shape ( num_layers x batch_size x hidden_size)
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
init_c(Variable): The initial cell state of the LSTM.
This is a tensor with shape ( num_layers x batch_size x hidden_size )
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len
dropout_prob(float): dropout prob, dropout ONLY work between rnn layers, NOT between time steps
There is NO dropout work on rnn output of the last RNN layers
input_size (int): hidden size of the input tensor
hidden_size (int): hidden size of the LSTM
num_layers (int): total layers number of the LSTM
is_bidirec (bool): If it is bidirectional
dtype (str): Data type. Choices = ["float32", "float64"], default "float32".
is_test (bool): If it is in test phrase
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
default_initializer(Initialize|None): Where use initializer to initialize the Weight
If set None, defaule initializer will be used
seed(int): Seed for dropout in LSTM, If it's -1, dropout will use random seed
Returns:
rnn_out(Tensor): result of LSTM hidden, shape is (seq_len x batch_size x hidden_size)
if is_bidirec set to True, shape will be ( seq_len x batch_sze x hidden_size*2)
last_h(Tensor): the hidden state of the last step of LSTM
shape is ( num_layers x batch_size x hidden_size )
if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size)
last_c(Tensor): the cell state of the last step of LSTM
shape is ( num_layers x batch_size x hidden_size )
if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size)
Examples:
.. code-block:: python
input = embedding
batch_size = 20
max_len = 100
dropout_prob = 0.2
input_size = 100
hidden_size = 150
num_layers = 1
init_hidden1 = layers.fill_constant( [num_layers, batch_size, hidden_size], 'float32', 0.0, stop_grad=False)
init_cell1 = layers.fill_constant( [num_layers, batch_size, hidden_size], 'float32', 0.0, stop_grad=False)
rnn_out, last_h, last_c = layers.lstm( input, init_h, init_c,
\
max_len, dropout_prob, input_size, hidden_size,
\
num_layers)
"""
helper
=
LayerHelper
(
'cudnn_lstm'
,
**
locals
())
weight_size
=
0
for
i
in
range
(
num_layers
):
if
i
==
0
:
input_weight_size
=
(
input_size
*
hidden_size
)
*
4
else
:
if
is_bidirec
:
input_weight_size
=
(
hidden_size
*
2
*
hidden_size
)
*
4
else
:
input_weight_size
=
(
hidden_size
*
hidden_size
)
*
4
hidden_weight_size
=
(
hidden_size
*
hidden_size
)
*
4
if
is_bidirec
:
weight_size
+=
(
input_weight_size
+
hidden_weight_size
)
*
2
weight_size
+=
hidden_size
*
8
*
2
else
:
weight_size
+=
input_weight_size
+
hidden_weight_size
weight_size
+=
hidden_size
*
8
weight
=
helper
.
create_parameter
(
attr
=
helper
.
param_attr
,
shape
=
[
weight_size
],
dtype
=
dtype
,
default_initializer
=
default_initializer
)
out
=
helper
.
create_variable_for_type_inference
(
dtype
)
last_h
=
helper
.
create_variable_for_type_inference
(
dtype
)
last_c
=
helper
.
create_variable_for_type_inference
(
dtype
)
cache
=
helper
.
create_variable
(
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
,
stop_gradient
=
True
)
helper
.
append_op
(
type
=
'cudnn_lstm'
,
inputs
=
{
'Input'
:
input
,
'InitH'
:
init_h
,
'InitC'
:
init_c
,
'W'
:
weight
,
'Cache'
:
cache
,
},
outputs
=
{
'Out'
:
out
,
'last_h'
:
last_h
,
'last_c'
:
last_c
,
},
attrs
=
{
'max_len'
:
max_len
,
'is_bidirec'
:
is_bidirec
,
'input_size'
:
input_size
,
'hidden_size'
:
hidden_size
,
'num_layers'
:
num_layers
,
'is_test'
:
is_test
,
'dropout_prob'
:
dropout_prob
,
'seed'
:
seed
,
})
return
out
,
last_h
,
last_c
def
dynamic_lstmp
(
input
,
size
,
proj_size
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录