Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
579f6846
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
579f6846
编写于
1月 15, 2018
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ctc_greedy_decode_op
上级
8d253e49
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
364 addition
and
0 deletion
+364
-0
paddle/operators/ctc_greedy_decode_op.cc
paddle/operators/ctc_greedy_decode_op.cc
+88
-0
paddle/operators/ctc_greedy_decode_op.cu
paddle/operators/ctc_greedy_decode_op.cu
+138
-0
paddle/operators/ctc_greedy_decode_op.h
paddle/operators/ctc_greedy_decode_op.h
+82
-0
python/paddle/v2/fluid/tests/test_ctc_greedy_decode.py
python/paddle/v2/fluid/tests/test_ctc_greedy_decode.py
+56
-0
未找到文件。
paddle/operators/ctc_greedy_decode_op.cc
0 → 100644
浏览文件 @
579f6846
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/ctc_greedy_decode_op.h"
namespace
paddle
{
namespace
operators
{
class
CTCGreedyDecodeOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input of CTCGreedyDecodeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Output"
),
"Output of CTCGreedyDecodeOp should not be null."
);
auto
input_dims
=
ctx
->
GetInputDim
(
"Input"
);
int
sequence_width
=
static_cast
<
int
>
(
framework
::
product
(
input_dims
)
/
input_dims
[
0
]);
int
blank
=
ctx
->
Attrs
().
Get
<
int
>
(
"blank"
);
PADDLE_ENFORCE
((
blank
>=
0
)
&&
(
blank
<
sequence_width
),
"The value of Attr(blank) should be in interval [0, %d)."
,
sequence_width
);
// TODO(wanghaoshuang): it is tricky to set the wrong dimension here.
ctx
->
SetOutputDim
(
"Output"
,
{
input_dims
[
0
],
1
});
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
device_context
());
}
};
class
CTCGreedyDecodeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
CTCGreedyDecodeOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"Input"
,
"(LodTensor, default: LoDTensor<float>), the unscaled "
"probabilities of variable-length sequences, which is a 2-D "
"Tensor with LoD information. It's shape is "
"[Lp, num_classes + 1], where Lp is the sum of all input "
"sequences' length and num_classes is the true number of classes "
"(not including the blank label)."
);
AddOutput
(
"Output"
,
"(Tensor, default: Tensor<int>), the decode result "
);
AddAttr
<
int
>
(
"blank"
,
"(int, default: 0), the blank label setted in Connectionist "
"Temporal Classification (CTC) op, and it is in the "
"half-opened interval [0, num_classes + 1)."
)
.
SetDefault
(
0
);
AddAttr
<
bool
>
(
"merge_repeated"
,
"(bool, default: true), whether to "
"merge repeated elements between two blanks. "
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
CTCGreedyDecoder is an implementation of the simple best path decoding
algorithm, selecting at each timestep the most likely class at each timestep.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
ctc_greedy_decode
,
ops
::
CTCGreedyDecodeOp
,
ops
::
CTCGreedyDecodeOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
ctc_greedy_decode
,
ops
::
CTCGreedyDecodeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
paddle/operators/ctc_greedy_decode_op.cu
0 → 100644
浏览文件 @
579f6846
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/operators/ctc_greedy_decode_op.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/gpu_info.h"
namespace
paddle
{
namespace
operators
{
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
__device__
static
float
atomicMaxF
(
float
*
address
,
float
val
)
{
int
*
address_as_i
=
(
int
*
)
address
;
int
old
=
*
address_as_i
,
assumed
;
do
{
assumed
=
old
;
old
=
::
atomicCAS
(
address_as_i
,
assumed
,
__float_as_int
(
::
fmaxf
(
val
,
__int_as_float
(
assumed
))));
}
while
(
assumed
!=
old
);
return
__int_as_float
(
old
);
}
template
<
typename
T
,
int
BlockSize
>
__global__
void
ArgmaxCudaKernel
(
const
size_t
seq_width
,
const
T
*
logits
,
int
*
output
)
{
T
local_max_value
=
0
;
int
local_max_index
=
0
;
__shared__
T
max_value
;
if
(
threadIdx
.
x
==
0
)
{
max_value
=
0
;
}
__syncthreads
();
for
(
int
i
=
threadIdx
.
x
;
i
<
seq_width
;
i
+=
BlockSize
)
{
T
value
=
logits
[
blockIdx
.
x
*
seq_width
+
i
];
if
(
value
>
local_max_value
)
{
local_max_value
=
value
;
local_max_index
=
i
;
}
}
atomicMaxF
(
&
max_value
,
local_max_value
);
__syncthreads
();
if
(
local_max_value
==
max_value
)
{
output
[
blockIdx
.
x
]
=
local_max_index
;
}
}
template
<
typename
T
>
__global__
void
MergeAndDelCudaKernel
(
const
int64_t
num_token
,
int
*
tokens
,
const
size_t
num_seq
,
size_t
*
lod0
,
const
int
blank
,
const
int
merge_repeated
,
size_t
*
out_lod0
,
int
*
output
)
{
int
ouput_idx
=
0
;
out_lod0
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
num_seq
;
++
i
)
{
int
pre_token
=
-
1
;
for
(
int
j
=
lod0
[
i
];
j
<
lod0
[
i
+
1
];
++
j
)
{
if
(
tokens
[
j
]
!=
blank
&&
!
(
merge_repeated
&&
tokens
[
j
]
==
pre_token
))
{
output
[
ouput_idx
]
=
tokens
[
j
];
++
ouput_idx
;
}
pre_token
=
tokens
[
j
];
}
out_lod0
[
i
+
1
]
=
ouput_idx
;
}
}
template
<
typename
T
>
class
CTCGreedyDecodeOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use CUDAPlace."
);
auto
*
input
=
ctx
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
output
=
ctx
.
Output
<
LoDTensor
>
(
"Output"
);
const
int64_t
num_tokens
=
input
->
dims
()[
0
];
const
size_t
seq_width
=
input
->
numel
()
/
num_tokens
;
const
T
*
logits
=
input
->
data
<
T
>
();
Tensor
tmp
;
int
*
tokens
=
tmp
.
mutable_data
<
int
>
({
num_tokens
,
1
},
ctx
.
GetPlace
());
// get argmax
// platform::GpuMemsetAsync(args, 0, sizeof(float), stream);
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
ArgmaxCudaKernel
<
T
,
PADDLE_CUDA_NUM_THREADS
><<<
num_tokens
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
seq_width
,
logits
,
tokens
);
const
size_t
level
=
0
;
auto
input_lod
=
framework
::
ToAbsOffset
(
input
->
lod
());
const
size_t
num_seq
=
input_lod
[
level
].
size
()
-
1
;
const
int
blank
=
ctx
.
Attr
<
int
>
(
"blank"
);
const
int
merge_repeated
=
static_cast
<
int
>
(
ctx
.
Attr
<
bool
>
(
"merge_repeated"
));
thrust
::
device_vector
<
size_t
>
dev_out_lod0
(
input_lod
[
level
].
size
());
size_t
*
dev_out_lod0_ptr
=
thrust
::
raw_pointer_cast
(
dev_out_lod0
.
data
());
int
*
output_data
=
output
->
mutable_data
<
int
>
({
num_tokens
,
1
},
ctx
.
GetPlace
());
MergeAndDelCudaKernel
<
T
><<<
1
,
1
,
0
,
stream
>>>
(
num_tokens
,
tokens
,
num_seq
,
input_lod
[
level
].
data
(),
blank
,
merge_repeated
,
dev_out_lod0_ptr
,
output_data
);
thrust
::
host_vector
<
size_t
>
host_out_lod0
(
dev_out_lod0
.
begin
(),
dev_out_lod0
.
end
());
framework
::
LoD
out_lod
;
out_lod
.
push_back
(
host_out_lod0
);
output
->
set_lod
(
out_lod
);
output
->
Resize
({
static_cast
<
int64_t
>
(
host_out_lod0
.
back
()),
1
});
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
ctc_greedy_decode
,
paddle
::
operators
::
CTCGreedyDecodeOpCUDAKernel
<
float
>
);
paddle/operators/ctc_greedy_decode_op.h
0 → 100644
浏览文件 @
579f6846
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string.h>
#include "paddle/framework/op_registry.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
DeviceContext
,
typename
T
>
class
CTCGreedyDecodeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
LoDTensor
>
(
"Input"
);
auto
*
output
=
ctx
.
Output
<
LoDTensor
>
(
"Output"
);
const
size_t
level
=
0
;
auto
input_lod
=
framework
::
ToAbsOffset
(
input
->
lod
());
auto
input_dims
=
input
->
dims
();
PADDLE_ENFORCE_EQ
(
input_dims
[
0
],
static_cast
<
int64_t
>
(
input_lod
[
level
].
back
()),
"The first dimension of Input(Input) should be equal to "
"the sum of all sequences' lengths."
);
const
size_t
num_sequences
=
input_lod
[
level
].
size
()
-
1
;
const
size_t
sequence_width
=
input
->
numel
()
/
input_dims
[
0
];
size_t
blank
=
static_cast
<
size_t
>
(
ctx
.
Attr
<
int
>
(
"blank"
));
bool
merge_repeated
=
ctx
.
Attr
<
bool
>
(
"merge_repeated"
);
std
::
vector
<
std
::
vector
<
int
>>
pathes
(
num_sequences
);
std
::
vector
<
size_t
>
output_lod0
(
1
,
0
);
const
T
*
input_data
=
input
->
data
<
T
>
();
Eigen
::
Map
<
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>>
input_mat
(
const_cast
<
T
*>
(
input_data
),
input
->
numel
()
/
sequence_width
,
sequence_width
);
size_t
max_class_idx
;
size_t
prev_class_idx
=
-
1
;
for
(
size_t
seq_idx
=
0
;
seq_idx
<
num_sequences
;
++
seq_idx
)
{
for
(
size_t
i
=
input_lod
[
level
][
seq_idx
];
i
<
input_lod
[
level
][
seq_idx
+
1
];
++
i
)
{
input_mat
.
row
(
i
).
maxCoeff
(
&
max_class_idx
);
if
(
max_class_idx
!=
blank
&&
!
(
merge_repeated
&&
max_class_idx
==
prev_class_idx
))
{
pathes
[
seq_idx
].
push_back
(
max_class_idx
);
}
prev_class_idx
=
max_class_idx
;
}
output_lod0
.
push_back
(
output_lod0
.
back
()
+
pathes
[
seq_idx
].
size
());
}
framework
::
LoD
output_lod
;
output_lod
.
push_back
(
output_lod0
);
output
->
set_lod
(
output_lod
);
int64_t
num_step
=
static_cast
<
int64_t
>
(
output_lod0
.
back
());
int
*
output_data
=
output
->
mutable_data
<
int
>
({
num_step
,
1
},
ctx
.
GetPlace
());
for
(
int
i
=
0
;
i
<
num_sequences
;
++
i
)
{
memcpy
(
output_data
+
output_lod0
[
i
],
pathes
[
i
].
data
(),
sizeof
(
int
)
*
pathes
[
i
].
size
());
}
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/v2/fluid/tests/test_ctc_greedy_decode.py
0 → 100644
浏览文件 @
579f6846
import
sys
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
from
test_softmax_op
import
stable_softmax
def
CTCGreedyDecode
(
softmax
,
blank
,
merge_repeated
):
prev_token
=
-
1
result
=
[]
for
token
in
np
.
argmax
(
softmax
,
axis
=
1
):
if
(
token
!=
blank
)
and
not
(
merge_repeated
and
token
==
prev_token
):
result
.
append
(
token
)
return
np
.
array
(
result
).
reshape
([
len
(
result
),
1
])
class
TestCTCGreedyDecodeOp
(
OpTest
):
def
config
(
self
):
self
.
op_type
=
"ctc_greedy_decode"
self
.
batch_size
=
4
self
.
num_classes
=
8
self
.
input_lod
=
[[
0
,
4
,
5
,
8
,
11
]]
self
.
blank
=
7
self
.
merge_repeated
=
True
def
setUp
(
self
):
self
.
config
()
input
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
self
.
input_lod
[
0
][
-
1
],
self
.
num_classes
]).
astype
(
"float32"
)
softmax
=
np
.
apply_along_axis
(
stable_softmax
,
1
,
input
)
output
=
CTCGreedyDecode
(
softmax
,
self
.
blank
,
self
.
merge_repeated
)
self
.
inputs
=
{
"Input"
:
(
softmax
,
self
.
input_lod
),
}
self
.
outputs
=
{
"Output"
:
output
}
self
.
attrs
=
{
"blank"
:
self
.
blank
,
"merge_repeated"
:
self
.
merge_repeated
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestCTCGreedyDecodeOpCase1
(
TestCTCGreedyDecodeOp
):
def
config
(
self
):
self
.
op_type
=
"ctc_greedy_decode"
self
.
batch_size
=
4
self
.
num_classes
=
1025
self
.
input_lod
=
[[
0
,
4
,
5
,
8
,
11
]]
self
.
blank
=
0
self
.
merge_repeated
=
True
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录