Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1d95173c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
1d95173c
编写于
11月 16, 2017
作者:
W
wanghaox
浏览文件
操作
浏览文件
下载
差异文件
change offset and length's rank to 2, dim[0] for batch size
上级
40a6c488
a76b6144
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
125 addition
and
113 deletion
+125
-113
paddle/operators/sequence_slice_op.cc
paddle/operators/sequence_slice_op.cc
+9
-1
paddle/operators/sequence_slice_op.h
paddle/operators/sequence_slice_op.h
+23
-23
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+85
-81
python/paddle/v2/fluid/tests/test_beam_search_op.py
python/paddle/v2/fluid/tests/test_beam_search_op.py
+2
-2
python/paddle/v2/fluid/tests/test_sequence_slice_op.py
python/paddle/v2/fluid/tests/test_sequence_slice_op.py
+6
-6
未找到文件。
paddle/operators/sequence_slice_op.cc
浏览文件 @
1d95173c
...
...
@@ -32,6 +32,14 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
"Output(Out) of SequenceSliceOp should not be null."
);
auto
input_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
offset_dim
=
ctx
->
GetInputDim
(
"Offset"
);
auto
length_dim
=
ctx
->
GetInputDim
(
"Length"
);
PADDLE_ENFORCE_EQ
(
offset_dim
.
size
(),
2UL
,
"Only support one level sequence now."
);
PADDLE_ENFORCE_EQ
(
length_dim
.
size
(),
2UL
,
"Only support one level sequence now."
);
ctx
->
SetOutputDim
(
"Out"
,
input_dims
);
}
...
...
@@ -95,7 +103,7 @@ It only supports sequence (LoD Tensor with level number is 1).
[d1, d2;
e1, e2]]
LoD(X) = {{0, 3, 5}}; Dims(X) = (5, 2)
Offset = [
0, 1]; Length = [2, 1
]
Offset = [
[0], [1]]; Length = [[2], [1]
]
Out = [[a1, a2;
b1, b2]
...
...
paddle/operators/sequence_slice_op.h
浏览文件 @
1d95173c
...
...
@@ -48,42 +48,42 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
auto
*
length
=
ctx
.
Input
<
Tensor
>
(
"Length"
);
auto
*
out
=
ctx
.
Output
<
LoDTensor
>
(
"Out"
);
auto
lod
=
in
->
lod
();
auto
n
=
lod
[
0
].
size
()
-
1
;
PADDLE_ENFORCE_EQ
(
lod
.
size
(),
1UL
,
"Only support one level sequence now."
);
PADDLE_ENFORCE_EQ
(
n
,
length
->
dims
()[
0
],
"The size of input-sequence and length-array should be the same"
)
PADDLE_ENFORCE_EQ
(
n
,
offset
->
dims
()[
0
],
"The size of input-sequence and offset-array should be the same"
)
const
int64_t
*
offset_data
=
offset
->
data
<
int64_t
>
();
const
int64_t
*
length_data
=
length
->
data
<
int64_t
>
();
framework
::
Tensor
offset_cpu
;
framework
::
Tensor
length_cpu
;
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
framework
::
Tensor
offset_cpu
;
offset_cpu
.
mutable_data
<
T
>
(
offset
->
dims
(),
platform
::
CPUPlace
());
offset_cpu
.
CopyFrom
(
*
offset
,
platform
::
CPUPlace
(),
ctx
.
device_context
());
offset_data
=
offset_cpu
.
data
<
int64_t
>
();
framework
::
Tensor
length_cpu
;
length_cpu
.
mutable_data
<
T
>
(
length
->
dims
(),
platform
::
CPUPlace
());
length_cpu
.
CopyFrom
(
*
length
,
platform
::
CPUPlace
(),
ctx
.
device_context
());
length_data
=
length_cpu
.
data
<
int64_t
>
();
}
auto
lod
=
in
->
lod
();
auto
n
=
lod
[
0
].
size
()
-
1
;
PADDLE_ENFORCE_EQ
(
lod
.
size
(),
1UL
,
"Only support one level sequence now."
);
PADDLE_ENFORCE_EQ
(
offset
->
dims
().
size
(),
1UL
,
"Only support one level sequence now."
);
PADDLE_ENFORCE_EQ
(
length
->
dims
().
size
(),
1UL
,
"Only support one level sequence now."
);
PADDLE_ENFORCE_EQ
(
n
,
length
->
dims
()[
0
],
"The size of input-sequence and length-array should be the same"
)
PADDLE_ENFORCE_EQ
(
n
,
offset
->
dims
()[
0
],
"The size of input-sequence and offset-array should be the same"
)
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
PADDLE_ENFORCE_LT
(
0
,
offset_data
[
i
],
"The offset must greater than zero"
)
PADDLE_ENFORCE_LT
(
0
,
length_data
[
i
],
"The length must greater than zero"
)
PADDLE_ENFORCE_LT
(
lod
[
0
][
i
]
+
offset_data
[
i
]
+
length_data
[
i
],
lod
[
0
][
i
+
1
],
"The target tensor's length overflow"
)
}
PADDLE_ENFORCE_LT
(
0
,
offset_data
[
i
],
"The offset must greater than zero"
)
PADDLE_ENFORCE_LT
(
0
,
length_data
[
i
],
"The length must greater than zero"
)
PADDLE_ENFORCE_LT
(
lod
[
0
][
i
]
+
offset_data
[
i
]
+
length_data
[
i
],
lod
[
0
][
i
+
1
],
"The target tensor's length overflow"
)}
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
out_lod
=
SequenceSliceLoD
(
*
in
,
offset_data
,
length_data
);
...
...
@@ -100,7 +100,7 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
Tensor
in_t
=
in
->
Slice
(
static_cast
<
int
>
(
lod
[
0
][
i
]
+
offset_data
[
i
]),
static_cast
<
int
>
(
lod
[
0
][
i
]
+
offset_data
[
i
]
+
length_data
[
i
]));
length_data
[
i
]));
StridedMemcpy
<
T
>
(
ctx
.
device_context
(),
in_t
.
data
<
T
>
(),
in_stride
,
in_t
.
dims
(),
out_stride
,
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
1d95173c
...
...
@@ -2987,8 +2987,10 @@ def img_cmrnorm_layer(input,
layer_attr
=
None
):
"""
Response normalization across feature maps.
The details please refer to
`Alex's paper <http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf>`_.
Reference:
ImageNet Classification with Deep Convolutional Neural Networks
http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf
The example usage is:
...
...
@@ -2997,7 +2999,7 @@ def img_cmrnorm_layer(input,
norm = img_cmrnorm_layer(input=net, size=5)
:param name: The name of this layer. It is optional.
:type name:
None |
basestring
:type name: basestring
:param input: The input of this layer.
:type input: LayerOutput
:param size: Normalize in number of :math:`size` feature maps.
...
...
@@ -3006,9 +3008,11 @@ def img_cmrnorm_layer(input,
:type scale: float
:param power: The hyper-parameter.
:type power: float
:param num_channels: input layer's filers number or channels. If
num_channels is None, it will be set automatically.
:param layer_attr: Extra Layer Attribute.
:param num_channels: The number of input channels. If the parameter is not set or
set to None, its actual value will be automatically set to
the channels number of the input.
:param layer_attr: The extra layer attributes. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
...
...
@@ -3036,7 +3040,7 @@ def batch_norm_layer(input,
use_global_stats
=
None
,
mean_var_names
=
None
):
"""
Batch Normalization Layer. The notation of this layer
as follow
.
Batch Normalization Layer. The notation of this layer
is as follows
.
:math:`x` is the input features over a mini-batch.
...
...
@@ -3050,8 +3054,10 @@ def batch_norm_layer(input,
\\
sigma_{
\\
beta}^{2} +
\\
epsilon}}
\\
qquad &//\ normalize
\\\\
y_i &
\\
gets
\\
gamma
\\
hat{x_i} +
\\
beta
\\
qquad &//\ scale\ and\ shift
The details of batch normalization please refer to this
`paper <http://arxiv.org/abs/1502.03167>`_.
Reference:
Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift
http://arxiv.org/abs/1502.03167
The example usage is:
...
...
@@ -3061,48 +3067,47 @@ def batch_norm_layer(input,
:param name: The name of this layer. It is optional.
:type name: basestring
:param input: batch normalization input. Better be linear activation.
Because there is an activation inside batch_normalization.
:param input: This layer's input which is to be performed batch normalization on.
:type input: LayerOutput
:param batch_norm_type: We have batch_norm, mkldnn_batch_norm and cudnn_batch_norm.
batch_norm supports CPU, MKLDNN and GPU. cudnn_batch_norm
requires cuDNN version greater or equal to v4 (>=v4).
But cudnn_batch_norm is faster and needs less
memory than batch_norm. mkldnn_batch_norm requires
enable use_mkldnn
. By default (None), we will
automaticly select cudnn_batch_norm for GPU,
use_mkldnn is enabled
. By default (None), we will
automatic
al
ly select cudnn_batch_norm for GPU,
mkldnn_batch_norm for MKLDNN and batch_norm for CPU.
Otherwise, select batch norm type based on th
e
specified type. If you use cudnn_batch_norm
,
we suggested you use latest version,
such as v5.1.
Users can specify the batch norm type. If you us
e
cudnn_batch_norm, we suggested you use latest version
,
such as v5.1.
:type batch_norm_type: None | string, None or "batch_norm" or "cudnn_batch_norm"
or "mkldnn_batch_norm"
:param act: Activation Type. Better be relu. Because batch
normalization will normalize input near zero.
:param act: Activation type. ReluActivation is the default activation.
:type act: BaseActivation
:param num_channels:
num of image channels or previous layer's number of
filters. None will automatically get from layer's
input.
:param num_channels:
The number of input channels. If the parameter is not set or
set to None, its actual value will be automatically set to
the channels number of the
input.
:type num_channels: int
:param bias_attr: :math:`
\\
beta`, better be zero when initialize. So the
initial_std=0, initial_mean=1 is best practice.
:param bias_attr: :math:`
\\
beta`. The bias attribute. If the parameter is set to
False or an object whose type is not ParameterAttribute, no
bias is defined. If the parameter is set to True, the bias is
initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param param_attr: :math:`
\\
gamma`
, better be one when initialize. So th
e
initial_std=0, initial_mean=1 is best practice
.
:param param_attr: :math:`
\\
gamma`
. The parameter attribute. See ParameterAttribut
e
for details
.
:type param_attr: ParameterAttribute
:param layer_attr: Extra Layer Attribute.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute
:param use_global_stats:
whether use moving mean/variance statistics
during testing peroid. If None or True,
it will use moving mean/variance statistics during
testing. If False, it will use the mean
and variance of current batch of test data for
testing
.
:param use_global_stats:
Whether use moving mean/variance statistics during
testing peroid. If the parameter is set to None or
True, it will use moving mean/variance statistics
during testing. If the parameter is set to False, it
will use the mean and variance of the current batch
of test data
.
:type use_global_stats: bool | None.
:param moving_average_fraction: Factor used in the moving average
computation, referred to as facotr,
:math:`runningMean = newMean*(1-factor)
+ runningMean*factor`
:param moving_average_fraction: Factor used in the moving average computation.
:math:`runningMean = newMean*(1-factor) + runningMean*factor`
:type moving_average_fraction: float.
:param mean_var_names: [mean name, variance name]
:type mean_var_names: string list
...
...
@@ -3164,8 +3169,9 @@ def sum_to_one_norm_layer(input, name=None, layer_attr=None):
:type input: LayerOutput
:param name: The name of this layer. It is optional.
:type name: basestring
:param layer_attr: extra layer attributes.
:type layer_attr: ExtraLayerAttribute.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute
for details.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
"""
...
...
@@ -3200,7 +3206,8 @@ def row_l2_norm_layer(input, name=None, layer_attr=None):
:type input: LayerOutput
:param name: The name of this layer. It is optional.
:type name: basestring
:param layer_attr: extra layer attributes.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute
for details.
:type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object.
:rtype: LayerOutput
...
...
@@ -3237,22 +3244,17 @@ def addto_layer(input, act=None, name=None, bias_attr=None, layer_attr=None):
act=ReluActivation(),
bias_attr=False)
This layer just simply add
all input layers together, then activate the sum
inputs. Each input of this layer should be the same size, which is also the
o
utput size of this layer
.
This layer just simply add
s all input layers together, then activates the
sum. All inputs should share the same dimension, which is also the dimension
o
f this layer's output
.
There is no weight matrix for each input, because it just a simple add
operation. If you want a complicated operation before add, please use
mixed_layer.
It is a very good way to set dropout outside the layers. Since not all
PaddlePaddle layer support dropout, you can add an add_to layer, set
dropout here.
Please refer to dropout_layer for details.
:param name: The name of this layer. It is optional.
:type name: basestring
:param input:
I
nput layers. It could be a LayerOutput or list/tuple of
:param input:
The i
nput layers. It could be a LayerOutput or list/tuple of
LayerOutput.
:type input: LayerOutput | list | tuple
:param act: Activation Type. LinearActivation is the default activation.
...
...
@@ -3261,7 +3263,8 @@ def addto_layer(input, act=None, name=None, bias_attr=None, layer_attr=None):
whose type is not ParameterAttribute, no bias is defined. If the
parameter is set to True, the bias is initialized to zero.
:type bias_attr: ParameterAttribute | None | bool | Any
:param layer_attr: Extra Layer attribute.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
...
...
@@ -3300,8 +3303,8 @@ def addto_layer(input, act=None, name=None, bias_attr=None, layer_attr=None):
@
layer_support
(
DROPOUT
,
ERROR_CLIPPING
)
def
concat_layer
(
input
,
act
=
None
,
name
=
None
,
layer_attr
=
None
,
bias_attr
=
None
):
"""
Concat
all input vector into one hug
e vector.
Inputs can be
list of LayerOutput or
list of projection.
Concat
enate all input vectors to on
e vector.
Inputs can be
a list of LayerOutput or a
list of projection.
The example usage is:
...
...
@@ -3311,11 +3314,12 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None):
:param name: The name of this layer. It is optional.
:type name: basestring
:param input: input layers or projections
:param input:
The
input layers or projections
:type input: list | tuple | collections.Sequence
:param act: Activation type. IdentityActivation is the default activation.
:type act: BaseActivation
:param layer_attr: Extra Layer Attribute.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
...
...
@@ -3385,7 +3389,7 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None):
def
seq_concat_layer
(
a
,
b
,
act
=
None
,
name
=
None
,
layer_attr
=
None
,
bias_attr
=
None
):
"""
Concat
sequence a with
sequence b.
Concat
enate sequence a and
sequence b.
Inputs:
- a = [a1, a2, ..., am]
...
...
@@ -3404,13 +3408,14 @@ def seq_concat_layer(a, b, act=None, name=None, layer_attr=None,
:param name: The name of this layer. It is optional.
:type name: basestring
:param a: input sequence layer
:param a:
The first
input sequence layer
:type a: LayerOutput
:param b: input sequence layer
:param b:
The second
input sequence layer
:type b: LayerOutput
:param act: Activation type. IdentityActivation is the default activation.
:type act: BaseActivation
:param layer_attr: Extra Layer Attribute.
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute
:param bias_attr: The bias attribute. If the parameter is set to False or an object
whose type is not ParameterAttribute, no bias is defined. If the
...
...
@@ -3447,31 +3452,25 @@ def memory(name,
boot_bias_active_type
=
None
,
boot_with_const_id
=
None
):
"""
The memory layers is a layer cross each time step. Reference this output
as previous time step layer :code:`name` 's output.
The memory takes a layer's output at previous time step as its own output.
The default memory is zero in first time step, previous time step's
output in the rest time steps.
If boot_bias, the activation of the bias is the initial value of the memory.
If boot_
bias, the first time step value is this bias and
with activation
.
If boot_
with_const_id is set, then the memory's output at the first time step
is a IndexSlot, the Arguments.ids()[0] is this :code:`cost_id`
.
If boot_
with_const_id, then the first time stop is a IndexSlot, the
Arguments.ids()[0] is this :code:`cost_id`
.
If boot_
layer is specified, the memory's output at the first time step will
be the boot_layer's output
.
If boot_layer is not null, the memory is just the boot_layer's output.
Set :code:`is_seq` is true boot layer is sequence.
The same name layer in recurrent group will set memory on each time
step.
In other case, the default memory's output at the first time step is zero.
.. code-block:: python
mem = memory(size=256, name='state')
state = fc_layer(input=mem, size=256, name='state')
If you do not want to specify the name, you can
equivalently
use set_input()
to specify the layer
needs
to be remembered as the following:
If you do not want to specify the name, you can
also
use set_input()
to specify the layer to be remembered as the following:
.. code-block:: python
...
...
@@ -3479,26 +3478,31 @@ def memory(name,
state = fc_layer(input=mem, size=256)
mem.set_input(mem)
:param name:
t
he name of the layer which this memory remembers.
:param name:
T
he name of the layer which this memory remembers.
If name is None, user should call set_input() to specify the
name of the layer which this memory remembers.
:type name: basestring
:param size:
size
of memory.
:param size:
The dimensionality
of memory.
:type size: int
:param memory_name: the name of the memory.
It is ignored when name is provided.
:param memory_name: The name of the memory. It is ignored when name is provided.
:type memory_name: basestring
:param is_seq: DEPRECATED. is sequence for boot_layer
:type is_seq: bool
:param boot_layer: boot layer of memory.
:param boot_layer: This parameter specifies memory's output at the first time
step and the output is boot_layer's output.
:type boot_layer: LayerOutput | None
:param boot_bias: boot layer's bias
:param boot_bias: The bias attribute of memory's output at the first time step.
If the parameter is set to False or an object whose type is not
ParameterAttribute, no bias is defined. If the parameter is set
to True, the bias is initialized to zero.
:type boot_bias: ParameterAttribute | None
:param boot_bias_active_type: boot layer's active type.
:param boot_bias_active_type: Activation type for memory's bias at the first time
step. LinearActivation is the default activation.
:type boot_bias_active_type: BaseActivation
:param boot_with_const_id: boot layer's id.
:param boot_with_const_id: This parameter specifies memory's output at the first
time step and the output is an index.
:type boot_with_const_id: int
:return: LayerOutput object
which is a memory
.
:return: LayerOutput object.
:rtype: LayerOutput
"""
if
boot_bias_active_type
is
None
:
...
...
python/paddle/v2/f
ramework
/tests/test_beam_search_op.py
→
python/paddle/v2/f
luid
/tests/test_beam_search_op.py
浏览文件 @
1d95173c
import
logging
from
paddle.v2.f
ramework
.op
import
Operator
,
DynamicRecurrentOp
import
paddle.v2.f
ramework
.core
as
core
from
paddle.v2.f
luid
.op
import
Operator
,
DynamicRecurrentOp
import
paddle.v2.f
luid
.core
as
core
import
unittest
import
numpy
as
np
...
...
python/paddle/v2/fluid/tests/test_sequence_slice_op.py
浏览文件 @
1d95173c
...
...
@@ -9,16 +9,16 @@ class TestSequenceSliceOp(OpTest):
# only supprot one level LoD
x
=
np
.
random
.
random
(
self
.
x_dim
).
astype
(
'float32'
)
lod
=
self
.
x_lod
offset
=
np
.
array
(
self
.
offset
).
flatten
().
astype
(
"int64"
)
length
=
np
.
array
(
self
.
length
).
flatten
().
astype
(
"int64"
)
offset
=
np
.
array
(
self
.
offset
).
astype
(
"int64"
)
length
=
np
.
array
(
self
.
length
).
astype
(
"int64"
)
self
.
inputs
=
{
'X'
:
(
x
,
lod
),
'Offset'
:
offset
,
'Length'
:
length
}
outs
=
[]
#np.zeros((100, 3, 2)).astype('float32')
out_lod
=
[[
0
]]
out_lod_offset
=
0
for
i
in
range
(
len
(
offset
)):
sub_x
=
x
[
lod
[
0
][
i
]
+
offset
[
i
]:
lod
[
0
]
[
i
]
+
offset
[
i
]
+
length
[
i
],
:]
sub_x
=
x
[
lod
[
0
][
i
]
+
offset
[
i
,
0
]:
lod
[
0
]
[
i
]
+
offset
[
i
,
0
]
+
length
[
i
,
0
],
:]
out_lod_offset
=
out_lod_offset
+
len
(
sub_x
)
outs
.
append
(
sub_x
)
out_lod
[
0
].
append
(
out_lod_offset
)
...
...
@@ -28,8 +28,8 @@ class TestSequenceSliceOp(OpTest):
def
init_test_case
(
self
):
self
.
x_dim
=
(
100
,
3
,
2
)
self
.
x_lod
=
[[
0
,
20
,
40
,
60
,
80
,
100
]]
self
.
offset
=
[
1
,
2
,
3
,
4
,
5
]
self
.
length
=
[
10
,
8
,
6
,
4
,
2
]
self
.
offset
=
[
[
1
],
[
2
],
[
3
],
[
4
],
[
5
]
]
self
.
length
=
[
[
10
],
[
8
],
[
6
],
[
4
],
[
2
]
]
def
setUp
(
self
):
self
.
op_type
=
"sequence_slice"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录