/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/edit_distance_op.h"
namespace paddle {
namespace operators {

20
class EditDistanceOp : public framework::OperatorWithKernel {
Y
Yibing Liu 已提交
21 22 23 24
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
25 26
    PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null.");
    PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null.");
Y
Yibing Liu 已提交
27
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
28 29
    PADDLE_ENFORCE(ctx->HasOutput("SequenceNum"),
                   "Output(SequenceNum) shouldn't be null.");
30 31
    auto hyp_dims = ctx->GetInputDim("Hyps");
    auto ref_dims = ctx->GetInputDim("Refs");
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55

    if (ctx->HasInput("HypsLength") && ctx->HasInput("RefsLength")) {
      auto hyp_length_dims = ctx->GetInputDim("HypsLength");
      auto ref_length_dims = ctx->GetInputDim("RefsLength");

      PADDLE_ENFORCE(hyp_dims.size() == 2 && ref_dims.size() == 2 &&
                         hyp_dims[0] == ref_dims[0],
                     "Input(Hyps) and Input(Refs) must be 2-D Tensors with "
                     "identical first dimension");
      PADDLE_ENFORCE(hyp_length_dims[0] == ref_length_dims[0] &&
                         hyp_length_dims[0] == hyp_dims[0],
                     "Input(HypsLength), Input(RefsLength) and Input(Hyps) "
                     "should have identical first dimension");
    } else {
      PADDLE_ENFORCE(
          hyp_dims.size() == 2 && hyp_dims[1] == 1,
          "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
          "equal to 1.");
      PADDLE_ENFORCE(
          ref_dims.size() == 2 && ref_dims[1] == 1,
          "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
          "equal to 1.");
    }

56
    ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
57
    ctx->SetOutputDim("SequenceNum", {1});
Y
Yibing Liu 已提交
58
  }
Y
Yibing Liu 已提交
59 60

 protected:
Y
Yibing Liu 已提交
61
  framework::OpKernelType GetExpectedKernelType(
Y
Yibing Liu 已提交
62
      const framework::ExecutionContext &ctx) const override {
63
    return framework::OpKernelType(framework::proto::VarType::FP32,
Y
Yibing Liu 已提交
64 65
                                   ctx.device_context());
  }
Y
Yibing Liu 已提交
66 67
};

68
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
Y
Yibing Liu 已提交
69
 public:
Y
Yu Yang 已提交
70
  void Make() override {
71
    AddInput("Hyps",
72 73
             "2-D Tensor<int64_t>, or 2-D LoDTensor<int64_t> with last "
             "dimension being 1. "
74 75
             "The indices for hypothesis strings.");
    AddInput("Refs",
76 77
             "2-D Tensor<int64_t>, or 2-D LoDTensor<int64_t> with last "
             "dimension being 1. "
78
             "The indices for reference strings.");
79 80 81 82 83 84 85 86
    AddInput("HypsLength",
             "1-D Tensor<int64_t>. "
             "Sequence length for hyps when hyps is a tensor")
        .AsDispensable();
    AddInput("RefsLength",
             "1-D Tensor<int64_t>. "
             "Sequence length for refs when refs is a tensor")
        .AsDispensable();
87
    AddOutput("SequenceNum", "The sequence count of current batch");
Y
Yibing Liu 已提交
88
    AddAttr<bool>("normalized",
89 90
                  "(bool, default false) Indicated whether to normalize "
                  "the edit distance by the length of reference string.")
Y
Yibing Liu 已提交
91 92
        .SetDefault(false);
    AddOutput("Out",
93 94
              "(2-D Tensor with shape [`batch_size` x 1]) "
              "The output edit distances of EditDistance operator.");
Y
Yibing Liu 已提交
95 96
    AddComment(R"DOC(

97 98
EditDistance operator computes the edit distances between a batch of hypothesis
strings and their references.
Y
Yibing Liu 已提交
99

100 101 102 103 104
Edit distance, also called Levenshtein distance, measures how dissimilar two strings
are by counting the minimum number of operations to transform one string into anthor.
Here the operations include insertion, deletion, and substitution. For example,
given hypothesis string A = "kitten" and reference B = "sitting", the edit distance
is 3 for A will be transformed into B at least after two substitutions and one
105
insertion:
106

107
   "kitten" -> "sitten" -> "sittin" -> "sitting"
Y
Yibing Liu 已提交
108

109
Input(Hyps) is a 2-D Tensor or a 2-D LoDTensor consisting of all the hypothesis strings.
110
And the `batch_size` reference strings are arranged in order in the same way in the
111
Input(Refs).
112

113
Output(Out) contains the `batch_size` results and each stands for the edit distance
114
for a pair of strings respectively. If Attr(normalized) is true, the edit distance
115
will be divided by the length of reference string.
Y
Yibing Liu 已提交
116 117 118 119 120 121 122 123 124
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

125 126
REGISTER_OPERATOR(edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker,
                  paddle::framework::EmptyGradOpMaker);
Y
Yibing Liu 已提交
127
REGISTER_OP_CPU_KERNEL(
128
    edit_distance, ops::EditDistanceKernel<paddle::platform::CPUPlace, float>);