Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
cc441ee1
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cc441ee1
编写于
12月 03, 2018
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add cudnn lstm
test=release/1.2
上级
fa6c2b53
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
207 addition
and
4 deletion
+207
-4
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/platform/dynload/cudnn.h
paddle/fluid/platform/dynload/cudnn.h
+17
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+163
-0
python/paddle/fluid/tests/unittests/op_test.py
python/paddle/fluid/tests/unittests/op_test.py
+20
-2
python/paddle/fluid/tests/unittests/testsuite.py
python/paddle/fluid/tests/unittests/testsuite.py
+6
-1
未找到文件。
paddle/fluid/API.spec
浏览文件 @
cc441ee1
...
...
@@ -194,6 +194,7 @@ paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=Non
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
...
...
paddle/fluid/platform/dynload/cudnn.h
浏览文件 @
cc441ee1
...
...
@@ -111,7 +111,23 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnFindConvolutionForwardAlgorithmEx); \
__macro(cudnnFindConvolutionBackwardFilterAlgorithmEx); \
__macro(cudnnFindConvolutionBackwardDataAlgorithmEx); \
__macro(cudnnGetErrorString);
__macro(cudnnGetErrorString); \
__macro(cudnnCreateDropoutDescriptor); \
__macro(cudnnDropoutGetStatesSize); \
__macro(cudnnSetDropoutDescriptor); \
__macro(cudnnCreateRNNDescriptor); \
__macro(cudnnSetRNNDescriptor); \
__macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNWorkspaceSize); \
__macro(cudnnGetRNNTrainingReserveSize); \
__macro(cudnnRNNForwardTraining); \
__macro(cudnnRNNBackwardData); \
__macro(cudnnRNNBackwardWeights); \
__macro(cudnnRNNForwardInference); \
__macro(cudnnDestroyDropoutDescriptor); \
__macro(cudnnDestroyRNNDescriptor); \
__macro(cudnnSetRNNDescriptor_v6);
CUDNN_DNN_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_CUDNN_WRAP
)
#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
cc441ee1
...
...
@@ -169,6 +169,7 @@ __all__ = [
'log_loss'
,
'add_position_encoding'
,
'bilinear_tensor_product'
,
'lstm'
,
]
...
...
@@ -472,6 +473,168 @@ def dynamic_lstm(input,
return
hidden
,
cell
def
lstm
(
input
,
init_h
,
init_c
,
max_len
,
hidden_size
,
num_layers
,
dropout_prob
=
0.0
,
is_bidirec
=
False
,
is_test
=
False
,
name
=
None
,
default_initializer
=
None
,
seed
=-
1
):
"""
If Device is GPU, This op will use cudnn LSTM implementation
A four-gate Long Short-Term Memory network with no peephole connections.
In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1,
the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations:
$$ i_t =
\\
sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) $$
$$ f_t =
\\
sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f) $$
$$ o_t =
\\
sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o) $$
$$
\\
tilde{c_t} = tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c) $$
$$ c_t = f_t
\\
odot c_{t-1} + i_t
\\
odot
\\
tilde{c_t} $$
$$ h_t = o_t
\\
odot tanh(c_t) $$
- W terms denote weight matrices (e.g. $W_{ix}$ is the matrix
of weights from the input gate to the input)
- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector).
- sigmoid is the logistic sigmoid function.
- $i, f, o$ and $c$ are the input gate, forget gate, output gate,
and cell activation vectors, respectively, all of which have the same size as
the cell output activation vector $h$.
- The $\odot$ is the element-wise product of the vectors.
- `tanh` is the activation functions.
- $
\t
ilde{c_t}$ is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.
Where sigmoid is the sigmoid operator: sigmoid(x) = 1 / (1 + e^-x), * represents a point-wise multiplication,
X represensts a matrix multiplication
Args:
input (Variable): LSTM input tensor, shape MUST be ( seq_len x batch_size x input_size )
init_h(Variable): The initial hidden state of the LSTM
This is a tensor with shape ( num_layers x batch_size x hidden_size)
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
init_c(Variable): The initial cell state of the LSTM.
This is a tensor with shape ( num_layers x batch_size x hidden_size )
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len
hidden_size (int): hidden size of the LSTM
num_layers (int): total layers number of the LSTM
dropout_prob(float|0.0): dropout prob, dropout ONLY work between rnn layers, NOT between time steps
There is NO dropout work on rnn output of the last RNN layers
is_bidirec (bool): If it is bidirectional
is_test (bool): If it is in test phrase
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
default_initializer(Initialize|None): Where use initializer to initialize the Weight
If set None, defaule initializer will be used
seed(int): Seed for dropout in LSTM, If it's -1, dropout will use random seed
Returns:
rnn_out(Tensor): result of LSTM hidden, shape is (seq_len x batch_size x hidden_size)
if is_bidirec set to True, shape will be ( seq_len x batch_sze x hidden_size*2)
last_h(Tensor): the hidden state of the last step of LSTM
shape is ( num_layers x batch_size x hidden_size )
if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size)
last_c(Tensor): the cell state of the last step of LSTM
shape is ( num_layers x batch_size x hidden_size )
if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size)
Examples:
.. code-block:: python
input = embedding
batch_size = 20
max_len = 100
dropout_prob = 0.2
input_size = 100
hidden_size = 150
num_layers = 1
init_hidden1 = layers.fill_constant( [num_layers, batch_size, hidden_size], 'float32', 0.0, stop_grad=False)
init_cell1 = layers.fill_constant( [num_layers, batch_size, hidden_size], 'float32', 0.0, stop_grad=False)
rnn_out, last_h, last_c = layers.lstm( input, init_h, init_c,
\
max_len, dropout_prob, input_size, hidden_size,
\
num_layers)
"""
helper
=
LayerHelper
(
'cudnn_lstm'
,
**
locals
())
dtype
=
input
.
dtype
input_shape
=
list
(
input
.
shape
)
input_size
=
input_shape
[
-
1
]
weight_size
=
0
for
i
in
range
(
num_layers
):
if
i
==
0
:
input_weight_size
=
(
input_size
*
hidden_size
)
*
4
else
:
if
is_bidirec
:
input_weight_size
=
(
hidden_size
*
2
*
hidden_size
)
*
4
else
:
input_weight_size
=
(
hidden_size
*
hidden_size
)
*
4
hidden_weight_size
=
(
hidden_size
*
hidden_size
)
*
4
if
is_bidirec
:
weight_size
+=
(
input_weight_size
+
hidden_weight_size
)
*
2
weight_size
+=
hidden_size
*
8
*
2
else
:
weight_size
+=
input_weight_size
+
hidden_weight_size
weight_size
+=
hidden_size
*
8
weight
=
helper
.
create_parameter
(
attr
=
helper
.
param_attr
,
shape
=
[
weight_size
],
dtype
=
dtype
,
default_initializer
=
default_initializer
)
out
=
helper
.
create_variable_for_type_inference
(
dtype
)
last_h
=
helper
.
create_variable_for_type_inference
(
dtype
)
last_c
=
helper
.
create_variable_for_type_inference
(
dtype
)
cache
=
helper
.
create_variable
(
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
,
stop_gradient
=
True
)
helper
.
append_op
(
type
=
'cudnn_lstm'
,
inputs
=
{
'Input'
:
input
,
'InitH'
:
init_h
,
'InitC'
:
init_c
,
'W'
:
weight
,
'Cache'
:
cache
,
},
outputs
=
{
'Out'
:
out
,
'last_h'
:
last_h
,
'last_c'
:
last_c
,
},
attrs
=
{
'max_len'
:
max_len
,
'is_bidirec'
:
is_bidirec
,
'input_size'
:
input_size
,
'hidden_size'
:
hidden_size
,
'num_layers'
:
num_layers
,
'is_test'
:
is_test
,
'dropout_prob'
:
dropout_prob
,
'seed'
:
seed
,
})
return
out
,
last_h
,
last_c
def
dynamic_lstmp
(
input
,
size
,
proj_size
,
...
...
python/paddle/fluid/tests/unittests/op_test.py
浏览文件 @
cc441ee1
...
...
@@ -216,6 +216,15 @@ class OpTest(unittest.TestCase):
self
.
dtype
)
outputs
=
append_input_output
(
block
,
op_proto
,
self
.
outputs
,
False
,
self
.
dtype
)
if
hasattr
(
self
,
"cache_name_list"
):
for
name
in
self
.
cache_name_list
:
inputs
[
name
]
=
block
.
create_var
(
name
=
name
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
,
stop_gradient
=
True
)
op
=
block
.
append_op
(
type
=
self
.
op_type
,
inputs
=
inputs
,
...
...
@@ -428,8 +437,17 @@ class OpTest(unittest.TestCase):
op_inputs
=
self
.
inputs
if
hasattr
(
self
,
"inputs"
)
else
dict
()
op_outputs
=
self
.
outputs
if
hasattr
(
self
,
"outputs"
)
else
dict
()
op_attrs
=
self
.
attrs
if
hasattr
(
self
,
"attrs"
)
else
dict
()
self
.
op
=
create_op
(
self
.
scope
,
self
.
op_type
,
op_inputs
,
op_outputs
,
op_attrs
)
cache_list
=
None
if
hasattr
(
self
,
"cache_name_list"
):
cache_list
=
self
.
cache_name_list
self
.
op
=
create_op
(
self
.
scope
,
self
.
op_type
,
op_inputs
,
op_outputs
,
op_attrs
,
cache_list
=
cache_list
)
if
no_grad_set
is
None
:
no_grad_set
=
set
()
...
...
python/paddle/fluid/tests/unittests/testsuite.py
浏览文件 @
cc441ee1
...
...
@@ -20,7 +20,7 @@ import paddle.fluid.core as core
from
paddle.fluid.op
import
Operator
def
create_op
(
scope
,
op_type
,
inputs
,
outputs
,
attrs
):
def
create_op
(
scope
,
op_type
,
inputs
,
outputs
,
attrs
,
cache_list
=
None
):
kwargs
=
dict
()
op_maker
=
core
.
op_proto_and_checker_maker
...
...
@@ -43,6 +43,11 @@ def create_op(scope, op_type, inputs, outputs, attrs):
__create_var__
(
in_name
,
sub_in_name
)
else
:
__create_var__
(
in_name
,
in_name
)
if
cache_list
!=
None
and
isinstance
(
cache_list
,
list
):
for
name
in
cache_list
:
kwargs
[
name
]
=
[]
scope
.
var
(
name
)
kwargs
[
name
].
append
(
name
)
for
out_name
,
out_dup
in
Operator
.
get_op_outputs
(
op_type
):
if
out_name
in
outputs
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录