Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9db40b8e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9db40b8e
编写于
8月 20, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Complete sequence_padding GPU kernel
上级
c3de69a9
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
113 addition
and
104 deletion
+113
-104
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-0
paddle/fluid/operators/math/sequence_padding.cc
paddle/fluid/operators/math/sequence_padding.cc
+13
-13
paddle/fluid/operators/math/sequence_padding.cu
paddle/fluid/operators/math/sequence_padding.cu
+69
-82
paddle/fluid/operators/math/sequence_padding.h
paddle/fluid/operators/math/sequence_padding.h
+4
-2
paddle/fluid/operators/math/sequence_padding_test.cc
paddle/fluid/operators/math/sequence_padding_test.cc
+12
-1
paddle/fluid/operators/sequence_pad_op.h
paddle/fluid/operators/sequence_pad_op.h
+1
-4
paddle/fluid/operators/warpctc_op.h
paddle/fluid/operators/warpctc_op.h
+13
-2
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
9db40b8e
...
@@ -282,6 +282,7 @@ op_library(unsqueeze_op DEPS reshape_op)
...
@@ -282,6 +282,7 @@ op_library(unsqueeze_op DEPS reshape_op)
op_library
(
squeeze_op DEPS reshape_op
)
op_library
(
squeeze_op DEPS reshape_op
)
op_library
(
extract_rows_op DEPS memory
)
op_library
(
extract_rows_op DEPS memory
)
op_library
(
flatten_op DEPS reshape_op
)
op_library
(
flatten_op DEPS reshape_op
)
op_library
(
sequence_pad_op DEPS sequence_padding
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
op_library
(
conv_op DEPS vol2col depthwise_conv im2col
)
op_library
(
conv_op DEPS vol2col depthwise_conv im2col
)
...
...
paddle/fluid/operators/math/sequence_padding.cc
浏览文件 @
9db40b8e
...
@@ -18,8 +18,6 @@ namespace paddle {
...
@@ -18,8 +18,6 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
enum
CopyType
{
kSeqToPad
,
kPadToSeq
};
template
<
typename
T
>
template
<
typename
T
>
void
CopyValidData
(
framework
::
Tensor
*
dst_tensor
,
void
CopyValidData
(
framework
::
Tensor
*
dst_tensor
,
const
framework
::
Tensor
*
src_tensor
,
const
framework
::
Tensor
*
src_tensor
,
...
@@ -67,7 +65,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
...
@@ -67,7 +65,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
LoDTensor
&
seq_tensor
,
const
framework
::
LoDTensor
&
seq_tensor
,
framework
::
LoDTensor
*
pad_tensor
,
framework
::
LoDTensor
*
pad_tensor
,
std
::
vector
<
T
>
pad_value
=
{
0
}
,
int
pad_seq_len
=
-
1
,
const
framework
::
LoDTensor
&
pad_value
,
int
pad_seq_len
=
-
1
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
const
PadLayout
layout
=
kBatchLengthWidth
)
{
const
PadLayout
layout
=
kBatchLengthWidth
)
{
auto
seq_lod
=
seq_tensor
.
lod
();
auto
seq_lod
=
seq_tensor
.
lod
();
...
@@ -81,19 +79,21 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
...
@@ -81,19 +79,21 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
CheckDims
(
seq_tensor_dims
,
pad_tensor_dims
,
seq_offsets
,
pad_seq_len
,
CheckDims
(
seq_tensor_dims
,
pad_tensor_dims
,
seq_offsets
,
pad_seq_len
,
step_width
,
layout
);
step_width
,
layout
);
PADDLE_ENFORCE
(
pad_value
.
size
()
==
1
||
PADDLE_ENFORCE
(
pad_value
.
numel
()
==
1
||
pad_value
.
numel
()
==
step_width
,
static_cast
<
int
>
(
pad_value
.
size
())
==
step_width
,
"The numel of 'pad_value' can only be 1 or be equal to the "
"The size of 'pad_value' can only be 1 or be equal to the "
"'step_width'."
);
"'step_width'."
);
if
(
pad_value
.
size
()
==
1
)
{
pad_value
=
std
::
vector
<
T
>
(
step_width
,
pad_value
[
0
]);
}
// fill padding value
// fill padding value
T
*
pad_data
=
pad_tensor
->
data
<
T
>
();
T
*
pad_data
=
pad_tensor
->
data
<
T
>
();
for
(
int
i
=
0
;
i
<
pad_tensor
->
numel
();
i
+=
step_width
)
{
const
T
*
pad_value_data
=
pad_value
.
data
<
T
>
();
memcpy
(
pad_data
+
i
,
pad_value
.
data
(),
step_width
*
sizeof
(
T
));
if
(
pad_value
.
numel
()
==
1
)
{
for
(
int
i
=
0
;
i
<
pad_tensor
->
numel
();
++
i
)
{
pad_data
[
i
]
=
*
pad_value_data
;
}
}
else
{
for
(
int
i
=
0
;
i
<
pad_tensor
->
numel
();
i
+=
step_width
)
{
memcpy
(
pad_data
+
i
,
pad_value_data
,
step_width
*
sizeof
(
T
));
}
}
}
CopyValidData
<
T
>
(
pad_tensor
,
&
seq_tensor
,
seq_offsets
,
pad_seq_len
,
CopyValidData
<
T
>
(
pad_tensor
,
&
seq_tensor
,
seq_offsets
,
pad_seq_len
,
...
@@ -117,7 +117,7 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
...
@@ -117,7 +117,7 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
const
framework
::
LoDTensor
&
pad_tensor
,
const
framework
::
LoDTensor
&
pad_tensor
,
framework
::
LoDTensor
*
seq_tensor
,
int
pad_seq_len
=
-
1
,
framework
::
LoDTensor
*
seq_tensor
,
int
pad_seq_len
=
-
1
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
const
PadLayout
&
layout
=
kBatchLengthWidth
)
{
const
PadLayout
layout
=
kBatchLengthWidth
)
{
auto
seq_offsets
=
framework
::
ToAbsOffset
(
seq_tensor
->
lod
())[
lod_level
];
auto
seq_offsets
=
framework
::
ToAbsOffset
(
seq_tensor
->
lod
())[
lod_level
];
const
auto
&
seq_tensor_dims
=
seq_tensor
->
dims
();
const
auto
&
seq_tensor_dims
=
seq_tensor
->
dims
();
const
auto
&
pad_tensor_dims
=
pad_tensor
.
dims
();
const
auto
&
pad_tensor_dims
=
pad_tensor
.
dims
();
...
...
paddle/fluid/operators/math/sequence_padding.cu
浏览文件 @
9db40b8e
...
@@ -19,46 +19,32 @@ namespace paddle {
...
@@ -19,46 +19,32 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
template
<
typename
T
,
bool
Padding
>
template
<
typename
T
,
CopyType
Type
>
__global__
void
SequencePaddingKernel
(
__global__
void
SequencePaddingKernel
(
T
*
pad_data
,
T
*
seq_data
,
const
size_t
*
seq_offset
,
const
size_t
&
seq_num
,
T
*
dst
,
const
T
*
src
,
const
T
*
pad_value
,
bool
is_constant_pad
,
const
size_t
&
max_seq_len
,
const
size_t
&
seq_width
,
bool
norm_by_times
,
const
size_t
*
seq_offsets
,
const
size_t
&
seq_num
,
const
size_t
&
pad_seq_len
,
const
T
&
pad_value
,
const
OutputLayout
&
output_
layout
)
{
const
size_t
&
step_width
,
bool
norm_by_len
,
const
PadLayout
&
layout
)
{
size_t
seq_idx
=
blockIdx
.
y
;
size_t
seq_idx
=
blockIdx
.
y
;
size_t
seq_start
=
seq_offset
[
seq_idx
];
size_t
seq_len
=
seq_offsets
[
seq_idx
+
1
]
-
seq_offsets
[
seq_idx
];
size_t
seq_len
=
seq_offset
[
seq_idx
+
1
]
-
seq_start
;
size_t
step_idx
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
size_t
seq_step_idx
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
size_t
seq_data_offset
=
(
seq_offsets
[
seq_idx
]
+
step_idx
)
*
step_width
;
size_t
pad_data_offset
=
layout
==
kBatchLengthWidth
size_t
seq_data_offset
=
(
seq_start
+
seq_step_idx
)
*
seq_width
;
?
(
seq_idx
*
pad_seq_len
+
step_idx
)
*
step_width
:
(
step_idx
*
seq_num
+
seq_idx
)
*
step_width
;
size_t
pad_data_offset
=
0
;
T
*
dst_data
=
dst
+
(
Type
==
kSeqToPad
?
pad_data_offset
:
seq_data_offset
);
if
(
output_layout
==
kLengthBatchWidth
)
{
const
T
*
src_data
=
pad_data_offset
=
(
seq_step_idx
*
seq_num
+
seq_idx
)
*
seq_width
;
src
+
(
Type
==
kSeqToPad
?
seq_data_offset
:
pad_data_offset
);
}
else
{
pad_data_offset
=
(
seq_idx
*
max_seq_len
+
seq_step_idx
)
*
seq_width
;
if
(
step_idx
<
seq_len
)
{
}
float
scale
=
norm_by_len
?
(
1.0
f
/
static_cast
<
float
>
(
seq_len
))
:
1.0
f
;
for
(
size_t
i
=
threadIdx
.
x
;
i
<
step_width
;
i
+=
blockDim
.
x
)
{
if
(
seq_step_idx
<
seq_len
)
{
dst_data
[
i
]
=
scale
*
src_data
[
i
];
T
scale
=
norm_by_times
?
(
1.0
f
/
static_cast
<
T
>
(
seq_len
))
:
1.0
f
;
if
(
Padding
)
{
/* seq -> pad */
for
(
size_t
i
=
threadIdx
.
x
;
i
<
seq_width
;
i
+=
blockDim
.
x
)
{
pad_data
[
pad_data_offset
+
i
]
=
scale
*
seq_data
[
seq_data_offset
+
i
];
}
}
else
{
/* pad -> seq */
for
(
size_t
i
=
threadIdx
.
x
;
i
<
seq_width
;
i
+=
blockDim
.
x
)
{
seq_data
[
seq_data_offset
+
i
]
=
scale
*
pad_data
[
pad_data_offset
+
i
];
}
}
}
}
else
if
(
seq_step_idx
<
max_seq_len
)
{
}
else
if
(
step_idx
<
pad_seq_len
&&
Type
==
kSeqToPad
)
{
if
(
Padding
)
{
for
(
size_t
i
=
threadIdx
.
x
;
i
<
seq_width
;
i
+=
blockDim
.
x
)
{
/* seq -> pad */
dst_data
[
i
]
=
is_constant_pad
?
pad_value
[
0
]
:
pad_value
[
i
];
for
(
size_t
i
=
threadIdx
.
x
;
i
<
seq_width
;
i
+=
blockDim
.
x
)
{
pad_data
[
pad_data_offset
+
i
]
=
pad_value
;
}
}
}
}
}
}
}
...
@@ -69,24 +55,26 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
...
@@ -69,24 +55,26 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
LoDTensor
&
seq_tensor
,
const
framework
::
LoDTensor
&
seq_tensor
,
framework
::
Tensor
*
pad_tensor
,
framework
::
Tensor
*
pad_tensor
,
T
pad_value
=
static_cast
<
T
>
(
0
),
bool
norm_by_times
=
false
,
const
framework
::
LoDTensor
&
pad_value
,
int
pad_seq_len
=
-
1
,
size_t
lod_level
=
0
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
OutputLayout
output_layout
=
kBatchLengthWidth
)
{
const
PadLayout
layout
=
kBatchLengthWidth
)
{
CheckLoD
(
seq_tensor
,
lod_level
);
auto
seq_lod
=
seq_tensor
.
lod
();
const
auto
seq_offsets
=
framework
::
ToAbsOffset
(
seq_lod
)[
lod_level
];
auto
&
lod
=
seq_tensor
.
lod
();
const
auto
&
seq_tensor_dims
=
seq_tensor
.
dims
();
auto
&
seq_offset
=
framework
::
ToAbsOffset
(
lod
)[
lod_level
];
const
auto
&
pad_tensor_dims
=
pad_tensor
->
dims
();
if
(
pad_seq_len
==
-
1
)
{
auto
seq_tensor_dims
=
seq_tensor
.
dims
();
pad_seq_len
=
MaximumSequenceLength
(
seq_offsets
);
auto
pad_tensor_dims
=
pad_tensor
->
dims
();
}
int64_t
max_seq_len
=
MaximumSequenceLength
(
seq_offset
);
int
step_width
=
seq_tensor
.
numel
()
/
seq_tensor_dims
[
0
];
int64_t
seq_num
=
seq_offset
.
size
()
-
1
;
int
seq_num
=
seq_offset
.
size
()
-
1
;
int64_t
seq_width
=
seq_tensor
.
numel
()
/
seq_tensor_dims
[
0
];
CheckDims
(
seq_tensor_dims
,
seq_offset
.
back
(),
pad_tensor_dims
,
max_seq_len
,
CheckDims
(
seq_tensor_dims
,
pad_tensor_dims
,
seq_offsets
,
pad_seq_len
,
seq_num
,
seq_width
,
output_layout
);
step_width
,
layout
);
PADDLE_ENFORCE
(
pad_value
.
numel
()
==
1
||
pad_value
.
numel
()
==
step_width
,
"The numel of 'pad_value' can only be 1 or be equal to the "
"'step_width'."
);
if
(
!
norm_by_times
&&
seq_num
==
1UL
)
{
if
(
!
norm_by_times
&&
seq_num
==
1UL
&&
pad_seq_len
==
-
1
)
{
TensorCopy
(
seq_tensor
,
context
.
GetPlace
(),
context
,
pad_tensor
);
TensorCopy
(
seq_tensor
,
context
.
GetPlace
(),
context
,
pad_tensor
);
pad_tensor
->
Resize
(
pad_tensor_dims
);
pad_tensor
->
Resize
(
pad_tensor_dims
);
return
;
return
;
...
@@ -98,21 +86,22 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
...
@@ -98,21 +86,22 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
* and at least 8 elements for each thread.
* and at least 8 elements for each thread.
*/
*/
size_t
block_dim_x
=
size_t
block_dim_x
=
std
::
min
(((((
s
eq
_width
+
7
)
>>
3
)
+
31
)
>>
5
)
<<
5
,
kBlockSize
);
std
::
min
(((((
s
tep
_width
+
7
)
>>
3
)
+
31
)
>>
5
)
<<
5
,
kBlockSize
);
size_t
block_dim_y
=
kBlockSize
/
block_dim_x
;
size_t
block_dim_y
=
kBlockSize
/
block_dim_x
;
dim3
threads
(
block_dim_x
,
block_dim_y
);
dim3
threads
(
block_dim_x
,
block_dim_y
);
size_t
grid_dim_x
=
(
max
_seq_len
+
block_dim_y
-
1
)
/
block_dim_y
;
size_t
grid_dim_x
=
(
pad
_seq_len
+
block_dim_y
-
1
)
/
block_dim_y
;
size_t
grid_dim_y
=
seq_num
;
size_t
grid_dim_y
=
seq_num
;
dim3
grid
(
grid_dim_x
,
grid_dim_y
);
dim3
grid
(
grid_dim_x
,
grid_dim_y
);
const
T
*
seq_data
=
seq_tensor
.
data
<
T
>
();
const
T
*
seq_data
=
seq_tensor
.
data
<
T
>
();
T
*
pad_data
=
pad_tensor
->
data
<
T
>
();
T
*
pad_data
=
pad_tensor
->
data
<
T
>
();
const
T
*
pad_value_data
=
pad_value
.
data
<
T
>
();
SequencePaddingKernel
<
T
,
1
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
SequencePaddingKernel
<
T
,
kSeqToPad
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
pad_data
,
const_cast
<
T
*>
(
seq_data
)
,
pad_data
,
seq_data
,
pad_value_data
,
pad_value
.
numel
()
==
1
,
seq_offset
.
CUDAData
(
context
.
GetPlace
()),
seq_num
,
max
_seq_len
,
seq_offset
.
CUDAData
(
context
.
GetPlace
()),
seq_num
,
pad
_seq_len
,
s
eq_width
,
norm_by_times
,
pad_value
,
output_
layout
);
s
tep_width
,
norm_by_times
,
layout
);
}
}
};
};
...
@@ -120,25 +109,23 @@ template <typename T>
...
@@ -120,25 +109,23 @@ template <typename T>
class
UnpaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
class
UnpaddingLoDTensorFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
framework
::
LoDTensor
*
seq_tensor
,
const
framework
::
LoDTensor
&
pad_tensor
,
const
framework
::
Tensor
&
pad_tensor
,
framework
::
LoDTensor
*
seq_tensor
,
int
pad_seq_len
=
-
1
,
bool
norm_by_times
=
false
,
size_t
lod_level
=
0
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
OutputLayout
output_layout
=
kBatchLengthWidth
)
{
const
PadLayout
layout
=
kBatchLengthWidth
)
{
CheckLoD
(
*
seq_tensor
,
lod_level
);
auto
seq_offsets
=
framework
::
ToAbsOffset
(
seq_tensor
->
lod
())[
lod_level
];
const
auto
&
seq_tensor_dims
=
seq_tensor
->
dims
();
auto
&
lod
=
seq_tensor
->
lod
();
const
auto
&
pad_tensor_dims
=
pad_tensor
.
dims
();
auto
&
seq_offset
=
framework
::
ToAbsOffset
(
lod
)[
lod_level
];
if
(
pad_seq_len
==
-
1
)
{
pad_seq_len
=
MaximumSequenceLength
(
seq_offsets
);
auto
seq_tensor_dims
=
seq_tensor
->
dims
();
}
auto
pad_tensor_dims
=
pad_tensor
.
dims
();
int
step_width
=
seq_tensor
->
numel
()
/
seq_tensor_dims
[
0
];
int64_t
max_seq_len
=
MaximumSequenceLength
(
seq_offset
);
int
seq_num
=
seq_offset
.
size
()
-
1
;
int64_t
seq_num
=
seq_offset
.
size
()
-
1
;
int64_t
seq_width
=
seq_tensor
->
numel
()
/
seq_tensor_dims
[
0
];
CheckDims
(
seq_tensor_dims
,
seq_offset
.
back
(),
pad_tensor_dims
,
max
_seq_len
,
CheckDims
(
seq_tensor_dims
,
pad_tensor_dims
,
seq_offsets
,
pad
_seq_len
,
s
eq_num
,
seq_width
,
output_
layout
);
s
tep_width
,
layout
);
if
(
!
norm_by_times
&&
seq_num
==
1UL
)
{
if
(
!
norm_by_times
&&
seq_num
==
1UL
&&
pad_seq_len
==
-
1
)
{
TensorCopy
(
pad_tensor
,
context
.
GetPlace
(),
context
,
seq_tensor
);
TensorCopy
(
pad_tensor
,
context
.
GetPlace
(),
context
,
seq_tensor
);
seq_tensor
->
Resize
(
seq_tensor_dims
);
seq_tensor
->
Resize
(
seq_tensor_dims
);
return
;
return
;
...
@@ -150,21 +137,21 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
...
@@ -150,21 +137,21 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
* and at least 8 elements for each thread.
* and at least 8 elements for each thread.
*/
*/
size_t
block_dim_x
=
size_t
block_dim_x
=
std
::
min
(((((
s
eq
_width
+
7
)
>>
3
)
+
31
)
>>
5
)
<<
5
,
kBlockSize
);
std
::
min
(((((
s
tep
_width
+
7
)
>>
3
)
+
31
)
>>
5
)
<<
5
,
kBlockSize
);
size_t
block_dim_y
=
kBlockSize
/
block_dim_x
;
size_t
block_dim_y
=
kBlockSize
/
block_dim_x
;
dim3
threads
(
block_dim_x
,
block_dim_y
);
dim3
threads
(
block_dim_x
,
block_dim_y
);
size_t
grid_dim_x
=
(
max
_seq_len
+
block_dim_y
-
1
)
/
block_dim_y
;
size_t
grid_dim_x
=
(
pad
_seq_len
+
block_dim_y
-
1
)
/
block_dim_y
;
size_t
grid_dim_y
=
seq_num
;
size_t
grid_dim_y
=
seq_num
;
dim3
grid
(
grid_dim_x
,
grid_dim_y
);
dim3
grid
(
grid_dim_x
,
grid_dim_y
);
const
T
*
pad_data
=
pad_tensor
.
data
<
T
>
();
const
T
*
pad_data
=
pad_tensor
.
data
<
T
>
();
T
*
seq_data
=
seq_tensor
->
data
<
T
>
();
T
*
seq_data
=
seq_tensor
->
data
<
T
>
();
SequencePaddingKernel
<
T
,
0
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
SequencePaddingKernel
<
T
,
kPadToSeq
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
const_cast
<
T
*>
(
pad_data
),
seq_data
,
seq_data
,
pad_data
,
nullptr
,
false
,
seq_offset
.
CUDAData
(
context
.
GetPlace
()),
seq_num
,
max
_seq_len
,
seq_offset
.
CUDAData
(
context
.
GetPlace
()),
seq_num
,
pad
_seq_len
,
s
eq_width
,
norm_by_times
,
static_cast
<
T
>
(
0
),
output_
layout
);
s
tep_width
,
norm_by_times
,
layout
);
}
}
};
};
...
...
paddle/fluid/operators/math/sequence_padding.h
浏览文件 @
9db40b8e
...
@@ -25,6 +25,8 @@ namespace math {
...
@@ -25,6 +25,8 @@ namespace math {
enum
PadLayout
{
kBatchLengthWidth
=
0
,
kLengthBatchWidth
};
enum
PadLayout
{
kBatchLengthWidth
=
0
,
kLengthBatchWidth
};
enum
CopyType
{
kSeqToPad
,
kPadToSeq
};
inline
static
size_t
MaximumSequenceLength
(
inline
static
size_t
MaximumSequenceLength
(
const
framework
::
Vector
<
size_t
>&
seq_offset
)
{
const
framework
::
Vector
<
size_t
>&
seq_offset
)
{
size_t
seq_num
=
seq_offset
.
size
()
-
1
;
size_t
seq_num
=
seq_offset
.
size
()
-
1
;
...
@@ -82,7 +84,7 @@ class PaddingLoDTensorFunctor {
...
@@ -82,7 +84,7 @@ class PaddingLoDTensorFunctor {
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
LoDTensor
&
seq_tensor
,
const
framework
::
LoDTensor
&
seq_tensor
,
framework
::
LoDTensor
*
pad_tensor
,
framework
::
LoDTensor
*
pad_tensor
,
std
::
vector
<
T
>
pad_value
=
{
0
}
,
int
pad_seq_len
=
-
1
,
const
framework
::
LoDTensor
&
pad_value
,
int
pad_seq_len
=
-
1
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
const
PadLayout
layout
=
kBatchLengthWidth
);
const
PadLayout
layout
=
kBatchLengthWidth
);
};
};
...
@@ -94,7 +96,7 @@ class UnpaddingLoDTensorFunctor {
...
@@ -94,7 +96,7 @@ class UnpaddingLoDTensorFunctor {
const
framework
::
LoDTensor
&
pad_tensor
,
const
framework
::
LoDTensor
&
pad_tensor
,
framework
::
LoDTensor
*
seq_tensor
,
int
pad_seq_len
=
-
1
,
framework
::
LoDTensor
*
seq_tensor
,
int
pad_seq_len
=
-
1
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
int
lod_level
=
0
,
bool
norm_by_times
=
false
,
const
PadLayout
&
layout
=
kBatchLengthWidth
);
const
PadLayout
layout
=
kBatchLengthWidth
);
};
};
}
// namespace math
}
// namespace math
...
...
paddle/fluid/operators/math/sequence_padding_test.cc
浏览文件 @
9db40b8e
...
@@ -24,6 +24,8 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
...
@@ -24,6 +24,8 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
paddle
::
framework
::
LoDTensor
seq
;
paddle
::
framework
::
LoDTensor
seq
;
paddle
::
framework
::
LoDTensor
seq_back
;
paddle
::
framework
::
LoDTensor
seq_back
;
paddle
::
framework
::
LoDTensor
padding
;
paddle
::
framework
::
LoDTensor
padding
;
paddle
::
framework
::
LoDTensor
cpu_pad_value
;
paddle
::
framework
::
LoDTensor
pad_value
;
const
size_t
level
=
lod
.
size
()
-
1
;
const
size_t
level
=
lod
.
size
()
-
1
;
auto
seq_dims
=
auto
seq_dims
=
...
@@ -55,8 +57,17 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
...
@@ -55,8 +57,17 @@ void TestSequencePadding(const paddle::framework::LoD& lod,
padding
.
mutable_data
<
T
>
(
padding_dims
,
*
place
);
padding
.
mutable_data
<
T
>
(
padding_dims
,
*
place
);
T
*
pad_value_data
=
cpu_pad_value
.
mutable_data
<
T
>
({
1
},
paddle
::
platform
::
CPUPlace
());
*
pad_value_data
=
static_cast
<
T
>
(
0
);
if
(
paddle
::
platform
::
is_cpu_place
(
*
place
))
{
pad_value
=
cpu_pad_value
;
}
else
{
TensorCopySync
(
cpu_pad_value
,
*
place
,
&
pad_value
);
}
paddle
::
operators
::
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
paddle
::
operators
::
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
*
context
,
seq
,
&
padding
,
{
0
}
,
-
1
,
0
,
false
,
*
context
,
seq
,
&
padding
,
pad_value
,
-
1
,
0
,
false
,
paddle
::
operators
::
math
::
kLengthBatchWidth
);
paddle
::
operators
::
math
::
kLengthBatchWidth
);
seq_back
.
set_lod
(
lod
);
seq_back
.
set_lod
(
lod
);
...
...
paddle/fluid/operators/sequence_pad_op.h
浏览文件 @
9db40b8e
...
@@ -35,14 +35,11 @@ class SequencePadOpKernel : public framework::OpKernel<T> {
...
@@ -35,14 +35,11 @@ class SequencePadOpKernel : public framework::OpKernel<T> {
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
auto
*
pad_value
=
ctx
.
Input
<
LoDTensor
>
(
"PadValue"
);
const
auto
*
pad_value
=
ctx
.
Input
<
LoDTensor
>
(
"PadValue"
);
const
T
*
pad_value_data
=
pad_value
->
data
<
T
>
();
std
::
vector
<
T
>
pad_value_vec
(
pad_value_data
,
pad_value_data
+
pad_value
->
numel
());
int
padded_length
=
ctx
.
Attr
<
int
>
(
"padded_length"
);
int
padded_length
=
ctx
.
Attr
<
int
>
(
"padded_length"
);
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
*
x
,
out
,
pad_value_vec
,
ctx
.
template
device_context
<
DeviceContext
>(),
*
x
,
out
,
*
pad_value
,
padded_length
,
0
,
false
,
math
::
kBatchLengthWidth
);
padded_length
,
0
,
false
,
math
::
kBatchLengthWidth
);
}
}
};
};
...
...
paddle/fluid/operators/warpctc_op.h
浏览文件 @
9db40b8e
...
@@ -161,10 +161,21 @@ class WarpCTCKernel : public framework::OpKernel<T> {
...
@@ -161,10 +161,21 @@ class WarpCTCKernel : public framework::OpKernel<T> {
static_cast
<
int64_t
>
(
num_sequences
),
static_cast
<
int64_t
>
(
num_sequences
),
static_cast
<
int64_t
>
(
sequence_width
)});
static_cast
<
int64_t
>
(
sequence_width
)});
warpctc_logits
.
mutable_data
<
T
>
(
warpctc_logits_dims
,
ctx
.
GetPlace
());
warpctc_logits
.
mutable_data
<
T
>
(
warpctc_logits_dims
,
ctx
.
GetPlace
());
LoDTensor
cpu_pad_value
;
T
*
pad_value_data
=
cpu_pad_value
.
mutable_data
<
T
>
({
1
},
platform
::
CPUPlace
());
*
pad_value_data
=
static_cast
<
T
>
(
0
);
LoDTensor
pad_value
;
if
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()))
{
pad_value
=
cpu_pad_value
;
}
else
{
TensorCopySync
(
cpu_pad_value
,
ctx
.
GetPlace
(),
&
pad_value
);
}
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
math
::
PaddingLoDTensorFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
*
logits
,
&
warpctc_logits
,
ctx
.
template
device_context
<
DeviceContext
>(),
*
logits
,
&
warpctc_logits
,
{
static_cast
<
T
>
(
0
)},
-
1
,
0
,
false
/* norm_by_times */
,
pad_value
,
-
1
,
0
,
false
/* norm_by_times */
,
math
::
kLengthBatchWidth
);
math
::
kLengthBatchWidth
);
const
T
*
warpctc_logits_data
=
warpctc_logits
.
data
<
T
>
();
const
T
*
warpctc_logits_data
=
warpctc_logits
.
data
<
T
>
();
std
::
vector
<
int
>
warpctc_label_lengths
(
num_sequences
);
std
::
vector
<
int
>
warpctc_label_lengths
(
num_sequences
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录