未验证 提交 396fe483 编写于 作者: H huangjiyi 提交者: GitHub

Support static graph code generation for op edit_distance (#53297)

* update

* fix bug

* support parsing fixed kernel data_type

* update op_compat

* update
上级 005fee12
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class EditDistanceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return phi::KernelKey(framework::proto::VarType::FP32,
ctx.device_context().GetPlace());
}
};
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Hyps",
"2-D Tensor<int64_t>, or 2-D phi::DenseTensor<int64_t> with last "
"dimension being 1. "
"The indices for hypothesis strings.");
AddInput("Refs",
"2-D Tensor<int64_t>, or 2-D phi::DenseTensor<int64_t> with last "
"dimension being 1. "
"The indices for reference strings.");
AddInput("HypsLength",
"1-D Tensor<int64_t>. "
"Sequence length for hyps when hyps is a tensor")
.AsDispensable();
AddInput("RefsLength",
"1-D Tensor<int64_t>. "
"Sequence length for refs when refs is a tensor")
.AsDispensable();
AddOutput("SequenceNum", "The sequence count of current batch");
AddAttr<bool>("normalized",
"(bool, default false) Indicated whether to normalize "
"the edit distance by the length of reference string.")
.SetDefault(false);
AddOutput("Out",
"(2-D Tensor with shape [`batch_size` x 1]) "
"The output edit distances of EditDistance operator.");
AddComment(R"DOC(
EditDistance operator computes the edit distances between a batch of hypothesis
strings and their references.
Edit distance, also called Levenshtein distance, measures how dissimilar two strings
are by counting the minimum number of operations to transform one string into another.
The operations include insertion, deletion, and substitution.
For example, given hypothesis string A = "kitten" and reference B = "sitting",
A will be transformed into B at least after two substitutions and one
insertion:
"kitten" -> "sitten" -> "sittin" -> "sitting"
So the edit distance between A and B is 3.
Input(Hyps) is a 2-D Tensor or a 2-D phi::DenseTensor consisting of all the hypothesis strings.
And the `batch_size` reference strings are arranged in order in the same way in the
Input(Refs).
Output(Out) contains the `batch_size` results and each stands for the edit distance
for a pair of strings respectively. If Attr(normalized) is true, the edit distance
will be divided by the length of reference string.
)DOC");
}
};
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(edit_distance,
EditDistanceShapeFunctor,
PD_INFER_META(phi::EditDistanceInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(
edit_distance,
ops::EditDistanceOp,
ops::EditDistanceOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
EditDistanceShapeFunctor);
......@@ -318,13 +318,16 @@ phi::KernelKey GetExpectedKernelType(
{% if kernel["data_type"]["candidates"] | length == 1 %}
{% set data_type_arg = kernel["data_type"]["candidates"][0] %}
{% set inputs = op["inputs"] | map(attribute="fluid_name") | list %}
{% set attrs = op["attrs"] | map(attribute="fluid_name") | list %}
{% if data_type_arg in inputs %}
auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_arg | to_opmaker_name}});
{% if kernel["data_type"]["to_complex_flag"][0] %}
data_type = framework::ToComplexType(data_type);
{% endif %}
{% else %}{# it is an attribute and probably named dtype#}
{% elif data_type_arg in attrs %}{# it is an attribute and probably named dtype#}
auto data_type = framework::proto::VarType::Type(ctx.Attr<int>("{{data_type_arg}}"));
{% else %}
auto data_type = framework::TransToProtoVarType(phi::{{data_type_arg}});
{% endif %}
{% elif kernel["data_type"]["candidates"] | length == 2 %}
{% set data_type_args = kernel["data_type"]["candidates"] %}
......
......@@ -312,16 +312,6 @@
optional : seed_tensor
backward : dropout_grad
- op : edit_distance
args : (Tensor hyps, Tensor refs, Tensor hypslength, Tensor refslength, bool normalized = false)
output : Tensor(sequencenum), Tensor(out)
infer_meta :
func : EditDistanceInferMeta
kernel :
func : edit_distance
data_type: DataType::FLOAT32
optional : hypslength, refslength
- op : einsum
args : (Tensor[] x, str equation)
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
......
......@@ -683,6 +683,16 @@
extra :
attrs : [bool fix_seed = false, int seed = 0]
- op : edit_distance
inputs :
hyps : Hyps
refs : Refs
hypslength : HypsLength
refslength : RefsLength
outputs :
sequencenum : SequenceNum
out : Out
- op : eig
inputs :
x : X
......
......@@ -565,6 +565,16 @@
data_type : x
backward : dot_grad
- op : edit_distance
args : (Tensor hyps, Tensor refs, Tensor hypslength, Tensor refslength, bool normalized = false)
output : Tensor(sequencenum), Tensor(out)
infer_meta :
func : EditDistanceInferMeta
kernel :
func : edit_distance
data_type : DataType::FLOAT32
optional : hypslength, refslength
- op : eig
args: (Tensor x)
output: Tensor(out_w), Tensor(out_v)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EditDistanceOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("edit_distance",
{"Hyps", "Refs", "HypsLength", "RefsLength"},
{"normalized"},
{"SequenceNum", "Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(edit_distance, phi::EditDistanceOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册