diff --git a/paddle/fluid/operators/unique_consecutive_op.cc b/paddle/fluid/operators/unique_consecutive_op.cc index 73f6918d52598fd4c0d082919a6115cb0d948731..0a36af362deb0d8f69973e766be61fe7f177573c 100644 --- a/paddle/fluid/operators/unique_consecutive_op.cc +++ b/paddle/fluid/operators/unique_consecutive_op.cc @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/unique_consecutive_op.h" - +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -118,11 +117,6 @@ namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(unique_consecutive, ops::UniqueConsecutiveOp, ops::UniqueConsecutiveOpMaker); -REGISTER_OP_CPU_KERNEL(unique_consecutive, - ops::UniqueConsecutiveKernel, - ops::UniqueConsecutiveKernel, - ops::UniqueConsecutiveKernel, - ops::UniqueConsecutiveKernel); REGISTER_OP_VERSION(unique_consecutive) .AddCheckpoint( R"ROC( diff --git a/paddle/fluid/operators/unique_consecutive_op.h b/paddle/fluid/operators/unique_consecutive_op.h deleted file mode 100644 index b0eadbd877de5453b4c69d50d1074e766285e8d2..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/unique_consecutive_op.h +++ /dev/null @@ -1,287 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include -#include -#include -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/operators/unique_op.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { -template -static void UniqueConsecutiveFlattendTensor( - const framework::ExecutionContext& context, - const framework::Tensor& in, - framework::Tensor* out, - bool return_inverse, - bool return_counts) { - const InT* in_data = in.data(); - std::vector out_vec(in.numel()); - std::vector inverse_vec(in.numel()); - std::vector counts_vec(in.numel()); - memcpy(out_vec.data(), in_data, in.numel() * sizeof(InT)); - InT* p = out_vec.data(); - int64_t last = 0; - IndexT* q = counts_vec.data(); - for (int64_t i = 0; i < in.numel(); i++) { - if (in_data[i] != *p) { - *(++p) = in_data[i]; - if (return_counts) { - *(q++) = i - last; - last = i; - } - } - if (return_inverse) { - inverse_vec[i] = p - out_vec.data(); - } - } - - int64_t output_size = p - out_vec.data() + 1; - if (return_counts) { - *q = in.numel() - last; - counts_vec.resize(output_size); - } - out_vec.resize(output_size); - - out->Resize(phi::make_ddim({output_size})); - auto* out_data = out->mutable_data(context.GetPlace()); - std::copy(out_vec.begin(), out_vec.end(), out_data); - - if (return_inverse) { - auto* inverse = context.Output("Index"); - inverse->Resize(phi::make_ddim({in.numel()})); - auto* inverse_data = inverse->mutable_data(context.GetPlace()); - std::copy(inverse_vec.begin(), inverse_vec.end(), inverse_data); - } - - if (return_counts) { - auto* count = context.Output("Counts"); - count->Resize(phi::make_ddim({out->numel()})); - auto* counts_data = count->mutable_data(context.GetPlace()); - std::copy(counts_vec.begin(), counts_vec.end(), counts_data); - } -} - -template -static ForwardIt UniqueConsecutiveDimImpl( - const framework::ExecutionContext& context, - ForwardIt first, - ForwardIt last, - const std::vector& sorted_indices_vec, - std::vector* inverse_vec, - std::vector* counts_vec) { - if (first == last) { - return last; - } - - (*inverse_vec)[sorted_indices_vec[0]] = 0; - (*counts_vec)[0] = 1; - - ForwardIt begin = first; - ForwardIt result = first; - - while (++first != last) { - int64_t idx_first = std::distance(begin, first); - int64_t idx_result = std::distance(begin, result); - if (!Equal(*result, *first)) { - if (++result != first) { - *result = std::move(*first); - } - idx_result += 1; - } - (*inverse_vec)[sorted_indices_vec[idx_first]] = idx_result; - (*counts_vec)[idx_result] += 1; - } - return ++result; -} - -template -static void UniqueConsecutiveDim(const framework::ExecutionContext& context, - const framework::Tensor& in, - framework::Tensor* out, - bool return_inverse, - bool return_counts, - int axis) { - // transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2] - std::vector permute(in.dims().size()); - std::iota(permute.begin(), permute.end(), 0); - permute[axis] = 0; - permute[0] = axis; - std::vector in_trans_dims_vec(phi::vectorize(in.dims())); - in_trans_dims_vec[axis] = in.dims()[0]; - in_trans_dims_vec[0] = in.dims()[axis]; - framework::Tensor in_trans; - framework::DDim in_trans_dims = phi::make_ddim(in_trans_dims_vec); - in_trans.Resize(in_trans_dims); - in_trans.mutable_data(context.GetPlace()); - auto& dev_ctx = context.template device_context(); - TransCompute( - in.dims().size(), dev_ctx, in, &in_trans, permute); - // reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2] - framework::DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1); - in_trans.Resize(in_trans_flat_dims); - - std::vector sorted_indices_vec(in_trans.dims()[0]); - std::iota(sorted_indices_vec.begin(), sorted_indices_vec.end(), 0); - int64_t col = in_trans.dims()[1]; - const InT* in_trans_data = in_trans.data(); - - // sort tensor according to indices - framework::Tensor input_sorted; - input_sorted.Resize(in_trans_dims); - input_sorted.mutable_data(context.GetPlace()); - InT* input_sorted_data = input_sorted.data(); - for (size_t i = 0; i < sorted_indices_vec.size(); ++i) { - memcpy(input_sorted_data + i * col, - in_trans_data + static_cast(sorted_indices_vec[i]) * col, - col * sizeof(InT)); - } - std::vector input_unbind = Unbind(input_sorted); - std::vector inverse_vec(sorted_indices_vec.size(), 0); - std::vector counts_vec(sorted_indices_vec.size(), 0); - auto last = - UniqueConsecutiveDimImpl::iterator, InT>( - context, - input_unbind.begin(), - input_unbind.end(), - sorted_indices_vec, - &inverse_vec, - &counts_vec); - input_unbind.erase(last, input_unbind.end()); - counts_vec.erase(counts_vec.begin() + input_unbind.size(), counts_vec.end()); - - math::ConcatFunctor concat_functor; - framework::Tensor out_trans; - std::vector out_trans_dims_vec = in_trans_dims_vec; - out_trans_dims_vec[0] = input_unbind.size(); - out_trans.Resize(phi::make_ddim(out_trans_dims_vec)); - out_trans.mutable_data(context.GetPlace()); - std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]); - out->Resize(phi::make_ddim(out_trans_dims_vec)); - out->mutable_data(context.GetPlace()); - concat_functor(dev_ctx, input_unbind, 0, &out_trans); - TransCompute( - out_trans.dims().size(), dev_ctx, out_trans, out, permute); - if (return_inverse) { - auto* inverse = context.Output("Index"); - framework::TensorFromVector(inverse_vec, context.device_context(), inverse); - } - if (return_counts) { - auto* count = context.Output("Counts"); - framework::TensorFromVector(counts_vec, context.device_context(), count); - } -} - -template -struct UniqueConsecutiveFlattendTensorFunctor { - const framework::ExecutionContext& ctx_; - const framework::Tensor& in_; - framework::Tensor* out_; - const bool return_inverse_; - const bool return_counts_; - - UniqueConsecutiveFlattendTensorFunctor( - const framework::ExecutionContext& context, - const framework::Tensor& in, - framework::Tensor* out, - bool return_inverse, - bool return_counts) - : ctx_(context), - in_(in), - out_(out), - return_inverse_(return_inverse), - return_counts_(return_counts) {} - - template - void apply() const { - UniqueConsecutiveFlattendTensor( - ctx_, in_, out_, return_inverse_, return_counts_); - } -}; - -template -struct UniqueConsecutiveDimFunctor { - const framework::ExecutionContext& ctx_; - const framework::Tensor& in_; - framework::Tensor* out_; - const int axis_; - const bool return_inverse_; - const bool return_counts_; - UniqueConsecutiveDimFunctor(const framework::ExecutionContext& context, - const framework::Tensor& in, - framework::Tensor* out, - const int axis, - bool return_inverse, - bool return_counts) - : ctx_(context), - in_(in), - out_(out), - axis_(axis), - return_inverse_(return_inverse), - return_counts_(return_counts) {} - - template - void apply() const { - UniqueConsecutiveDim( - ctx_, in_, out_, return_inverse_, return_counts_, axis_); - } -}; -template -class UniqueConsecutiveKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - auto data_type = static_cast( - context.Attr("dtype")); - if (data_type == framework::proto::VarType::INT32) { - PADDLE_ENFORCE_LE( - x->numel(), - INT_MAX, - platform::errors::InvalidArgument( - "The number of elements in Input(X) should be less than or " - "equal to INT_MAX, but received num is %d. Please set `dtype` to " - "int64.", - x->numel())); - } - std::vector axis_vec = context.Attr>("axis"); - bool return_inverse = context.Attr("return_inverse"); - bool return_counts = context.Attr("return_counts"); - - if (axis_vec.empty()) { - framework::VisitDataTypeTiny( - data_type, - UniqueConsecutiveFlattendTensorFunctor( - context, *x, out, return_inverse, return_counts)); - } else { - int axis = axis_vec[0]; - framework::VisitDataTypeTiny( - data_type, - UniqueConsecutiveDimFunctor( - context, *x, out, axis, return_inverse, return_counts)); - } - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/cpu/unique_consecutive_functor.h b/paddle/phi/kernels/cpu/unique_consecutive_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..85081e58069331ca9e14f2550afc5557707a8664 --- /dev/null +++ b/paddle/phi/kernels/cpu/unique_consecutive_functor.h @@ -0,0 +1,261 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/tensor_util.h" + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/unique_functor.h" + +namespace phi { + +template +static void UniqueConsecutiveFlattenedTensor(const Context& context, + const DenseTensor& in, + DenseTensor* out, + bool return_inverse, + bool return_counts, + DenseTensor* inverse, + DenseTensor* count) { + const InT* in_data = in.data(); + std::vector out_vec(in.numel()); + std::vector inverse_vec(in.numel()); + std::vector counts_vec(in.numel()); + memcpy(out_vec.data(), in_data, in.numel() * sizeof(InT)); + InT* p = out_vec.data(); + int64_t last = 0; + IndexT* q = counts_vec.data(); + for (int64_t i = 0; i < in.numel(); i++) { + if (in_data[i] != *p) { + *(++p) = in_data[i]; + if (return_counts) { + *(q++) = i - last; + last = i; + } + } + if (return_inverse) { + inverse_vec[i] = p - out_vec.data(); + } + } + + int64_t output_size = p - out_vec.data() + 1; + if (return_counts) { + *q = in.numel() - last; + counts_vec.resize(output_size); + } + out_vec.resize(output_size); + + out->Resize(phi::make_ddim({output_size})); + auto* out_data = context.template Alloc(out); + std::copy(out_vec.begin(), out_vec.end(), out_data); + + if (return_inverse) { + inverse->Resize(phi::make_ddim({in.numel()})); + auto* inverse_data = context.template Alloc(inverse); + std::copy(inverse_vec.begin(), inverse_vec.end(), inverse_data); + } + + if (return_counts) { + count->Resize(phi::make_ddim({out->numel()})); + auto* counts_data = context.template Alloc(count); + std::copy(counts_vec.begin(), counts_vec.end(), counts_data); + } +} + +template +struct UniqueConsecutiveFlattenedTensorFunctor { + const Context& ctx_; + const DenseTensor& in_; + DenseTensor* out_; + const bool return_inverse_; + const bool return_counts_; + DenseTensor* inverse_; + DenseTensor* count_; + + UniqueConsecutiveFlattenedTensorFunctor(const Context& context, + const DenseTensor& in, + DenseTensor* out, + bool return_inverse, + bool return_counts, + DenseTensor* inverse, + DenseTensor* count) + : ctx_(context), + in_(in), + out_(out), + return_inverse_(return_inverse), + return_counts_(return_counts), + inverse_(inverse), + count_(count) {} + + template + void apply() const { + UniqueConsecutiveFlattenedTensor( + ctx_, in_, out_, return_inverse_, return_counts_, inverse_, count_); + } +}; + +template +static ForwardIt UniqueConsecutiveDimImpl( + const Context& context, + ForwardIt first, + ForwardIt last, + const std::vector& sorted_indices_vec, + std::vector* inverse_vec, + std::vector* counts_vec) { + if (first == last) { + return last; + } + + (*inverse_vec)[sorted_indices_vec[0]] = 0; + (*counts_vec)[0] = 1; + + ForwardIt begin = first; + ForwardIt result = first; + + while (++first != last) { + int64_t idx_first = std::distance(begin, first); + int64_t idx_result = std::distance(begin, result); + if (!phi::funcs::Equal(*result, *first)) { + if (++result != first) { + *result = std::move(*first); + } + idx_result += 1; + } + (*inverse_vec)[sorted_indices_vec[idx_first]] = idx_result; + (*counts_vec)[idx_result] += 1; + } + return ++result; +} + +template +static void UniqueConsecutiveDim(const Context& context, + const DenseTensor& in, + DenseTensor* out, + bool return_inverse, + bool return_counts, + int axis, + DenseTensor* inverse, + DenseTensor* count) { + // transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2] + std::vector permute(in.dims().size()); + std::iota(permute.begin(), permute.end(), 0); + permute[axis] = 0; + permute[0] = axis; + std::vector in_trans_dims_vec(phi::vectorize(in.dims())); + in_trans_dims_vec[axis] = in.dims()[0]; + in_trans_dims_vec[0] = in.dims()[axis]; + DenseTensor in_trans; + DDim in_trans_dims = phi::make_ddim(in_trans_dims_vec); + in_trans.Resize(in_trans_dims); + context.template Alloc(&in_trans); + phi::funcs::TransCompute( + in.dims().size(), context, in, &in_trans, permute); + // reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2] + DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1); + in_trans.Resize(in_trans_flat_dims); + + std::vector sorted_indices_vec(in_trans.dims()[0]); + std::iota(sorted_indices_vec.begin(), sorted_indices_vec.end(), 0); + int64_t col = in_trans.dims()[1]; + const InT* in_trans_data = in_trans.data(); + + // sort tensor according to indices + DenseTensor input_sorted; + input_sorted.Resize(in_trans_dims); + context.template Alloc(&input_sorted); + InT* input_sorted_data = input_sorted.data(); + for (size_t i = 0; i < sorted_indices_vec.size(); ++i) { + memcpy(input_sorted_data + i * col, + in_trans_data + static_cast(sorted_indices_vec[i]) * col, + col * sizeof(InT)); + } + std::vector input_unbind = phi::funcs::Unbind(input_sorted); + std::vector inverse_vec(sorted_indices_vec.size(), 0); + std::vector counts_vec(sorted_indices_vec.size(), 0); + auto last = UniqueConsecutiveDimImpl::iterator, + InT>(context, + input_unbind.begin(), + input_unbind.end(), + sorted_indices_vec, + &inverse_vec, + &counts_vec); + input_unbind.erase(last, input_unbind.end()); + counts_vec.erase(counts_vec.begin() + input_unbind.size(), counts_vec.end()); + + phi::funcs::ConcatFunctor concat_functor; + DenseTensor out_trans; + std::vector out_trans_dims_vec = in_trans_dims_vec; + out_trans_dims_vec[0] = input_unbind.size(); + out_trans.Resize(phi::make_ddim(out_trans_dims_vec)); + context.template Alloc(&out_trans); + std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]); + out->Resize(phi::make_ddim(out_trans_dims_vec)); + context.template Alloc(out); + concat_functor(context, input_unbind, 0, &out_trans); + phi::funcs::TransCompute( + out_trans.dims().size(), context, out_trans, out, permute); + if (return_inverse) { + paddle::framework::TensorFromVector(inverse_vec, context, inverse); + } + if (return_counts) { + paddle::framework::TensorFromVector(counts_vec, context, count); + } +} + +template +struct UniqueConsecutiveDimFunctor { + const Context& ctx_; + const DenseTensor& in_; + DenseTensor* out_; + const int axis_; + const bool return_inverse_; + const bool return_counts_; + DenseTensor* inverse_; + DenseTensor* count_; + + UniqueConsecutiveDimFunctor(const Context& context, + const DenseTensor& in, + DenseTensor* out, + const int axis, + bool return_inverse, + bool return_counts, + DenseTensor* inverse, + DenseTensor* count) + : ctx_(context), + in_(in), + out_(out), + axis_(axis), + return_inverse_(return_inverse), + return_counts_(return_counts), + inverse_(inverse), + count_(count) {} + + template + void apply() const { + UniqueConsecutiveDim(ctx_, + in_, + out_, + return_inverse_, + return_counts_, + axis_, + inverse_, + count_); + } +}; + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/unique_consecutive_kernel.cc b/paddle/phi/kernels/cpu/unique_consecutive_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..86fe53b72c98595d251d35a21390de8a90799d0a --- /dev/null +++ b/paddle/phi/kernels/cpu/unique_consecutive_kernel.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/unique_consecutive_kernel.h" +#include "paddle/phi/kernels/cpu/unique_consecutive_functor.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/fluid/framework/data_type.h" + +namespace phi { + +template +void UniqueConsecutiveKernel(const Context& dev_ctx, + const DenseTensor& x, + bool return_inverse, + bool return_counts, + const std::vector& axis, + int dtype, + DenseTensor* out, + DenseTensor* index, + DenseTensor* counts) { + auto data_type = static_cast(dtype); + if (data_type == paddle::framework::proto::VarType::INT32) { + PADDLE_ENFORCE_LE( + x.numel(), + INT_MAX, + phi::errors::InvalidArgument( + "The number of elements in Input(X) should be less than or " + "equal to INT_MAX, but received num is %d. Please set `dtype` to " + "int64.", + x.numel())); + } + + if (axis.empty()) { + paddle::framework::VisitDataTypeTiny( + data_type, + UniqueConsecutiveFlattenedTensorFunctor( + dev_ctx, x, out, return_inverse, return_counts, index, counts)); + } else { + int valid_axis = axis[0]; + paddle::framework::VisitDataTypeTiny( + data_type, + UniqueConsecutiveDimFunctor(dev_ctx, + x, + out, + valid_axis, + return_inverse, + return_counts, + index, + counts)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(unique_consecutive, + CPU, + ALL_LAYOUT, + phi::UniqueConsecutiveKernel, + float, + double, + int32_t, + int64_t) {} diff --git a/paddle/fluid/operators/unique_consecutive_op.cu b/paddle/phi/kernels/gpu/unique_consecutive_functor.h similarity index 53% rename from paddle/fluid/operators/unique_consecutive_op.cu rename to paddle/phi/kernels/gpu/unique_consecutive_functor.h index b96499cdb20e8255ba628ace185eedde32ca8a2e..e603f695039c07d76da6b3673bd9b05a1a84a24b 100644 --- a/paddle/fluid/operators/unique_consecutive_op.cu +++ b/paddle/phi/kernels/gpu/unique_consecutive_functor.h @@ -1,16 +1,18 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once #include #include #include @@ -22,13 +24,204 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/tensor_util.h" // TensorToVector() -#include "paddle/fluid/operators/unique_consecutive_op.h" // TransComute() +#include "paddle/fluid/framework/tensor_util.h" + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/unique_functor.h" + +namespace phi { + +// The core logic of computing Unique Consecutive for a flattend Tensor +template +static void UniqueConsecutiveFlattenedCUDATensor(const Context& context, + const DenseTensor& in, + DenseTensor* out, + bool return_inverse, + bool return_counts, + equal_T equal, + not_equal_T not_equal, + int64_t num_input, + DenseTensor* inverse, + DenseTensor* counts) { + // 0. Preparation + DenseTensor in_hat; + phi::Copy(context, in, context.GetPlace(), false, &in_hat); + auto in_data_hat = context.template Alloc(&in_hat); + + DenseTensor sorted_indices; + sorted_indices.Resize(phi::make_ddim({num_input})); + auto sorted_indices_data = context.template Alloc(&sorted_indices); + thrust::sequence( + thrust::device, sorted_indices_data, sorted_indices_data + num_input); + // 1. Calculate op result: 'out' + DenseTensor range; + range.Resize(phi::make_ddim({num_input + 1})); + auto range_data_ptr = context.template Alloc(&range); + thrust::sequence( + thrust::device, range_data_ptr, range_data_ptr + num_input + 1); + phi::Copy(context, in_hat, context.GetPlace(), false, out); + int num_out; + auto out_data = context.template Alloc(out); + num_out = + thrust::unique_by_key( + thrust::device, out_data, out_data + num_input, range_data_ptr, equal) + .first - + out_data; + out->Resize(phi::make_ddim({num_out})); + + // 2. Calculate inverse index: 'inverse' + if (return_inverse) { + inverse->Resize(phi::make_ddim({num_input})); + auto inverse_data = context.template Alloc(inverse); + DenseTensor inv_loc; + inv_loc.Resize(phi::make_ddim({num_input})); + auto inv_loc_data_ptr = context.template Alloc(&inv_loc); + thrust::adjacent_difference(thrust::device, + in_data_hat, + in_data_hat + num_input, + inv_loc_data_ptr, + not_equal); + thrust::device_ptr inv_loc_data_dev(inv_loc_data_ptr); + inv_loc_data_dev[0] = 0; // without device_ptr, segmentation fault + thrust::inclusive_scan(thrust::device, + inv_loc_data_ptr, + inv_loc_data_ptr + num_input, + inv_loc_data_ptr); + thrust::scatter(thrust::device, + inv_loc_data_ptr, + inv_loc_data_ptr + num_input, + sorted_indices_data, + inverse_data); + } + // 3. Calculate 'counts' + if (return_counts) { + counts->Resize(phi::make_ddim({num_out})); + auto count_data = context.template Alloc(counts); + // init 'count_data' as 0 + thrust::fill(thrust::device, count_data, count_data + num_out, 0); + thrust::device_ptr range_data_ptr_dev(range_data_ptr); + range_data_ptr_dev[num_out] = num_input; + thrust::adjacent_difference(thrust::device, + range_data_ptr + 1, + range_data_ptr + num_out + 1, + count_data); + } +} + +// functor for processing a flattend Tensor +template +struct UniqueConsecutiveFlattenedCUDAFunctor { + const Context& ctx_; + const DenseTensor& in_; + DenseTensor* out_; + const bool return_inverse_; + const bool return_counts_; + DenseTensor* inverse_; + DenseTensor* count_; + + UniqueConsecutiveFlattenedCUDAFunctor(const Context& context, + const DenseTensor& in, + DenseTensor* out, + bool return_inverse, + bool return_counts, + DenseTensor* inverse, + DenseTensor* count) + : ctx_(context), + in_(in), + out_(out), + return_inverse_(return_inverse), + return_counts_(return_counts), + inverse_(inverse), + count_(count) {} + + template + void apply() const { + UniqueConsecutiveFlattenedCUDATensor( + ctx_, + in_, + out_, + return_inverse_, + return_counts_, + thrust::equal_to(), + thrust::not_equal_to(), + in_.numel(), + inverse_, + count_); + } +}; -namespace paddle { -namespace operators { +// The logic of compute unique with axis required, it's a little different +// from above function +template +static void ComputeUniqueConsecutiveDims(const Context& context, + DenseTensor* sorted_indices, + IndexT* sorted_indices_data, + DenseTensor* out, + bool return_inverse, + bool return_counts, + equal_T equal, + not_equal_T not_equal, + int64_t row, + DenseTensor* inverse, + DenseTensor* counts) { + // 1. inverse indices: 'inverse' + inverse->Resize(phi::make_ddim({row})); + auto inverse_data = context.template Alloc(inverse); + DenseTensor inv_loc; + inv_loc.Resize(phi::make_ddim({row})); + auto inv_loc_data_ptr = context.template Alloc(&inv_loc); + thrust::adjacent_difference(thrust::device, + sorted_indices_data, + sorted_indices_data + row, + inv_loc_data_ptr, + not_equal); + thrust::device_ptr inv_loc_data_dev(inv_loc_data_ptr); + inv_loc_data_dev[0] = 0; + thrust::inclusive_scan(thrust::device, + inv_loc_data_ptr, + inv_loc_data_ptr + row, + inv_loc_data_ptr); + thrust::scatter(thrust::device, + inv_loc_data_ptr, + inv_loc_data_ptr + row, + sorted_indices_data, + inverse_data); + + // 2. sorted indices + DenseTensor range; + range.Resize(phi::make_ddim({row + 1})); + auto range_data_ptr = context.template Alloc(&range); + thrust::sequence(thrust::device, range_data_ptr, range_data_ptr + row + 1); + int num_out; + num_out = thrust::unique_by_key(thrust::device, + sorted_indices_data, + sorted_indices_data + row, + range_data_ptr, + equal) + .first - + sorted_indices_data; + thrust::device_ptr range_data_ptr_dev(range_data_ptr); + range_data_ptr_dev[num_out] = row; + sorted_indices->Resize(phi::make_ddim({num_out})); -using Tensor = framework::Tensor; + // 3. counts: 'counts' + counts->Resize(phi::make_ddim({num_out})); + auto count_data = context.template Alloc(counts); + thrust::fill(thrust::device, count_data, count_data + row, 0); + thrust::adjacent_difference( + thrust::device, range_data_ptr + 1, range_data_ptr + row + 1, count_data); +} // Binary function 'equal_to' template @@ -73,11 +266,11 @@ struct BinaryNotEqual { }; // index_select() function for Tensor -template -void IndexSelect(const framework::ExecutionContext& context, - const Tensor& input, - const Tensor& index, - Tensor* output, +template +void IndexSelect(const Context& context, + const DenseTensor& input, + const DenseTensor& index, + DenseTensor* output, int dim) { auto input_dim = input.dims(); auto input_dim_size = input_dim.size(); @@ -100,17 +293,15 @@ void IndexSelect(const framework::ExecutionContext& context, std::vector input_vec; std::vector index_vec; - paddle::framework::TensorToVector( - input, context.device_context(), &input_vec); - paddle::framework::TensorToVector( - index, context.device_context(), &index_vec); + paddle::framework::TensorToVector(input, context, &input_vec); + paddle::framework::TensorToVector(index, context, &index_vec); std::vector out_vec(output->numel()); for (int i = 0; i < index_size; i++) { PADDLE_ENFORCE_GE( index_vec[i], 0, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Variable value (index) of OP(index_select) " "expected >= 0 and < %ld, but got %ld. Please check input " "value.", @@ -119,7 +310,7 @@ void IndexSelect(const framework::ExecutionContext& context, PADDLE_ENFORCE_LT( index_vec[i], input_dim[dim], - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Variable value (index) of OP(index_select) " "expected >= 0 and < %ld, but got %ld. Please check input " "value.", @@ -139,162 +330,21 @@ void IndexSelect(const framework::ExecutionContext& context, } } } - output->mutable_data(context.GetPlace()); - framework::TensorFromVector(out_vec, context.device_context(), output); + context.template Alloc(output); + paddle::framework::TensorFromVector(out_vec, context, output); output->Resize(output_dim); } -// The core logic of computing Unique Consecutive for a flattend Tensor -template -static void UniqueConsecutiveFlattendCUDATensor( - const framework::ExecutionContext& context, - const Tensor& in, - Tensor* out, - bool return_inverse, - bool return_counts, - equal_T equal, - not_equal_T not_equal, - int64_t num_input) { - // 0. Prepration - Tensor in_hat; - framework::TensorCopy(in, context.GetPlace(), &in_hat); - auto in_data_hat = in_hat.mutable_data(context.GetPlace()); - - Tensor sorted_indices; - sorted_indices.Resize(phi::make_ddim({num_input})); - auto sorted_indices_data = - sorted_indices.mutable_data(context.GetPlace()); - thrust::sequence( - thrust::device, sorted_indices_data, sorted_indices_data + num_input); - // 1. Calculate op result: 'out' - Tensor range; - range.Resize(phi::make_ddim({num_input + 1})); - auto range_data_ptr = range.mutable_data(context.GetPlace()); - thrust::sequence( - thrust::device, range_data_ptr, range_data_ptr + num_input + 1); - framework::TensorCopy(in_hat, context.GetPlace(), out); - int num_out; - auto out_data = out->mutable_data(context.GetPlace()); - num_out = - thrust::unique_by_key( - thrust::device, out_data, out_data + num_input, range_data_ptr, equal) - .first - - out_data; - out->Resize(phi::make_ddim({num_out})); - - // 2. Calculate inverse index: 'inverse' - if (return_inverse) { - Tensor* inverse = context.Output("Index"); - inverse->Resize(phi::make_ddim({num_input})); - auto inverse_data = inverse->mutable_data(context.GetPlace()); - Tensor inv_loc; - inv_loc.Resize(phi::make_ddim({num_input})); - auto inv_loc_data_ptr = inv_loc.mutable_data(context.GetPlace()); - thrust::adjacent_difference(thrust::device, - in_data_hat, - in_data_hat + num_input, - inv_loc_data_ptr, - not_equal); - thrust::device_ptr inv_loc_data_dev(inv_loc_data_ptr); - inv_loc_data_dev[0] = 0; // without device_ptr, segmentation fault - thrust::inclusive_scan(thrust::device, - inv_loc_data_ptr, - inv_loc_data_ptr + num_input, - inv_loc_data_ptr); - thrust::scatter(thrust::device, - inv_loc_data_ptr, - inv_loc_data_ptr + num_input, - sorted_indices_data, - inverse_data); - } - // 3. Calculate 'counts' - if (return_counts) { - Tensor* counts = context.Output("Counts"); - counts->Resize(phi::make_ddim({num_out})); - auto count_data = counts->mutable_data(context.GetPlace()); - // init 'count_data' as 0 - thrust::fill(thrust::device, count_data, count_data + num_out, 0); - thrust::device_ptr range_data_ptr_dev(range_data_ptr); - range_data_ptr_dev[num_out] = num_input; - thrust::adjacent_difference(thrust::device, - range_data_ptr + 1, - range_data_ptr + num_out + 1, - count_data); - } -} - -// The logic of compute unique with axis required, it's a little different -// from above function -template -static void ComputeUniqueConsecutiveDims( - const framework::ExecutionContext& context, - Tensor* sorted_indices, - IndexT* sorted_indices_data, - Tensor* out, - bool return_inverse, - bool return_counts, - equal_T equal, - not_equal_T not_equal, - int64_t row) { - // 1. inverse indices: 'inverse' - Tensor* inverse = context.Output("Index"); - inverse->Resize(phi::make_ddim({row})); - auto inverse_data = inverse->mutable_data(context.GetPlace()); - Tensor inv_loc; - inv_loc.Resize(phi::make_ddim({row})); - auto inv_loc_data_ptr = inv_loc.mutable_data(context.GetPlace()); - thrust::adjacent_difference(thrust::device, - sorted_indices_data, - sorted_indices_data + row, - inv_loc_data_ptr, - not_equal); - thrust::device_ptr inv_loc_data_dev(inv_loc_data_ptr); - inv_loc_data_dev[0] = 0; - thrust::inclusive_scan(thrust::device, - inv_loc_data_ptr, - inv_loc_data_ptr + row, - inv_loc_data_ptr); - thrust::scatter(thrust::device, - inv_loc_data_ptr, - inv_loc_data_ptr + row, - sorted_indices_data, - inverse_data); - - // 2. sorted indices - Tensor range; - range.Resize(phi::make_ddim({row + 1})); - auto range_data_ptr = range.mutable_data(context.GetPlace()); - thrust::sequence(thrust::device, range_data_ptr, range_data_ptr + row + 1); - int num_out; - num_out = thrust::unique_by_key(thrust::device, - sorted_indices_data, - sorted_indices_data + row, - range_data_ptr, - equal) - .first - - sorted_indices_data; - thrust::device_ptr range_data_ptr_dev(range_data_ptr); - range_data_ptr_dev[num_out] = row; - sorted_indices->Resize(phi::make_ddim({num_out})); - - // 3. counts: 'counts' - Tensor* counts = context.Output("Counts"); - counts->Resize(phi::make_ddim({num_out})); - auto count_data = counts->mutable_data(context.GetPlace()); - thrust::fill(thrust::device, count_data, count_data + row, 0); - thrust::adjacent_difference( - thrust::device, range_data_ptr + 1, range_data_ptr + row + 1, count_data); -} - // Calculate unique consecutive when 'axis' is set -template -static void UniqueConsecutiveDimsCUDATensor( - const framework::ExecutionContext& context, - const Tensor& in, - Tensor* out, - bool return_inverse, - bool return_counts, - int axis) { +template +static void UniqueConsecutiveDimsCUDATensor(const Context& context, + const DenseTensor& in, + DenseTensor* out, + bool return_inverse, + bool return_counts, + int axis, + DenseTensor* inverse, + DenseTensor* counts) { // 1. Transpose & reshape // Transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2] std::vector permute(in.dims().size()); @@ -304,19 +354,18 @@ static void UniqueConsecutiveDimsCUDATensor( std::vector in_trans_dims_vec(phi::vectorize(in.dims())); in_trans_dims_vec[axis] = in.dims()[0]; in_trans_dims_vec[0] = in.dims()[axis]; - framework::Tensor in_trans; - framework::DDim in_trans_dims = phi::make_ddim(in_trans_dims_vec); + DenseTensor in_trans; + DDim in_trans_dims = phi::make_ddim(in_trans_dims_vec); in_trans.Resize(in_trans_dims); - in_trans.mutable_data(context.GetPlace()); - auto& dev_ctx = context.cuda_device_context(); - TransCompute(in.dims().size(), // num of dims - dev_ctx, // device - in, // original Tensor - &in_trans, // Tensor after reshape - permute); // index of axis + context.template Alloc(&in_trans); + phi::funcs::TransCompute(in.dims().size(), // num of dims + context, // device + in, // original Tensor + &in_trans, // Tensor after reshape + permute); // index of axis // Reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2] - framework::DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1); + DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1); in_trans.Resize(in_trans_flat_dims); // now 'in_trans' is 2D @@ -324,16 +373,15 @@ static void UniqueConsecutiveDimsCUDATensor( int64_t row = in_trans.dims()[0]; const InT* in_trans_data = in_trans.data(); - Tensor sorted_indices; + DenseTensor sorted_indices; sorted_indices.Resize(phi::make_ddim({row})); - auto sorted_indices_data = - sorted_indices.mutable_data(context.GetPlace()); + auto sorted_indices_data = context.template Alloc(&sorted_indices); // 2. Calculate 'inverse', 'counts' // Init index thrust::sequence( thrust::device, sorted_indices_data, sorted_indices_data + row); - ComputeUniqueConsecutiveDims( + ComputeUniqueConsecutiveDims( context, &sorted_indices, sorted_indices_data, @@ -342,143 +390,70 @@ static void UniqueConsecutiveDimsCUDATensor( return_counts, BinaryEqual(col, in_trans_data), BinaryNotEqual(col, in_trans_data), - row); + row, + inverse, + counts); // 3. Select indices and reshape back to get 'out' - Tensor out_trans; + DenseTensor out_trans; std::vector out_trans_dims_vec = in_trans_dims_vec; out_trans_dims_vec[0] = sorted_indices.numel(); out_trans.Resize(phi::make_ddim(out_trans_dims_vec)); - out_trans.mutable_data(context.GetPlace()); + context.template Alloc(&out_trans); - IndexSelect(context, in_trans, sorted_indices, &out_trans, 0); + IndexSelect( + context, in_trans, sorted_indices, &out_trans, 0); std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]); out->Resize(phi::make_ddim(out_trans_dims_vec)); - out->mutable_data(context.GetPlace()); - std::vector out_trans_unbind = Unbind(out_trans); - math::ConcatFunctor concat_functor; - concat_functor(dev_ctx, out_trans_unbind, 0, &out_trans); - TransCompute( - out_trans.dims().size(), dev_ctx, out_trans, out, permute); + context.template Alloc(out); + std::vector out_trans_unbind = phi::funcs::Unbind(out_trans); + phi::funcs::ConcatFunctor concat_functor; + concat_functor(context, out_trans_unbind, 0, &out_trans); + phi::funcs::TransCompute( + out_trans.dims().size(), context, out_trans, out, permute); } -// functor for processing a flattend Tensor -template -struct UniqueConsecutiveFlattendCUDAFunctor { - const framework::ExecutionContext& ctx_; - const Tensor& in_; - Tensor* out_; - const bool return_inverse_; - const bool return_counts_; - - UniqueConsecutiveFlattendCUDAFunctor( - const framework::ExecutionContext& context, - const Tensor& in, - Tensor* out, - bool return_inverse, - bool return_counts) - : ctx_(context), - in_(in), - out_(out), - return_inverse_(return_inverse), - return_counts_(return_counts) {} - - template - void apply() const { - UniqueConsecutiveFlattendCUDATensor( - ctx_, - in_, - out_, - return_inverse_, - return_counts_, - thrust::equal_to(), - thrust::not_equal_to(), - in_.numel()); - } -}; - // functor for processing a multi-dimentional Tensor -template +template struct UniqueConsecutiveDimsCUDAFunctor { - const framework::ExecutionContext& ctx_; - const Tensor& in_; - Tensor* out_; + const Context& ctx_; + const DenseTensor& in_; + DenseTensor* out_; const int axis_; const bool return_inverse_; const bool return_counts_; + DenseTensor* inverse_; + DenseTensor* count_; - UniqueConsecutiveDimsCUDAFunctor(const framework::ExecutionContext& context, - const Tensor& in, - Tensor* out, + UniqueConsecutiveDimsCUDAFunctor(const Context& context, + const DenseTensor& in, + DenseTensor* out, const int axis, bool return_inverse, - bool return_counts) + bool return_counts, + DenseTensor* inverse, + DenseTensor* count) : ctx_(context), in_(in), out_(out), axis_(axis), return_inverse_(return_inverse), - return_counts_(return_counts) {} + return_counts_(return_counts), + inverse_(inverse), + count_(count) {} template void apply() const { - UniqueConsecutiveDimsCUDATensor( - ctx_, in_, out_, return_inverse_, return_counts_, axis_); + UniqueConsecutiveDimsCUDATensor(ctx_, + in_, + out_, + return_inverse_, + return_counts_, + axis_, + inverse_, + count_); } }; -// Unique_Consecutive_op CUDA implementation. -template -class UniqueConsecutiveKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - auto data_type = static_cast( - context.Attr("dtype")); - if (data_type == framework::proto::VarType::INT32) { - PADDLE_ENFORCE_LE( - x->numel() + 1, - INT_MAX, - platform::errors::InvalidArgument( - "The number of elements in Input(X) should be less than or " - "equal to INT_MAX, but received num is %d. Please set `dtype` to " - "int64.", - x->numel())); - } - - std::vector axis_vec = context.Attr>("axis"); - bool return_inverse = context.Attr("return_inverse"); - bool return_counts = context.Attr("return_counts"); - - // if 'axis' is not required, flatten the Tensor. - if (axis_vec.empty()) { - framework::VisitDataTypeTiny( - data_type, - UniqueConsecutiveFlattendCUDAFunctor( - context, *x, out, return_inverse, return_counts)); - } else { - // 'axis' is required. - int axis = axis_vec[0]; - framework::VisitDataTypeTiny( - data_type, - UniqueConsecutiveDimsCUDAFunctor( - context, *x, out, axis, return_inverse, return_counts)); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL( - unique_consecutive, - ops::UniqueConsecutiveKernel, - ops::UniqueConsecutiveKernel, - ops::UniqueConsecutiveKernel, - ops::UniqueConsecutiveKernel); +} // namespace phi diff --git a/paddle/phi/kernels/gpu/unique_consecutive_kernel.cu b/paddle/phi/kernels/gpu/unique_consecutive_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..4ce91a0dd66b4bad79406e716d23e9dce460424f --- /dev/null +++ b/paddle/phi/kernels/gpu/unique_consecutive_kernel.cu @@ -0,0 +1,81 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/unique_consecutive_kernel.h" +#include "paddle/phi/kernels/gpu/unique_consecutive_functor.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/fluid/framework/data_type.h" + +namespace phi { + +template +void UniqueConsecutiveKernel(const Context& dev_ctx, + const DenseTensor& x, + bool return_inverse, + bool return_counts, + const std::vector& axis, + int dtype, + DenseTensor* out, + DenseTensor* index, + DenseTensor* counts) { + auto data_type = static_cast(dtype); + if (data_type == paddle::framework::proto::VarType::INT32) { + PADDLE_ENFORCE_LE( + x.numel() + 1, + INT_MAX, + phi::errors::InvalidArgument( + "The number of elements in Input(X) should be less than or " + "equal to INT_MAX, but received num is %d. Please set `dtype` to " + "int64.", + x.numel())); + } + + // if 'axis' is not required, flatten the Tensor. + if (axis.empty()) { + paddle::framework::VisitDataTypeTiny( + data_type, + UniqueConsecutiveFlattenedCUDAFunctor( + dev_ctx, x, out, return_inverse, return_counts, index, counts)); + } else { + // 'axis' is required. + int valid_axis = axis[0]; + paddle::framework::VisitDataTypeTiny( + data_type, + UniqueConsecutiveDimsCUDAFunctor(dev_ctx, + x, + out, + valid_axis, + return_inverse, + return_counts, + index, + counts)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(unique_consecutive, + GPU, + ALL_LAYOUT, + phi::UniqueConsecutiveKernel, + float, + double, + int32_t, + int64_t) {} diff --git a/paddle/phi/kernels/unique_consecutive_kernel.h b/paddle/phi/kernels/unique_consecutive_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ade35d4d49730e5813987cb3ddd6fba1f864d504 --- /dev/null +++ b/paddle/phi/kernels/unique_consecutive_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void UniqueConsecutiveKernel(const Context& dev_ctx, + const DenseTensor& x, + bool return_inverse, + bool return_counts, + const std::vector& axis, + int dtype, + DenseTensor* out, + DenseTensor* index, + DenseTensor* counts); + +} // namespace phi diff --git a/paddle/phi/ops/compat/unique_consecutive_sig.cc b/paddle/phi/ops/compat/unique_consecutive_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..f085858d8cb0d316791864ecdcd2bfd2d4937812 --- /dev/null +++ b/paddle/phi/ops/compat/unique_consecutive_sig.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature UniqueConsecutiveOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("unique_consecutive", + {"X"}, + {"return_inverse", "return_counts", "axis", "dtype"}, + {"Out", "Index", "Counts"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(unique_consecutive, + phi::UniqueConsecutiveOpArgumentMapping);