/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/edit_distance_op.h"
namespace paddle {
namespace operators {
class EditDistanceOp : public framework::OperatorWithKernel {
Y
Yibing Liu 已提交
21 22 23 24
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
25 26 27 28 29
    OP_INOUT_CHECK(ctx->HasInput("Hyps"), "Input", "Hyps", "EditDistance");
    OP_INOUT_CHECK(ctx->HasInput("Refs"), "Input", "Refs", "EditDistance");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "EditDistance");
    OP_INOUT_CHECK(ctx->HasOutput("SequenceNum"), "Output", "SequenceNum",
                   "EditDistance");
30 31
    auto hyp_dims = ctx->GetInputDim("Hyps");
    auto ref_dims = ctx->GetInputDim("Refs");
32 33 34 35 36

    if (ctx->HasInput("HypsLength") && ctx->HasInput("RefsLength")) {
      auto hyp_length_dims = ctx->GetInputDim("HypsLength");
      auto ref_length_dims = ctx->GetInputDim("RefsLength");

37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
      PADDLE_ENFORCE_EQ(
          hyp_dims.size() == 2 && ref_dims.size() == 2 &&
              hyp_dims[0] == ref_dims[0],
          true, platform::errors::InvalidArgument(
                    "Input(Hyps) and Input(Refs) must be 2-D Tensors with "
                    "identical first dimension. But received Input(Hyps): "
                    "input rank %u, input shape [%s]; received Input(Refs): "
                    "input rank %u, input shape [%s]",
                    hyp_dims.size(), hyp_dims, ref_dims.size(), ref_dims));
      PADDLE_ENFORCE_EQ(
          hyp_length_dims[0] == ref_length_dims[0] &&
              hyp_length_dims[0] == hyp_dims[0],
          true,
          platform::errors::InvalidArgument(
              "Input(HypsLength), Input(RefsLength) and Input(Hyps) "
              "should have identical first dimension. But received "
              "Input(HypsLength): input rank %u, input shape [%s]; "
              "received Input(RefsLength): input rank %u, input shape "
              "[%s]; received Input(Hyps): input rank %u, input shape "
              "[%s].",
              hyp_length_dims.size(), hyp_length_dims, ref_length_dims.size(),
              ref_length_dims, hyp_dims.size(), hyp_dims));
59
    } else {
60 61 62 63 64 65 66 67 68 69 70 71
      PADDLE_ENFORCE_EQ(
          hyp_dims.size() == 2 && hyp_dims[1] == 1, true,
          platform::errors::InvalidArgument(
              "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension "
              "equal to 1. But received: input rank %u, input shape [%s].",
              hyp_dims.size(), hyp_dims));
      PADDLE_ENFORCE_EQ(
          ref_dims.size() == 2 && ref_dims[1] == 1, true,
          platform::errors::InvalidArgument(
              "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
              "equal to 1. But received: input rank %u, input shape [%s].",
              ref_dims.size(), ref_dims));
72 73
    }

74
    ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
75
    ctx->SetOutputDim("SequenceNum", {1});
Y
Yibing Liu 已提交
76
  }
Y
Yibing Liu 已提交
77 78

 protected:
Y
Yibing Liu 已提交
79
  framework::OpKernelType GetExpectedKernelType(
Y
Yibing Liu 已提交
80
      const framework::ExecutionContext &ctx) const override {
81
    return framework::OpKernelType(framework::proto::VarType::FP32,
Y
Yibing Liu 已提交
82 83
                                   ctx.device_context());
  }
Y
Yibing Liu 已提交
84 85
};
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
Y
Yibing Liu 已提交
87
 public:
Y
Yu Yang 已提交
88
  void Make() override {
89
    AddInput("Hyps",
90 91
             "2-D Tensor<int64_t>, or 2-D LoDTensor<int64_t> with last "
             "dimension being 1. "
92 93
             "The indices for hypothesis strings.");
    AddInput("Refs",
94 95
             "2-D Tensor<int64_t>, or 2-D LoDTensor<int64_t> with last "
             "dimension being 1. "
96
             "The indices for reference strings.");
97 98 99 100 101 102 103 104
    AddInput("HypsLength",
             "1-D Tensor<int64_t>. "
             "Sequence length for hyps when hyps is a tensor")
        .AsDispensable();
    AddInput("RefsLength",
             "1-D Tensor<int64_t>. "
             "Sequence length for refs when refs is a tensor")
        .AsDispensable();
105
    AddOutput("SequenceNum", "The sequence count of current batch");
Y
Yibing Liu 已提交
106
    AddAttr<bool>("normalized",
107 108
                  "(bool, default false) Indicated whether to normalize "
                  "the edit distance by the length of reference string.")
Y
Yibing Liu 已提交
109 110
        .SetDefault(false);
    AddOutput("Out",
111 112
              "(2-D Tensor with shape [`batch_size` x 1]) "
              "The output edit distances of EditDistance operator.");
Y
Yibing Liu 已提交
113 114
    AddComment(R"DOC(

115 116
EditDistance operator computes the edit distances between a batch of hypothesis
strings and their references.
Y
Yibing Liu 已提交
117

118
Edit distance, also called Levenshtein distance, measures how dissimilar two strings
R
ruri 已提交
119 120 121 122 123
are by counting the minimum number of operations to transform one string into another.
The operations include insertion, deletion, and substitution. 

For example, given hypothesis string A = "kitten" and reference B = "sitting",
A will be transformed into B at least after two substitutions and one
124
insertion:
125

126
   "kitten" -> "sitten" -> "sittin" -> "sitting"
Y
Yibing Liu 已提交
127

R
ruri 已提交
128 129
So the edit distance between A and B is 3.

130
Input(Hyps) is a 2-D Tensor or a 2-D LoDTensor consisting of all the hypothesis strings.
131
And the `batch_size` reference strings are arranged in order in the same way in the
132
Input(Refs).
133

134
Output(Out) contains the `batch_size` results and each stands for the edit distance
135
for a pair of strings respectively. If Attr(normalized) is true, the edit distance
136
will be divided by the length of reference string.
Y
Yibing Liu 已提交
137 138 139 140 141 142 143 144 145
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// edit_distance has no gradient. Empty grad-op makers are registered for both
// the OpDesc-based program and the imperative (OpBase) mode.
REGISTER_OPERATOR(edit_distance, paddle::operators::EditDistanceOp,
                  paddle::operators::EditDistanceOpMaker,
                  paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
                  paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

// This file registers a single CPU kernel, computing with float.
REGISTER_OP_CPU_KERNEL(
    edit_distance,
    paddle::operators::EditDistanceKernel<paddle::platform::CPUPlace, float>);