edit_distance_op.cc 4.1 KB
Newer Older
Y
Yibing Liu 已提交
1 2
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Y
Yibing Liu 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Y
Yibing Liu 已提交
6

Y
Yibing Liu 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
Y
Yibing Liu 已提交
8

Y
Yibing Liu 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yibing Liu 已提交
14

15
#include "paddle/operators/edit_distance_op.h"
Y
Yibing Liu 已提交
16 17 18 19

namespace paddle {
namespace operators {

20
class EditDistanceOp : public framework::OperatorWithKernel {
Y
Yibing Liu 已提交
21 22 23 24
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
25 26
    PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null.");
    PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null.");
Y
Yibing Liu 已提交
27
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
28 29 30 31 32 33 34 35 36
    auto hyp_dims = ctx->GetInputDim("Hyps");
    auto ref_dims = ctx->GetInputDim("Refs");
    PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1,
                   "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
                   "equal to 1.");
    PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1,
                   "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
                   "equal to 1.");
    ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
Y
Yibing Liu 已提交
37
  }
Y
Yibing Liu 已提交
38 39

 protected:
Y
Yibing Liu 已提交
40
  framework::OpKernelType GetExpectedKernelType(
Y
Yibing Liu 已提交
41
      const framework::ExecutionContext &ctx) const override {
42
    return framework::OpKernelType(framework::proto::DataType::FP32,
Y
Yibing Liu 已提交
43 44
                                   ctx.device_context());
  }
Y
Yibing Liu 已提交
45 46
};

47
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
Y
Yibing Liu 已提交
48
 public:
49
  EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
Y
Yibing Liu 已提交
50
      : OpProtoAndCheckerMaker(proto, op_checker) {
51
    AddInput("Hyps",
52
             "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
53 54
             "The indices for hypothesis strings.");
    AddInput("Refs",
55
             "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
56
             "The indices for reference strings.");
Y
Yibing Liu 已提交
57
    AddAttr<bool>("normalized",
58 59
                  "(bool, default false) Indicated whether to normalize "
                  "the edit distance by the length of reference string.")
Y
Yibing Liu 已提交
60 61
        .SetDefault(false);
    AddOutput("Out",
62 63
              "(2-D Tensor with shape [`batch_size` x 1]) "
              "The output edit distances of EditDistance operator.");
Y
Yibing Liu 已提交
64 65
    AddComment(R"DOC(

66 67
EditDistance operator computes the edit distances between a batch of hypothesis
strings and their references.
Y
Yibing Liu 已提交
68

69 70 71 72 73
Edit distance, also called Levenshtein distance, measures how dissimilar two strings
are by counting the minimum number of operations to transform one string into anthor.
Here the operations include insertion, deletion, and substitution. For example,
given hypothesis string A = "kitten" and reference B = "sitting", the edit distance
is 3 for A will be transformed into B at least after two substitutions and one
74
insertion:
75

76
   "kitten" -> "sitten" -> "sittin" -> "sitting"
Y
Yibing Liu 已提交
77

78 79 80
Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total
number denoted by `batch_size`, and the separation is specified by the LoD information.
And the `batch_size` reference strings are arranged in order in the same way in the
81 82
LoDTensor Input(Refs).

83 84
Output(Out) contains the `batch_size` results and each stands for the edit stance
for a pair of strings respectively. If Attr(normalized) is true, the edit distance
85
will be divided by the length of reference string.
Y
Yibing Liu 已提交
86 87 88 89 90 91 92 93 94
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

95 96
REGISTER_OPERATOR(edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker,
                  paddle::framework::EmptyGradOpMaker);
Y
Yibing Liu 已提交
97
REGISTER_OP_CPU_KERNEL(
98
    edit_distance, ops::EditDistanceKernel<paddle::platform::CPUPlace, float>);