Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
7b84c580
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7b84c580
编写于
8月 28, 2018
作者:
F
fengjiayi
提交者:
GitHub
8月 28, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12824 from JiayiFeng/dev_sequence_padding_op
Sequence pad op
上级
0ee6fed0
7e0c9f50
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
743 addition
and
271 deletion
+743
-271
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-0
paddle/fluid/operators/math/sequence_padding.cc
paddle/fluid/operators/math/sequence_padding.cc
+97
-103
paddle/fluid/operators/math/sequence_padding.cu
paddle/fluid/operators/math/sequence_padding.cu
+103
-144
paddle/fluid/operators/math/sequence_padding.h
paddle/fluid/operators/math/sequence_padding.h
+38
-14
paddle/fluid/operators/math/sequence_padding_test.cc
paddle/fluid/operators/math/sequence_padding_test.cc
+19
-4
paddle/fluid/operators/sequence_pad_op.cc
paddle/fluid/operators/sequence_pad_op.cc
+194
-0
paddle/fluid/operators/sequence_pad_op.cu
paddle/fluid/operators/sequence_pad_op.cu
+29
-0
paddle/fluid/operators/sequence_pad_op.h
paddle/fluid/operators/sequence_pad_op.h
+66
-0
paddle/fluid/operators/warpctc_op.h
paddle/fluid/operators/warpctc_op.h
+18
-6
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+46
-0
python/paddle/fluid/tests/unittests/test_sequence_pad_op.py
python/paddle/fluid/tests/unittests/test_sequence_pad_op.py
+131
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
7b84c580
...
...
@@ -113,6 +113,7 @@ paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
paddle.fluid.layers.conv3d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
paddle.fluid.layers.sequence_expand ArgSpec(args=['x', 'y', 'ref_level', 'name'], varargs=None, keywords=None, defaults=(-1, None))
paddle.fluid.layers.sequence_pad ArgSpec(args=['x', 'pad_value', 'maxlen'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm_unit ArgSpec(args=['x_t', 'hidden_t_prev', 'cell_t_prev', 'forget_bias', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(0.0, None, None, None))
paddle.fluid.layers.reduce_sum ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
paddle.fluid.layers.reduce_mean ArgSpec(args=['input', 'dim', 'keep_dim', 'name'], varargs=None, keywords=None, defaults=(None, False, None))
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
7b84c580
...
...
@@ -291,6 +291,7 @@ op_library(unsqueeze_op DEPS reshape_op)
op_library
(
squeeze_op DEPS reshape_op
)
op_library
(
extract_rows_op DEPS memory
)
op_library
(
flatten_op DEPS reshape_op
)
op_library
(
sequence_pad_op DEPS sequence_padding
)
op_library
(
unstack_op DEPS stack_op
)
if
(
WITH_GPU
)
...
...
paddle/fluid/operators/math/sequence_padding.cc
浏览文件 @
7b84c580
...
...
@@ -18,65 +18,86 @@ namespace paddle {
namespace
operators
{
namespace
math
{
template
<
typename
T
>
void
CopyValidData
(
framework
::
Tensor
*
dst_tensor
,
const
framework
::
Tensor
*
src_tensor
,
const
framework
::
Vector
<
size_t
>&
seq_offsets
,
int
pad_seq_len
,
int
step_width
,
bool
norm_by_len
,
CopyType
type
,
PadLayout
layout
)
{
int
seq_num
=
seq_offsets
.
size
()
-
1
;
const
T
*
src_data
=
src_tensor
->
data
<
T
>
();
T
*
dst_data
=
dst_tensor
->
data
<
T
>
();
int
seq_cpy_gap
=
step_width
;
int
pad_cpy_gap
=
layout
==
kBatchLengthWidth
?
step_width
:
seq_num
*
step_width
;
for
(
int
seq_idx
=
0
;
seq_idx
<
seq_num
;
++
seq_idx
)
{
int
valid_seq_len
=
seq_offsets
[
seq_idx
+
1
]
-
seq_offsets
[
seq_idx
];
PADDLE_ENFORCE_GE
(
pad_seq_len
,
valid_seq_len
,
"The padded sequence length can not be less than its original length."
);
int
seq_data_offset
=
seq_offsets
[
seq_idx
]
*
step_width
;
int
pad_data_offset
=
layout
==
kBatchLengthWidth
?
seq_idx
*
pad_seq_len
*
step_width
:
seq_idx
*
step_width
;
float
scale
=
1.0
f
/
static_cast
<
float
>
(
valid_seq_len
);
for
(
int
step_idx
=
0
;
step_idx
<
valid_seq_len
;
++
step_idx
)
{
const
T
*
src
=
src_data
+
(
type
==
kSeqToPad
?
seq_data_offset
:
pad_data_offset
);
T
*
dst
=
dst_data
+
(
type
==
kSeqToPad
?
pad_data_offset
:
seq_data_offset
);
memcpy
(
dst
,
src
,
step_width
*
sizeof
(
T
));
if
(
norm_by_len
)
{
for
(
int
i
=
0
;
i
<
step_width
;
++
i
)
{
*
(
dst
+
i
)
*=
scale
;
}
}
seq_data_offset
+=
seq_cpy_gap
;
pad_data_offset
+=
pad_cpy_gap
;
}
}
}
template
<
typename
T
>
class
PaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
LoDTensor
&
seq
,
framework
::
Tensor
*
padding
,
bool
norm_by_times
)
{
auto
lod
=
seq
.
lod
();
PADDLE_ENFORCE_GT
(
lod
.
size
(),
0UL
,
"The LoD of LoDTensor seq should not be null."
);
const
size_t
level
=
0
;
framework
::
LoD
abs_offset_lod
=
framework
::
ToAbsOffset
(
lod
);
auto
seq_dims
=
seq
.
dims
();
PADDLE_ENFORCE_EQ
(
seq_dims
[
0
],
static_cast
<
int64_t
>
(
abs_offset_lod
[
level
].
back
()),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length."
);
auto
padding_dims
=
padding
->
dims
();
PADDLE_ENFORCE_EQ
(
padding_dims
.
size
(),
3UL
,
"The input padding should be a 3-D Tensor of shape "
"[max_sequence_length, num_sequences, sequence_width]."
);
const
int64_t
max_sequence_length
=
MaximumSequenceLength
(
lod
,
level
);
PADDLE_ENFORCE_EQ
(
padding_dims
[
0
],
max_sequence_length
,
"The first dimension of Tensor padding should be the "
"maximum length of all sequences in LoDTensor seq."
);
const
int64_t
num_sequences
=
abs_offset_lod
[
level
].
size
()
-
1
;
PADDLE_ENFORCE_EQ
(
padding_dims
[
1
],
num_sequences
,
"The second dimension of Tensor padding should be the "
"number of sequences in LoDTensor seq."
);
const
int64_t
sequence_width
=
seq
.
numel
()
/
seq_dims
[
0
];
PADDLE_ENFORCE_EQ
(
padding_dims
[
2
],
sequence_width
,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq."
);
const
T
*
seq_data
=
seq
.
data
<
T
>
();
T
*
padding_data
=
padding
->
data
<
T
>
();
for
(
int64_t
i
=
0
;
i
<
max_sequence_length
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
num_sequences
;
++
j
)
{
int64_t
start_pos
=
abs_offset_lod
[
level
][
j
];
int64_t
sequence_length
=
abs_offset_lod
[
level
][
j
+
1
]
-
start_pos
;
if
(
i
<
sequence_length
)
{
// i > 0 => sequence_length > 0
T
scale
=
norm_by_times
?
(
1.0
f
/
static_cast
<
T
>
(
sequence_length
))
:
1.0
f
;
for
(
int64_t
k
=
0
;
k
<
sequence_width
;
++
k
)
{
padding_data
[(
i
*
num_sequences
+
j
)
*
sequence_width
+
k
]
=
seq_data
[(
start_pos
+
i
)
*
sequence_width
+
k
]
*
scale
;
const
framework
::
LoDTensor
&
seq_tensor
,
framework
::
LoDTensor
*
pad_tensor
,
const
framework
::
LoDTensor
&
pad_value
,
int
pad_seq_len
=
-
1
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
const
PadLayout
layout
=
kBatchLengthWidth
)
{
auto
seq_lod
=
seq_tensor
.
lod
();
const
auto
seq_offsets
=
framework
::
ToAbsOffset
(
seq_lod
)[
lod_level
];
const
auto
&
seq_tensor_dims
=
seq_tensor
.
dims
();
const
auto
&
pad_tensor_dims
=
pad_tensor
->
dims
();
if
(
pad_seq_len
==
-
1
)
{
pad_seq_len
=
MaximumSequenceLength
(
seq_offsets
);
}
}
else
{
memset
(
padding_data
+
(
i
*
num_sequences
+
j
)
*
sequence_width
,
0
,
sequence_width
*
sizeof
(
T
));
int
step_width
=
seq_tensor
.
numel
()
/
seq_tensor_dims
[
0
];
CheckDims
(
seq_tensor_dims
,
pad_tensor_dims
,
seq_offsets
,
pad_seq_len
,
step_width
,
layout
);
PADDLE_ENFORCE
(
pad_value
.
numel
()
==
1
||
pad_value
.
numel
()
==
step_width
,
"The numel of 'pad_value' can only be 1 or be equal to the "
"'step_width'."
);
// fill padding value
T
*
pad_data
=
pad_tensor
->
data
<
T
>
();
const
T
*
pad_value_data
=
pad_value
.
data
<
T
>
();
if
(
pad_value
.
numel
()
==
1
)
{
for
(
int
i
=
0
;
i
<
pad_tensor
->
numel
();
++
i
)
{
pad_data
[
i
]
=
*
pad_value_data
;
}
}
else
{
for
(
int
i
=
0
;
i
<
pad_tensor
->
numel
();
i
+=
step_width
)
{
memcpy
(
pad_data
+
i
,
pad_value_data
,
step_width
*
sizeof
(
T
));
}
}
CopyValidData
<
T
>
(
pad_tensor
,
&
seq_tensor
,
seq_offsets
,
pad_seq_len
,
step_width
,
norm_by_times
,
kSeqToPad
,
layout
);
}
};
...
...
@@ -84,62 +105,35 @@ template <typename T>
class
UnpaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
framework
::
LoDTensor
*
seq
,
const
framework
::
Tensor
&
padding
,
bool
norm_by_times
)
{
auto
lod
=
seq
->
lod
();
PADDLE_ENFORCE_GT
(
lod
.
size
(),
0UL
,
"The LoD of LoDTensor seq should not be null."
);
const
size_t
level
=
0
;
framework
::
LoD
abs_offset_lod
=
framework
::
ToAbsOffset
(
lod
);
auto
seq_dims
=
seq
->
dims
();
PADDLE_ENFORCE_EQ
(
seq_dims
[
0
],
static_cast
<
int64_t
>
(
abs_offset_lod
[
level
].
back
()),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length."
);
auto
padding_dims
=
padding
.
dims
();
PADDLE_ENFORCE_EQ
(
padding_dims
.
size
(),
3UL
,
"The input padding should be a 3-D Tensor of shape "
"[max_sequnece_length, num_sequences, sequence_width]."
);
const
int64_t
max_sequence_length
=
MaximumSequenceLength
(
lod
,
level
);
PADDLE_ENFORCE_EQ
(
padding_dims
[
0
],
max_sequence_length
,
"The first dimension of Tensor padding should be "
"the maximum length of all sequences in LoDTensor seq."
);
const
int64_t
num_sequences
=
abs_offset_lod
[
level
].
size
()
-
1
;
PADDLE_ENFORCE_EQ
(
padding_dims
[
1
],
num_sequences
,
"The second dimension of Tensor padding should be "
"the number of sequences in LoDTensor seq."
);
const
int64_t
sequence_width
=
seq
->
numel
()
/
seq_dims
[
0
];
PADDLE_ENFORCE_EQ
(
padding_dims
[
2
],
sequence_width
,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq."
);
const
T
*
padding_data
=
padding
.
data
<
T
>
();
T
*
seq_data
=
seq
->
data
<
T
>
();
for
(
int64_t
i
=
0
;
i
<
num_sequences
;
++
i
)
{
int64_t
start_pos
=
abs_offset_lod
[
level
][
i
];
int64_t
sequence_length
=
abs_offset_lod
[
level
][
i
+
1
]
-
start_pos
;
for
(
int64_t
j
=
0
;
j
<
sequence_length
;
++
j
)
{
// sequence_width > j > 0
T
scale
=
norm_by_times
?
(
1.0
f
/
static_cast
<
T
>
(
sequence_length
))
:
1.0
f
;
for
(
int64_t
k
=
0
;
k
<
sequence_width
;
++
k
)
{
seq_data
[(
start_pos
+
j
)
*
sequence_width
+
k
]
=
padding_data
[(
j
*
num_sequences
+
i
)
*
sequence_width
+
k
]
*
scale
;
}
}
const
framework
::
LoDTensor
&
pad_tensor
,
framework
::
LoDTensor
*
seq_tensor
,
int
pad_seq_len
=
-
1
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
const
PadLayout
layout
=
kBatchLengthWidth
)
{
auto
seq_offsets
=
framework
::
ToAbsOffset
(
seq_tensor
->
lod
())[
lod_level
];
const
auto
&
seq_tensor_dims
=
seq_tensor
->
dims
();
const
auto
&
pad_tensor_dims
=
pad_tensor
.
dims
();
if
(
pad_seq_len
==
-
1
)
{
pad_seq_len
=
MaximumSequenceLength
(
seq_offsets
);
}
int
step_width
=
seq_tensor
->
numel
()
/
seq_tensor_dims
[
0
];
CheckDims
(
seq_tensor_dims
,
pad_tensor_dims
,
seq_offsets
,
pad_seq_len
,
step_width
,
layout
);
CopyValidData
<
T
>
(
seq_tensor
,
&
pad_tensor
,
seq_offsets
,
pad_seq_len
,
step_width
,
norm_by_times
,
kPadToSeq
,
layout
);
}
};
template
class
PaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
int
>;
template
class
PaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
int64_t
>;
template
class
PaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
class
PaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
double
>;
template
class
UnpaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
int
>;
template
class
UnpaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
int64_t
>;
template
class
UnpaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
class
UnpaddingLoDTensorFunctor
<
platform
::
CPUDeviceContext
,
double
>;
}
// namespace math
}
// namespace operators
...
...
paddle/fluid/operators/math/sequence_padding.cu
浏览文件 @
7b84c580
...
...
@@ -19,41 +19,32 @@ namespace paddle {
namespace
operators
{
namespace
math
{
template
<
typename
T
,
bool
NormByTimes
,
bool
Padding
>
__global__
void
SequencePaddingKernel
(
T
*
padding
,
T
*
sequence
,
const
size_t
*
sequence_start_positions
,
const
size_t
sequence_width
,
const
size_t
max_sequence_length
,
const
size_t
num_sequences
)
{
size_t
padding_idx
=
blockIdx
.
y
;
size_t
start_pos
=
sequence_start_positions
[
padding_idx
];
size_t
sequence_length
=
sequence_start_positions
[
padding_idx
+
1
]
-
start_pos
;
size_t
sequence_idx
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
size_t
padding_base_idx
=
(
sequence_idx
*
num_sequences
+
padding_idx
)
*
sequence_width
;
size_t
sequence_base_idx
=
(
start_pos
+
sequence_idx
)
*
sequence_width
;
if
(
sequence_idx
<
sequence_length
)
{
T
scale
=
NormByTimes
?
(
1.0
f
/
static_cast
<
T
>
(
sequence_length
))
:
1.0
f
;
if
(
Padding
)
{
/* sequence -> padding */
for
(
size_t
i
=
threadIdx
.
x
;
i
<
sequence_width
;
i
+=
blockDim
.
x
)
{
padding
[
padding_base_idx
+
i
]
=
scale
*
sequence
[
sequence_base_idx
+
i
];
}
}
else
{
/* padding -> sequence */
for
(
size_t
i
=
threadIdx
.
x
;
i
<
sequence_width
;
i
+=
blockDim
.
x
)
{
sequence
[
sequence_base_idx
+
i
]
=
scale
*
padding
[
padding_base_idx
+
i
];
}
}
}
else
if
(
sequence_idx
<
max_sequence_length
)
{
if
(
Padding
)
{
/* sequence -> padding */
for
(
size_t
i
=
threadIdx
.
x
;
i
<
sequence_width
;
i
+=
blockDim
.
x
)
{
padding
[
padding_base_idx
+
i
]
=
0
;
template
<
typename
T
,
CopyType
Type
>
__global__
void
SequencePaddingKernel
(
T
*
dst
,
const
T
*
src
,
const
T
*
pad_value
,
bool
is_constant_pad
,
const
size_t
*
seq_offsets
,
const
size_t
seq_num
,
const
size_t
pad_seq_len
,
const
size_t
step_width
,
bool
norm_by_len
,
const
PadLayout
layout
)
{
size_t
seq_idx
=
blockIdx
.
y
;
size_t
seq_len
=
seq_offsets
[
seq_idx
+
1
]
-
seq_offsets
[
seq_idx
];
size_t
step_idx
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
size_t
seq_data_offset
=
(
seq_offsets
[
seq_idx
]
+
step_idx
)
*
step_width
;
size_t
pad_data_offset
=
layout
==
kBatchLengthWidth
?
(
seq_idx
*
pad_seq_len
+
step_idx
)
*
step_width
:
(
step_idx
*
seq_num
+
seq_idx
)
*
step_width
;
T
*
dst_data
=
dst
+
(
Type
==
kSeqToPad
?
pad_data_offset
:
seq_data_offset
);
const
T
*
src_data
=
src
+
(
Type
==
kSeqToPad
?
seq_data_offset
:
pad_data_offset
);
if
(
step_idx
<
seq_len
)
{
float
scale
=
norm_by_len
?
(
1.0
f
/
static_cast
<
float
>
(
seq_len
))
:
1.0
f
;
for
(
size_t
i
=
threadIdx
.
x
;
i
<
step_width
;
i
+=
blockDim
.
x
)
{
dst_data
[
i
]
=
scale
*
src_data
[
i
];
}
}
else
if
(
step_idx
<
pad_seq_len
&&
Type
==
kSeqToPad
)
{
for
(
size_t
i
=
threadIdx
.
x
;
i
<
step_width
;
i
+=
blockDim
.
x
)
{
dst_data
[
i
]
=
is_constant_pad
?
pad_value
[
0
]
:
pad_value
[
i
];
}
}
}
...
...
@@ -62,74 +53,59 @@ template <typename T>
class
PaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
LoDTensor
&
seq
,
framework
::
Tensor
*
padding
,
bool
norm_by_times
)
{
auto
lod
=
seq
.
lod
();
PADDLE_ENFORCE_GT
(
lod
.
size
(),
0UL
,
"The lod of LoDTensor seq should not be null."
);
const
size_t
level
=
0
;
framework
::
LoD
abs_offset_lod
=
framework
::
ToAbsOffset
(
lod
);
auto
seq_dims
=
seq
.
dims
();
PADDLE_ENFORCE_EQ
(
seq_dims
[
0
],
static_cast
<
int64_t
>
(
abs_offset_lod
[
level
].
back
()),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length."
);
auto
padding_dims
=
padding
->
dims
();
PADDLE_ENFORCE_EQ
(
padding_dims
.
size
(),
3UL
,
"The input padding should be a 3-D Tensor of shape "
"[max_sequence_length, num_sequences, sequence_width]."
);
int64_t
max_sequence_length
=
MaximumSequenceLength
(
lod
,
level
);
PADDLE_ENFORCE_EQ
(
padding_dims
[
0
],
max_sequence_length
,
"The first dimension of Tensor padding should be the "
"maximum length of all sequences in LoDTensor seq."
);
const
int64_t
num_sequences
=
abs_offset_lod
[
level
].
size
()
-
1
;
PADDLE_ENFORCE_EQ
(
padding_dims
[
1
],
num_sequences
,
"The second dimension of Tensor padding should be the "
"number of sequences in LoDTensor seq."
);
const
int64_t
sequence_width
=
seq
.
numel
()
/
seq_dims
[
0
];
PADDLE_ENFORCE_EQ
(
padding_dims
[
2
],
sequence_width
,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq."
);
if
(
!
norm_by_times
&&
num_sequences
==
1UL
)
{
TensorCopy
(
seq
,
context
.
GetPlace
(),
context
,
padding
);
padding
->
Resize
(
padding_dims
);
const
framework
::
LoDTensor
&
seq_tensor
,
framework
::
LoDTensor
*
pad_tensor
,
const
framework
::
LoDTensor
&
pad_value
,
int
pad_seq_len
=
-
1
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
const
PadLayout
layout
=
kBatchLengthWidth
)
{
auto
seq_lod
=
seq_tensor
.
lod
();
const
auto
seq_offsets
=
framework
::
ToAbsOffset
(
seq_lod
)[
lod_level
];
const
auto
&
seq_tensor_dims
=
seq_tensor
.
dims
();
const
auto
&
pad_tensor_dims
=
pad_tensor
->
dims
();
int
max_seq_len
=
MaximumSequenceLength
(
seq_offsets
);
if
(
pad_seq_len
==
-
1
)
{
pad_seq_len
=
max_seq_len
;
}
PADDLE_ENFORCE_GE
(
pad_seq_len
,
max_seq_len
,
"The pad_seq_len must be equal to or greater than the "
"original max sequence length."
);
int
step_width
=
seq_tensor
.
numel
()
/
seq_tensor_dims
[
0
];
int
seq_num
=
seq_offsets
.
size
()
-
1
;
CheckDims
(
seq_tensor_dims
,
pad_tensor_dims
,
seq_offsets
,
pad_seq_len
,
step_width
,
layout
);
PADDLE_ENFORCE
(
pad_value
.
numel
()
==
1
||
pad_value
.
numel
()
==
step_width
,
"The numel of 'pad_value' can only be 1 or be equal to the "
"'step_width'."
);
if
(
!
norm_by_times
&&
seq_num
==
1UL
&&
pad_seq_len
==
max_seq_len
)
{
TensorCopy
(
seq_tensor
,
context
.
GetPlace
(),
context
,
pad_tensor
);
pad_tensor
->
Resize
(
pad_tensor_dims
);
return
;
}
const
int
64_t
kBlockSize
=
512
;
const
int
kBlockSize
=
512
;
/* At least use 32 threads to copy sequence_width elements,
* and at least 8 elements for each thread.
*/
size_t
block_dim_x
=
std
::
min
(((((
s
equence
_width
+
7
)
>>
3
)
+
31
)
>>
5
)
<<
5
,
kBlockSize
);
std
::
min
(((((
s
tep
_width
+
7
)
>>
3
)
+
31
)
>>
5
)
<<
5
,
kBlockSize
);
size_t
block_dim_y
=
kBlockSize
/
block_dim_x
;
dim3
threads
(
block_dim_x
,
block_dim_y
);
size_t
grid_dim_x
=
(
max_sequence_length
+
block_dim_y
-
1
)
/
block_dim_y
;
size_t
grid_dim_y
=
num_sequences
;
size_t
grid_dim_x
=
(
pad_seq_len
+
block_dim_y
-
1
)
/
block_dim_y
;
size_t
grid_dim_y
=
seq_num
;
dim3
grid
(
grid_dim_x
,
grid_dim_y
);
const
T
*
seq_data
=
seq
.
data
<
T
>
();
T
*
padding_data
=
padding
->
data
<
T
>
();
if
(
norm_by_times
)
{
SequencePaddingKernel
<
T
,
1
,
1
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
padding_data
,
const_cast
<
T
*>
(
seq_data
),
abs_offset_lod
[
level
].
CUDAData
(
context
.
GetPlace
()),
sequence_width
,
max_sequence_length
,
num_sequences
);
}
else
{
SequencePaddingKernel
<
T
,
0
,
1
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
padding_data
,
const_cast
<
T
*>
(
seq_data
),
abs_offset_lod
[
level
].
CUDAData
(
context
.
GetPlace
()),
sequence_width
,
max_sequence_length
,
num_sequences
);
}
const
T
*
seq_data
=
seq_tensor
.
data
<
T
>
();
T
*
pad_data
=
pad_tensor
->
data
<
T
>
();
const
T
*
pad_value_data
=
pad_value
.
data
<
T
>
();
SequencePaddingKernel
<
T
,
kSeqToPad
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
pad_data
,
seq_data
,
pad_value_data
,
pad_value
.
numel
()
==
1
,
seq_offsets
.
CUDAData
(
context
.
GetPlace
()),
seq_num
,
pad_seq_len
,
step_width
,
norm_by_times
,
layout
);
}
};
...
...
@@ -137,79 +113,62 @@ template <typename T>
class
UnpaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
framework
::
LoDTensor
*
seq
,
const
framework
::
Tensor
&
padding
,
bool
norm_by_times
)
{
auto
lod
=
seq
->
lod
();
PADDLE_ENFORCE_GT
(
lod
.
size
(),
0UL
,
"The lod of LoDTensor seq should not be null."
);
const
size_t
level
=
0
;
framework
::
LoD
abs_offset_lod
=
framework
::
ToAbsOffset
(
lod
);
auto
seq_dims
=
seq
->
dims
();
PADDLE_ENFORCE_EQ
(
seq_dims
[
0
],
static_cast
<
int64_t
>
(
abs_offset_lod
[
level
].
back
()),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length."
);
auto
padding_dims
=
padding
.
dims
();
PADDLE_ENFORCE_EQ
(
padding_dims
.
size
(),
3UL
,
"The input padding should be a 3-D Tensor of shape "
"[max_sequnece_length, num_sequences, sequence_width]."
);
int64_t
max_sequence_length
=
MaximumSequenceLength
(
lod
,
level
);
PADDLE_ENFORCE_EQ
(
padding_dims
[
0
],
max_sequence_length
,
"The first dimension of Tensor padding should be "
"the maximum length of all sequences in LoDTensor seq."
);
const
int64_t
num_sequences
=
abs_offset_lod
[
level
].
size
()
-
1
;
PADDLE_ENFORCE_EQ
(
padding_dims
[
1
],
num_sequences
,
"The second dimension of Tensor padding should be "
"the number of sequences in LoDTensor seq."
);
const
int64_t
sequence_width
=
seq
->
numel
()
/
seq_dims
[
0
];
PADDLE_ENFORCE_EQ
(
padding_dims
[
2
],
sequence_width
,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq."
);
if
(
!
norm_by_times
&&
num_sequences
==
1UL
)
{
TensorCopy
(
padding
,
context
.
GetPlace
(),
context
,
seq
);
seq
->
Resize
(
seq_dims
);
const
framework
::
LoDTensor
&
pad_tensor
,
framework
::
LoDTensor
*
seq_tensor
,
int
pad_seq_len
=
-
1
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
const
PadLayout
layout
=
kBatchLengthWidth
)
{
auto
seq_offsets
=
framework
::
ToAbsOffset
(
seq_tensor
->
lod
())[
lod_level
];
const
auto
&
seq_tensor_dims
=
seq_tensor
->
dims
();
const
auto
&
pad_tensor_dims
=
pad_tensor
.
dims
();
int
max_seq_len
=
MaximumSequenceLength
(
seq_offsets
);
if
(
pad_seq_len
==
-
1
)
{
pad_seq_len
=
max_seq_len
;
}
int
step_width
=
seq_tensor
->
numel
()
/
seq_tensor_dims
[
0
];
int
seq_num
=
seq_offsets
.
size
()
-
1
;
CheckDims
(
seq_tensor_dims
,
pad_tensor_dims
,
seq_offsets
,
pad_seq_len
,
step_width
,
layout
);
if
(
!
norm_by_times
&&
seq_num
==
1UL
&&
pad_seq_len
==
max_seq_len
)
{
TensorCopy
(
pad_tensor
,
context
.
GetPlace
(),
context
,
seq_tensor
);
seq_tensor
->
Resize
(
seq_tensor_dims
);
return
;
}
const
int
64_t
kBlockSize
=
512
;
const
int
kBlockSize
=
512
;
/* At least use 32 threads to copy sequence_width elements,
* and at least 8 elements for each thread.
*/
size_t
block_dim_x
=
std
::
min
(((((
s
equence
_width
+
7
)
>>
3
)
+
31
)
>>
5
)
<<
5
,
kBlockSize
);
std
::
min
(((((
s
tep
_width
+
7
)
>>
3
)
+
31
)
>>
5
)
<<
5
,
kBlockSize
);
size_t
block_dim_y
=
kBlockSize
/
block_dim_x
;
dim3
threads
(
block_dim_x
,
block_dim_y
);
size_t
grid_dim_x
=
(
max_sequence_length
+
block_dim_y
-
1
)
/
block_dim_y
;
size_t
grid_dim_y
=
num_sequences
;
size_t
grid_dim_x
=
(
pad_seq_len
+
block_dim_y
-
1
)
/
block_dim_y
;
size_t
grid_dim_y
=
seq_num
;
dim3
grid
(
grid_dim_x
,
grid_dim_y
);
const
T
*
padding_data
=
padding
.
data
<
T
>
();
T
*
seq_data
=
seq
->
data
<
T
>
();
if
(
norm_by_times
)
{
SequencePaddingKernel
<
T
,
1
,
0
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
const_cast
<
T
*>
(
padding_data
),
seq_data
,
abs_offset_lod
[
level
].
CUDAData
(
context
.
GetPlace
()),
sequence_width
,
max_sequence_length
,
num_sequences
);
}
else
{
SequencePaddingKernel
<
T
,
0
,
0
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
const_cast
<
T
*>
(
padding_data
),
seq_data
,
abs_offset_lod
[
level
].
CUDAData
(
context
.
GetPlace
()),
sequence_width
,
max_sequence_length
,
num_sequences
);
}
const
T
*
pad_data
=
pad_tensor
.
data
<
T
>
();
T
*
seq_data
=
seq_tensor
->
data
<
T
>
();
SequencePaddingKernel
<
T
,
kPadToSeq
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
seq_data
,
pad_data
,
nullptr
,
false
,
seq_offsets
.
CUDAData
(
context
.
GetPlace
()),
seq_num
,
pad_seq_len
,
step_width
,
norm_by_times
,
layout
);
}
};
template
class
PaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
int
>;
template
class
PaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
int64_t
>;
template
class
PaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
PaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
double
>;
template
class
UnpaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
int
>;
template
class
UnpaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
int64_t
>;
template
class
UnpaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
UnpaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace math
}
// namespace operators
...
...
paddle/fluid/operators/math/sequence_padding.h
浏览文件 @
7b84c580
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/device_context.h"
...
...
@@ -22,17 +23,33 @@ namespace paddle {
namespace
operators
{
namespace
math
{
inline
static
size_t
MaximumSequenceLength
(
const
framework
::
LoD
&
lod
,
const
size_t
level
)
{
const
size_t
num_sequences
=
lod
[
level
].
size
()
-
1
;
size_t
max_sequence_length
=
0
;
framework
::
LoD
abs_offset_lod
=
framework
::
ToAbsOffset
(
lod
);
for
(
size_t
i
=
0
;
i
<
num_sequences
;
++
i
)
{
max_sequence_length
=
std
::
max
(
max_sequence_length
,
abs_offset_lod
[
level
][
i
+
1
]
-
abs_offset_lod
[
level
][
i
]);
enum
PadLayout
{
kBatchLengthWidth
=
0
,
kLengthBatchWidth
};
enum
CopyType
{
kSeqToPad
,
kPadToSeq
};
inline
static
size_t
MaximumSequenceLength
(
const
framework
::
Vector
<
size_t
>&
seq_offset
)
{
size_t
seq_num
=
seq_offset
.
size
()
-
1
;
size_t
max_seq_len
=
0
;
for
(
size_t
i
=
0
;
i
<
seq_num
;
++
i
)
{
max_seq_len
=
std
::
max
(
max_seq_len
,
seq_offset
[
i
+
1
]
-
seq_offset
[
i
]);
}
return
max_sequence_length
;
return
max_seq_len
;
}
inline
static
void
CheckDims
(
const
framework
::
DDim
&
seq_tensor_dims
,
const
framework
::
DDim
&
pad_tensor_dims
,
const
framework
::
Vector
<
size_t
>&
seq_offset
,
int64_t
padded_seq_len
,
int64_t
step_width
,
const
PadLayout
&
layout
)
{
PADDLE_ENFORCE_EQ
(
static_cast
<
size_t
>
(
seq_tensor_dims
[
0
]),
seq_offset
.
back
(),
"Value of 1st dimension of the sequence tensor should be "
"equal to sum of lengths of all sequences."
);
PADDLE_ENFORCE
(
seq_tensor_dims
.
size
()
+
1
==
pad_tensor_dims
.
size
()
||
seq_tensor_dims
.
size
()
==
pad_tensor_dims
.
size
(),
"pad_tensor's rank should be 1 greater than seq_tensor's "
"rank, or be equal with it."
);
}
/*
...
...
@@ -64,15 +81,22 @@ inline static size_t MaximumSequenceLength(const framework::LoD& lod,
template
<
typename
DeviceContext
,
typename
T
>
class
PaddingLoDTensorFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
seq
,
framework
::
Tensor
*
padding
,
bool
norm_by_times
);
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
seq_tensor
,
framework
::
LoDTensor
*
pad_tensor
,
const
framework
::
LoDTensor
&
pad_value
,
int
pad_seq_len
=
-
1
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
const
PadLayout
layout
=
kBatchLengthWidth
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
UnpaddingLoDTensorFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
framework
::
LoDTensor
*
seq
,
const
framework
::
Tensor
&
padding
,
bool
norm_by_times
);
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
pad_tensor
,
framework
::
LoDTensor
*
seq_tensor
,
int
pad_seq_len
=
-
1
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
const
PadLayout
layout
=
kBatchLengthWidth
);
};
}
// namespace math
...
...
paddle/fluid/operators/math/sequence_padding_test.cc
浏览文件 @
7b84c580
...
...
@@ -23,7 +23,9 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
paddle
::
framework
::
LoDTensor
cpu_seq_back
;
paddle
::
framework
::
LoDTensor
seq
;
paddle
::
framework
::
LoDTensor
seq_back
;
paddle
::
framework
::
Tensor
padding
;
paddle
::
framework
::
LoDTensor
padding
;
paddle
::
framework
::
LoDTensor
cpu_pad_value
;
paddle
::
framework
::
LoDTensor
pad_value
;
const
size_t
level
=
lod
.
size
()
-
1
;
auto
seq_dims
=
...
...
@@ -46,20 +48,33 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
}
const
size_t
max_sequence_length
=
paddle
::
operators
::
math
::
MaximumSequenceLength
(
lod
,
level
);
paddle
::
operators
::
math
::
MaximumSequenceLength
(
lod
[
level
]
);
const
size_t
num_sequences
=
lod
[
level
].
size
()
-
1
;
auto
padding_dims
=
paddle
::
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
max_sequence_length
),
static_cast
<
int64_t
>
(
num_sequences
),
static_cast
<
int64_t
>
(
sequence_width
)});
padding
.
mutable_data
<
T
>
(
padding_dims
,
*
place
);
T
*
pad_value_data
=
cpu_pad_value
.
mutable_data
<
T
>
({
1
},
paddle
::
platform
::
CPUPlace
());
*
pad_value_data
=
static_cast
<
T
>
(
0
);
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
pad_value
=
cpu_pad_value
;
}
else
{
TensorCopySync
(
cpu_pad_value
,
*
place
,
&
pad_value
);
}
paddle
::
operators
::
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
*
context
,
seq
,
&
padding
,
false
);
*
context
,
seq
,
&
padding
,
pad_value
,
-
1
,
0
,
false
,
paddle
::
operators
::
math
::
kLengthBatchWidth
);
seq_back
.
set_lod
(
lod
);
seq_back
.
mutable_data
<
T
>
(
seq_dims
,
*
place
);
paddle
::
operators
::
math
::
UnpaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
*
context
,
&
seq_back
,
padding
,
false
);
*
context
,
padding
,
&
seq_back
,
-
1
,
0
,
false
,
paddle
::
operators
::
math
::
kLengthBatchWidth
);
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
cpu_seq_back
=
seq_back
;
...
...
paddle/fluid/operators/sequence_pad_op.cc
0 → 100644
浏览文件 @
7b84c580
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_pad_op.h"
namespace
paddle
{
namespace
operators
{
class
SequencePadOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequencePadOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"PadValue"
),
"Input(PadValue) of SequencePadOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of SequencePadOp should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_GE
(
x_dims
.
size
(),
2
,
"The rank of Input(x) can't be less than 2."
);
auto
time_step_dims
=
framework
::
slice_ddim
(
x_dims
,
1
,
x_dims
.
size
());
auto
pad_value_dims
=
ctx
->
GetInputDim
(
"PadValue"
);
PADDLE_ENFORCE
(
pad_value_dims
==
framework
::
make_ddim
({
1
})
||
pad_value_dims
==
time_step_dims
,
"The Input(PadValue) must be a scalar or a tensor whose "
"shape equals to time steps in sequences"
);
int
out_dim_0
=
-
1
;
int
out_dim_1
=
-
1
;
if
(
ctx
->
IsRuntime
())
{
// run time
framework
::
Variable
*
x_var
=
boost
::
get
<
framework
::
Variable
*>
(
ctx
->
GetInputVarPtrs
(
"X"
)[
0
]);
const
auto
&
x_lod
=
x_var
->
Get
<
LoDTensor
>
().
lod
();
PADDLE_ENFORCE
(
!
x_lod
.
empty
(),
"The Input(X) must hold lod info."
);
const
auto
&
x_lod_0
=
x_lod
[
0
];
PADDLE_ENFORCE_GE
(
x_lod_0
.
size
(),
2
,
"The Input(X)'s lod info is corrupted."
);
PADDLE_ENFORCE_EQ
(
x_dims
[
0
],
static_cast
<
int64_t
>
(
x_lod_0
.
back
()),
"The Input(X)'s lod info mismatches the actual tensor shape."
);
int
seq_num
=
x_lod_0
.
size
()
-
1
;
int
max_seq_len
=
math
::
MaximumSequenceLength
(
x_lod_0
);
int
padded_length
=
ctx
->
Attrs
().
Get
<
int
>
(
"padded_length"
);
if
(
padded_length
==
-
1
)
{
padded_length
=
max_seq_len
;
}
PADDLE_ENFORCE_GE
(
padded_length
,
max_seq_len
,
"The Attr(padded_length) must be -1 or an int greater "
"than the length of the longest original sequence."
);
out_dim_0
=
seq_num
;
out_dim_1
=
padded_length
;
}
else
{
// compile time
framework
::
VarDesc
*
x_desc
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetInputVarPtrs
(
"X"
)[
0
]);
PADDLE_ENFORCE_GE
(
x_desc
->
GetLoDLevel
(),
1
);
}
std
::
vector
<
int
>
out_dims_vec
{
out_dim_0
,
out_dim_1
};
auto
time_step_dims_vec
=
framework
::
vectorize2int
(
time_step_dims
);
out_dims_vec
.
insert
(
out_dims_vec
.
end
(),
time_step_dims_vec
.
begin
(),
time_step_dims_vec
.
end
());
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
out_dims_vec
));
}
};
class
SequencePadOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(LoDTensor, default LoDTensor<float>) Input variable which "
"should contain lod information."
);
AddInput
(
"PadValue"
,
"(LoDTensor), this Tensor holds values that will be fill into "
"padded steps. It can be a scalar or a tensor whose shape equals "
"to time steps in sequences. If it's a scalar, it will be "
"automatically broadcasted to the shape of time step."
);
AddOutput
(
"Out"
,
"(LoDTensor) The output vairable, which contains padded sequences."
);
AddAttr
<
int
>
(
"padded_length"
,
"The length of padded sequences. It can be setted to -1 or "
"any positive int. When it is -1, all sequences will be padded up to "
"the length of the longest one among them; when it a certain positive "
"value, it must be greater than the length of the longest original "
"sequence."
)
.
SetDefault
(
-
1
);
AddComment
(
R"DOC(
Sequence Pad Operator
This operator pads sequences in a same batch to a consistent length.
The length is specified by attribute 'padded_length'. New elements,
whose values are specified by input 'PadValue', will be appended to
the end of each sequence, to make their final lengths consistent.
Following are cases to better explain how this works:
Case 1:
Given a 1-level LoDTensor input(X):
X.lod = [[0, 2, 5]]
X.data = [a, b, c, d, e]
and Input(PadValue):
PadValue.data = [0]
and attribite 'padded_length' = 4,
then we get LoDTensor:
Out.data = [[a, b, 0, 0],
[c, d, e, 0]]
Case 2:
Given a 1-level LoDTensor input(X):
X.lod = [[0, 2, 5]]
X.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]]
and Input(PadValue):
PadValue.data = [0]
and attribite 'padded_length' = -1, which mean using the length
of longest input sequence(3 in this case),
then we get LoDTensor:
Out.data = [[[a1, a2], [b1, b2], [0, 0]],
[[c1, c2], [d1, d2], [e1, e2]]]
Case 3:
Given a 1-level LoDTensor input(X):
X.lod = [[0, 2, 5]]
X.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]]
and Input(PadValue):
PadValue.data = [p1, p2]
and attribite 'padded_length' = -1, which mean using the length
of longest input sequence(3 in this case),
then we get LoDTensor:
Out.data = [[[a1, a2], [b1, b2], [p1, p2]],
[[c1, c2], [d1, d2], [e1, e2]]]
)DOC"
);
}
};
class
SequencePadGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SequencePadGradOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) of SequencePadGradOp should not be null."
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
sequence_pad
,
ops
::
SequencePadOp
,
ops
::
SequencePadOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
sequence_pad_grad
,
ops
::
SequencePadGradOp
);
REGISTER_OP_CPU_KERNEL
(
sequence_pad
,
ops
::
SequencePadOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SequencePadOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
SequencePadOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
SequencePadOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
sequence_pad_grad
,
ops
::
SequencePadGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SequencePadGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
SequencePadGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
ops
::
SequencePadGradOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
paddle/fluid/operators/sequence_pad_op.cu
0 → 100644
浏览文件 @
7b84c580
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sequence_pad_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
sequence_pad
,
ops
::
SequencePadOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SequencePadOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
SequencePadOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
SequencePadOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
REGISTER_OP_CUDA_KERNEL
(
sequence_pad_grad
,
ops
::
SequencePadGradOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SequencePadGradOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
SequencePadGradOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
ops
::
SequencePadGradOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
paddle/fluid/operators/sequence_pad_op.h
0 → 100644
浏览文件 @
7b84c580
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_padding.h"
namespace
paddle
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoD
=
framework
::
LoD
;
template
<
typename
DeviceContext
,
typename
T
>
class
SequencePadOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
LoDTensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
auto
*
pad_value
=
ctx
.
Input
<
LoDTensor
>
(
"PadValue"
);
int
padded_length
=
ctx
.
Attr
<
int
>
(
"padded_length"
);
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
*
x
,
out
,
*
pad_value
,
padded_length
,
0
,
false
,
math
::
kBatchLengthWidth
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
SequencePadGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
d_x
=
ctx
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
if
(
d_x
)
{
const
auto
*
d_out
=
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
padded_length
=
ctx
.
Attr
<
int
>
(
"padded_length"
);
math
::
UnpaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
*
d_out
,
d_x
,
padded_length
,
0
,
false
,
math
::
kBatchLengthWidth
);
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/warpctc_op.h
浏览文件 @
7b84c580
...
...
@@ -153,17 +153,29 @@ class WarpCTCKernel : public framework::OpKernel<T> {
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
num_sequences
),
1
});
// warpctc needs sequences data stored in transposed padding format
Tensor
warpctc_logits
;
LoD
Tensor
warpctc_logits
;
const
size_t
max_sequence_length
=
math
::
MaximumSequenceLength
(
logits_lod
,
level
);
math
::
MaximumSequenceLength
(
logits_lod
[
level
]
);
auto
warpctc_logits_dims
=
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
max_sequence_length
),
static_cast
<
int64_t
>
(
num_sequences
),
static_cast
<
int64_t
>
(
sequence_width
)});
warpctc_logits
.
mutable_data
<
T
>
(
warpctc_logits_dims
,
ctx
.
GetPlace
());
LoDTensor
cpu_pad_value
;
T
*
pad_value_data
=
cpu_pad_value
.
mutable_data
<
T
>
({
1
},
platform
::
CPUPlace
());
*
pad_value_data
=
static_cast
<
T
>
(
0
);
LoDTensor
pad_value
;
if
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()))
{
pad_value
=
cpu_pad_value
;
}
else
{
TensorCopySync
(
cpu_pad_value
,
ctx
.
GetPlace
(),
&
pad_value
);
}
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
*
logits
,
&
warpctc_logits
,
false
);
pad_value
,
-
1
,
0
,
false
/* norm_by_times */
,
math
::
kLengthBatchWidth
);
const
T
*
warpctc_logits_data
=
warpctc_logits
.
data
<
T
>
();
std
::
vector
<
int
>
warpctc_label_lengths
(
num_sequences
);
...
...
@@ -209,15 +221,15 @@ template <typename DeviceContext, typename T>
class
WarpCTCGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
warpctc_grad
=
ctx
.
Input
<
Tensor
>
(
"WarpCTCGrad"
);
auto
*
warpctc_grad
=
ctx
.
Input
<
LoD
Tensor
>
(
"WarpCTCGrad"
);
auto
*
logits_grad
=
ctx
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"Logits"
));
const
Tensor
*
loss_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
));
logits_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
bool
norm_by_times
=
ctx
.
Attr
<
bool
>
(
"norm_by_times"
);
math
::
UnpaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
logits
_grad
,
*
warpctc_grad
,
norm_by_times
);
ctx
.
template
device_context
<
DeviceContext
>(),
*
warpctc
_grad
,
logits_grad
,
-
1
,
0
,
norm_by_times
,
math
::
kLengthBatchWidth
);
const
T
*
loss_grad_data
=
loss_grad
->
data
<
T
>
();
math
::
ScaleLoDTensorFunctor
<
DeviceContext
,
T
>
()(
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
7b84c580
...
...
@@ -54,6 +54,7 @@ __all__ = [
'conv2d_transpose'
,
'conv3d_transpose'
,
'sequence_expand'
,
'sequence_pad'
,
'lstm_unit'
,
'reduce_sum'
,
'reduce_mean'
,
...
...
@@ -2656,6 +2657,51 @@ def sequence_expand(x, y, ref_level=-1, name=None):
return
tmp
@
templatedoc
()
def
sequence_pad
(
x
,
pad_value
,
maxlen
=
None
):
"""
${comment}
Args:
x(Variable): Input variable which should contain lod information.
pad_value(Variable): The Variable that holds values that will be fill
into padded steps. It can be a scalar or a tensor whose shape
equals to time steps in sequences. If it's a scalar, it will be
automatically broadcasted to the shape of time step.
maxlen(int, default None): The length of padded sequences. It can be
None or any positive int. When it is None, all sequences will be
padded up to the length of the longest one among them; when it a
certain positive value, it must be greater than the length of the
longest original sequence."
Returns:
Variable: The padded sequence batch. All sequences has the same length.
Examples:
.. code-block:: python
import numpy
x = fluid.layers.data(name='y', shape=[10, 5],
dtype='float32', lod_level=1)
pad_value = fluid.layers.assign(input=numpy.array([0]))
out = fluid.layers.sequence_pad(x=x, pad_value=pad_value)
"""
helper
=
LayerHelper
(
'sequence_pad'
,
input
=
x
,
**
locals
())
dtype
=
helper
.
input_dtype
()
out
=
helper
.
create_tmp_variable
(
dtype
)
if
maxlen
is
None
:
maxlen
=
-
1
helper
.
append_op
(
type
=
'sequence_pad'
,
inputs
=
{
'X'
:
x
,
'PadValue'
:
pad_value
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
'padded_length'
:
maxlen
})
return
out
def
beam_search
(
pre_ids
,
pre_scores
,
ids
,
...
...
python/paddle/fluid/tests/unittests/test_sequence_pad_op.py
0 → 100644
浏览文件 @
7b84c580
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
class
TestSequencePadOp
(
OpTest
):
def
set_attr
(
self
):
self
.
x_shape
=
[
12
,
4
]
self
.
x_len_lod
=
[[
2
,
3
,
4
,
3
]]
self
.
pad_value
=
[
1.0
]
self
.
padded_length
=
-
1
self
.
dtype
=
'float32'
def
set_data
(
self
):
x_data
=
np
.
random
.
uniform
(
0.1
,
0.5
,
self
.
x_shape
).
astype
(
self
.
dtype
)
pad_value_data
=
np
.
array
(
self
.
pad_value
).
astype
(
self
.
dtype
)
self
.
inputs
=
{
'X'
:
(
x_data
,
self
.
x_len_lod
),
'PadValue'
:
pad_value_data
}
self
.
attrs
=
{
'padded_length'
:
self
.
padded_length
}
def
compute
(
self
):
# get padded length
padded_length
=
self
.
padded_length
x_len_lod_0
=
self
.
x_len_lod
[
0
]
if
padded_length
==
-
1
:
max_seq_len
=
0
for
l
in
x_len_lod_0
:
max_seq_len
=
max
(
max_seq_len
,
l
)
padded_length
=
max_seq_len
# do padding
x_data
=
self
.
inputs
[
'X'
][
0
]
pad_value_data
=
self
.
inputs
[
'PadValue'
]
if
pad_value_data
.
shape
==
(
1
,
):
pad_value_data
=
np
.
broadcast_to
(
pad_value_data
,
shape
=
x_data
.
shape
[
1
:])
padded_sequences
=
[]
start_idx
=
0
for
l
in
x_len_lod_0
:
end_idx
=
start_idx
+
l
seq
=
x_data
[
start_idx
:
end_idx
]
to_pad_len
=
padded_length
-
l
for
_
in
range
(
to_pad_len
):
seq
=
np
.
append
(
seq
,
pad_value_data
[
np
.
newaxis
,
:],
axis
=
0
)
padded_sequences
.
append
(
seq
)
start_idx
=
end_idx
out_data
=
np
.
array
(
padded_sequences
)
self
.
outputs
=
{
'Out'
:
out_data
}
def
setUp
(
self
):
self
.
op_type
=
'sequence_pad'
self
.
set_attr
()
self
.
set_data
()
self
.
compute
()
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
class
TestSequencePadOp2
(
TestSequencePadOp
):
def
set_attr
(
self
):
self
.
x_shape
=
[
12
,
4
]
self
.
x_len_lod
=
[[
2
,
3
,
4
,
3
]]
self
.
pad_value
=
[
1.0
,
2.0
,
3.0
,
4.0
]
self
.
padded_length
=
-
1
self
.
dtype
=
'float32'
class
TestSequencePadOp3
(
TestSequencePadOp
):
def
set_attr
(
self
):
self
.
x_shape
=
[
12
,
4
]
self
.
x_len_lod
=
[[
2
,
3
,
4
,
3
]]
self
.
pad_value
=
[
1.0
]
self
.
padded_length
=
7
self
.
dtype
=
'float32'
class
TestSequencePadOp4
(
TestSequencePadOp
):
def
set_attr
(
self
):
self
.
x_shape
=
[
12
,
4
]
self
.
x_len_lod
=
[[
2
,
3
,
4
,
3
]]
self
.
pad_value
=
[
1.0
,
2.0
,
3.0
,
4.0
]
self
.
padded_length
=
7
self
.
dtype
=
'float32'
class
TestSequencePadOp5
(
TestSequencePadOp
):
def
set_attr
(
self
):
self
.
x_shape
=
[
12
,
2
,
2
]
self
.
x_len_lod
=
[[
2
,
3
,
4
,
3
]]
self
.
pad_value
=
[
1.0
]
self
.
padded_length
=
-
1
self
.
dtype
=
'float32'
class
TestSequencePadOp6
(
TestSequencePadOp
):
def
set_attr
(
self
):
self
.
x_shape
=
[
12
,
2
,
2
]
self
.
x_len_lod
=
[[
2
,
3
,
4
,
3
]]
self
.
pad_value
=
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
self
.
padded_length
=
-
1
self
.
dtype
=
'float32'
class
TestSequencePadOp7
(
TestSequencePadOp
):
def
set_attr
(
self
):
self
.
x_shape
=
[
12
,
2
,
2
]
self
.
x_len_lod
=
[[
2
,
3
,
4
,
3
]]
self
.
pad_value
=
[
1.0
]
self
.
padded_length
=
7
self
.
dtype
=
'float32'
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录