Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
8f37c3c2
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8f37c3c2
编写于
1月 15, 2018
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix sequence scale functor cuda kernel
1. Fix kernel 2. Add more test case
上级
45cf2341
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
47 addition
and
24 deletion
+47
-24
paddle/operators/math/sequence_scale.cu
paddle/operators/math/sequence_scale.cu
+10
-5
python/paddle/v2/fluid/tests/test_warpctc_op.py
python/paddle/v2/fluid/tests/test_warpctc_op.py
+37
-19
未找到文件。
paddle/operators/math/sequence_scale.cu
浏览文件 @
8f37c3c2
...
...
@@ -13,16 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence_scale.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
template
<
typename
T
,
int
BlockSize
>
__global__
void
SequenceScaleKernel
(
T
*
seq
,
size_t
*
lod
,
const
T
*
scales
,
const
size_t
seq_width
)
{
if
(
threadIdx
.
x
<
(
lod
[
blockIdx
.
x
+
1
]
-
lod
[
blockIdx
.
x
])
*
seq_width
)
{
int
idx
=
lod
[
blockIdx
.
x
]
*
seq_width
+
threadIdx
.
x
;
for
(
int
i
=
threadIdx
.
x
;
i
<
(
lod
[
blockIdx
.
x
+
1
]
-
lod
[
blockIdx
.
x
])
*
seq_width
;
i
+=
BlockSize
)
{
int
idx
=
lod
[
blockIdx
.
x
]
*
seq_width
+
i
;
seq
[
idx
]
*=
scales
[
blockIdx
.
x
];
}
}
...
...
@@ -39,8 +44,8 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
framework
::
LoD
abs_offset_lod
=
framework
::
ToAbsOffset
(
lod
);
T
*
seq_data
=
seq
.
mutable_data
<
T
>
(
context
.
GetPlace
());
int
threads
=
1024
;
SequenceScaleKernel
<
T
><<<
num_seq
,
threads
,
0
,
context
.
stream
()
>>>
(
SequenceScaleKernel
<
T
,
PADDLE_CUDA_NUM_THREADS
><<<
num_seq
,
PADDLE_CUDA_NUM_THREADS
,
0
,
context
.
stream
()
>>>
(
seq_data
,
abs_offset_lod
[
level
].
data
(),
scales
,
seq_width
);
}
};
...
...
python/paddle/v2/fluid/tests/test_warpctc_op.py
浏览文件 @
8f37c3c2
...
...
@@ -4,6 +4,8 @@ import numpy as np
from
op_test
import
OpTest
from
test_softmax_op
import
stable_softmax
CUDA_BLOCK_SIZE
=
512
class
CTCForward
(
object
):
def
__init__
(
self
,
softmax
,
softmax_lod
,
labels
,
labels_lod
,
blank
,
...
...
@@ -154,39 +156,45 @@ class CTCForward(object):
class
TestWarpCTCOp
(
OpTest
):
def
config
(
self
):
self
.
batch_size
=
4
self
.
num_classes
=
8
self
.
logits_lod
=
[[
0
,
4
,
5
,
8
,
11
]]
self
.
labels_lod
=
[[
0
,
3
,
4
,
8
,
12
]]
self
.
blank
=
self
.
num_classes
-
1
self
.
norm_by_times
=
False
def
setUp
(
self
):
self
.
op_type
=
"warpctc"
self
.
config
()
batch_size
=
4
num_classes
=
8
logits_lod
=
[[
0
,
4
,
5
,
8
,
11
]]
logits
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
11
,
num_classes
]).
astype
(
"float32"
)
logits
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
self
.
logits_lod
[
0
][
-
1
],
self
.
num_classes
]).
astype
(
"float32"
)
softmax
=
np
.
apply_along_axis
(
stable_softmax
,
1
,
logits
)
labels_lod
=
[[
0
,
3
,
4
,
8
,
12
]]
# labels should not be blank
labels
=
np
.
random
.
randint
(
0
,
num_classes
-
1
,
[
12
,
1
],
dtype
=
"int32"
)
blank
=
num_classes
-
1
norm_by_times
=
False
labels
=
np
.
random
.
randint
(
0
,
self
.
num_classes
-
1
,
[
self
.
labels_lod
[
0
][
-
1
],
1
],
dtype
=
"int32"
)
ctc
=
CTCForward
(
softmax
,
logits_lod
,
labels
,
labels_lod
,
blank
,
norm_by_times
)
ctc
=
CTCForward
(
softmax
,
self
.
logits_lod
,
labels
,
self
.
labels_lod
,
self
.
blank
,
self
.
norm_by_times
)
loss
=
ctc
.
forward
()
max_sequence_length
=
0
for
i
in
range
(
batch_size
):
max_sequence_length
=
max
(
max_sequence_length
,
logits_lod
[
0
][
i
+
1
]
-
logits_lod
[
0
][
i
])
for
i
in
range
(
self
.
batch_size
):
max_sequence_length
=
max
(
max_sequence_length
,
self
.
logits_lod
[
0
][
i
+
1
]
-
self
.
logits_lod
[
0
][
i
])
self
.
gradient
=
np
.
zeros
(
[
max_sequence_length
,
batch_size
,
num_classes
],
dtype
=
"float32"
)
[
max_sequence_length
,
self
.
batch_size
,
self
.
num_classes
],
dtype
=
"float32"
)
self
.
inputs
=
{
"Logits"
:
(
logits
,
logits_lod
),
"Label"
:
(
labels
,
labels_lod
)
"Logits"
:
(
logits
,
self
.
logits_lod
),
"Label"
:
(
labels
,
self
.
labels_lod
)
}
self
.
outputs
=
{
"Loss"
:
loss
}
self
.
attrs
=
{
"blank"
:
blank
,
"norm_by_times"
:
norm_by_times
}
self
.
attrs
=
{
"blank"
:
self
.
blank
,
"norm_by_times"
:
self
.
norm_by_times
}
def
test_check_output
(
self
):
self
.
check_output
()
...
...
@@ -196,5 +204,15 @@ class TestWarpCTCOp(OpTest):
self
.
check_grad
([
"Logits"
],
"Loss"
,
max_relative_error
=
0.007
)
class
TestWarpCTCOpCase1
(
TestWarpCTCOp
):
def
config
(
self
):
self
.
batch_size
=
4
self
.
num_classes
=
CUDA_BLOCK_SIZE
+
2
self
.
logits_lod
=
[[
0
,
4
,
5
,
8
,
11
]]
self
.
labels_lod
=
[[
0
,
3
,
4
,
8
,
12
]]
self
.
blank
=
0
self
.
norm_by_times
=
False
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录