/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/edit_distance_op.h"
Y
Yibing Liu 已提交
16 17 18 19

namespace paddle {
namespace operators {

20
class EditDistanceOp : public framework::OperatorWithKernel {
Y
Yibing Liu 已提交
21 22 23 24
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
25 26
    PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null.");
    PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null.");
Y
Yibing Liu 已提交
27
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
28 29
    PADDLE_ENFORCE(ctx->HasOutput("SequenceNum"),
                   "Output(SequenceNum) shouldn't be null.");
30 31 32 33 34 35 36 37 38
    auto hyp_dims = ctx->GetInputDim("Hyps");
    auto ref_dims = ctx->GetInputDim("Refs");
    PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1,
                   "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
                   "equal to 1.");
    PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1,
                   "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
                   "equal to 1.");
    ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
39
    ctx->SetOutputDim("SequenceNum", {1});
Y
Yibing Liu 已提交
40
  }
Y
Yibing Liu 已提交
41 42

 protected:
Y
Yibing Liu 已提交
43
  framework::OpKernelType GetExpectedKernelType(
Y
Yibing Liu 已提交
44
      const framework::ExecutionContext &ctx) const override {
45
    return framework::OpKernelType(framework::proto::VarType::FP32,
Y
Yibing Liu 已提交
46 47
                                   ctx.device_context());
  }
Y
Yibing Liu 已提交
48 49
};

50
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
Y
Yibing Liu 已提交
51
 public:
52
  EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
Y
Yibing Liu 已提交
53
      : OpProtoAndCheckerMaker(proto, op_checker) {
54
    AddInput("Hyps",
55
             "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
56 57
             "The indices for hypothesis strings.");
    AddInput("Refs",
58
             "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
59
             "The indices for reference strings.");
60
    AddOutput("SequenceNum", "The sequence count of current batch");
Y
Yibing Liu 已提交
61
    AddAttr<bool>("normalized",
62 63
                  "(bool, default false) Indicated whether to normalize "
                  "the edit distance by the length of reference string.")
Y
Yibing Liu 已提交
64 65
        .SetDefault(false);
    AddOutput("Out",
66 67
              "(2-D Tensor with shape [`batch_size` x 1]) "
              "The output edit distances of EditDistance operator.");
Y
Yibing Liu 已提交
68 69
    AddComment(R"DOC(

70 71
EditDistance operator computes the edit distances between a batch of hypothesis
strings and their references.
Y
Yibing Liu 已提交
72

73 74 75 76 77
Edit distance, also called Levenshtein distance, measures how dissimilar two strings
are by counting the minimum number of operations to transform one string into anthor.
Here the operations include insertion, deletion, and substitution. For example,
given hypothesis string A = "kitten" and reference B = "sitting", the edit distance
is 3 for A will be transformed into B at least after two substitutions and one
78
insertion:
79

80
   "kitten" -> "sitten" -> "sittin" -> "sitting"
Y
Yibing Liu 已提交
81

82 83 84
Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total
number denoted by `batch_size`, and the separation is specified by the LoD information.
And the `batch_size` reference strings are arranged in order in the same way in the
85 86
LoDTensor Input(Refs).

87 88
Output(Out) contains the `batch_size` results and each stands for the edit stance
for a pair of strings respectively. If Attr(normalized) is true, the edit distance
89
will be divided by the length of reference string.
Y
Yibing Liu 已提交
90 91 92 93 94 95 96 97 98
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

99 100
REGISTER_OPERATOR(edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker,
                  paddle::framework::EmptyGradOpMaker);
Y
Yibing Liu 已提交
101
REGISTER_OP_CPU_KERNEL(
102
    edit_distance, ops::EditDistanceKernel<paddle::platform::CPUPlace, float>);