Commit db694172 authored by Yibing Liu

Add edit distance operator

Parent 8efd0876
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/ctc_edit_distance_op.h"
namespace paddle {
namespace operators {
class CTCEditDistanceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
ctx->SetOutputDim("Out", {1});
}
};
class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CTCEditDistanceOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X1",
"(2-D tensor with shape [M x 1]) The indices for "
"hypothesis string");
AddInput("X2",
"(2-D tensor with shape [batch_size x 1]) The indices "
"for reference string.");
AddAttr<bool>("normalized",
"(bool, default false) Indicated whether "
"normalize. the Output(Out) by the length of reference "
"string (X2).")
.SetDefault(false);
AddOutput("Out",
"(2-D tensor with shape [1 x 1]) "
"The output distance of CTCEditDistance operator.");
AddComment(R"DOC(
CTCEditDistance operator computes the edit distance of two sequences, one named
hypothesis and another named reference.
Edit distance measures how dissimilar two strings, one is hypothesis and another
is reference, are by counting the minimum number of operations to transform
one string into anthor.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ctc_edit_distance, ops::CTCEditDistanceOp,
ops::CTCEditDistanceOpMaker);
REGISTER_OP_CPU_KERNEL(
ctc_edit_distance,
ops::CTCEditDistanceKernel<paddle::platform::CPUPlace, int32_t>,
ops::CTCEditDistanceKernel<paddle::platform::CPUPlace, int64_t>);
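
// Illustrative note (added here, not part of the operator registration above):
// with the hypothesis indices {0, 12, 3, 5} and the reference indices
// {0, 12, 4, 7, 8}, the cheapest transformation is two substitutions
// (3 -> 4, 5 -> 7) plus one insertion (8), so Out is 3; with
// normalized == true the result is divided by the reference length 5,
// giving 0.6. These are the same values exercised by the Python unit test
// further down.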
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class CTCEditDistanceKernel : public framework::OpKernel<T> {
public:
  void Compute(const framework::ExecutionContext& ctx) const override {
auto* out_t = ctx.Output<framework::Tensor>("Out");
auto* x1_t = ctx.Input<framework::Tensor>("X1");
auto* x2_t = ctx.Input<framework::Tensor>("X2");
out_t->mutable_data<float>(ctx.GetPlace());
auto normalized = ctx.Attr<bool>("normalized");
    // m and n are the lengths of the hypothesis and the reference.
    auto m = x1_t->numel();
    auto n = x2_t->numel();
    float distance = 0.0;
if (m == 0) {
distance = n;
} else if (n == 0) {
distance = m;
    } else {
      // DP table of (m + 1) x (n + 1) entries, stored row-major in one flat
      // buffer, so dist[i][j] is kept at dist[i * (n + 1) + j].
      framework::Tensor dist_t;
      dist_t.Resize({m + 1, n + 1});
      dist_t.mutable_data<T>(ctx.GetPlace());
      auto dist = dist_t.data<T>();
      auto x1 = x1_t->data<T>();
      auto x2 = x2_t->data<T>();
for (int i = 0; i < m + 1; ++i) {
dist[i * (n + 1)] = i; // dist[i][0] = i;
}
for (int j = 0; j < n + 1; ++j) {
dist[j] = j; // dist[0][j] = j;
}
      for (int i = 1; i < m + 1; ++i) {
        for (int j = 1; j < n + 1; ++j) {
          // Levenshtein recurrence: pick the cheapest of a deletion, an
          // insertion and a substitution (free when the tokens match).
          int cost = x1[i - 1] == x2[j - 1] ? 0 : 1;
          int deletions = dist[(i - 1) * (n + 1) + j] + 1;
          int insertions = dist[i * (n + 1) + (j - 1)] + 1;
          int substitutions = dist[(i - 1) * (n + 1) + (j - 1)] + cost;
          dist[i * (n + 1) + j] =
              std::min(deletions, std::min(insertions, substitutions));
        }
      }
      distance = dist[m * (n + 1) + n];
}
    if (normalized) {
      // Normalize by the reference length; n is expected to be non-zero
      // whenever normalization is requested.
      distance = distance / n;
    }
auto out = out_t->data<float>();
out[0] = distance;
}
};
} // namespace operators
} // namespace paddle
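
// Explanatory sketch (an added note, not part of the original kernel): the DP
// table above is stored row-major in a single buffer of (m + 1) * (n + 1)
// entries, so the 2-D entry dist[i][j] lives at the flat offset
// i * (n + 1) + j. For example, with n = 5 the entry (i = 2, j = 3) sits at
// offset 2 * 6 + 3 = 15, which is exactly the indexing used in the recurrence.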
import unittest
import numpy as np
from op_test import OpTest
def Levenshtein(hyp, ref):
    """Compute the Levenshtein distance between two sequences.

    :param hyp: hypothesis sequence
    :type hyp: list
    :param ref: reference sequence
    :type ref: list
    """
m = len(hyp)
n = len(ref)
if m == 0:
return n
if n == 0:
return m
dist = np.zeros((m + 1, n + 1))
for i in range(0, m + 1):
dist[i][0] = i
for j in range(0, n + 1):
dist[0][j] = j
for i in range(1, m + 1):
for j in range(1, n + 1):
cost = 0 if hyp[i - 1] == ref[j - 1] else 1
deletion = dist[i - 1][j] + 1
insertion = dist[i][j - 1] + 1
substitution = dist[i - 1][j - 1] + cost
dist[i][j] = min(deletion, insertion, substitution)
return dist[m][n]
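
# Example usage (an added illustrative note, not part of the original test):
# the commented-out check at the bottom of this file corresponds to
#
#   Levenshtein(['c', 'a', 'f', 'e'], ['c', 'o', 'f', 'f', 'e', 'e'])
#
# which evaluates to 3.0 (substitute 'a' -> 'o', then insert 'f' and 'e').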
class TestCTCEditDistanceOp(OpTest):
def setUp(self):
self.op_type = "ctc_edit_distance"
normalized = True
x1 = np.array([0, 12, 3, 5]).astype("int64")
x2 = np.array([0, 12, 4, 7, 8]).astype("int64")
distance = Levenshtein(hyp=x1, ref=x2)
        if normalized:
            distance = distance / len(x2)
print "distance = ", distance
self.attrs = {'normalized': normalized}
self.inputs = {'X1': x1, 'X2': x2}
self.outputs = {'Out': distance}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
#x1 = ['c', 'a', 'f', 'e']
#x2 = ['c', 'o', 'f', 'f', 'e', 'e']
#print Levenshtein(x1, x2)
unittest.main()