Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0250e54c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
0250e54c
编写于
1月 03, 2018
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Enable batch input in edit_distance_op
上级
2e49faca
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
189 addition
and
101 deletion
+189
-101
paddle/operators/edit_distance_op.cc
paddle/operators/edit_distance_op.cc
+31
-18
paddle/operators/edit_distance_op.cu
paddle/operators/edit_distance_op.cu
+58
-40
paddle/operators/edit_distance_op.h
paddle/operators/edit_distance_op.h
+55
-36
python/paddle/v2/fluid/tests/test_edit_distance_op.py
python/paddle/v2/fluid/tests/test_edit_distance_op.py
+45
-7
未找到文件。
paddle/operators/edit_distance_op.cc
浏览文件 @
0250e54c
...
...
@@ -22,10 +22,18 @@ class EditDistanceOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Hyp
"
),
"Input(Hyp
) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ref
"
),
"Input(Ref
) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Hyp
s"
),
"Input(Hyps
) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ref
s"
),
"Input(Refs
) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) shouldn't be null."
);
ctx
->
SetOutputDim
(
"Out"
,
{
1
});
auto
hyp_dims
=
ctx
->
GetInputDim
(
"Hyps"
);
auto
ref_dims
=
ctx
->
GetInputDim
(
"Refs"
);
PADDLE_ENFORCE
(
hyp_dims
.
size
()
==
2
&&
hyp_dims
[
1
]
==
1
,
"Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1."
);
PADDLE_ENFORCE
(
ref_dims
.
size
()
==
2
&&
ref_dims
[
1
]
==
1
,
"Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
"equal to 1."
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"Refs"
));
}
protected:
...
...
@@ -40,24 +48,23 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
EditDistanceOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"Hyp"
,
"(2-D
tensor with shape [M x 1]) The indices for
"
"
hypothesis string
"
);
AddInput
(
"Ref"
,
"(2-D
tensor with shape [N x 1]) The indices
"
"
for reference string
."
);
AddInput
(
"Hyp
s
"
,
"(2-D
LoDTensor, 2nd dim. equal to 1)
"
"
The indices for hypothesis strings.
"
);
AddInput
(
"Ref
s
"
,
"(2-D
LoDTensor, 2nd dim. equal to 1)
"
"
The indices for reference strings
."
);
AddAttr
<
bool
>
(
"normalized"
,
"(bool, default false) Indicated whether "
"normalize the Output(Out) by the length of reference "
"string (Ref)."
)
"(bool, default false) Indicated whether to normalize "
"the edit distance by the length of reference string."
)
.
SetDefault
(
false
);
AddOutput
(
"Out"
,
"(2-D
tensor with shape [1
x 1]) "
"The output
distance
of EditDistance operator."
);
"(2-D
Tensor with shape [`batch_size`
x 1]) "
"The output
edit distances
of EditDistance operator."
);
AddComment
(
R"DOC(
EditDistance operator computes the edit distance
of two sequences, one named
hypothesis with length M and another named reference with length N
.
EditDistance operator computes the edit distance
s between a batch of hypothesis
strings and their references
.
Edit distance, also called Levenshtein distance, measures how dissimilar two strings
are by counting the minimum number of operations to transform one string into anthor.
...
...
@@ -68,8 +75,14 @@ insertion:
"kitten" -> "sitten" -> "sittin" -> "sitting"
If Attr(normalized) is true, the edit distance will be divided by the length of
reference string N.
Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total
number denoted by `batch_size`, and the separation is specified by the LoD information.
And the `batch_size` reference strings are arranged in order in the same way in the
LoDTensor Input(Refs).
Output(Out) contains the `batch_size` results and each stands for the edit stance
for a pair of strings respectively. If Attr(normalized) is true, the edit distance
will be divided by the length of reference string.
)DOC"
);
}
};
...
...
paddle/operators/edit_distance_op.cu
浏览文件 @
0250e54c
...
...
@@ -70,53 +70,71 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
out_t
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
x1_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Hyp"
);
auto
*
x2_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Ref"
);
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
out
=
out_t
->
data
<
T
>
();
auto
*
x1_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Hyps"
);
auto
*
x2_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Refs"
);
auto
normalized
=
ctx
.
Attr
<
bool
>
(
"normalized"
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
();
auto
m
=
x1_t
->
numel
();
auto
n
=
x2_t
->
numel
();
T
distance
=
0.0
;
if
(
m
==
0
||
n
==
0
)
{
distance
=
std
::
max
(
m
,
n
);
if
(
normalized
)
{
distance
=
distance
/
n
;
}
memory
::
Copy
(
boost
::
get
<
Place
>
(
ctx
.
GetPlace
()),
out
,
platform
::
CPUPlace
(),
&
distance
,
sizeof
(
T
),
stream
);
}
else
{
framework
::
Tensor
dist_t
;
dist_t
.
Resize
({
m
+
1
,
n
+
1
});
dist_t
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dist
=
dist_t
.
data
<
T
>
();
auto
x1
=
x1_t
->
data
<
int
>
();
auto
x2
=
x2_t
->
data
<
int
>
();
FillFirstColumn
<
T
><<<
1
+
m
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
m
,
n
);
FillFirstRow
<
T
><<<
1
+
n
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
n
);
// Compute the elements of distance matrix in the anti-diagonal diretion
for
(
int64_t
slice
=
2
;
slice
<
m
+
n
+
1
;
++
slice
)
{
int
z_m
=
slice
<
m
+
1
?
0
:
slice
-
m
;
int
z_n
=
slice
<
n
+
1
?
0
:
slice
-
n
;
int
size
=
slice
-
(
z_m
+
z_n
)
+
1
;
// number of elments in the same
// anti-diagonal line to update
// the start index at which computes from
int
start
=
slice
<
n
+
1
?
slice
:
(
z_n
+
1
)
*
(
n
+
1
)
-
1
;
Levenshtein
<
T
><<<
1
+
(
size
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
x1
,
x2
,
m
,
n
,
start
);
auto
hyp_lod
=
x1_t
->
lod
()[
0
];
auto
ref_lod
=
x2_t
->
lod
()[
0
];
PADDLE_ENFORCE
(
hyp_lod
.
size
()
==
ref_lod
.
size
(),
"Input(Hyps) and Input(Refs) must have the same batch size."
);
for
(
size_t
i
=
1
;
i
<
ref_lod
.
size
();
++
i
)
{
PADDLE_ENFORCE
(
ref_lod
[
i
]
>
ref_lod
[
i
-
1
],
"Reference string %d is empty."
,
i
);
}
auto
num_strs
=
hyp_lod
.
size
()
-
1
;
out_t
->
Resize
({
static_cast
<
int64_t
>
(
num_strs
),
1
});
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
out
=
out_t
->
data
<
T
>
();
std
::
vector
<
T
>
distance
(
num_strs
,
0.0
);
for
(
size_t
num
=
0
;
num
<
num_strs
;
num
++
)
{
auto
m
=
static_cast
<
int64_t
>
(
hyp_lod
[
num
+
1
]
-
hyp_lod
[
num
]);
auto
n
=
static_cast
<
int64_t
>
(
ref_lod
[
num
+
1
]
-
ref_lod
[
num
]);
if
(
m
==
0
||
n
==
0
)
{
distance
[
num
]
=
std
::
max
(
m
,
n
);
if
(
normalized
)
{
PADDLE_ENFORCE
(
n
>
0
,
"The reference string (#%d) cannot be empty "
"when Attr(normalized) is enabled."
,
n
);
distance
[
num
]
=
distance
[
num
]
/
n
;
}
memory
::
Copy
(
boost
::
get
<
Place
>
(
ctx
.
GetPlace
()),
out
+
num
,
platform
::
CPUPlace
(),
&
distance
[
num
],
sizeof
(
T
),
stream
);
}
else
{
framework
::
Tensor
dist_t
;
dist_t
.
Resize
({
m
+
1
,
n
+
1
});
dist_t
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dist
=
dist_t
.
data
<
T
>
();
auto
x1
=
x1_t
->
data
<
int
>
()
+
hyp_lod
[
num
];
auto
x2
=
x2_t
->
data
<
int
>
()
+
ref_lod
[
num
];
FillFirstColumn
<
T
><<<
1
+
m
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
m
,
n
);
FillFirstRow
<
T
><<<
1
+
n
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
n
);
// Compute the elements of distance matrix in the anti-diagonal diretion
for
(
int64_t
slice
=
2
;
slice
<
m
+
n
+
1
;
++
slice
)
{
int
z_m
=
slice
<
m
+
1
?
0
:
slice
-
m
;
int
z_n
=
slice
<
n
+
1
?
0
:
slice
-
n
;
int
size
=
slice
-
(
z_m
+
z_n
)
+
1
;
// number of elments in the same
// anti-diagonal line to update
// the start index at which computes from
int
start
=
slice
<
n
+
1
?
slice
:
(
z_n
+
1
)
*
(
n
+
1
)
-
1
;
Levenshtein
<
T
><<<
1
+
(
size
-
1
)
/
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dist
,
x1
,
x2
,
m
,
n
,
start
);
}
SetOutput
<
T
><<<
1
,
1
,
0
,
stream
>>>
(
out
+
num
,
dist
,
m
,
n
,
normalized
);
}
SetOutput
<
T
><<<
1
,
1
,
0
,
stream
>>>
(
out
,
dist
,
m
,
n
,
normalized
);
}
}
};
...
...
paddle/operators/edit_distance_op.h
浏览文件 @
0250e54c
...
...
@@ -26,50 +26,69 @@ class EditDistanceKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
out_t
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
x1_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Hyp
"
);
auto
*
x2_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Ref
"
);
auto
*
x1_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Hyps
"
);
auto
*
x2_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Refs
"
);
auto
normalized
=
ctx
.
Attr
<
bool
>
(
"normalized"
);
auto
hyp_lod
=
x1_t
->
lod
()[
0
];
auto
ref_lod
=
x2_t
->
lod
()[
0
];
PADDLE_ENFORCE
(
hyp_lod
.
size
()
==
ref_lod
.
size
(),
"Input(Hyps) and Input(Refs) must have the same batch size."
);
for
(
size_t
i
=
1
;
i
<
ref_lod
.
size
();
++
i
)
{
PADDLE_ENFORCE
(
ref_lod
[
i
]
>
ref_lod
[
i
-
1
],
"Reference string %d is empty."
,
i
);
}
auto
num_strs
=
hyp_lod
.
size
()
-
1
;
out_t
->
Resize
({
static_cast
<
int64_t
>
(
num_strs
),
1
});
out_t
->
mutable_data
<
float
>
(
ctx
.
GetPlace
());
auto
out
=
out_t
->
data
<
T
>
();
auto
normalized
=
ctx
.
Attr
<
bool
>
(
"normalized"
);
std
::
vector
<
T
>
distance
(
num_strs
,
0.0
);
for
(
size_t
num
=
0
;
num
<
num_strs
;
++
num
)
{
auto
m
=
static_cast
<
int64_t
>
(
hyp_lod
[
num
+
1
]
-
hyp_lod
[
num
]);
auto
n
=
static_cast
<
int64_t
>
(
ref_lod
[
num
+
1
]
-
ref_lod
[
num
]);
auto
m
=
x1_t
->
numel
();
auto
n
=
x2_t
->
numel
();
T
distance
=
0.0
;
if
(
m
==
0
)
{
distance
=
n
;
}
else
if
(
n
==
0
)
{
distance
=
m
;
}
else
{
framework
::
Tensor
dist_t
;
dist_t
.
Resize
({
m
+
1
,
n
+
1
});
dist_t
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dist
=
dist_t
.
data
<
T
>
();
auto
x1
=
x1_t
->
data
<
int
>
();
auto
x2
=
x2_t
->
data
<
int
>
();
for
(
int64_t
i
=
0
;
i
<
m
+
1
;
++
i
)
{
dist
[
i
*
(
n
+
1
)]
=
i
;
}
for
(
int64_t
j
=
0
;
j
<
n
+
1
;
++
j
)
{
dist
[
j
]
=
j
;
}
for
(
int64_t
i
=
1
;
i
<
m
+
1
;
++
i
)
{
for
(
int64_t
j
=
1
;
j
<
n
+
1
;
++
j
)
{
int
cost
=
x1
[
i
-
1
]
==
x2
[
j
-
1
]
?
0
:
1
;
int
dels
=
dist
[(
i
-
1
)
*
(
n
+
1
)
+
j
]
+
1
;
int
ins
=
dist
[
i
*
(
n
+
1
)
+
(
j
-
1
)]
+
1
;
int
subs
=
dist
[(
i
-
1
)
*
(
n
+
1
)
+
(
j
-
1
)]
+
cost
;
dist
[
i
*
(
n
+
1
)
+
j
]
=
std
::
min
(
dels
,
std
::
min
(
ins
,
subs
));
if
(
m
==
0
)
{
distance
[
num
]
=
n
;
}
else
if
(
n
==
0
)
{
distance
[
num
]
=
m
;
}
else
{
framework
::
Tensor
dist_t
;
dist_t
.
Resize
({
m
+
1
,
n
+
1
});
dist_t
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dist
=
dist_t
.
data
<
T
>
();
auto
x1
=
x1_t
->
data
<
int
>
()
+
hyp_lod
[
num
];
auto
x2
=
x2_t
->
data
<
int
>
()
+
ref_lod
[
num
];
for
(
int64_t
i
=
0
;
i
<
m
+
1
;
++
i
)
{
dist
[
i
*
(
n
+
1
)]
=
i
;
}
for
(
int64_t
j
=
0
;
j
<
n
+
1
;
++
j
)
{
dist
[
j
]
=
j
;
}
for
(
int64_t
i
=
1
;
i
<
m
+
1
;
++
i
)
{
for
(
int64_t
j
=
1
;
j
<
n
+
1
;
++
j
)
{
int
cost
=
x1
[
i
-
1
]
==
x2
[
j
-
1
]
?
0
:
1
;
int
dels
=
dist
[(
i
-
1
)
*
(
n
+
1
)
+
j
]
+
1
;
int
ins
=
dist
[
i
*
(
n
+
1
)
+
(
j
-
1
)]
+
1
;
int
subs
=
dist
[(
i
-
1
)
*
(
n
+
1
)
+
(
j
-
1
)]
+
cost
;
dist
[
i
*
(
n
+
1
)
+
j
]
=
std
::
min
(
dels
,
std
::
min
(
ins
,
subs
));
}
}
distance
[
num
]
=
dist
[
m
*
(
n
+
1
)
+
n
];
}
distance
=
dist
[
m
*
(
n
+
1
)
+
n
];
}
if
(
normalized
)
{
distance
=
distance
/
n
;
if
(
normalized
)
{
PADDLE_ENFORCE
(
n
>
0
,
"The reference string (#%d) cannot be empty "
"when Attr(normalized) is enabled."
,
n
);
distance
[
num
]
=
distance
[
num
]
/
n
;
}
out
[
num
]
=
distance
[
num
];
}
auto
out
=
out_t
->
data
<
T
>
();
out
[
0
]
=
distance
;
}
};
...
...
python/paddle/v2/fluid/tests/test_edit_distance_op.py
浏览文件 @
0250e54c
...
...
@@ -18,7 +18,7 @@ def Levenshtein(hyp, ref):
if
n
==
0
:
return
m
dist
=
np
.
zeros
((
m
+
1
,
n
+
1
))
dist
=
np
.
zeros
((
m
+
1
,
n
+
1
))
.
astype
(
"float32"
)
for
i
in
range
(
0
,
m
+
1
):
dist
[
i
][
0
]
=
i
for
j
in
range
(
0
,
n
+
1
):
...
...
@@ -35,17 +35,55 @@ def Levenshtein(hyp, ref):
class
TestCTCEditDistanceOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"edit_distance"
normalized
=
False
x1
=
np
.
array
([[
0
,
12
,
3
,
5
,
8
,
2
]]).
astype
(
"int32"
)
x2
=
np
.
array
([[
0
,
12
,
4
,
7
,
8
]]).
astype
(
"int32"
)
x1
=
np
.
transpose
(
x1
)
x2
=
np
.
transpose
(
x2
)
x1_lod
=
[
0
,
1
,
5
]
x2_lod
=
[
0
,
3
,
4
]
num_strs
=
len
(
x1_lod
)
-
1
distance
=
np
.
zeros
((
num_strs
,
1
)).
astype
(
"float32"
)
for
i
in
range
(
0
,
num_strs
):
distance
[
i
]
=
Levenshtein
(
hyp
=
x1
[
x1_lod
[
i
]:
x1_lod
[
i
+
1
]],
ref
=
x2
[
x2_lod
[
i
]:
x2_lod
[
i
+
1
]])
if
normalized
is
True
:
len_ref
=
x2_lod
[
i
+
1
]
-
x2_lod
[
i
]
distance
[
i
]
=
distance
[
i
]
/
len_ref
self
.
attrs
=
{
'normalized'
:
normalized
}
self
.
inputs
=
{
'Hyps'
:
(
x1
,
[
x1_lod
]),
'Refs'
:
(
x2
,
[
x2_lod
])}
self
.
outputs
=
{
'Out'
:
distance
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestCTCEditDistanceOpNormalized
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"edit_distance"
normalized
=
True
x1
=
np
.
array
([
0
,
12
,
3
,
5
]).
astype
(
"int32"
)
x2
=
np
.
array
([
0
,
12
,
4
,
7
,
8
]).
astype
(
"int32"
)
x1
=
np
.
array
([[
0
,
10
,
3
,
6
,
5
,
8
,
2
]]).
astype
(
"int32"
)
x2
=
np
.
array
([[
0
,
10
,
4
,
6
,
7
,
8
]]).
astype
(
"int32"
)
x1
=
np
.
transpose
(
x1
)
x2
=
np
.
transpose
(
x2
)
x1_lod
=
[
0
,
1
,
3
,
6
]
x2_lod
=
[
0
,
2
,
3
,
5
]
distance
=
Levenshtein
(
hyp
=
x1
,
ref
=
x2
)
if
normalized
is
True
:
distance
=
distance
/
len
(
x2
)
num_strs
=
len
(
x1_lod
)
-
1
distance
=
np
.
zeros
((
num_strs
,
1
)).
astype
(
"float32"
)
for
i
in
range
(
0
,
num_strs
):
distance
[
i
]
=
Levenshtein
(
hyp
=
x1
[
x1_lod
[
i
]:
x1_lod
[
i
+
1
]],
ref
=
x2
[
x2_lod
[
i
]:
x2_lod
[
i
+
1
]])
if
normalized
is
True
:
len_ref
=
x2_lod
[
i
+
1
]
-
x2_lod
[
i
]
distance
[
i
]
=
distance
[
i
]
/
len_ref
self
.
attrs
=
{
'normalized'
:
normalized
}
self
.
inputs
=
{
'Hyp
'
:
x1
,
'Ref'
:
x2
}
self
.
inputs
=
{
'Hyp
s'
:
(
x1
,
[
x1_lod
]),
'Refs'
:
(
x2
,
[
x2_lod
])
}
self
.
outputs
=
{
'Out'
:
distance
}
def
test_check_output
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录