/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/edit_distance_op.h"
namespace paddle {
namespace operators {

20
class EditDistanceOp : public framework::OperatorWithKernel {
Y
Yibing Liu 已提交
21 22 23 24
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
25 26
    PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null.");
    PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null.");
Y
Yibing Liu 已提交
27
    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
28 29
    PADDLE_ENFORCE(ctx->HasOutput("SequenceNum"),
                   "Output(SequenceNum) shouldn't be null.");
30 31
    auto hyp_dims = ctx->GetInputDim("Hyps");
    auto ref_dims = ctx->GetInputDim("Refs");
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55

    if (ctx->HasInput("HypsLength") && ctx->HasInput("RefsLength")) {
      auto hyp_length_dims = ctx->GetInputDim("HypsLength");
      auto ref_length_dims = ctx->GetInputDim("RefsLength");

      PADDLE_ENFORCE(hyp_dims.size() == 2 && ref_dims.size() == 2 &&
                         hyp_dims[0] == ref_dims[0],
                     "Input(Hyps) and Input(Refs) must be 2-D Tensors with "
                     "identical first dimension");
      PADDLE_ENFORCE(hyp_length_dims[0] == ref_length_dims[0] &&
                         hyp_length_dims[0] == hyp_dims[0],
                     "Input(HypsLength), Input(RefsLength) and Input(Hyps) "
                     "should have identical first dimension");
    } else {
      PADDLE_ENFORCE(
          hyp_dims.size() == 2 && hyp_dims[1] == 1,
          "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
          "equal to 1.");
      PADDLE_ENFORCE(
          ref_dims.size() == 2 && ref_dims[1] == 1,
          "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
          "equal to 1.");
    }

56
    ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
57
    ctx->SetOutputDim("SequenceNum", {1});
Y
Yibing Liu 已提交
58
  }
Y
Yibing Liu 已提交
59 60

 protected:
Y
Yibing Liu 已提交
61
  framework::OpKernelType GetExpectedKernelType(
Y
Yibing Liu 已提交
62
      const framework::ExecutionContext &ctx) const override {
63
    return framework::OpKernelType(framework::proto::VarType::FP32,
Y
Yibing Liu 已提交
64 65
                                   ctx.device_context());
  }
Y
Yibing Liu 已提交
66 67
};

68
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
Y
Yibing Liu 已提交
69
 public:
Y
Yu Yang 已提交
70
  void Make() override {
71
    AddInput("Hyps",
72 73
             "2-D Tensor<int64_t>, or 2-D LoDTensor<int64_t> with last "
             "dimension being 1. "
74 75
             "The indices for hypothesis strings.");
    AddInput("Refs",
76 77
             "2-D Tensor<int64_t>, or 2-D LoDTensor<int64_t> with last "
             "dimension being 1. "
78
             "The indices for reference strings.");
79 80 81 82 83 84 85 86
    AddInput("HypsLength",
             "1-D Tensor<int64_t>. "
             "Sequence length for hyps when hyps is a tensor")
        .AsDispensable();
    AddInput("RefsLength",
             "1-D Tensor<int64_t>. "
             "Sequence length for refs when refs is a tensor")
        .AsDispensable();
87
    AddOutput("SequenceNum", "The sequence count of current batch");
Y
Yibing Liu 已提交
88
    AddAttr<bool>("normalized",
89 90
                  "(bool, default false) Indicated whether to normalize "
                  "the edit distance by the length of reference string.")
Y
Yibing Liu 已提交
91 92
        .SetDefault(false);
    AddOutput("Out",
93 94
              "(2-D Tensor with shape [`batch_size` x 1]) "
              "The output edit distances of EditDistance operator.");
Y
Yibing Liu 已提交
95 96
    AddComment(R"DOC(

97 98
EditDistance operator computes the edit distances between a batch of hypothesis
strings and their references.
Y
Yibing Liu 已提交
99

100
Edit distance, also called Levenshtein distance, measures how dissimilar two strings
R
ruri 已提交
101 102 103 104 105
are by counting the minimum number of operations to transform one string into another.
The operations include insertion, deletion, and substitution. 

For example, given hypothesis string A = "kitten" and reference B = "sitting",
A will be transformed into B at least after two substitutions and one
106
insertion:
107

108
   "kitten" -> "sitten" -> "sittin" -> "sitting"
Y
Yibing Liu 已提交
109

R
ruri 已提交
110 111
So the edit distance between A and B is 3.

112
Input(Hyps) is a 2-D Tensor or a 2-D LoDTensor consisting of all the hypothesis strings.
113
And the `batch_size` reference strings are arranged in order in the same way in the
114
Input(Refs).
115

116
Output(Out) contains the `batch_size` results and each stands for the edit distance
117
for a pair of strings respectively. If Attr(normalized) is true, the edit distance
118
will be divided by the length of reference string.
Y
Yibing Liu 已提交
119 120 121 122 123 124 125 126 127
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

H
hong 已提交
128 129 130 131
REGISTER_OPERATOR(
    edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
Y
Yibing Liu 已提交
132
REGISTER_OP_CPU_KERNEL(
133
    edit_distance, ops::EditDistanceKernel<paddle::platform::CPUPlace, float>);