diff --git a/paddle/fluid/operators/edit_distance_op.cc b/paddle/fluid/operators/edit_distance_op.cc index 69f3354996a1176f792594e047c6ae73e51df3f0..8197b115cddcc9831ecc15b4c35b286462d38863 100644 --- a/paddle/fluid/operators/edit_distance_op.cc +++ b/paddle/fluid/operators/edit_distance_op.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/edit_distance_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { @@ -21,72 +23,6 @@ class EditDistanceOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Hyps"), "Input", "Hyps", "EditDistance"); - OP_INOUT_CHECK(ctx->HasInput("Refs"), "Input", "Refs", "EditDistance"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "EditDistance"); - OP_INOUT_CHECK( - ctx->HasOutput("SequenceNum"), "Output", "SequenceNum", "EditDistance"); - auto hyp_dims = ctx->GetInputDim("Hyps"); - auto ref_dims = ctx->GetInputDim("Refs"); - - if (ctx->HasInput("HypsLength") && ctx->HasInput("RefsLength")) { - auto hyp_length_dims = ctx->GetInputDim("HypsLength"); - auto ref_length_dims = ctx->GetInputDim("RefsLength"); - - PADDLE_ENFORCE_EQ( - hyp_dims.size() == 2 && ref_dims.size() == 2 && - hyp_dims[0] == ref_dims[0], - true, - platform::errors::InvalidArgument( - "Input(Hyps) and Input(Refs) must be 2-D Tensors with " - "identical first dimension. But received Input(Hyps): " - "input rank %u, input shape [%s]; received Input(Refs): " - "input rank %u, input shape [%s]", - hyp_dims.size(), - hyp_dims, - ref_dims.size(), - ref_dims)); - PADDLE_ENFORCE_EQ( - hyp_length_dims[0] == ref_length_dims[0] && - hyp_length_dims[0] == hyp_dims[0], - true, - platform::errors::InvalidArgument( - "Input(HypsLength), Input(RefsLength) and Input(Hyps) " - "should have identical first dimension. But received " - "Input(HypsLength): input rank %u, input shape [%s]; " - "received Input(RefsLength): input rank %u, input shape " - "[%s]; received Input(Hyps): input rank %u, input shape " - "[%s].", - hyp_length_dims.size(), - hyp_length_dims, - ref_length_dims.size(), - ref_length_dims, - hyp_dims.size(), - hyp_dims)); - } else { - PADDLE_ENFORCE_EQ( - hyp_dims.size() == 2 && hyp_dims[1] == 1, - true, - platform::errors::InvalidArgument( - "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension " - "equal to 1. But received: input rank %u, input shape [%s].", - hyp_dims.size(), - hyp_dims)); - PADDLE_ENFORCE_EQ( - ref_dims.size() == 2 && ref_dims[1] == 1, - true, - platform::errors::InvalidArgument( - "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension " - "equal to 1. But received: input rank %u, input shape [%s].", - ref_dims.size(), - ref_dims)); - } - - ctx->SetOutputDim("Out", ctx->GetInputDim("Refs")); - ctx->SetOutputDim("SequenceNum", {1}); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -153,6 +89,10 @@ will be divided by the length of reference string. } // namespace operators } // namespace paddle +DECLARE_INFER_SHAPE_FUNCTOR(edit_distance, + EditDistanceShapeFunctor, + PD_INFER_META(phi::EditDistanceInferMeta)); + namespace ops = paddle::operators; REGISTER_OPERATOR( @@ -160,6 +100,5 @@ REGISTER_OPERATOR( ops::EditDistanceOp, ops::EditDistanceOpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - edit_distance, ops::EditDistanceKernel); + paddle::framework::EmptyGradOpMaker, + EditDistanceShapeFunctor); diff --git a/paddle/fluid/operators/edit_distance_op.cu b/paddle/fluid/operators/edit_distance_op.cu deleted file mode 100644 index 681f91ffa689df337c683149057fa3987ed45871..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/edit_distance_op.cu +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/mixed_vector.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/edit_distance_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_info.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using platform::PADDLE_CUDA_NUM_THREADS; - -template -__global__ void FillFirstRow(T* dist, const int N) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - if (idx < N + 1) { - dist[idx] = idx; - } -} - -template -__global__ void FillFirstColumn(T* dist, const int M, const int N) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - if (idx < M + 1) { - dist[idx * (N + 1)] = idx; - } -} - -template -__global__ void Levenshtein(T* dist, - const int64_t* x1, - const int64_t* x2, - const int M, - const int N, - const int start) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = N; - int index = start + idx * offset; - int row = index / (N + 1); - int col = index % (N + 1); - if (row > 0 && col > 0 && row < M + 1 && col < N + 1) { - int cost = x1[row - 1] == x2[col - 1] ? 0 : 1; - int dels = dist[(row - 1) * (N + 1) + col] + 1; - int ins = dist[row * (N + 1) + col - 1] + 1; - int subs = dist[(row - 1) * (N + 1) + (col - 1)] + cost; - dist[index] = min(dels, min(ins, subs)); - } -} - -template -__global__ void SetOutput( - T* out, const T* dist, const int M, const int N, bool normalized) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - if (idx == 0) { - out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N]; - } -} - -template -class EditDistanceGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - auto* out_t = ctx.Output("Out"); - - auto* x1_t = ctx.Input("Hyps"); - auto* x2_t = ctx.Input("Refs"); - auto* sequence_num = ctx.Output("SequenceNum"); - sequence_num->mutable_data(ctx.GetPlace()); - auto batch_size = x1_t->dims()[0]; - - auto normalized = ctx.Attr("normalized"); - auto stream = - reinterpret_cast(ctx.device_context()).stream(); - - framework::Vector hyp_lod(batch_size + 1); - framework::Vector ref_lod(batch_size + 1); - - bool use_length = ctx.HasInput("HypsLength"); - - if (use_length) { - // build lod when using padding - auto* hyp_length = ctx.Input("HypsLength"); - auto* ref_length = ctx.Input("RefsLength"); - - framework::Tensor hyp_length_cpu; - framework::Tensor ref_length_cpu; - framework::TensorCopy(*hyp_length, platform::CPUPlace(), &hyp_length_cpu); - framework::TensorCopy(*ref_length, platform::CPUPlace(), &ref_length_cpu); - - for (auto i = 0; i < batch_size; i++) { - hyp_lod[i + 1] = hyp_lod[i] + hyp_length_cpu.data()[i]; - ref_lod[i + 1] = ref_lod[i] + ref_length_cpu.data()[i]; - } - - } else { - hyp_lod = x1_t->lod()[0]; - ref_lod = x2_t->lod()[0]; - } - - if (normalized) { - for (size_t i = 1; i < ref_lod.size(); ++i) { - PADDLE_ENFORCE_GT(ref_lod[i], - ref_lod[i - 1], - platform::errors::InvalidArgument( - "Reference string %d is empty.", i)); - } - } - - const size_t num_strs = hyp_lod.size() - 1; - phi::funcs::SetConstant set_constant; - set_constant(ctx.template device_context(), - sequence_num, - static_cast(num_strs)); - - out_t->Resize({static_cast(num_strs), 1}); - out_t->mutable_data(ctx.GetPlace()); - auto out = out_t->data(); - - T distance = 0.0; - for (size_t num = 0; num < num_strs; num++) { - auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); - auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); - if (m == 0 || n == 0) { - distance = std::max(m, n); - if (normalized) { - distance = distance / n; - } - memory::Copy(ctx.GetPlace(), - out + num, - platform::CPUPlace(), - &distance, - sizeof(T), - stream); - } else { - framework::Tensor dist_t; - dist_t.Resize({m + 1, n + 1}); - dist_t.mutable_data(ctx.GetPlace()); - auto dist = dist_t.data(); - auto hyp_offset = use_length ? num * x1_t->dims()[1] : hyp_lod[num]; - auto ref_offset = use_length ? num * x2_t->dims()[1] : ref_lod[num]; - auto x1 = x1_t->data() + hyp_offset; - auto x2 = x2_t->data() + ref_offset; - - FillFirstColumn<<<1 + m / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(dist, m, n); - - FillFirstRow<<<1 + n / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(dist, n); - - // Compute the elements of distance matrix in the anti-diagonal diretion - for (int64_t slice = 2; slice < m + n + 1; ++slice) { - int z_m = slice < m + 1 ? 0 : slice - m; - int z_n = slice < n + 1 ? 0 : slice - n; - int size = slice - (z_m + z_n) + 1; // number of elments in the same - // anti-diagonal line to update - // the start index at which computes from - int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1; - Levenshtein<<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(dist, x1, x2, m, n, start); - } - SetOutput<<<1, 1, 0, stream>>>(out + num, dist, m, n, normalized); - } - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL( - edit_distance, - ops::EditDistanceGPUKernel); diff --git a/paddle/fluid/operators/edit_distance_op.h b/paddle/fluid/operators/edit_distance_op.h deleted file mode 100644 index 96e9c4281c491c6e1b1ea0f9e7454d01cd7a3fbc..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/edit_distance_op.h +++ /dev/null @@ -1,126 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/mixed_vector.h" -#include "paddle/fluid/framework/op_registry.h" -namespace paddle { -namespace operators { - -template -class EditDistanceKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - auto* out_t = ctx.Output("Out"); - - auto* x1_t = ctx.Input("Hyps"); - auto* x2_t = ctx.Input("Refs"); - auto* sequence_num = ctx.Output("SequenceNum"); - int64_t* seq_num_data = sequence_num->mutable_data(ctx.GetPlace()); - auto batch_size = x1_t->dims()[0]; - - auto normalized = ctx.Attr("normalized"); - - framework::Vector hyp_lod(batch_size + 1); - framework::Vector ref_lod(batch_size + 1); - - bool use_length = ctx.HasInput("HypsLength"); - - if (use_length) { - // build lod when using padding - auto hyp_length_ptr = - ctx.Input("HypsLength")->data(); - auto ref_length_ptr = - ctx.Input("RefsLength")->data(); - - for (auto i = 0; i < batch_size; i++) { - hyp_lod[i + 1] = hyp_lod[i] + hyp_length_ptr[i]; - ref_lod[i + 1] = ref_lod[i] + ref_length_ptr[i]; - } - - } else { - hyp_lod = x1_t->lod()[0]; - ref_lod = x2_t->lod()[0]; - } - - if (normalized) { - for (size_t i = 1; i < ref_lod.size(); ++i) { - PADDLE_ENFORCE_GT(ref_lod[i], - ref_lod[i - 1], - platform::errors::InvalidArgument( - "Reference string %d is empty.", i)); - } - } - auto num_strs = hyp_lod.size() - 1; - *seq_num_data = static_cast(num_strs); - - out_t->Resize({static_cast(num_strs), 1}); - out_t->mutable_data(ctx.GetPlace()); - auto out = out_t->data(); - - T distance = 0.0; - for (size_t num = 0; num < num_strs; ++num) { - auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); - auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); - - if (m == 0) { - distance = n; - } else if (n == 0) { - distance = m; - } else { - framework::Tensor dist_t; - dist_t.Resize({m + 1, n + 1}); - dist_t.mutable_data(ctx.GetPlace()); - auto dist = dist_t.data(); - auto hyp_offset = use_length ? num * x1_t->dims()[1] : hyp_lod[num]; - auto ref_offset = use_length ? num * x2_t->dims()[1] : ref_lod[num]; - auto x1 = x1_t->data() + hyp_offset; - auto x2 = x2_t->data() + ref_offset; - for (int64_t i = 0; i < m + 1; ++i) { - dist[i * (n + 1)] = i; - } - for (int64_t j = 0; j < n + 1; ++j) { - dist[j] = j; - } - for (int64_t i = 1; i < m + 1; ++i) { - for (int64_t j = 1; j < n + 1; ++j) { - int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; - int dels = dist[(i - 1) * (n + 1) + j] + 1; - int ins = dist[i * (n + 1) + (j - 1)] + 1; - int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost; - dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs)); - } - } - distance = dist[m * (n + 1) + n]; - } - - if (normalized) { - PADDLE_ENFORCE_GT(n, - 0UL, - platform::errors::InvalidArgument( - "The reference string (#%d) cannot be empty " - "when Attr(normalized) is enabled.", - n)); - distance = distance / n; - } - out[num] = distance; - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 51be045ce4cf8efb7f641f7e2e6b1344a80e7f86..3bd11fa8cd19828599150ab8c4002bebd8a32aa2 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -654,6 +654,16 @@ optional : seed_tensor backward : dropout_grad +- api : edit_distance + args : (Tensor hyps, Tensor refs, Tensor hypslength, Tensor refslength, bool normalized = false) + output : Tensor(sequencenum), Tensor(out) + infer_meta : + func : EditDistanceInferMeta + kernel : + func : edit_distance + data_type: DataType::FLOAT32 + optional : hypslength, refslength + # eigh - api : eigh args : (Tensor x, str uplo) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index a524506c7f07b126d205eab93432024ae4978105..f89b13abc5214d190a982419e107be63793dc2e4 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1019,6 +1019,75 @@ void DeformableConvInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void EditDistanceInferMeta(const MetaTensor& hyps, + const MetaTensor& refs, + const MetaTensor& hypslength, + const MetaTensor& refslength, + bool normalized, + MetaTensor* sequencenum, + MetaTensor* out) { + auto hyp_dims = hyps.dims(); + auto ref_dims = refs.dims(); + + if (hypslength && refslength) { + auto hyp_length_dims = hypslength.dims(); + auto ref_length_dims = refslength.dims(); + + PADDLE_ENFORCE_EQ( + hyp_dims.size() == 2 && ref_dims.size() == 2 && + hyp_dims[0] == ref_dims[0], + true, + errors::InvalidArgument( + "Input(hyps) and Input(refs) must be 2-D Tensors with " + "identical first dimension. But received Input(Hyps): " + "input rank %u, input shape [%s]; received Input(Refs): " + "input rank %u, input shape [%s]", + hyp_dims.size(), + hyp_dims, + ref_dims.size(), + ref_dims)); + PADDLE_ENFORCE_EQ( + hyp_length_dims[0] == ref_length_dims[0] && + hyp_length_dims[0] == hyp_dims[0], + true, + errors::InvalidArgument( + "Input(hypslength), Input(refslength) and Input(hyps) " + "should have identical first dimension. But received " + "Input(hypslength): input rank %u, input shape [%s]; " + "received Input(refslength): input rank %u, input shape " + "[%s]; received Input(hyps): input rank %u, input shape " + "[%s].", + hyp_length_dims.size(), + hyp_length_dims, + ref_length_dims.size(), + ref_length_dims, + hyp_dims.size(), + hyp_dims)); + } else { + PADDLE_ENFORCE_EQ( + hyp_dims.size() == 2 && hyp_dims[1] == 1, + true, + errors::InvalidArgument( + "Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension " + "equal to 1. But received: input rank %u, input shape [%s].", + hyp_dims.size(), + hyp_dims)); + PADDLE_ENFORCE_EQ( + ref_dims.size() == 2 && ref_dims[1] == 1, + true, + errors::InvalidArgument( + "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension " + "equal to 1. But received: input rank %u, input shape [%s].", + ref_dims.size(), + ref_dims)); + } + + out->set_dims(refs.dims()); + out->set_dtype(DataType::FLOAT32); + sequencenum->set_dims(phi::make_ddim({1})); + sequencenum->set_dtype(DataType::FLOAT32); +} + void HierarchicalSigmoidInferMeta(const MetaTensor& x, const MetaTensor& w, const MetaTensor& label, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 60342dc58f5c91543762fe210ab9dbb7fc79604c..98008a3ebd06738cfb43363379230c923473e46a 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -212,6 +212,14 @@ void DeformableConvInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void EditDistanceInferMeta(const MetaTensor& hyps, + const MetaTensor& refs, + const MetaTensor& hypslength, + const MetaTensor& refslength, + bool normalized, + MetaTensor* sequencenum, + MetaTensor* out); + void HierarchicalSigmoidInferMeta(const MetaTensor& x, const MetaTensor& w, const MetaTensor& label, diff --git a/paddle/phi/kernels/cpu/edit_distance_kernel.cc b/paddle/phi/kernels/cpu/edit_distance_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..735086ba0edcdc63e65166e0bad74eaa7cdaa42d --- /dev/null +++ b/paddle/phi/kernels/cpu/edit_distance_kernel.cc @@ -0,0 +1,124 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/edit_distance_kernel.h" + +#include "paddle/fluid/framework/mixed_vector.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void EditDistanceKernel(const Context& ctx, + const DenseTensor& hyps, + const DenseTensor& refs, + const paddle::optional& hypslength, + const paddle::optional& refslength, + bool normalized, + DenseTensor* sequencenum, + DenseTensor* out) { + int64_t* seq_num_data = ctx.template Alloc(sequencenum); + auto batch_size = hyps.dims()[0]; + + paddle::framework::Vector hyp_lod(batch_size + 1); + paddle::framework::Vector ref_lod(batch_size + 1); + + bool use_length = hypslength.get_ptr() != nullptr; + + if (use_length) { + // build lod when using padding + auto hyp_length_ptr = hypslength.get_ptr()->data(); + auto ref_length_ptr = refslength.get_ptr()->data(); + + for (auto i = 0; i < batch_size; i++) { + hyp_lod[i + 1] = hyp_lod[i] + hyp_length_ptr[i]; + ref_lod[i + 1] = ref_lod[i] + ref_length_ptr[i]; + } + + } else { + hyp_lod = hyps.lod()[0]; + ref_lod = refs.lod()[0]; + } + + if (normalized) { + for (size_t i = 1; i < ref_lod.size(); ++i) { + PADDLE_ENFORCE_GT( + ref_lod[i], + ref_lod[i - 1], + errors::InvalidArgument("Reference string %d is empty.", i)); + } + } + auto num_strs = hyp_lod.size() - 1; + *seq_num_data = static_cast(num_strs); + + out->Resize({static_cast(num_strs), 1}); + ctx.template Alloc(out); + auto outdata = out->data(); + + T distance = 0.0; + for (size_t num = 0; num < num_strs; ++num) { + auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); + auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); + + if (m == 0) { + distance = n; + } else if (n == 0) { + distance = m; + } else { + DenseTensor dist_t; + dist_t.Resize({m + 1, n + 1}); + ctx.template Alloc(&dist_t); + auto dist = dist_t.data(); + auto hyp_offset = use_length ? num * hyps.dims()[1] : hyp_lod[num]; + auto ref_offset = use_length ? num * refs.dims()[1] : ref_lod[num]; + auto x1 = hyps.data() + hyp_offset; + auto x2 = refs.data() + ref_offset; + for (int64_t i = 0; i < m + 1; ++i) { + dist[i * (n + 1)] = i; + } + for (int64_t j = 0; j < n + 1; ++j) { + dist[j] = j; + } + for (int64_t i = 1; i < m + 1; ++i) { + for (int64_t j = 1; j < n + 1; ++j) { + int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; + int dels = dist[(i - 1) * (n + 1) + j] + 1; + int ins = dist[i * (n + 1) + (j - 1)] + 1; + int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost; + dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs)); + } + } + distance = dist[m * (n + 1) + n]; + } + + if (normalized) { + PADDLE_ENFORCE_GT( + n, + 0UL, + errors::InvalidArgument("The reference string (#%d) cannot be empty " + "when Attr(normalized) is enabled.", + n)); + distance = distance / n; + } + outdata[num] = distance; + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + edit_distance, CPU, ALL_LAYOUT, phi::EditDistanceKernel, float) {} diff --git a/paddle/phi/kernels/edit_distance_kernel.h b/paddle/phi/kernels/edit_distance_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..40b3e5f0aa025dc7ed8fea3ecfe49cc30b54be6f --- /dev/null +++ b/paddle/phi/kernels/edit_distance_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void EditDistanceKernel(const Context& ctx, + const DenseTensor& hyps, + const DenseTensor& refs, + const paddle::optional& hypslength, + const paddle::optional& refslength, + bool normalized, + DenseTensor* sequencenum, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/edit_distance_kernel.cu b/paddle/phi/kernels/gpu/edit_distance_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..993b4771cc9584cff50e7f1e1a008e73d2012ccf --- /dev/null +++ b/paddle/phi/kernels/gpu/edit_distance_kernel.cu @@ -0,0 +1,186 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/edit_distance_kernel.h" + +#include +#include + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +using paddle::platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void FillFirstRow(T* dist, const int N) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < N + 1) { + dist[idx] = idx; + } +} + +template +__global__ void FillFirstColumn(T* dist, const int M, const int N) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < M + 1) { + dist[idx * (N + 1)] = idx; + } +} + +template +__global__ void Levenshtein(T* dist, + const int64_t* x1, + const int64_t* x2, + const int M, + const int N, + const int start) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = N; + int index = start + idx * offset; + int row = index / (N + 1); + int col = index % (N + 1); + if (row > 0 && col > 0 && row < M + 1 && col < N + 1) { + int cost = x1[row - 1] == x2[col - 1] ? 0 : 1; + int dels = dist[(row - 1) * (N + 1) + col] + 1; + int ins = dist[row * (N + 1) + col - 1] + 1; + int subs = dist[(row - 1) * (N + 1) + (col - 1)] + cost; + dist[index] = min(dels, min(ins, subs)); + } +} + +template +__global__ void SetOutput( + T* out, const T* dist, const int M, const int N, bool normalized) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx == 0) { + out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N]; + } +} + +template +void EditDistanceKernel(const Context& ctx, + const DenseTensor& hyps, + const DenseTensor& refs, + const paddle::optional& hypslength, + const paddle::optional& refslength, + bool normalized, + DenseTensor* sequencenum, + DenseTensor* out) { + ctx.template Alloc(sequencenum); + auto batch_size = hyps.dims()[0]; + + auto stream = reinterpret_cast(ctx).stream(); + + paddle::framework::Vector hyp_lod(batch_size + 1); + paddle::framework::Vector ref_lod(batch_size + 1); + + bool use_length = hypslength.get_ptr() != nullptr; + + if (use_length) { + DenseTensor hyp_length_cpu; + DenseTensor ref_length_cpu; + phi::Copy( + ctx, *(hypslength.get_ptr()), phi::CPUPlace(), false, &hyp_length_cpu); + phi::Copy( + ctx, *(refslength.get_ptr()), phi::CPUPlace(), false, &ref_length_cpu); + + for (auto i = 0; i < batch_size; i++) { + hyp_lod[i + 1] = hyp_lod[i] + hyp_length_cpu.data()[i]; + ref_lod[i + 1] = ref_lod[i] + ref_length_cpu.data()[i]; + } + + } else { + hyp_lod = hyps.lod()[0]; + ref_lod = refs.lod()[0]; + } + + if (normalized) { + for (size_t i = 1; i < ref_lod.size(); ++i) { + PADDLE_ENFORCE_GT( + ref_lod[i], + ref_lod[i - 1], + errors::InvalidArgument("Reference string %d is empty.", i)); + } + } + + const size_t num_strs = hyp_lod.size() - 1; + phi::funcs::SetConstant set_constant; + set_constant(ctx, sequencenum, static_cast(num_strs)); + + out->Resize({static_cast(num_strs), 1}); + ctx.template Alloc(out); + auto out_data = out->data(); + + T distance = 0.0; + for (size_t num = 0; num < num_strs; num++) { + auto m = static_cast(hyp_lod[num + 1] - hyp_lod[num]); + auto n = static_cast(ref_lod[num + 1] - ref_lod[num]); + if (m == 0 || n == 0) { + distance = std::max(m, n); + if (normalized) { + distance = distance / n; + } + paddle::memory::Copy(ctx.GetPlace(), + out_data + num, + CPUPlace(), + &distance, + sizeof(T), + stream); + } else { + DenseTensor dist_t; + dist_t.Resize({m + 1, n + 1}); + ctx.template Alloc(&dist_t); + auto dist = dist_t.data(); + auto hyp_offset = use_length ? num * hyps.dims()[1] : hyp_lod[num]; + auto ref_offset = use_length ? num * refs.dims()[1] : ref_lod[num]; + auto x1 = hyps.data() + hyp_offset; + auto x2 = refs.data() + ref_offset; + + FillFirstColumn<<<1 + m / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>(dist, m, n); + + FillFirstRow<<<1 + n / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>(dist, n); + + // Compute the elements of distance matrix in the anti-diagonal diretion + for (int64_t slice = 2; slice < m + n + 1; ++slice) { + int z_m = slice < m + 1 ? 0 : slice - m; + int z_n = slice < n + 1 ? 0 : slice - n; + int size = slice - (z_m + z_n) + 1; // number of elments in the same + // anti-diagonal line to update + // the start index at which computes from + int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1; + Levenshtein<<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>(dist, x1, x2, m, n, start); + } + SetOutput<<<1, 1, 0, stream>>>(out_data + num, dist, m, n, normalized); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + edit_distance, GPU, ALL_LAYOUT, phi::EditDistanceKernel, float) {} diff --git a/paddle/phi/ops/compat/edit_distance_sig.cc b/paddle/phi/ops/compat/edit_distance_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..e558b354e1c2b6ec9cd639d2f7f9b43de0e81ece --- /dev/null +++ b/paddle/phi/ops/compat/edit_distance_sig.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature EditDistanceOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("edit_distance", + {"Hyps", "Refs", "HypsLength", "RefsLength"}, + {"normalized"}, + {"SequenceNum", "Out"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(edit_distance, phi::EditDistanceOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_edit_distance_op.py b/python/paddle/fluid/tests/unittests/test_edit_distance_op.py index 561a379b6fa62d3dccbe4985bd09e7379ad3f37b..54759b41cd7bda5a903e1dc0a9d9425d334a6a82 100644 --- a/python/paddle/fluid/tests/unittests/test_edit_distance_op.py +++ b/python/paddle/fluid/tests/unittests/test_edit_distance_op.py @@ -17,6 +17,22 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle + + +def python_edit_distance(input, + label, + input_length=None, + label_length=None, + normalized=True, + ignored_tokens=None): + return paddle.nn.functional.loss.edit_distance( + input, + label, + normalized=normalized, + ignored_tokens=ignored_tokens, + input_length=input_length, + label_length=label_length) def Levenshtein(hyp, ref): @@ -54,6 +70,7 @@ class TestEditDistanceOp(OpTest): def setUp(self): self.op_type = "edit_distance" + self.python_api = python_edit_distance normalized = False x1 = np.array([[12, 3, 5, 8, 2]]).astype("int64") x2 = np.array([[12, 4, 7, 8]]).astype("int64") @@ -83,7 +100,7 @@ class TestEditDistanceOp(OpTest): self.outputs = {'Out': distance, 'SequenceNum': sequence_num} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestEditDistanceOpNormalizedCase0(OpTest): @@ -96,6 +113,7 @@ class TestEditDistanceOpNormalizedCase0(OpTest): def setUp(self): self.op_type = "edit_distance" + self.python_api = python_edit_distance normalized = True self.x1 = np.array([[10, 3, 6, 5, 8, 2]]).astype("int64") self.x2 = np.array([[10, 4, 6, 7, 8]]).astype("int64") @@ -132,7 +150,7 @@ class TestEditDistanceOpNormalizedCase0(OpTest): self.post_config() def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestEditDistanceOpNormalizedCase1(TestEditDistanceOpNormalizedCase0): @@ -159,6 +177,7 @@ class TestEditDistanceOpNormalizedTensor(OpTest): def setUp(self): self.op_type = "edit_distance" + self.python_api = python_edit_distance normalized = True self.reset_config() @@ -184,8 +203,9 @@ class TestEditDistanceOpNormalizedTensor(OpTest): self.outputs = {'Out': distance, 'SequenceNum': sequence_num} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 9ebc5c03ef00bbd3d6ada9021d9a33d212a788e4..4e568a571edace6f9496832156bf331c295bc548 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -532,6 +532,10 @@ def edit_distance(input, attrs={"tokens": ignored_tokens}) label = erased_label + if in_dygraph_mode(): + return _C_ops.final_state_edit_distance(input, label, input_length, + label_length, normalized) + this_inputs = {"Hyps": [input], "Refs": [label]} if input_length is not None and label_length is not None: this_inputs['HypsLength'] = [input_length]