提交 b7a4e3d7 编写于 作者: Y Yibing Liu

rename some variables in ctc_edit_distance_op

上级 db694172
...@@ -38,11 +38,11 @@ class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -38,11 +38,11 @@ class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
"(2-D tensor with shape [M x 1]) The indices for " "(2-D tensor with shape [M x 1]) The indices for "
"hypothesis string"); "hypothesis string");
AddInput("X2", AddInput("X2",
"(2-D tensor with shape [batch_size x 1]) The indices " "(2-D tensor with shape [N x 1]) The indices "
"for reference string."); "for reference string.");
AddAttr<bool>("normalized", AddAttr<bool>("normalized",
"(bool, default false) Indicated whether " "(bool, default false) Indicated whether "
"normalize. the Output(Out) by the length of reference " "normalize the Output(Out) by the length of reference "
"string (X2).") "string (X2).")
.SetDefault(false); .SetDefault(false);
AddOutput("Out", AddOutput("Out",
......
...@@ -47,20 +47,19 @@ class CTCEditDistanceKernel : public framework::OpKernel<T> { ...@@ -47,20 +47,19 @@ class CTCEditDistanceKernel : public framework::OpKernel<T> {
auto dist = dist_t.data<T>(); auto dist = dist_t.data<T>();
auto x1 = x1_t->data<T>(); auto x1 = x1_t->data<T>();
auto x2 = x2_t->data<T>(); auto x2 = x2_t->data<T>();
for (int i = 0; i < m + 1; ++i) { for (size_t i = 0; i < m + 1; ++i) {
dist[i * (n + 1)] = i; // dist[i][0] = i; dist[i * (n + 1)] = i;
} }
for (int j = 0; j < n + 1; ++j) { for (size_t j = 0; j < n + 1; ++j) {
dist[j] = j; // dist[0][j] = j; dist[j] = j;
} }
for (int i = 1; i < m + 1; ++i) { for (size_t i = 1; i < m + 1; ++i) {
for (int j = 1; j < n + 1; ++j) { for (size_t j = 1; j < n + 1; ++j) {
int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; int cost = x1[i - 1] == x2[j - 1] ? 0 : 1;
int deletions = dist[(i - 1) * (n + 1) + j] + 1; int dels = dist[(i - 1) * (n + 1) + j] + 1;
int insertions = dist[i * (n + 1) + (j - 1)] + 1; int ins = dist[i * (n + 1) + (j - 1)] + 1;
int substitutions = dist[(i - 1) * (n + 1) + (j - 1)] + cost; int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
dist[i * (n + 1) + j] = dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs));
std::min(deletions, std::min(insertions, substitutions));
} }
} }
distance = dist[m * (n + 1) + n]; distance = dist[m * (n + 1) + n];
......
...@@ -6,9 +6,9 @@ from op_test import OpTest ...@@ -6,9 +6,9 @@ from op_test import OpTest
def Levenshtein(hyp, ref): def Levenshtein(hyp, ref):
""" Compute the Levenshtein distance between two strings. """ Compute the Levenshtein distance between two strings.
:param hyp: :param hyp: hypothesis string in index
:type hyp: list :type hyp: list
:param ref: :param ref: reference string in index
:type ref: list :type ref: list
""" """
m = len(hyp) m = len(hyp)
...@@ -44,7 +44,6 @@ class TestCTCEditDistanceOp(OpTest): ...@@ -44,7 +44,6 @@ class TestCTCEditDistanceOp(OpTest):
distance = Levenshtein(hyp=x1, ref=x2) distance = Levenshtein(hyp=x1, ref=x2)
if normalized is True: if normalized is True:
distance = distance / len(x2) distance = distance / len(x2)
print "distance = ", distance
self.attrs = {'normalized': normalized} self.attrs = {'normalized': normalized}
self.inputs = {'X1': x1, 'X2': x2} self.inputs = {'X1': x1, 'X2': x2}
self.outputs = {'Out': distance} self.outputs = {'Out': distance}
...@@ -54,7 +53,4 @@ class TestCTCEditDistanceOp(OpTest): ...@@ -54,7 +53,4 @@ class TestCTCEditDistanceOp(OpTest):
if __name__ == '__main__': if __name__ == '__main__':
#x1 = ['c', 'a', 'f', 'e']
#x2 = ['c', 'o', 'f', 'f', 'e', 'e']
#print Levenshtein(x1, x2)
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册