diff --git a/paddle/fluid/operators/unique_consecutive_op.cc b/paddle/fluid/operators/unique_consecutive_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..464660d80be0194a1f19c9e7de885e65acfd6f0b
--- /dev/null
+++ b/paddle/fluid/operators/unique_consecutive_op.cc
@@ -0,0 +1,142 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/unique_consecutive_op.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class UniqueConsecutiveOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "unique_consecutive");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
+                   "unique_consecutive");
+
+    auto in_dims = ctx->GetInputDim("X");
+    bool return_inverse = ctx->Attrs().Get<bool>("return_inverse");
+    bool return_counts = ctx->Attrs().Get<bool>("return_counts");
+    auto axis_vec = ctx->Attrs().Get<std::vector<int>>("axis");
+    if (return_inverse) {
+      OP_INOUT_CHECK(ctx->HasOutput("Index"), "Output", "Index",
+                     "unique_consecutive");
+    }
+    if (return_counts) {
+      OP_INOUT_CHECK(ctx->HasOutput("Counts"), "Output", "Counts",
+                     "unique_consecutive");
+    }
+
+    if (axis_vec.empty()) {
+      ctx->SetOutputDim("Out", {-1});
+      if (return_inverse) {
+        ctx->SetOutputDim("Index", {framework::product(in_dims)});
+      }
+    } else {
+      int axis = axis_vec[0];
+      if (axis < 0) {
+        axis += in_dims.size();
+      }
+      PADDLE_ENFORCE_LT(
+          axis, in_dims.size(),
+          platform::errors::InvalidArgument("The axis(%d) should be less than "
+                                            "the dimension size(%d) of x.",
+                                            axis, in_dims.size()));
+      auto out_dims = in_dims;
+      out_dims[axis] = -1;
+      ctx->SetOutputDim("Out", out_dims);
+      if (return_inverse) {
+        ctx->SetOutputDim("Index", {in_dims[axis]});
+      }
+    }
+    if (return_counts) {
+      ctx->SetOutputDim("Counts", {-1});
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class UniqueConsecutiveOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input tensor of unique_consecutive op.");
+    AddAttr<int>("dtype",
+                 "(int, default 5(FP32)) "
+                 "data type for output index")
+        .SetDefault(framework::proto::VarType::FP32);
+
+    AddOutput("Out", "A unique consecutive subsequence for input tensor.");
+    AddOutput("Index",
+              "The indices for where elements in the original input ended up "
+              "in the returned unique tensor.")
+        .AsDispensable();
+    AddOutput("Counts", "The counts for each unique element.").AsDispensable();
+    AddAttr<bool>(
+        "return_inverse",
+        "If True, also return the indices for where elements"
+        " in the original input ended up in the returned unique tensor.")
+        .SetDefault(false);
+    AddAttr<bool>("return_counts",
+                  "If True, also return the counts for each unique element.")
+        .SetDefault(false);
+    AddAttr<std::vector<int>>(
+        "axis",
+        "The axis to apply unique. If None, the input will be flattened.")
+        .SetDefault({});
+    AddComment(R"DOC(
+    This function is different from paddle.unique() in the sense that this
+    function only eliminates consecutive duplicate values.
+)DOC");
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(unique_consecutive, ops::UniqueConsecutiveOp,
+                             ops::UniqueConsecutiveOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    unique_consecutive,
+    ops::UniqueConsecutiveKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::UniqueConsecutiveKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::UniqueConsecutiveKernel<paddle::platform::CPUDeviceContext, int32_t>,
+    ops::UniqueConsecutiveKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_VERSION(unique_consecutive)
+    .AddCheckpoint(
+        R"ROC(
+        Upgrade unique_consecutive, add 2 outputs [Indices, Counts] and 3
+        attributes [return_inverse, return_counts, axis].
+      )ROC",
+        paddle::framework::compatible::OpVersionDesc()
+            .NewOutput("Counts", "The counts for each unique element.")
+            .NewAttr("return_inverse",
+                     "If True, also return the indices for where elements"
+                     " in the original input ended up in the returned unique "
+                     "tensor.",
+                     false)
+            .NewAttr("return_counts",
+                     "If True, also return the counts for each unique "
+                     "element.",
+                     false)
+            .NewAttr("axis",
+                     "The axis to apply unique. If None, the input will be "
+                     "flattened.",
+                     std::vector<int>{}));
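A note on the shape inference above: every data-dependent dimension is left as -1 until the kernel runs. A minimal Python sketch of what `InferShape` promises, not part of the patch (`inferred_out_shape` is an illustrative name):

```python
import numpy as np

def inferred_out_shape(x_shape, axis=None):
    # Mirrors UniqueConsecutiveOp::InferShape: without an axis the input is
    # flattened, so Out is rank-1 with a data-dependent (-1) length; with an
    # axis, only that dimension becomes data-dependent.
    if axis is None:
        return (-1,)
    shape = list(x_shape)
    shape[axis % len(shape)] = -1  # negative axis wraps, as in the C++ code
    return tuple(shape)

print(inferred_out_shape((4, 3)))           # (-1,)
print(inferred_out_shape((4, 3), axis=0))   # (-1, 3)
print(inferred_out_shape((4, 3), axis=-1))  # (4, -1)
```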
AddAttr("return_counts", + "If True, also return the counts for each unique element.") + .SetDefault(false); + AddAttr>( + "axis", + "The axis to apply unique. If None, the input will be flattened.") + .SetDefault({}); + AddComment(R"DOC( + This function is different from paddle.unique() in the sense that this + function only eliminates consecutive duplicate values. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(unique_consecutive, ops::UniqueConsecutiveOp, + ops::UniqueConsecutiveOpMaker); +REGISTER_OP_CPU_KERNEL( + unique_consecutive, + ops::UniqueConsecutiveKernel, + ops::UniqueConsecutiveKernel, + ops::UniqueConsecutiveKernel, + ops::UniqueConsecutiveKernel); +REGISTER_OP_VERSION(unique_consecutive) + .AddCheckpoint( + R"ROC( + Upgrade unique_consecutive, add 2 outputs [Indices, Counts] and 3 attribute + [return_inverse, return_counts, axis]. + )ROC", + paddle::framework::compatible::OpVersionDesc() + .NewOutput("Counts", "The counts for each unique element.") + .NewAttr("return_inverse", + "If True, also return the indices for where elements" + " in the original input ended up in the returned unique " + "tensor.", + false) + .NewAttr("return_counts", + "If True, also return the counts for each unique element.", + false) + .NewAttr("axis", + "The axis to apply unique. If None, the input will be " + "flattened.", + std::vector{})); diff --git a/paddle/fluid/operators/unique_consecutive_op.cu b/paddle/fluid/operators/unique_consecutive_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..1f0023c467c01cdd3456153ec895ebde0f8c8728 --- /dev/null +++ b/paddle/fluid/operators/unique_consecutive_op.cu @@ -0,0 +1,424 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+#include <thrust/adjacent_difference.h>
+#include <thrust/device_vector.h>
+#include <thrust/execution_policy.h>
+#include <thrust/functional.h>
+#include <thrust/scatter.h>
+#include <thrust/sequence.h>
+#include <thrust/unique.h>
+#include <iostream>
+#include <vector>
+#include "paddle/fluid/framework/tensor_util.h"  // TensorToVector()
+#include "paddle/fluid/operators/unique_consecutive_op.h"  // TransCompute()
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+// Binary function 'equal_to'
+template <typename InT>
+struct BinaryEqual {
+  int64_t col;
+  const InT* in_trans_data;
+
+  BinaryEqual(int64_t _col, const InT* _in_trans_data)
+      : col(_col), in_trans_data(_in_trans_data) {}
+
+  __device__ bool operator()(int64_t a, int64_t b) const {
+    for (int64_t i = 0; i < col; ++i) {
+      InT lhs = in_trans_data[i + a * col];
+      InT rhs = in_trans_data[i + b * col];
+      if (lhs != rhs) {
+        return false;
+      }
+    }
+    return true;
+  }
+};
+
+// Binary function 'not_equal_to'
+template <typename InT>
+struct BinaryNotEqual {
+  int64_t col;
+  const InT* in_trans_data;
+
+  BinaryNotEqual(int64_t _col, const InT* _in_trans_data)
+      : col(_col), in_trans_data(_in_trans_data) {}
+
+  __device__ bool operator()(int64_t a, int64_t b) const {
+    for (int64_t i = 0; i < col; ++i) {
+      InT lhs = in_trans_data[i + a * col];
+      InT rhs = in_trans_data[i + b * col];
+      if (lhs != rhs) {
+        return true;
+      }
+    }
+    return false;
+  }
+};
+
+// index_select() function for Tensor
+template <typename InT, typename IndexT>
+void IndexSelect(const framework::ExecutionContext& context,
+                 const Tensor& input, const Tensor& index, Tensor* output,
+                 int dim) {
+  auto input_dim = input.dims();
+  auto input_dim_size = input_dim.size();
+  auto output_dim = output->dims();
+
+  auto slice_size = 1;
+  for (auto i = dim + 1; i < input_dim_size; i++) {
+    slice_size *= input_dim[i];
+  }
+
+  auto input_width = slice_size * input_dim[dim];
+  auto output_width = slice_size * output_dim[dim];
+
+  auto outer_nums = 1;
+  for (auto i = 0; i < dim; i++) {
+    outer_nums *= input_dim[i];
+  }
+
+  auto index_size = index.dims()[0];
+
+  std::vector<InT> input_vec;
+  std::vector<IndexT> index_vec;
+  TensorToVector(input, context.device_context(), &input_vec);
+  TensorToVector(index, context.device_context(), &index_vec);
+  std::vector<InT> out_vec(output->numel());
+
+  for (int i = 0; i < index_size; i++) {
+    PADDLE_ENFORCE_GE(
+        index_vec[i], 0,
+        platform::errors::InvalidArgument(
+            "Variable value (index) of OP(index_select) "
+            "expected >= 0 and < %ld, but got %ld. Please check input "
+            "value.",
+            input_dim[dim], index_vec[i]));
+    PADDLE_ENFORCE_LT(
+        index_vec[i], input_dim[dim],
+        platform::errors::InvalidArgument(
+            "Variable value (index) of OP(index_select) "
+            "expected >= 0 and < %ld, but got %ld. Please check input "
+            "value.",
+            input_dim[dim], index_vec[i]));
+  }
+
+  for (auto i = 0; i < outer_nums; i++) {
+    auto input_start_offset = i * input_width;
+    auto output_start_offset = i * output_width;
+
+    for (auto j = 0; j < index_size; j++) {
+      IndexT index_value = index_vec[j];
+      for (auto k = 0; k < slice_size; k++) {
+        out_vec[output_start_offset + j * slice_size + k] =
+            input_vec[input_start_offset + index_value * slice_size + k];
+      }
+    }
+  }
+  output->mutable_data<InT>(context.GetPlace());
+  framework::TensorFromVector(out_vec, context.device_context(), output);
+  output->Resize(output_dim);
+}
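`IndexSelect` above copies both tensors to host vectors and gathers whole `slice_size` blocks along `dim` with explicit offset arithmetic. For intuition, a numpy sketch of the same gather (illustrative only; `np.take` stands in for the offset loops):

```python
import numpy as np

x = np.arange(12).reshape(4, 3)
index = np.array([0, 2, 3])
# Gather whole rows (slices along dim=0), as the offset loops above do.
print(np.take(x, index, axis=0))
# [[ 0  1  2]
#  [ 6  7  8]
#  [ 9 10 11]]
```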
+
+// The core logic of computing Unique Consecutive for a flattened Tensor
+template <typename InT, typename IndexT, typename equal_T,
+          typename not_equal_T>
+static void UniqueConsecutiveFlattendCUDATensor(
+    const framework::ExecutionContext& context, const Tensor& in, Tensor* out,
+    bool return_inverse, bool return_counts, equal_T equal,
+    not_equal_T not_equal, int64_t num_input) {
+  // 0. Preparation
+  Tensor in_hat;
+  framework::TensorCopy(in, context.GetPlace(), &in_hat);
+  auto in_data_hat = in_hat.mutable_data<InT>(context.GetPlace());
+
+  Tensor sorted_indices;
+  sorted_indices.Resize(framework::make_ddim({num_input}));
+  auto sorted_indices_data =
+      sorted_indices.mutable_data<IndexT>(context.GetPlace());
+  thrust::sequence(thrust::device, sorted_indices_data,
+                   sorted_indices_data + num_input);
+  // 1. Calculate op result: 'out'
+  Tensor range;
+  range.Resize(framework::make_ddim({num_input + 1}));
+  auto range_data_ptr = range.mutable_data<IndexT>(context.GetPlace());
+  thrust::sequence(thrust::device, range_data_ptr,
+                   range_data_ptr + num_input + 1);
+  framework::TensorCopy(in_hat, context.GetPlace(), out);
+  int num_out;
+  auto out_data = out->mutable_data<InT>(context.GetPlace());
+  num_out = thrust::unique_by_key(thrust::device, out_data,
+                                  out_data + num_input, range_data_ptr, equal)
+                .first -
+            out_data;
+  out->Resize(framework::make_ddim({num_out}));
+
+  // 2. Calculate inverse index: 'inverse'
+  if (return_inverse) {
+    Tensor* inverse = context.Output<Tensor>("Index");
+    inverse->Resize(framework::make_ddim({num_input}));
+    auto inverse_data = inverse->mutable_data<IndexT>(context.GetPlace());
+    Tensor inv_loc;
+    inv_loc.Resize(framework::make_ddim({num_input}));
+    auto inv_loc_data_ptr = inv_loc.mutable_data<IndexT>(context.GetPlace());
+    thrust::adjacent_difference(thrust::device, in_data_hat,
+                                in_data_hat + num_input, inv_loc_data_ptr,
+                                not_equal);
+    thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
+    inv_loc_data_dev[0] = 0;  // without device_ptr, segmentation fault
+    thrust::inclusive_scan(thrust::device, inv_loc_data_ptr,
+                           inv_loc_data_ptr + num_input, inv_loc_data_ptr);
+    thrust::scatter(thrust::device, inv_loc_data_ptr,
+                    inv_loc_data_ptr + num_input, sorted_indices_data,
+                    inverse_data);
+  }
+  // 3. Calculate 'counts'
+  if (return_counts) {
+    Tensor* counts = context.Output<Tensor>("Counts");
+    counts->Resize(framework::make_ddim({num_out}));
+    auto count_data = counts->mutable_data<IndexT>(context.GetPlace());
+    // init 'count_data' as 0
+    thrust::fill(thrust::device, count_data, count_data + num_out, 0);
+    thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
+    range_data_ptr_dev[num_out] = num_input;
+    thrust::adjacent_difference(thrust::device, range_data_ptr + 1,
+                                range_data_ptr + num_out + 1, count_data);
+  }
+}
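`UniqueConsecutiveFlattendCUDATensor` derives all three outputs from run boundaries: `unique_by_key` keeps run heads for `Out`, an adjacent-difference plus inclusive scan yields `Index`, and differences of the surviving `range` keys yield `Counts`. The same dataflow in numpy, as a sketch of the algorithm rather than the device code:

```python
import numpy as np

x = np.array([1, 1, 2, 2, 3, 1, 1, 2])
new_run = x[1:] != x[:-1]                      # adjacent_difference w/ not_equal_to
starts = np.flatnonzero(np.r_[True, new_run])  # run heads kept by unique_by_key
out = x[starts]                                # Out
inverse = np.cumsum(np.r_[0, new_run])         # Index (inclusive scan)
counts = np.diff(np.r_[starts, len(x)])        # Counts (diff of range keys)
print(out, inverse, counts)
# [1 2 3 1 2] [0 0 1 1 2 3 3 4] [2 2 1 2 1]
```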
+
+// The logic of computing unique consecutive with 'axis' specified; it is a
+// little different from the function above
+template <typename InT, typename IndexT, typename equal_T,
+          typename not_equal_T>
+static void ComputeUniqueConsecutiveDims(
+    const framework::ExecutionContext& context, Tensor* sorted_indices,
+    IndexT* sorted_indices_data, Tensor* out, bool return_inverse,
+    bool return_counts, equal_T equal, not_equal_T not_equal, int64_t row) {
+  // 1. inverse indices: 'inverse'
+  Tensor* inverse = context.Output<Tensor>("Index");
+  inverse->Resize(framework::make_ddim({row}));
+  auto inverse_data = inverse->mutable_data<IndexT>(context.GetPlace());
+  Tensor inv_loc;
+  inv_loc.Resize(framework::make_ddim({row}));
+  auto inv_loc_data_ptr = inv_loc.mutable_data<IndexT>(context.GetPlace());
+  thrust::adjacent_difference(thrust::device, sorted_indices_data,
+                              sorted_indices_data + row, inv_loc_data_ptr,
+                              not_equal);
+  thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
+  inv_loc_data_dev[0] = 0;
+  thrust::inclusive_scan(thrust::device, inv_loc_data_ptr,
+                         inv_loc_data_ptr + row, inv_loc_data_ptr);
+  thrust::scatter(thrust::device, inv_loc_data_ptr, inv_loc_data_ptr + row,
+                  sorted_indices_data, inverse_data);
+
+  // 2. sorted indices
+  Tensor range;
+  range.Resize(framework::make_ddim({row + 1}));
+  auto range_data_ptr = range.mutable_data<IndexT>(context.GetPlace());
+  thrust::sequence(thrust::device, range_data_ptr, range_data_ptr + row + 1);
+  int num_out;
+  num_out =
+      thrust::unique_by_key(thrust::device, sorted_indices_data,
+                            sorted_indices_data + row, range_data_ptr, equal)
+          .first -
+      sorted_indices_data;
+  thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
+  range_data_ptr_dev[num_out] = row;
+  sorted_indices->Resize(framework::make_ddim({num_out}));
+
+  // 3. counts: 'counts'
+  Tensor* counts = context.Output<Tensor>("Counts");
+  counts->Resize(framework::make_ddim({num_out}));
+  auto count_data = counts->mutable_data<IndexT>(context.GetPlace());
+  thrust::fill(thrust::device, count_data, count_data + row, 0);
+  thrust::adjacent_difference(thrust::device, range_data_ptr + 1,
+                              range_data_ptr + row + 1, count_data);
+}
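With `axis` set, the tensor is transposed so the target axis leads, flattened to 2-D, and each row is treated as one element; `BinaryEqual`/`BinaryNotEqual` above compare the `col` contiguous values of two rows. A numpy sketch of that row-wise reduction (illustrative; `unique_consecutive_rows` is not part of the patch):

```python
import numpy as np

def unique_consecutive_rows(x2d):
    # Each row of the 2-D view is one "element"; BinaryNotEqual reports
    # True as soon as any of the `col` values in a row pair differ.
    new_run = np.any(x2d[1:] != x2d[:-1], axis=1)
    starts = np.flatnonzero(np.r_[True, new_run])
    return (x2d[starts], np.cumsum(np.r_[0, new_run]),
            np.diff(np.r_[starts, len(x2d)]))

x = np.array([[2, 1, 3], [3, 0, 1], [2, 1, 3], [2, 1, 3]])
out, inverse, counts = unique_consecutive_rows(x)
print(out.tolist(), inverse, counts)
# [[2, 1, 3], [3, 0, 1], [2, 1, 3]] [0 1 2 2] [1 1 2]
```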
+
+// Calculate unique consecutive when 'axis' is set
+template <typename DeviceContext, typename InT, typename IndexT>
+static void UniqueConsecutiveDimsCUDATensor(
+    const framework::ExecutionContext& context, const Tensor& in, Tensor* out,
+    bool return_inverse, bool return_counts, int axis) {
+  // 1. Transpose & reshape
+  // Transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2]
+  std::vector<int> permute(in.dims().size());
+  std::iota(permute.begin(), permute.end(), 0);
+  permute[axis] = 0;
+  permute[0] = axis;
+  std::vector<int64_t> in_trans_dims_vec(framework::vectorize(in.dims()));
+  in_trans_dims_vec[axis] = in.dims()[0];
+  in_trans_dims_vec[0] = in.dims()[axis];
+  framework::Tensor in_trans;
+  framework::DDim in_trans_dims = framework::make_ddim(in_trans_dims_vec);
+  in_trans.Resize(in_trans_dims);
+  in_trans.mutable_data<InT>(context.GetPlace());
+  auto& dev_ctx = context.cuda_device_context();
+  TransCompute<DeviceContext, InT>(in.dims().size(),  // num of dims
+                                   dev_ctx,           // device
+                                   in,                // original Tensor
+                                   &in_trans,  // Tensor after transpose
+                                   permute);   // index of axis
+
+  // Reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2]
+  framework::DDim in_trans_flat_dims =
+      framework::flatten_to_2d(in_trans_dims, 1);
+  in_trans.Resize(in_trans_flat_dims);
+
+  // now 'in_trans' is 2D
+  int64_t col = in_trans.dims()[1];
+  int64_t row = in_trans.dims()[0];
+  const InT* in_trans_data = in_trans.data<InT>();
+
+  Tensor sorted_indices;
+  sorted_indices.Resize(framework::make_ddim({row}));
+  auto sorted_indices_data =
+      sorted_indices.mutable_data<IndexT>(context.GetPlace());
+
+  // 2. Calculate 'inverse', 'counts'
+  // Init index
+  thrust::sequence(thrust::device, sorted_indices_data,
+                   sorted_indices_data + row);
+  ComputeUniqueConsecutiveDims<InT, IndexT>(
+      context, &sorted_indices, sorted_indices_data, out, return_inverse,
+      return_counts, BinaryEqual<InT>(col, in_trans_data),
+      BinaryNotEqual<InT>(col, in_trans_data), row);
+
+  // 3. Select indices and reshape back to get 'out'
+  Tensor out_trans;
+  std::vector<int64_t> out_trans_dims_vec = in_trans_dims_vec;
+  out_trans_dims_vec[0] = sorted_indices.numel();
+  out_trans.Resize(framework::make_ddim(out_trans_dims_vec));
+  out_trans.mutable_data<InT>(context.GetPlace());
+
+  IndexSelect<InT, IndexT>(context, in_trans, sorted_indices, &out_trans, 0);
+
+  std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]);
+  out->Resize(framework::make_ddim(out_trans_dims_vec));
+  out->mutable_data<InT>(context.GetPlace());
+  std::vector<Tensor> out_trans_unbind = Unbind(out_trans);
+  math::ConcatFunctor<DeviceContext, InT> concat_functor;
+  concat_functor(dev_ctx, out_trans_unbind, 0, &out_trans);
+  TransCompute<DeviceContext, InT>(out_trans.dims().size(), dev_ctx, out_trans,
+                                   out, permute);
+}
+
+// functor for processing a flattened Tensor
+template <typename DeviceContext, typename InT>
+struct UniqueConsecutiveFlattendCUDAFunctor {
+  const framework::ExecutionContext& ctx_;
+  const Tensor& in_;
+  Tensor* out_;
+  const bool return_inverse_;
+  const bool return_counts_;
+
+  UniqueConsecutiveFlattendCUDAFunctor(
+      const framework::ExecutionContext& context, const Tensor& in,
+      Tensor* out, bool return_inverse, bool return_counts)
+      : ctx_(context),
+        in_(in),
+        out_(out),
+        return_inverse_(return_inverse),
+        return_counts_(return_counts) {}
+
+  template <typename IndexT>
+  void apply() const {
+    UniqueConsecutiveFlattendCUDATensor<InT, IndexT>(
+        ctx_, in_, out_, return_inverse_, return_counts_,
+        thrust::equal_to<InT>(), thrust::not_equal_to<InT>(), in_.numel());
+  }
+};
+
+// functor for processing a multi-dimensional Tensor
+template <typename DeviceContext, typename InT>
+struct UniqueConsecutiveDimsCUDAFunctor {
+  const framework::ExecutionContext& ctx_;
+  const Tensor& in_;
+  Tensor* out_;
+  const int axis_;
+  const bool return_inverse_;
+  const bool return_counts_;
+
+  UniqueConsecutiveDimsCUDAFunctor(const framework::ExecutionContext& context,
+                                   const Tensor& in, Tensor* out,
+                                   const int axis, bool return_inverse,
+                                   bool return_counts)
+      : ctx_(context),
+        in_(in),
+        out_(out),
+        axis_(axis),
+        return_inverse_(return_inverse),
+        return_counts_(return_counts) {}
+
+  template <typename IndexT>
+  void apply() const {
+    UniqueConsecutiveDimsCUDATensor<DeviceContext, InT, IndexT>(
+        ctx_, in_, out_, return_inverse_, return_counts_, axis_);
+  }
+};
+
+// Unique_Consecutive_op CUDA implementation.
+template <typename InT>
+class UniqueConsecutiveKernel<platform::CUDADeviceContext, InT>
+    : public framework::OpKernel<InT> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    auto data_type = static_cast<framework::proto::VarType::Type>(
+        context.Attr<int>("dtype"));
+    if (data_type == framework::proto::VarType::INT32) {
+      PADDLE_ENFORCE_LE(
+          x->numel() + 1, INT_MAX,
+          platform::errors::InvalidArgument(
+              "The number of elements in Input(X) should be less than or "
+              "equal to INT_MAX, but received num is %d. Please set `dtype` "
+              "to int64.",
+              x->numel()));
+    }
+
+    std::vector<int> axis_vec = context.Attr<std::vector<int>>("axis");
+    bool return_inverse = context.Attr<bool>("return_inverse");
+    bool return_counts = context.Attr<bool>("return_counts");
+
+    // if 'axis' is not required, flatten the Tensor.
+    if (axis_vec.empty()) {
+      framework::VisitDataTypeTiny(
+          data_type,
+          UniqueConsecutiveFlattendCUDAFunctor<platform::CUDADeviceContext,
+                                               InT>(context, *x, out,
+                                                    return_inverse,
+                                                    return_counts));
+    } else {
+      // 'axis' is required.
+      int axis = axis_vec[0];
+      framework::VisitDataTypeTiny(
+          data_type,
+          UniqueConsecutiveDimsCUDAFunctor<platform::CUDADeviceContext, InT>(
+              context, *x, out, axis, return_inverse, return_counts));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    unique_consecutive,
+    ops::UniqueConsecutiveKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::UniqueConsecutiveKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::UniqueConsecutiveKernel<paddle::platform::CUDADeviceContext,
+                                 int32_t>,
+    ops::UniqueConsecutiveKernel<paddle::platform::CUDADeviceContext,
+                                 int64_t>);
diff --git a/paddle/fluid/operators/unique_consecutive_op.h b/paddle/fluid/operators/unique_consecutive_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..e6cb5dafe343308d58142f8e67fa3c42318fca48
--- /dev/null
+++ b/paddle/fluid/operators/unique_consecutive_op.h
@@ -0,0 +1,268 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <cmath>
+#include <numeric>
+#include <set>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/transpose_op.h"
+#include "paddle/fluid/operators/unique_op.h"
+
+namespace paddle {
+namespace operators {
+template <typename InT, typename IndexT>
+static void UniqueConsecutiveFlattendTensor(
+    const framework::ExecutionContext& context, const framework::Tensor& in,
+    framework::Tensor* out, bool return_inverse, bool return_counts) {
+  const InT* in_data = in.data<InT>();
+  std::vector<InT> out_vec(in.numel());
+  std::vector<IndexT> inverse_vec(in.numel());
+  std::vector<IndexT> counts_vec(in.numel());
+  memcpy(out_vec.data(), in_data, in.numel() * sizeof(InT));
+  InT* p = out_vec.data();
+  int64_t last = 0;
+  IndexT* q = counts_vec.data();
+  for (int64_t i = 0; i < in.numel(); i++) {
+    if (in_data[i] != *p) {
+      *(++p) = in_data[i];
+      if (return_counts) {
+        *(q++) = i - last;
+        last = i;
+      }
+    }
+    if (return_inverse) {
+      inverse_vec[i] = p - out_vec.data();
+    }
+  }
+
+  int64_t output_size = p - out_vec.data() + 1;
+  if (return_counts) {
+    *q = in.numel() - last;
+    counts_vec.resize(output_size);
+  }
+  out_vec.resize(output_size);
+
+  out->Resize(framework::make_ddim({output_size}));
+  auto* out_data = out->mutable_data<InT>(context.GetPlace());
+  std::copy(out_vec.begin(), out_vec.end(), out_data);
+
+  if (return_inverse) {
+    auto* inverse = context.Output<framework::Tensor>("Index");
+    inverse->Resize(framework::make_ddim({in.numel()}));
+    auto* inverse_data = inverse->mutable_data<IndexT>(context.GetPlace());
+    std::copy(inverse_vec.begin(), inverse_vec.end(), inverse_data);
+  }
+
+  if (return_counts) {
+    auto* count = context.Output<framework::Tensor>("Counts");
+    count->Resize(framework::make_ddim({out->numel()}));
+    auto* counts_data = count->mutable_data<IndexT>(context.GetPlace());
+    std::copy(counts_vec.begin(), counts_vec.end(), counts_data);
+  }
+}
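The CPU flattened path above is a single pass with two write cursors: `p` compacts run heads in place and `q` records run lengths. A direct Python transcription of the cursor logic, for readability (illustrative; assumes a non-empty input, as the C++ loop effectively does):

```python
def unique_consecutive_flattened(xs):
    # out[-1] plays the role of cursor `p`, counts of cursor `q`.
    out, inverse, counts = [xs[0]], [], []
    last = 0
    for i, v in enumerate(xs):
        if v != out[-1]:
            out.append(v)
            counts.append(i - last)
            last = i
        inverse.append(len(out) - 1)
    counts.append(len(xs) - last)
    return out, inverse, counts

print(unique_consecutive_flattened([1, 1, 2, 2, 3, 1, 1, 2]))
# ([1, 2, 3, 1, 2], [0, 0, 1, 1, 2, 3, 3, 4], [2, 2, 1, 2, 1])
```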
+
+template <class ForwardIt, typename InT, typename IndexT>
+static ForwardIt UniqueConsecutiveDimImpl(
+    const framework::ExecutionContext& context, ForwardIt first,
+    ForwardIt last, const std::vector<IndexT>& sorted_indices_vec,
+    std::vector<IndexT>* inverse_vec, std::vector<IndexT>* counts_vec) {
+  if (first == last) {
+    return last;
+  }
+
+  (*inverse_vec)[sorted_indices_vec[0]] = 0;
+  (*counts_vec)[0] = 1;
+
+  ForwardIt begin = first;
+  ForwardIt result = first;
+
+  while (++first != last) {
+    int64_t idx_first = std::distance(begin, first);
+    int64_t idx_result = std::distance(begin, result);
+    if (!Equal<InT>(*result, *first)) {
+      if (++result != first) {
+        *result = std::move(*first);
+      }
+      idx_result += 1;
+    }
+    (*inverse_vec)[sorted_indices_vec[idx_first]] = idx_result;
+    (*counts_vec)[idx_result] += 1;
+  }
+  return ++result;
+}
+
+template <typename DeviceContext, typename InT, typename IndexT>
+static void UniqueConsecutiveDim(const framework::ExecutionContext& context,
+                                 const framework::Tensor& in,
+                                 framework::Tensor* out, bool return_inverse,
+                                 bool return_counts, int axis) {
+  // transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2]
+  std::vector<int> permute(in.dims().size());
+  std::iota(permute.begin(), permute.end(), 0);
+  permute[axis] = 0;
+  permute[0] = axis;
+  std::vector<int64_t> in_trans_dims_vec(framework::vectorize(in.dims()));
+  in_trans_dims_vec[axis] = in.dims()[0];
+  in_trans_dims_vec[0] = in.dims()[axis];
+  framework::Tensor in_trans;
+  framework::DDim in_trans_dims = framework::make_ddim(in_trans_dims_vec);
+  in_trans.Resize(in_trans_dims);
+  in_trans.mutable_data<InT>(context.GetPlace());
+  auto& dev_ctx = context.template device_context<DeviceContext>();
+  TransCompute<DeviceContext, InT>(in.dims().size(), dev_ctx, in, &in_trans,
+                                   permute);
+  // reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2]
+  framework::DDim in_trans_flat_dims =
+      framework::flatten_to_2d(in_trans_dims, 1);
+  in_trans.Resize(in_trans_flat_dims);
+
+  std::vector<IndexT> sorted_indices_vec(in_trans.dims()[0]);
+  std::iota(sorted_indices_vec.begin(), sorted_indices_vec.end(), 0);
+  int64_t col = in_trans.dims()[1];
+  const InT* in_trans_data = in_trans.data<InT>();
+
+  // sort tensor according to indices
+  framework::Tensor input_sorted;
+  input_sorted.Resize(in_trans_dims);
+  input_sorted.mutable_data<InT>(context.GetPlace());
+  InT* input_sorted_data = input_sorted.data<InT>();
+  for (size_t i = 0; i < sorted_indices_vec.size(); ++i) {
+    memcpy(input_sorted_data + i * col,
+           in_trans_data + static_cast<int64_t>(sorted_indices_vec[i]) * col,
+           col * sizeof(InT));
+  }
+  std::vector<framework::Tensor> input_unbind = Unbind(input_sorted);
+  std::vector<IndexT> inverse_vec(sorted_indices_vec.size(), 0);
+  std::vector<IndexT> counts_vec(sorted_indices_vec.size(), 0);
+  auto last =
+      UniqueConsecutiveDimImpl<std::vector<framework::Tensor>::iterator, InT>(
+          context, input_unbind.begin(), input_unbind.end(),
+          sorted_indices_vec, &inverse_vec, &counts_vec);
+  input_unbind.erase(last, input_unbind.end());
+  counts_vec.erase(counts_vec.begin() + input_unbind.size(),
+                   counts_vec.end());
+
+  math::ConcatFunctor<DeviceContext, InT> concat_functor;
+  framework::Tensor out_trans;
+  std::vector<int64_t> out_trans_dims_vec = in_trans_dims_vec;
+  out_trans_dims_vec[0] = input_unbind.size();
+  out_trans.Resize(framework::make_ddim(out_trans_dims_vec));
+  out_trans.mutable_data<InT>(context.GetPlace());
+  std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]);
+  out->Resize(framework::make_ddim(out_trans_dims_vec));
+  out->mutable_data<InT>(context.GetPlace());
+  concat_functor(dev_ctx, input_unbind, 0, &out_trans);
+  TransCompute<DeviceContext, InT>(out_trans.dims().size(), dev_ctx,
+                                   out_trans, out, permute);
+  if (return_inverse) {
+    auto* inverse = context.Output<framework::Tensor>("Index");
+    framework::TensorFromVector(inverse_vec, context.device_context(),
+                                inverse);
+  }
+  if (return_counts) {
+    auto* count = context.Output<framework::Tensor>("Counts");
+    framework::TensorFromVector(counts_vec, context.device_context(), count);
+  }
+}
+
+template <typename DeviceContext, typename InT>
+struct UniqueConsecutiveFlattendTensorFunctor {
+  const framework::ExecutionContext& ctx_;
+  const framework::Tensor& in_;
+  framework::Tensor* out_;
+  const bool return_inverse_;
+  const bool return_counts_;
+
+  UniqueConsecutiveFlattendTensorFunctor(
+      const framework::ExecutionContext& context, const framework::Tensor& in,
+      framework::Tensor* out, bool return_inverse, bool return_counts)
+      : ctx_(context),
+        in_(in),
+        out_(out),
+        return_inverse_(return_inverse),
+        return_counts_(return_counts) {}
+
+  template <typename IndexT>
+  void apply() const {
+    UniqueConsecutiveFlattendTensor<InT, IndexT>(
+        ctx_, in_, out_, return_inverse_, return_counts_);
+  }
+};
+
+template <typename DeviceContext, typename InT>
+struct UniqueConsecutiveDimFunctor {
+  const framework::ExecutionContext& ctx_;
+  const framework::Tensor& in_;
+  framework::Tensor* out_;
+  const int axis_;
+  const bool return_inverse_;
+  const bool return_counts_;
+  UniqueConsecutiveDimFunctor(const framework::ExecutionContext& context,
+                              const framework::Tensor& in,
+                              framework::Tensor* out, const int axis,
+                              bool return_inverse, bool return_counts)
+      : ctx_(context),
+        in_(in),
+        out_(out),
+        axis_(axis),
+        return_inverse_(return_inverse),
+        return_counts_(return_counts) {}
+
+  template <typename IndexT>
+  void apply() const {
+    UniqueConsecutiveDim<DeviceContext, InT, IndexT>(
+        ctx_, in_, out_, return_inverse_, return_counts_, axis_);
+  }
+};
+template <typename DeviceContext, typename T>
+class UniqueConsecutiveKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    auto data_type = static_cast<framework::proto::VarType::Type>(
+        context.Attr<int>("dtype"));
+    if (data_type == framework::proto::VarType::INT32) {
+      PADDLE_ENFORCE_LE(
+          x->numel(), INT_MAX,
+          platform::errors::InvalidArgument(
+              "The number of elements in Input(X) should be less than or "
+              "equal to INT_MAX, but received num is %d. Please set `dtype` "
+              "to int64.",
+              x->numel()));
+    }
+    std::vector<int> axis_vec = context.Attr<std::vector<int>>("axis");
+    bool return_inverse = context.Attr<bool>("return_inverse");
+    bool return_counts = context.Attr<bool>("return_counts");
+
+    if (axis_vec.empty()) {
+      framework::VisitDataTypeTiny(
+          data_type, UniqueConsecutiveFlattendTensorFunctor<DeviceContext, T>(
+                         context, *x, out, return_inverse, return_counts));
+    } else {
+      int axis = axis_vec[0];
+      framework::VisitDataTypeTiny(
+          data_type,
+          UniqueConsecutiveDimFunctor<DeviceContext, T>(
+              context, *x, out, axis, return_inverse, return_counts));
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc
index 07a3fc8a8df331eba66d9c1c1c987638830458aa..dc27befd26cda829409e5e119c02bd8f89189106 100644
--- a/paddle/fluid/pybind/op_function_generator.cc
+++ b/paddle/fluid/pybind/op_function_generator.cc
@@ -86,6 +86,7 @@ std::map<std::string, std::vector<std::string>> op_outs_map = {
     {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
      "ReserveSpace"}},
    {"unique", {"Out", "Index", "Indices", "Counts"}},
+   {"unique_consecutive", {"Out", "Index", "Counts"}},
    {"generate_proposals", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
    {"collect_fpn_proposals", {"FpnRois", "RoisNum"}},
    {"matrix_nms", {"Out", "Index", "RoisNum"}},
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index f72fb6c1806b10e089399ad29427f3990f4bd80e..907a667cb6ba789694fbebc4b185e73715009a6e 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -142,6 +142,7 @@ from .tensor.manipulation import squeeze_  # noqa: F401
 from .tensor.manipulation import stack  # noqa: F401
 from .tensor.manipulation import strided_slice  # noqa: F401
 from .tensor.manipulation import unique  # noqa: F401
+from .tensor.manipulation import unique_consecutive  # noqa: F401
 from .tensor.manipulation import unsqueeze  # noqa: F401
 from .tensor.manipulation import unsqueeze_  # noqa: F401
 from .tensor.manipulation import unstack  # noqa: F401
@@ -470,6 +471,7 @@ __all__ = [  # noqa
     'randn',
     'strided_slice',
     'unique',
+    'unique_consecutive',
     'set_cuda_rng_state',
     'set_printoptions',
     'std',
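The unit tests that follow validate the op against a pure-Python reference. As a quick sanity check of the contract under test, the docstring example added later in this patch reduces to:

```python
import paddle

x = paddle.to_tensor([1, 1, 2, 2, 3, 1, 1, 2])
out, inverse, counts = paddle.unique_consecutive(
    x, return_inverse=True, return_counts=True)
print(out.numpy())      # [1 2 3 1 2]
print(inverse.numpy())  # [0 0 1 1 2 3 3 4]
print(counts.numpy())   # [2 2 1 2 1]
```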
diff --git a/python/paddle/fluid/tests/unittests/test_unique_consecutive_op.py b/python/paddle/fluid/tests/unittests/test_unique_consecutive_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..a12f1aaff45969d8d37a604017be03c9c71d6a19
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_unique_consecutive_op.py
@@ -0,0 +1,238 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.fluid.core as core
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.framework as framework
+
+
+def reference_unique_consecutive(X, return_inverse=False, return_counts=False):
+    """
+    Reference unique_consecutive implementation using python.
+
+    Args:
+        x(Tensor): the input tensor. Its data type should be float32, float64, int32 or int64.
+        return_inverse(bool, optional): If True, also return the indices for where elements in
+            the original input ended up in the returned unique consecutive tensor. Default is False.
+        return_counts(bool, optional): If True, also return the counts for each unique consecutive element.
+    """
+    X = list(X)
+    counts_vec = [1] * len(X)
+    i = 0
+    counts = 1
+    last = 0
+    inverse_vec = [0] * len(X)
+    inverse_vec[last] = i
+    cnt = 0
+    while i < len(X) - 1:
+        if X[i] == X[i + 1]:
+            if return_counts:
+                counts_vec[cnt] += 1
+            del X[i]
+        else:
+            i += 1
+            cnt += 1
+        if return_inverse:
+            last += 1
+            inverse_vec[last] = i
+    if return_counts:
+        counts_vec = counts_vec[:len(X)]
+    if return_inverse and return_counts:
+        return X, np.array(inverse_vec), np.array(counts_vec)
+    elif return_counts:
+        return X, np.array(counts_vec)
+    elif return_inverse:
+        return X, np.array(inverse_vec)
+    else:
+        return X
+
+
+class TestUniqueConsecutiveOp(OpTest):
+    """case 1"""
+
+    def config(self):
+        self.x_size = 100
+        self.x_range = 20
+        self.return_inverse = False
+        self.return_counts = False
+
+    def init_kernel_type(self):
+        self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"
+
+    def setUp(self):
+        self.init_kernel_type()
+        self.config()
+        self.op_type = "unique_consecutive"
+        x = np.random.randint(self.x_range, size=self.x_size).astype(self.dtype)
+        result = reference_unique_consecutive(x, self.return_inverse,
+                                              self.return_counts)
+        out = reference_unique_consecutive(x)
+        out = np.array(out).astype(self.dtype)
+        self.inputs = {'X': x, }
+        self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)}
+        self.outputs = {'Out': out, }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestUniqueConsecutiveOp2(TestUniqueConsecutiveOp):
+    """case 2"""
+
+    def config(self):
+        self.x_size = 100
+        self.x_range = 20
+        self.return_inverse = True
+        self.return_counts = False
+
+    def setUp(self):
+        self.init_kernel_type()
+        self.config()
+        self.op_type = "unique_consecutive"
+        x = np.random.randint(self.x_range, size=self.x_size).astype(self.dtype)
+        result, inverse = reference_unique_consecutive(x, self.return_inverse,
+                                                       self.return_counts)
+        result = np.array(result).astype(self.dtype)
+        inverse = inverse.astype(self.dtype)
+        self.inputs = {'X': x, }
+        self.attrs = {
+            'return_inverse': self.return_inverse,
+            'dtype': int(core.VarDesc.VarType.INT32)
+        }
+        self.outputs = {'Out': result, 'Index': inverse}
+
+
+class TestUniqueConsecutiveOp3(TestUniqueConsecutiveOp):
+    """case 3"""
+
+    def config(self):
+        self.x_size = 100
+        self.x_range = 20
+        self.return_inverse = False
+        self.return_counts = True
+
+    def setUp(self):
+        self.init_kernel_type()
+        self.config()
+        self.op_type = "unique_consecutive"
+        x = np.random.randint(self.x_range, size=self.x_size).astype(self.dtype)
+        result, counts = reference_unique_consecutive(x, self.return_inverse,
+                                                      self.return_counts)
+        result = np.array(result).astype(self.dtype)
+        counts = counts.astype(self.dtype)
+        self.inputs = {'X': x, }
+        self.attrs = {
+            'return_counts': self.return_counts,
+            'dtype': int(core.VarDesc.VarType.INT32)
+        }
+        self.outputs = {'Out': result, 'Counts': counts}
+
+
+class TestUniqueConsecutiveOp4(TestUniqueConsecutiveOp):
+    """case 4"""
+
+    def config(self):
+        self.x_size = 100
+        self.x_range = 20
+        self.return_inverse = True
+        self.return_counts = True
+
+    def setUp(self):
+        self.init_kernel_type()
+        self.config()
+        self.op_type = "unique_consecutive"
+        x = np.random.randint(self.x_range, size=self.x_size).astype(self.dtype)
+        result, inverse, counts = reference_unique_consecutive(
+            x, self.return_inverse, self.return_counts)
+        result = np.array(result).astype(self.dtype)
+        inverse = inverse.astype(self.dtype)
+        counts = counts.astype(self.dtype)
+        self.inputs = {'X': x, }
+        self.attrs = {
+            'return_inverse': self.return_inverse,
+            'return_counts': self.return_counts,
+            'dtype': int(core.VarDesc.VarType.INT32)
+        }
+        self.outputs = {'Out': result, 'Index': inverse, 'Counts': counts}
+
+
+class TestUniqueConsecutiveAPI(unittest.TestCase):
+    def setUp(self):
+        self.places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            self.places.append(fluid.CUDAPlace(0))
+
+    def check_static_result(self, place):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            paddle.enable_static()
+            input_x = fluid.data(name="input_x", shape=[100, ], dtype="float32")
+            result = paddle.unique_consecutive(input_x)
+            x_np = np.random.randint(20, size=100).astype("float32")
+            exe = fluid.Executor(place)
+            fetches = exe.run(fluid.default_main_program(),
+                              feed={"input_x": x_np},
+                              fetch_list=[result])
+
+    def test_static(self):
+        for place in self.places:
+            self.check_static_result(place=place)
+
+    def test_dygraph(self):
+        for place in self.places:
+            with fluid.dygraph.guard(place):
+                input_x = np.random.randint(20, size=100).astype("float64")
+                x = paddle.to_tensor(input_x)
+                result = paddle.unique_consecutive(x)
+
+
+class TestUniqueConsecutiveCase2API(unittest.TestCase):
+    def setUp(self):
+        self.places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            self.places.append(fluid.CUDAPlace(0))
+
+    def check_static_result(self, place):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            paddle.enable_static()
+            input_x = fluid.data(name="input_x", shape=[100, ], dtype="float32")
+            result, inverse, counts = paddle.unique_consecutive(
+                input_x, return_inverse=True, return_counts=True)
+            x_np = np.random.randint(20, size=100).astype("float32")
+            exe = fluid.Executor(place)
+            fetches = exe.run(fluid.default_main_program(),
+                              feed={"input_x": x_np},
+                              fetch_list=[result])
+
+    def test_static(self):
+        for place in self.places:
+            self.check_static_result(place=place)
+
+    def test_dygraph(self):
+        for place in self.places:
+            with fluid.dygraph.guard(place):
+                input_x = np.random.randint(20, size=100).astype("float64")
+                x = paddle.to_tensor(input_x)
+                result, inverse, counts = paddle.unique_consecutive(
+                    x, return_inverse=True, return_counts=True)
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 1c6996bcad6e5c88cf593b029fa9445a1089d622..bcb508d11922fc9613953a90ef03133b133cb689 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -88,6 +88,7 @@ from .manipulation import squeeze_  # noqa: F401
 from .manipulation import stack  # noqa: F401
 from .manipulation import strided_slice  # noqa: F401
 from .manipulation import unique  # noqa: F401
+from .manipulation import unique_consecutive  # noqa: F401
 from .manipulation import unsqueeze  # noqa: F401
 from .manipulation import unsqueeze_  # noqa: F401
 from .manipulation import unstack  # noqa: F401
@@ -333,6 +334,7 @@ tensor_method_func = [  # noqa
     'strided_slice',
     'transpose',
     'unique',
+    'unique_consecutive',
     'unsqueeze',
     'unsqueeze_',
     'unstack',
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 434069fe74bce673ecfc292be66c1ad378f0a2ba..4b84401aa094583c9448b5cb7a711c01eee4a850 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -717,6 +717,112 @@ def squeeze_(x, axis=None, name=None):
     return out
 
 
+def unique_consecutive(x,
+                       return_inverse=False,
+                       return_counts=False,
+                       axis=None,
+                       dtype="int64",
+                       name=None):
+    r"""
+    Eliminates all but the first element from every consecutive group of equivalent elements.
+
+    .. note:: This function is different from :func:`paddle.unique` in the sense that this function
+        only eliminates consecutive duplicate values. Its semantics are similar to `std::unique` in C++.
+
+    Args:
+        x(Tensor): the input tensor. Its data type should be float32, float64, int32 or int64.
+        return_inverse(bool, optional): If True, also return the indices for where elements in
+            the original input ended up in the returned unique consecutive tensor. Default is False.
+        return_counts(bool, optional): If True, also return the counts for each unique consecutive element.
+            Default is False.
+        axis(int, optional): The axis to apply unique consecutive. If None, the input will be flattened.
+            Default is None.
+        dtype(np.dtype|str, optional): The data type of the `inverse` tensor: int32 or int64.
+            Default: int64.
+        name(str, optional): Name for the operation. For more information, please refer to
+            :ref:`api_guide_Name`. Default is None.
+
+    Returns:
+        tuple: (out, inverse, counts). `out` is the unique consecutive tensor for `x`. `inverse` is provided only if `return_inverse` is True. `counts` is provided only if `return_counts` is True.
+
+    Example:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.to_tensor([1, 1, 2, 2, 3, 1, 1, 2])
+            output = paddle.unique_consecutive(x)
+            np_output = output.numpy()  # [1 2 3 1 2]
+            _, inverse, counts = paddle.unique_consecutive(x, return_inverse=True, return_counts=True)
+            np_inverse = inverse.numpy()  # [0 0 1 1 2 3 3 4]
+            np_counts = counts.numpy()    # [2 2 1 2 1]
+
+            x = paddle.to_tensor([[2, 1, 3], [3, 0, 1], [2, 1, 3], [2, 1, 3]])
+            output = paddle.unique_consecutive(x, axis=0)
+            np_output = output.numpy()
+            # [[2 1 3]
+            #  [3 0 1]
+            #  [2 1 3]]
+    """
+
+    if axis is None:
+        axis = []
+    else:
+        axis = [axis]
+    attr_dtype = convert_np_dtype_to_dtype_(dtype)
+    if in_dygraph_mode():
+        out, inverse, counts = core.ops.unique_consecutive(
+            x, 'dtype', attr_dtype, 'return_inverse', return_inverse,
+            'return_counts', return_counts, 'axis', axis)
+        outs = [out]
+        if return_inverse:
+            outs.append(inverse)
+        if return_counts:
+            outs.append(counts)
+        if len(outs) == 1:
+            return outs[0]
+        return tuple(outs)
+    check_variable_and_dtype(x, "input",
+                             ['float32', 'float64', 'int32', 'int64'],
+                             'unique_consecutive')
+    check_type(return_inverse, 'return_inverse', bool, 'unique_consecutive')
+    check_type(return_counts, 'return_counts', bool, 'unique_consecutive')
+    check_dtype(dtype, 'dtype', ['int32', 'int64'], 'unique_consecutive')
+    if len(axis) != 0:
+        check_type(axis[0], 'axis', int, 'unique_consecutive')
+    helper = LayerHelper('unique_consecutive', **locals())
+    attrs = {
+        'dtype': attr_dtype,
+        "return_inverse": return_inverse,
+        "return_counts": return_counts,
+        "axis": axis,
+    }
+    out = helper.create_variable_for_type_inference(
+        dtype=x.dtype, stop_gradient=True)
+    inverse = helper.create_variable_for_type_inference(
+        dtype=attr_dtype, stop_gradient=True)
+    counts = helper.create_variable_for_type_inference(
+        dtype=attr_dtype, stop_gradient=True)
+    outputs = {"Out": out, "Index": inverse, "Counts": counts}
+    outs = [out]
+    if return_inverse:
+        outs.append(inverse)
+    if return_counts:
+        outs.append(counts)
+    helper.append_op(
+        type="unique_consecutive",
+        inputs={"X": x},
+        attrs=attrs,
+        outputs=outputs)
+    if len(outs) == 1:
+        return outs[0]
+    return tuple(outs)
+
+
 def unique(x,
            return_index=False,
            return_inverse=False,