提交 2c1adb06 编写于 作者: Y Yibing Liu

Rename ctc_edit_distance_op to edit_distance_op

上级 36ec3e90
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/ctc_edit_distance_op.h" #include "paddle/operators/edit_distance_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class CTCEditDistanceOp : public framework::OperatorWithKernel { class EditDistanceOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -29,17 +29,16 @@ class CTCEditDistanceOp : public framework::OperatorWithKernel { ...@@ -29,17 +29,16 @@ class CTCEditDistanceOp : public framework::OperatorWithKernel {
} }
protected: protected:
framework::OpKernelType GetKernelType( framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::DataType::FP32, return framework::OpKernelType(framework::proto::DataType::FP32,
ctx.device_context()); ctx.device_context());
} }
}; };
class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CTCEditDistanceOpMaker(framework::OpProto *proto, EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X1", AddInput("X1",
"(2-D tensor with shape [M x 1]) The indices for " "(2-D tensor with shape [M x 1]) The indices for "
...@@ -54,10 +53,10 @@ class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -54,10 +53,10 @@ class CTCEditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(false); .SetDefault(false);
AddOutput("Out", AddOutput("Out",
"(2-D tensor with shape [1 x 1]) " "(2-D tensor with shape [1 x 1]) "
"The output distance of CTCEditDistance operator."); "The output distance of EditDistance operator.");
AddComment(R"DOC( AddComment(R"DOC(
CTCEditDistance operator computes the edit distance of two sequences, one named EditDistance operator computes the edit distance of two sequences, one named
hypothesis with length M and another named reference with length N. hypothesis with length M and another named reference with length N.
Edit distance, also called Levenshtein distance, measures how dissimilar two strings Edit distance, also called Levenshtein distance, measures how dissimilar two strings
...@@ -80,8 +79,7 @@ reference string N. ...@@ -80,8 +79,7 @@ reference string N.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ctc_edit_distance, ops::CTCEditDistanceOp, REGISTER_OPERATOR(edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker,
ops::CTCEditDistanceOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
ctc_edit_distance, edit_distance, ops::EditDistanceKernel<paddle::platform::CPUPlace, float>);
ops::CTCEditDistanceKernel<paddle::platform::CPUPlace, float>);
...@@ -65,7 +65,7 @@ __global__ void SetOutput(T* out, const T* dist, const int M, const int N, ...@@ -65,7 +65,7 @@ __global__ void SetOutput(T* out, const T* dist, const int M, const int N,
} }
template <typename Place, typename T> template <typename Place, typename T>
class CTCEditDistanceGPUKernel : public framework::OpKernel<T> { class EditDistanceGPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::Tensor>("Out"); auto* out_t = ctx.Output<framework::Tensor>("Out");
...@@ -110,8 +110,8 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel<T> { ...@@ -110,8 +110,8 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel<T> {
int z_n = slice < n + 1 ? 0 : slice - n; int z_n = slice < n + 1 ? 0 : slice - n;
int size = slice - (z_m + z_n) + 1; // number of elments in the same int size = slice - (z_m + z_n) + 1; // number of elments in the same
// anti-diagonal line to update // anti-diagonal line to update
int start = slice < n + 1 ? slice : z_n * (n + 1) - 1; // start index // the start index at which computes from
int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1;
Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, m, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, m,
n, start); n, start);
...@@ -126,6 +126,6 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel<T> { ...@@ -126,6 +126,6 @@ class CTCEditDistanceGPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_CUDA_KERNEL(
ctc_edit_distance, edit_distance,
ops::CTCEditDistanceGPUKernel<paddle::platform::GPUPlace, float>); ops::EditDistanceGPUKernel<paddle::platform::CUDAPlace, float>);
...@@ -21,7 +21,7 @@ namespace paddle { ...@@ -21,7 +21,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class CTCEditDistanceKernel : public framework::OpKernel<T> { class EditDistanceKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::Tensor>("Out"); auto* out_t = ctx.Output<framework::Tensor>("Out");
......
...@@ -36,7 +36,7 @@ def Levenshtein(hyp, ref): ...@@ -36,7 +36,7 @@ def Levenshtein(hyp, ref):
class TestCTCEditDistanceOp(OpTest): class TestCTCEditDistanceOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "ctc_edit_distance" self.op_type = "edit_distance"
normalized = True normalized = True
x1 = np.array([0, 12, 3, 5]).astype("int32") x1 = np.array([0, 12, 3, 5]).astype("int32")
x2 = np.array([0, 12, 4, 7, 8]).astype("int32") x2 = np.array([0, 12, 4, 7, 8]).astype("int32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册