提交 b7a4e3d7 编写于 作者: Y Yibing Liu

rename some variables in ctc_edit_distance_op

上级 db694172
......@@ -38,11 +38,11 @@ class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
"(2-D tensor with shape [M x 1]) The indices for "
"hypothesis string");
AddInput("X2",
"(2-D tensor with shape [batch_size x 1]) The indices "
"(2-D tensor with shape [N x 1]) The indices "
"for reference string.");
AddAttr<bool>("normalized",
"(bool, default false) Indicated whether "
"normalize. the Output(Out) by the length of reference "
"normalize the Output(Out) by the length of reference "
"string (X2).")
.SetDefault(false);
AddOutput("Out",
......
......@@ -47,20 +47,19 @@ class CTCEditDistanceKernel : public framework::OpKernel<T> {
auto dist = dist_t.data<T>();
auto x1 = x1_t->data<T>();
auto x2 = x2_t->data<T>();
for (int i = 0; i < m + 1; ++i) {
dist[i * (n + 1)] = i; // dist[i][0] = i;
for (size_t i = 0; i < m + 1; ++i) {
dist[i * (n + 1)] = i;
}
for (int j = 0; j < n + 1; ++j) {
dist[j] = j; // dist[0][j] = j;
for (size_t j = 0; j < n + 1; ++j) {
dist[j] = j;
}
for (int i = 1; i < m + 1; ++i) {
for (int j = 1; j < n + 1; ++j) {
for (size_t i = 1; i < m + 1; ++i) {
for (size_t j = 1; j < n + 1; ++j) {
int cost = x1[i - 1] == x2[j - 1] ? 0 : 1;
int deletions = dist[(i - 1) * (n + 1) + j] + 1;
int insertions = dist[i * (n + 1) + (j - 1)] + 1;
int substitutions = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
dist[i * (n + 1) + j] =
std::min(deletions, std::min(insertions, substitutions));
int dels = dist[(i - 1) * (n + 1) + j] + 1;
int ins = dist[i * (n + 1) + (j - 1)] + 1;
int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs));
}
}
distance = dist[m * (n + 1) + n];
......
......@@ -6,9 +6,9 @@ from op_test import OpTest
def Levenshtein(hyp, ref):
""" Compute the Levenshtein distance between two strings.
:param hyp:
:param hyp: hypothesis string in index
:type hyp: list
:param ref:
:param ref: reference string in index
:type ref: list
"""
m = len(hyp)
......@@ -44,7 +44,6 @@ class TestCTCEditDistanceOp(OpTest):
distance = Levenshtein(hyp=x1, ref=x2)
if normalized is True:
distance = distance / len(x2)
print "distance = ", distance
self.attrs = {'normalized': normalized}
self.inputs = {'X1': x1, 'X2': x2}
self.outputs = {'Out': distance}
......@@ -54,7 +53,4 @@ class TestCTCEditDistanceOp(OpTest):
if __name__ == '__main__':
#x1 = ['c', 'a', 'f', 'e']
#x2 = ['c', 'o', 'f', 'f', 'e', 'e']
#print Levenshtein(x1, x2)
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部