Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
116687a8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
116687a8
编写于
11月 29, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean up code in ctc_edit_distance_op
上级
6bc6ccd1
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
49 addition
and
44 deletion
+49
-44
paddle/operators/ctc_edit_distance_op.cc
paddle/operators/ctc_edit_distance_op.cc
+8
-2
paddle/operators/ctc_edit_distance_op.cu
paddle/operators/ctc_edit_distance_op.cu
+30
-29
paddle/operators/ctc_edit_distance_op.h
paddle/operators/ctc_edit_distance_op.h
+8
-8
python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py
python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py
+3
-5
未找到文件。
paddle/operators/ctc_edit_distance_op.cc
浏览文件 @
116687a8
...
@@ -27,6 +27,13 @@ class CTCEditDistanceOp : public framework::OperatorWithKernel {
...
@@ -27,6 +27,13 @@ class CTCEditDistanceOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) shouldn't be null."
);
ctx
->
SetOutputDim
(
"Out"
,
{
1
});
ctx
->
SetOutputDim
(
"Out"
,
{
1
});
}
}
protected:
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
DataType
::
FP32
,
ctx
.
device_context
());
}
};
};
class
CTCEditDistanceOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
CTCEditDistanceOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
@@ -70,5 +77,4 @@ REGISTER_OP_WITHOUT_GRADIENT(ctc_edit_distance, ops::CTCEditDistanceOp,
...
@@ -70,5 +77,4 @@ REGISTER_OP_WITHOUT_GRADIENT(ctc_edit_distance, ops::CTCEditDistanceOp,
ops
::
CTCEditDistanceOpMaker
);
ops
::
CTCEditDistanceOpMaker
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
ctc_edit_distance
,
ctc_edit_distance
,
ops
::
CTCEditDistanceKernel
<
paddle
::
platform
::
CPUPlace
,
int32_t
>
,
ops
::
CTCEditDistanceKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
ops
::
CTCEditDistanceKernel
<
paddle
::
platform
::
CPUPlace
,
int64_t
>
);
paddle/operators/ctc_edit_distance_op.cu
浏览文件 @
116687a8
...
@@ -39,7 +39,7 @@ __global__ void FillFirstColumn(T* dist, const int M, const int N) {
...
@@ -39,7 +39,7 @@ __global__ void FillFirstColumn(T* dist, const int M, const int N) {
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
Levenshtein
(
T
*
dist
,
const
T
*
x1
,
const
T
*
x2
,
const
int
M
,
__global__
void
Levenshtein
(
T
*
dist
,
const
int
*
x1
,
const
int
*
x2
,
const
int
M
,
const
int
N
,
const
int
start
)
{
const
int
N
,
const
int
start
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
offset
=
N
;
int
offset
=
N
;
...
@@ -55,6 +55,15 @@ __global__ void Levenshtein(T* dist, const T* x1, const T* x2, const int M,
...
@@ -55,6 +55,15 @@ __global__ void Levenshtein(T* dist, const T* x1, const T* x2, const int M,
}
}
}
}
template
<
typename
T
>
__global__
void
SetOutput
(
T
*
out
,
const
T
*
dist
,
const
int
M
,
const
int
N
,
bool
normalized
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
if
(
idx
==
0
)
{
out
[
0
]
=
normalized
?
dist
[
M
*
(
N
+
1
)
+
N
]
/
N
:
dist
[
M
*
(
N
+
1
)
+
N
];
}
}
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
CTCEditDistanceGPUKernel
:
public
framework
::
OpKernel
<
T
>
{
class
CTCEditDistanceGPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -64,7 +73,8 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel<T> {
...
@@ -64,7 +73,8 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel<T> {
auto
*
x1_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X1"
);
auto
*
x1_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X1"
);
auto
*
x2_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X2"
);
auto
*
x2_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X2"
);
out_t
->
mutable_data
<
float
>
(
ctx
.
GetPlace
());
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
out
=
out_t
->
data
<
T
>
();
auto
normalized
=
ctx
.
Attr
<
bool
>
(
"normalized"
);
auto
normalized
=
ctx
.
Attr
<
bool
>
(
"normalized"
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
...
@@ -73,49 +83,41 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel<T> {
...
@@ -73,49 +83,41 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel<T> {
auto
m
=
x1_t
->
numel
();
auto
m
=
x1_t
->
numel
();
auto
n
=
x2_t
->
numel
();
auto
n
=
x2_t
->
numel
();
T
distance
=
0
;
T
distance
=
0.0
;
if
(
m
==
0
)
{
if
(
m
==
0
||
n
==
0
)
{
distance
=
n
;
distance
=
std
::
max
(
m
,
n
);
}
else
if
(
n
==
0
)
{
if
(
normalized
)
{
distance
=
m
;
distance
=
distance
/
n
;
}
memory
::
Copy
(
boost
::
get
<
Place
>
(
ctx
.
GetPlace
()),
out
,
platform
::
CPUPlace
(),
&
distance
,
sizeof
(
T
),
stream
);
}
else
{
}
else
{
framework
::
Tensor
dist_t
;
framework
::
Tensor
dist_t
;
dist_t
.
Resize
({
m
+
1
,
n
+
1
});
dist_t
.
Resize
({
m
+
1
,
n
+
1
});
dist_t
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
dist_t
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dist
=
dist_t
.
data
<
T
>
();
auto
dist
=
dist_t
.
data
<
T
>
();
auto
x1
=
x1_t
->
data
<
T
>
();
auto
x1
=
x1_t
->
data
<
int
>
();
auto
x2
=
x2_t
->
data
<
T
>
();
auto
x2
=
x2_t
->
data
<
int
>
();
FillFirstColumn
<
T
><<<
1
+
m
/
PADDLE_CUDA_NUM_THREADS
,
FillFirstColumn
<
T
><<<
1
+
m
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
m
,
n
);
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
m
,
n
);
FillFirstRow
<
T
><<<
1
+
n
/
PADDLE_CUDA_NUM_THREADS
,
FillFirstRow
<
T
><<<
1
+
n
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
n
);
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
n
);
//
c
ompute the elements of distance matrix in the anti-diagonal diretion
//
C
ompute the elements of distance matrix in the anti-diagonal diretion
for
(
size
_t
slice
=
2
;
slice
<
m
+
n
+
1
;
++
slice
)
{
for
(
int64
_t
slice
=
2
;
slice
<
m
+
n
+
1
;
++
slice
)
{
int
z_m
=
slice
<
m
+
1
?
0
:
slice
-
m
;
int
z_m
=
slice
<
m
+
1
?
0
:
slice
-
m
;
int
z_n
=
slice
<
n
+
1
?
0
:
slice
-
n
;
int
z_n
=
slice
<
n
+
1
?
0
:
slice
-
n
;
// number of elments in the same anti-diagonal line
int
size
=
slice
-
(
z_m
+
z_n
)
+
1
;
// number of elments in the same
int
size
=
slice
-
(
z_m
+
z_n
)
+
1
;
// anti-diagonal line to update
int
start
=
slice
<
n
+
1
?
slice
:
z_n
*
(
n
+
1
)
-
1
;
int
start
=
slice
<
n
+
1
?
slice
:
z_n
*
(
n
+
1
)
-
1
;
// start index
Levenshtein
<
T
><<<
1
+
(
size
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
Levenshtein
<
T
><<<
1
+
(
size
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
x1
,
x2
,
m
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
x1
,
x2
,
m
,
n
,
start
);
n
,
start
);
}
}
SetOutput
<
T
><<<
1
,
1
,
0
,
stream
>>>
(
out
,
dist
,
m
,
n
,
normalized
);
Place
gpu_place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
memory
::
Copy
(
platform
::
CPUPlace
(),
&
distance
,
gpu_place
,
dist
+
m
*
(
n
+
1
)
+
n
,
sizeof
(
T
),
stream
);
}
if
(
normalized
)
{
distance
=
distance
/
n
;
}
}
auto
out
=
out_t
->
data
<
float
>
();
Place
gpu_place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
float
dist_f
=
distance
;
memory
::
Copy
(
gpu_place
,
out
,
platform
::
CPUPlace
(),
&
dist_f
,
sizeof
(
float
),
stream
);
}
}
};
};
...
@@ -126,5 +128,4 @@ namespace ops = paddle::operators;
...
@@ -126,5 +128,4 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL
(
REGISTER_OP_GPU_KERNEL
(
ctc_edit_distance
,
ctc_edit_distance
,
ops
::
CTCEditDistanceGPUKernel
<
paddle
::
platform
::
GPUPlace
,
int
>
,
ops
::
CTCEditDistanceGPUKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
ops
::
CTCEditDistanceGPUKernel
<
paddle
::
platform
::
GPUPlace
,
int64_t
>
);
paddle/operators/ctc_edit_distance_op.h
浏览文件 @
116687a8
...
@@ -35,7 +35,7 @@ class CTCEditDistanceKernel : public framework::OpKernel<T> {
...
@@ -35,7 +35,7 @@ class CTCEditDistanceKernel : public framework::OpKernel<T> {
auto
m
=
x1_t
->
numel
();
auto
m
=
x1_t
->
numel
();
auto
n
=
x2_t
->
numel
();
auto
n
=
x2_t
->
numel
();
float
distance
=
0.0
;
T
distance
=
0.0
;
if
(
m
==
0
)
{
if
(
m
==
0
)
{
distance
=
n
;
distance
=
n
;
}
else
if
(
n
==
0
)
{
}
else
if
(
n
==
0
)
{
...
@@ -45,16 +45,16 @@ class CTCEditDistanceKernel : public framework::OpKernel<T> {
...
@@ -45,16 +45,16 @@ class CTCEditDistanceKernel : public framework::OpKernel<T> {
dist_t
.
Resize
({
m
+
1
,
n
+
1
});
dist_t
.
Resize
({
m
+
1
,
n
+
1
});
dist_t
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
dist_t
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dist
=
dist_t
.
data
<
T
>
();
auto
dist
=
dist_t
.
data
<
T
>
();
auto
x1
=
x1_t
->
data
<
T
>
();
auto
x1
=
x1_t
->
data
<
int
>
();
auto
x2
=
x2_t
->
data
<
T
>
();
auto
x2
=
x2_t
->
data
<
int
>
();
for
(
size
_t
i
=
0
;
i
<
m
+
1
;
++
i
)
{
for
(
int64
_t
i
=
0
;
i
<
m
+
1
;
++
i
)
{
dist
[
i
*
(
n
+
1
)]
=
i
;
dist
[
i
*
(
n
+
1
)]
=
i
;
}
}
for
(
size
_t
j
=
0
;
j
<
n
+
1
;
++
j
)
{
for
(
int64
_t
j
=
0
;
j
<
n
+
1
;
++
j
)
{
dist
[
j
]
=
j
;
dist
[
j
]
=
j
;
}
}
for
(
size
_t
i
=
1
;
i
<
m
+
1
;
++
i
)
{
for
(
int64
_t
i
=
1
;
i
<
m
+
1
;
++
i
)
{
for
(
size
_t
j
=
1
;
j
<
n
+
1
;
++
j
)
{
for
(
int64
_t
j
=
1
;
j
<
n
+
1
;
++
j
)
{
int
cost
=
x1
[
i
-
1
]
==
x2
[
j
-
1
]
?
0
:
1
;
int
cost
=
x1
[
i
-
1
]
==
x2
[
j
-
1
]
?
0
:
1
;
int
dels
=
dist
[(
i
-
1
)
*
(
n
+
1
)
+
j
]
+
1
;
int
dels
=
dist
[(
i
-
1
)
*
(
n
+
1
)
+
j
]
+
1
;
int
ins
=
dist
[
i
*
(
n
+
1
)
+
(
j
-
1
)]
+
1
;
int
ins
=
dist
[
i
*
(
n
+
1
)
+
(
j
-
1
)]
+
1
;
...
@@ -68,7 +68,7 @@ class CTCEditDistanceKernel : public framework::OpKernel<T> {
...
@@ -68,7 +68,7 @@ class CTCEditDistanceKernel : public framework::OpKernel<T> {
if
(
normalized
)
{
if
(
normalized
)
{
distance
=
distance
/
n
;
distance
=
distance
/
n
;
}
}
auto
out
=
out_t
->
data
<
float
>
();
auto
out
=
out_t
->
data
<
T
>
();
out
[
0
]
=
distance
;
out
[
0
]
=
distance
;
}
}
};
};
...
...
python/paddle/v2/fluid/tests/test_ctc_edit_distance_op.py
浏览文件 @
116687a8
...
@@ -37,11 +37,9 @@ def Levenshtein(hyp, ref):
...
@@ -37,11 +37,9 @@ def Levenshtein(hyp, ref):
class
TestCTCEditDistanceOp
(
OpTest
):
class
TestCTCEditDistanceOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"ctc_edit_distance"
self
.
op_type
=
"ctc_edit_distance"
normalized
=
False
normalized
=
True
#x1 = np.array([0, 12, 3, 5]).astype("int64")
x1
=
np
.
array
([
0
,
12
,
3
,
5
]).
astype
(
"int32"
)
#x2 = np.array([0, 12, 4, 7, 8]).astype("int64")
x2
=
np
.
array
([
0
,
12
,
4
,
7
,
8
]).
astype
(
"int32"
)
x1
=
np
.
array
([
0
,
12
,
5
]).
astype
(
"int64"
)
x2
=
np
.
array
([
0
,
12
,
4
]).
astype
(
"int64"
)
distance
=
Levenshtein
(
hyp
=
x1
,
ref
=
x2
)
distance
=
Levenshtein
(
hyp
=
x1
,
ref
=
x2
)
if
normalized
is
True
:
if
normalized
is
True
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录