机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 6884dc80
Authored on September 25, 2019 by Liufang Sang; committed by whs on September 25, 2019.

refine ctc align op with padding (#19926)

* refine ctc align op with padding
* refine api sample code

Parent: 65a02fc1

Showing 6 changed files with 182 additions and 41 deletions (+182 / -41).
Changed files:
  paddle/fluid/API.spec                                    +1  -1
  paddle/fluid/operators/ctc_align_op.cc                   +25 -6
  paddle/fluid/operators/ctc_align_op.cu                   +13 -9
  paddle/fluid/operators/ctc_align_op.h                    +8  -1
  python/paddle/fluid/layers/nn.py                         +81 -15
  python/paddle/fluid/tests/unittests/test_ctc_align.py    +54 -9
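For orientation before the per-file diffs, a sketch of how the updated Python API is meant to be called once this change is in place. It mirrors the sample code added to python/paddle/fluid/layers/nn.py below; the 1.x fluid.layers static-graph interface is assumed.

    import paddle.fluid as fluid

    # lod mode (unchanged): a single decode result is returned
    x = fluid.layers.data(name='x', shape=[8], dtype='float32')
    decoded = fluid.layers.ctc_greedy_decoder(input=x, blank=0)

    # padding mode (new in this commit): pass the real sequence lengths and get
    # back both the padded decode result and the per-sequence output lengths
    x_pad = fluid.layers.data(name='x_pad', shape=[4, 8], dtype='float32')
    x_pad_len = fluid.layers.data(name='x_pad_len', shape=[1], dtype='int64')
    out, out_len = fluid.layers.ctc_greedy_decoder(input=x_pad, blank=0,
                                                   input_length=x_pad_len)

In lod mode the call still returns a single Variable; in padding mode it returns the padded result together with a [batch_size, 1] tensor of decoded lengths.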
paddle/fluid/API.spec

@@ -161,7 +161,7 @@ paddle.fluid.layers.sequence_last_step (ArgSpec(args=['input'], varargs=None, ke
 paddle.fluid.layers.sequence_slice (ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '39fbc5437be389f6c0c769f82fc1fba2'))
 paddle.fluid.layers.dropout (ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer')), ('document', '558d13133596209190df9a624264f28f'))
 paddle.fluid.layers.split (ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None)), ('document', '78cf3a7323d1a7697658242e13f63759'))
-paddle.fluid.layers.ctc_greedy_decoder (ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '2bc3a59efa9d52b628a6255422d9f0e8'))
+paddle.fluid.layers.ctc_greedy_decoder (ArgSpec(args=['input', 'blank', 'input_length', 'padding_value', 'name'], varargs=None, keywords=None, defaults=(None, 0, None)), ('document', '9abb7bb8d267e017620a39a146dc47ea'))
 paddle.fluid.layers.edit_distance (ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens', 'input_length', 'label_length'], varargs=None, keywords=None, defaults=(True, None, None, None)), ('document', '77cbfb28cd2fc589f589c7013c5086cd'))
 paddle.fluid.layers.l2_normalize (ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None)), ('document', 'c1df110ea65998984f564c5c10abc54a'))
 paddle.fluid.layers.matmul (ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None)), ('document', '3720b4a386585094435993deb028b592'))
paddle/fluid/operators/ctc_align_op.cc

@@ -22,15 +22,18 @@ class CTCAlignOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input of CTCAlignOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Output"),
-                   "Output of CTCAlignOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
+                      "Input of CTCAlignOp should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true,
+                      "Output of CTCAlignOp should not be null.");
     auto input_dims = ctx->GetInputDim("Input");
 
     // TODO(wanghaoshuang): it is tricky to set the wrong dimension here.
     ctx->SetOutputDim("Output", input_dims);
+    if (ctx->HasInput("InputLength")) {
+      ctx->SetOutputDim("OutputLength", {input_dims[0], 1});
+    }
   }
 
  protected:

@@ -47,7 +50,17 @@ class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Input",
              "2-D Tensor or LodTensor with shape "
             "[Lp, 1], where Lp is the sum of all input sequences' length.");
+    AddInput("InputLength",
+             "2-D Tensor with shape [batch_size, 1], "
+             " When Input is padding mode, InputLength is length of every "
+             "sequence in Input.")
+        .AsDispensable();
     AddOutput("Output", "(Tensor, default: Tensor<int>), The align result.");
+    AddOutput("OutputLength",
+              "2-D Tensor with shape [batch_size, 1], "
+              "When Input is padding mode, OutputLength is length of every "
+              "sequence in Output.")
+        .AsDispensable();
     AddAttr<int>("blank",
                  "(int, default: 0), the blank label setted in Connectionist "
                  "Temporal Classification (CTC) op.")

@@ -84,6 +97,9 @@ or Given:
     Input.data = [[0, 1, 2, 2, 0, 4],
                   [0, 4, 5, 0, 6, 0],
                   [0, 7, 7, 7, 0, 0]]
+    InputLength.data = [[6],
+                        [5],
+                        [4]],
     Input.dims = {3, 6},
     Input.Lod = []
 And:

@@ -94,7 +110,10 @@ And:
 Then:
     Output.data = [[1, 2, 4, 0, 0, 0],
                    [4, 5, 6, 0, 0, 0],
-                   [7, 0, 0, 0, 0, 0]]
+                   [7, 0, 0, 0, 0, 0]],
+    OutputLength.data = [[3],
+                         [3],
+                         [1]],
     Output.dims = {3, 6},
     Output.Lod = []
 )DOC");
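The Given/Then example in the DOC string above can be checked with a small NumPy reference; this is only a sketch of the documented padding-mode behaviour, not the C++ kernel itself. Within each row, only the first InputLength tokens are decoded, blanks are dropped, repeats are merged, the tail is filled with padding_value, and OutputLength records how many tokens survived.

    import numpy as np

    def ctc_align_padded(tokens, lengths, blank=0, merge_repeated=True, padding_value=0):
        # output starts fully padded; decoded tokens are written from the left
        output = np.full_like(tokens, padding_value)
        out_len = np.zeros((tokens.shape[0], 1), dtype="int32")
        for b in range(tokens.shape[0]):
            prev, idx = -1, 0
            for t in tokens[b, :lengths[b, 0]]:          # only the valid prefix
                if t != blank and not (merge_repeated and t == prev):
                    output[b, idx] = t
                    idx += 1
                prev = t
            out_len[b, 0] = idx
        return output, out_len

    inp = np.array([[0, 1, 2, 2, 0, 4],
                    [0, 4, 5, 0, 6, 0],
                    [0, 7, 7, 7, 0, 0]])
    inp_len = np.array([[6], [5], [4]])
    out, out_len = ctc_align_padded(inp, inp_len)
    print(out)      # [[1 2 4 0 0 0] [4 5 6 0 0 0] [7 0 0 0 0 0]]
    print(out_len)  # [[3] [3] [1]]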
paddle/fluid/operators/ctc_align_op.cu

@@ -43,17 +43,15 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
 }
 
 template <typename T>
-__global__ void PaddingMergeAndDelCudaKernel(const int64_t num_token,
-                                             const T* tokens, const int blank,
-                                             const int merge_repeated,
-                                             const int padding_value,
-                                             const int64_t batch_size,
-                                             T* output) {
+__global__ void PaddingMergeAndDelCudaKernel(
+    const int64_t num_token, const T* tokens, const T* tokens_length,
+    const int blank, const int merge_repeated, const int padding_value,
+    const int64_t batch_size, T* output, T* output_length) {
   int ind = blockIdx.x * blockDim.x + threadIdx.x;
   if (ind >= batch_size) return;
   int output_idx = ind * num_token;
   T prev_token = -1;
-  for (int i = ind * num_token; i < ind * num_token + num_token; i++) {
+  for (int i = ind * num_token; i < ind * num_token + tokens_length[ind]; i++) {
     if ((unsigned)tokens[i] != blank &&
         !(merge_repeated && tokens[i] == prev_token)) {
       output[output_idx] = tokens[i];

@@ -61,6 +59,7 @@ __global__ void PaddingMergeAndDelCudaKernel(const int64_t num_token,
     }
     prev_token = tokens[i];
   }
+  output_length[ind] = output_idx - ind * num_token;
   for (int i = output_idx; i < ind * num_token + num_token; i++) {
     output[i] = padding_value;
   }

@@ -86,10 +85,15 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
       auto input_dims = input->dims();
       T* output_data = output->mutable_data<T>({input_dims[0], input_dims[1]},
                                                ctx.GetPlace());
+      auto* input_length = ctx.Input<LoDTensor>("InputLength");
+      const T* input_length_data = input_length->data<T>();
+      auto* output_length = ctx.Output<LoDTensor>("OutputLength");
+      T* output_length_data =
+          output_length->mutable_data<T>({input_dims[0], 1}, ctx.GetPlace());
       PaddingMergeAndDelCudaKernel<
           T><<<32, (input_dims[0] + 32 - 1) / 32, 0, stream>>>(
-          input_dims[1], tokens, blank, merge_repeated, padding_value,
-          input_dims[0], output_data);
+          input_dims[1], tokens, input_length_data, blank, merge_repeated,
+          padding_value, input_dims[0], output_data, output_length_data);
     } else {
       const size_t level = 0;
       auto input_lod = framework::ToAbsOffset(input->lod());
paddle/fluid/operators/ctc_align_op.h

@@ -41,11 +41,17 @@ class CTCAlignKernel : public framework::OpKernel<T> {
     if (input->lod().empty()) {
       size_t padding_value =
           static_cast<size_t>(ctx.Attr<int>("padding_value"));
+      auto* input_length = ctx.Input<LoDTensor>("InputLength");
+      const T* input_length_data = input_length->data<T>();
+      auto* output_length = ctx.Output<LoDTensor>("OutputLength");
+      T* output_length_data = output_length->mutable_data<T>(ctx.GetPlace());
       for (size_t batch_id = 0; batch_id < (unsigned)input_dims[0];
            batch_id++) {
         T prev_token = -1;
         size_t output_idx = 0;
-        for (size_t i = 0; i < (unsigned)input_dims[1]; i++) {
+        for (size_t i = 0; i < (unsigned)input_length_data[batch_id]; i++) {
           size_t input_ind = batch_id * input_dims[1] + i;
           if ((unsigned)input_data[input_ind] != blank &&
               !(merge_repeated && input_data[input_ind] == prev_token)) {

@@ -55,6 +61,7 @@ class CTCAlignKernel : public framework::OpKernel<T> {
          }
          prev_token = input_data[input_ind];
        }
+       output_length_data[batch_id] = output_idx;
        for (size_t j = output_idx; j < (unsigned)input_dims[1]; j++)
          output_data[batch_id * input_dims[1] + j] = padding_value;
      }
python/paddle/fluid/layers/nn.py

@@ -5851,7 +5851,11 @@ def edit_distance(input,
     return edit_distance_out, sequence_num
 
 
-def ctc_greedy_decoder(input, blank, name=None):
+def ctc_greedy_decoder(input,
+                       blank,
+                       input_length=None,
+                       padding_value=0,
+                       name=None):
     """
     This op is used to decode sequences by greedy policy by below steps:

@@ -5865,6 +5869,7 @@ def ctc_greedy_decoder(input, blank, name=None):
     .. code-block:: text
 
         Given:
+        for lod mode:
 
             input.data = [[0.6, 0.1, 0.3, 0.1],
                           [0.3, 0.2, 0.4, 0.1],

@@ -5893,38 +5898,83 @@ def ctc_greedy_decoder(input, blank, name=None):
             output.lod = [[2, 1]]
 
+        for padding mode:
+
+            input.data = [[[0.6, 0.1, 0.3, 0.1],
+                           [0.3, 0.2, 0.4, 0.1],
+                           [0.1, 0.5, 0.1, 0.3],
+                           [0.5, 0.1, 0.3, 0.1]],
+
+                          [[0.5, 0.1, 0.3, 0.1],
+                           [0.2, 0.2, 0.2, 0.4],
+                           [0.2, 0.2, 0.1, 0.5],
+                           [0.5, 0.1, 0.3, 0.1]]]
+
+            input_length.data = [[4], [4]]
+            input.shape = [2, 4, 4]
+
+            step1: Apply argmax to first input sequence which is input.data[0:4]. Then we get:
+                   [[0], [2], [1], [0]], for input.data[4:8] is [[0], [3], [3], [0]], shape is [2,4,1]
+            step2: Change the argmax result to use padding mode, then argmax result is
+                   [[0, 2, 1, 0], [0, 3, 3, 0]], shape is [2, 4], lod is [], input_length is [[4], [4]]
+            step3: Apply ctc_align to padding argmax result, padding_value is 0
+
+            Finally:
+            output.data = [[2, 1, 0, 0],
+                           [3, 0, 0, 0]]
+            output_length.data = [[2], [1]]
+
     Args:
 
         input(Variable): (LoDTensor<float>), the probabilities of
-                         variable-length sequences, which is a 2-D Tensor with
-                         LoD information. It's shape is [Lp, num_classes + 1],
+                         variable-length sequences. When in lod mode, it is a 2-D Tensor with
+                         LoD information. It's shape is [Lp, num_classes + 1]
                          where Lp is the sum of all input sequences' length and
-                         num_classes is the true number of classes. (not
-                         including the blank label).
+                         num_classes is the true number of classes. When in padding mode,
+                         it is a 3-D Tensor with padding, It's shape is [batch_size, N, num_classes + 1].
+                         (not including the blank label).
         blank(int): the blank label index of Connectionist Temporal
                     Classification (CTC) loss, which is in thehalf-opened
                     interval [0, num_classes + 1).
-        name (str): The name of this layer. It is optional.
+        input_length(Variable, optional): (LoDTensor<int>), shape is [batch_size, 1], when in lod mode, input_length
+                                 is None.
+        padding_value(int): padding value.
+        name (str, optional): The name of this layer. It is optional.
 
     Returns:
-        Variable: CTC greedy decode result which is a 2-D tensor with shape [Lp, 1]. \
+        output(Variable): For lod mode, CTC greedy decode result which is a 2-D tensor with shape [Lp, 1]. \
                   'Lp' is the sum if all output sequences' length. If all the sequences \
                   in result were empty, the result LoDTensor will be [-1] with \
-                  LoD [[]] and dims [1, 1].
+                  LoD [[]] and dims [1, 1]. For padding mode, CTC greedy decode result is a 2-D tensor \
+                  with shape [batch_size, N], output length's shape is [batch_size, 1] which is length \
+                  of every sequence in output.
+        output_length(Variable, optional): length of each sequence of output for padding mode.
 
     Examples:
         .. code-block:: python
 
+            # for lod mode
             import paddle.fluid as fluid
             x = fluid.layers.data(name='x', shape=[8], dtype='float32')
             cost = fluid.layers.ctc_greedy_decoder(input=x, blank=0)
+
+            # for padding mode
+            x_pad = fluid.layers.data(name='x_pad', shape=[4,8], dtype='float32')
+            x_pad_len = fluid.layers.data(name='x_pad_len', shape=[1], dtype='int64')
+            out, out_len = fluid.layers.ctc_greedy_decoder(input=x_pad, blank=0,
+                            input_length=x_pad_len)
+
     """
     helper = LayerHelper("ctc_greedy_decoder", **locals())
     _, topk_indices = topk(input, k=1)
 
     # ctc align op
     ctc_out = helper.create_variable_for_type_inference(dtype="int64")
-    helper.append_op(
-        type="ctc_align",
-        inputs={"Input": [topk_indices]},
+
+    if input_length is None:
+        helper.append_op(
+            type="ctc_align",
+            inputs={"Input": [topk_indices]},

@@ -5932,6 +5982,22 @@ def ctc_greedy_decoder(input, blank, name=None):
-        attrs={"merge_repeated": True,
-               "blank": blank})
-    return ctc_out
+            attrs={"merge_repeated": True,
+                   "blank": blank})
+        return ctc_out
+    else:
+        ctc_out_len = helper.create_variable_for_type_inference(dtype="int64")
+        ctc_input = squeeze(topk_indices, [2])
+
+        helper.append_op(
+            type="ctc_align",
+            inputs={"Input": [ctc_input],
+                    "InputLength": [input_length]},
+            outputs={"Output": [ctc_out],
+                     "OutputLength": [ctc_out_len]},
+            attrs={
+                "merge_repeated": True,
+                "blank": blank,
+                "padding_value": padding_value
+            })
+        return ctc_out, ctc_out_len
 
 
 def warpctc(input,
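The three decode steps described in the new docstring can be traced with a short NumPy sketch (an illustration of the documented behaviour, not the fluid implementation): argmax over the class axis, then the ctc_align step that drops blanks, merges repeats and pads, reproducing the numbers in the padding-mode example.

    import numpy as np

    probs = np.array([[[0.6, 0.1, 0.3, 0.1],
                       [0.3, 0.2, 0.4, 0.1],
                       [0.1, 0.5, 0.1, 0.3],
                       [0.5, 0.1, 0.3, 0.1]],
                      [[0.5, 0.1, 0.3, 0.1],
                       [0.2, 0.2, 0.2, 0.4],
                       [0.2, 0.2, 0.1, 0.5],
                       [0.5, 0.1, 0.3, 0.1]]])   # [batch=2, N=4, num_classes+1=4]
    input_length = np.array([[4], [4]])

    tokens = probs.argmax(axis=-1)               # steps 1-2: [[0, 2, 1, 0], [0, 3, 3, 0]]

    blank, padding_value = 0, 0
    output = np.full_like(tokens, padding_value)
    output_length = np.zeros((len(tokens), 1), dtype=int)
    for b, row in enumerate(tokens):             # step 3: ctc_align on the padded argmax
        prev, idx = -1, 0
        for t in row[:input_length[b, 0]]:
            if t != blank and t != prev:         # blanks dropped, repeats merged
                output[b, idx] = t
                idx += 1
            prev = t
        output_length[b, 0] = idx

    print(output)          # [[2 1 0 0] [3 0 0 0]]
    print(output_length)   # [[2] [1]]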
python/paddle/fluid/tests/unittests/test_ctc_align.py

@@ -19,10 +19,11 @@ import unittest
 import numpy as np
 from op_test import OpTest
 from test_softmax_op import stable_softmax
+import paddle.fluid as fluid
 
 
-def CTCAlign(input, lod, blank, merge_repeated, padding=0):
-    if lod is not None and len(lod) > 0:
+def CTCAlign(input, lod, blank, merge_repeated, padding=0, input_length=None):
+    if input_length is None:
         lod0 = lod[0]
         result = []
         cur_offset = 0

@@ -38,23 +39,28 @@ def CTCAlign(input, lod, blank, merge_repeated, padding=0):
         result = np.array(result).reshape([len(result), 1]).astype("int32")
         if len(result) == 0:
             result = np.array([-1])
+        return result
     else:
         result = [[] for i in range(len(input))]
+        output_length = []
         for i in range(len(input)):
             prev_token = -1
-            for j in range(len(input[i])):
+            for j in range(input_length[i][0]):
                 token = input[i][j]
                 if (token != blank) and not (merge_repeated and
                                              token == prev_token):
                     result[i].append(token)
                 prev_token = token
             start = len(result[i])
+            output_length.append([start])
             for j in range(start, len(input[i])):
                 result[i].append(padding)
         result = np.array(result).reshape(
             [len(input), len(input[0])]).astype("int32")
-    return result
+        output_length = np.array(output_length).reshape(
+            [len(input), 1]).astype("int32")
+        return result, output_length
 
 
 class TestCTCAlignOp(OpTest):

@@ -114,13 +120,18 @@ class TestCTCAlignPaddingOp(OpTest):
         self.input = np.array([[0, 2, 4, 4, 0, 6, 3, 6, 6, 0, 0],
                                [1, 1, 3, 0, 0, 4, 5, 6, 0, 0, 0]]).reshape(
                                    [2, 11]).astype("int32")
+        self.input_length = np.array([[9], [8]]).reshape([2, 1]).astype("int32")
 
     def setUp(self):
         self.config()
-        output = CTCAlign(self.input, self.input_lod, self.blank,
-                          self.merge_repeated, self.padding_value)
-        self.inputs = {"Input": (self.input, self.input_lod), }
-        self.outputs = {"Output": output}
+        output, output_length = CTCAlign(self.input, self.input_lod, self.blank,
+                                         self.merge_repeated,
+                                         self.padding_value, self.input_length)
+        self.inputs = {
+            "Input": (self.input, self.input_lod),
+            "InputLength": self.input_length
+        }
+        self.outputs = {"Output": output, "OutputLength": output_length}
         self.attrs = {
             "blank": self.blank,
             "merge_repeated": self.merge_repeated,

@@ -129,7 +140,6 @@ class TestCTCAlignPaddingOp(OpTest):
 
     def test_check_output(self):
         self.check_output()
-        pass
 
 
 class TestCTCAlignOpCase3(TestCTCAlignPaddingOp):

@@ -142,6 +152,8 @@ class TestCTCAlignOpCase3(TestCTCAlignPaddingOp):
         self.input = np.array([[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0],
                                [0, 7, 7, 7, 0, 0]]).reshape(
                                    [3, 6]).astype("int32")
+        self.input_length = np.array([[6], [5], [4]]).reshape(
+            [3, 1]).astype("int32")
 
 
 class TestCTCAlignOpCase4(TestCTCAlignPaddingOp):

@@ -158,6 +170,8 @@ class TestCTCAlignOpCase4(TestCTCAlignPaddingOp):
         self.input = np.array([[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0],
                                [0, 7, 7, 7, 0, 0]]).reshape(
                                    [3, 6]).astype("int32")
+        self.input_length = np.array([[6], [5], [4]]).reshape(
+            [3, 1]).astype("int32")
 
 
 class TestCTCAlignOpCase5(TestCTCAlignPaddingOp):

@@ -170,6 +184,37 @@ class TestCTCAlignOpCase5(TestCTCAlignPaddingOp):
         self.input = np.array([[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0],
                                [0, 7, 1, 7, 0, 0]]).reshape(
                                    [3, 6]).astype("int32")
+        self.input_length = np.array([[6], [5], [4]]).reshape(
+            [3, 1]).astype("int32")
+
+
+class TestCTCAlignOpApi(unittest.TestCase):
+    def test_api(self):
+        x = fluid.layers.data('x', shape=[4], dtype='float32')
+        y = fluid.layers.ctc_greedy_decoder(x, blank=0)
+
+        x_pad = fluid.layers.data('x_pad', shape=[4, 4], dtype='float32')
+        x_pad_len = fluid.layers.data('x_pad_len', shape=[1], dtype='int64')
+        y_pad, y_pad_len = fluid.layers.ctc_greedy_decoder(
+            x_pad, blank=0, input_length=x_pad_len)
+
+        place = fluid.CPUPlace()
+        x_tensor = fluid.create_lod_tensor(
+            np.random.rand(8, 4).astype("float32"), [[4, 4]], place)
+
+        x_pad_tensor = np.random.rand(2, 4, 4).astype("float32")
+        x_pad_len_tensor = np.array([[4], [4]]).reshape([2, 1]).astype("int64")
+
+        exe = fluid.Executor(place)
+        exe.run(fluid.default_startup_program())
+        ret = exe.run(feed={
+            'x': x_tensor,
+            'x_pad': x_pad_tensor,
+            'x_pad_len': x_pad_len_tensor
+        },
+                      fetch_list=[y, y_pad, y_pad_len],
+                      return_numpy=False)
+
 
 if __name__ == "__main__":