// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/unique_kernel.h"

#include <thrust/adjacent_difference.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/fill.h>
#include <thrust/functional.h>
#include <thrust/scan.h>
#include <thrust/scatter.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/unique.h>

#include <climits>
#include <numeric>
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/unique_functor.h"
#include "paddle/phi/kernels/index_select_kernel.h"

namespace phi {

// Binary function 'less than'
template <typename InT>
struct LessThan {
  int col;
  const InT* in_trans_data;

  LessThan(int64_t _col, const InT* _in_trans_data)
      : col(_col), in_trans_data(_in_trans_data) {}

  // Lexicographic comparison of the two 'col'-element rows that start at
  // offsets a * col and b * col.
  __device__ bool operator()(int64_t a, int64_t b) const {
    for (int i = 0; i < col; ++i) {
      InT lhs = in_trans_data[i + a * col];
      InT rhs = in_trans_data[i + b * col];
      if (lhs < rhs) {
        return true;
      } else if (lhs > rhs) {
        return false;
      }
    }
    return false;
  }
};

// Binary function 'equal_to'
template <typename InT>
struct BinaryEqual {
  int64_t col;
  const InT* in_trans_data;

  BinaryEqual(int64_t _col, const InT* _in_trans_data)
      : col(_col), in_trans_data(_in_trans_data) {}

  __device__ bool operator()(int64_t a, int64_t b) const {
    for (int64_t i = 0; i < col; ++i) {
      InT lhs = in_trans_data[i + a * col];
      InT rhs = in_trans_data[i + b * col];
      if (lhs != rhs) {
        return false;
      }
    }
    return true;
  }
};

// Binary function 'not_equal_to'
template <typename InT>
struct BinaryNotEqual {
  int64_t col;
  const InT* in_trans_data;

  BinaryNotEqual(int64_t _col, const InT* _in_trans_data)
      : col(_col), in_trans_data(_in_trans_data) {}

  __device__ bool operator()(int64_t a, int64_t b) const {
    for (int64_t i = 0; i < col; ++i) {
      InT lhs = in_trans_data[i + a * col];
      InT rhs = in_trans_data[i + b * col];
      if (lhs != rhs) {
        return true;
      }
    }
    return false;
  }
};

// The core logic of computing Unique for a flattened DenseTensor
template <typename Context, typename InT, typename IndexT>
static typename std::enable_if<
    !std::is_same<InT, phi::dtype::float16>::value &&
    !std::is_same<InT, phi::dtype::bfloat16>::value>::type
UniqueFlattendCUDATensor(const Context& context,
                         const DenseTensor& in,
                         DenseTensor* out,
                         DenseTensor* indices,
                         DenseTensor* index,
                         DenseTensor* counts,
                         bool return_index,
                         bool return_inverse,
                         bool return_counts,
                         int64_t num_input) {
  // 0. Preparation
  auto equal = thrust::equal_to<InT>();
  auto not_equal = thrust::not_equal_to<InT>();
  DenseTensor in_hat;
  phi::Copy(context, in, context.GetPlace(), false, &in_hat);
  auto* in_data_hat = context.template Alloc<InT>(&in_hat);

  indices->Resize(phi::make_ddim({num_input}));
  auto* indices_data = context.template Alloc<IndexT>(indices);

  thrust::sequence(thrust::device, indices_data, indices_data + num_input);
  thrust::sort_by_key(
      thrust::device, in_data_hat, in_data_hat + num_input, indices_data);
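  // Illustrative walk-through (added; values are hypothetical): with
  // in = [2, 1, 2], 'indices' starts as [0, 1, 2]; after sort_by_key,
  // in_data_hat = [1, 2, 2] and indices = [1, 0, 2]. Equal values are now
  // adjacent, and 'indices' records where each sorted element came from in
  // the original input, which is what the unique_by_key steps below rely on.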
  // 1. Calculate op result: 'out'
  DenseTensor range;
  range.Resize(phi::make_ddim({num_input + 1}));
  auto* range_data_ptr = context.template Alloc<IndexT>(&range);
  thrust::sequence(
      thrust::device, range_data_ptr, range_data_ptr + num_input + 1);
  phi::Copy(context, in_hat, context.GetPlace(), false, out);
  int num_out;
  auto out_data = context.template Alloc<InT>(out);
  num_out = thrust::unique_by_key(thrust::device,
                                  out_data,
                                  out_data + num_input,
                                  range_data_ptr,
                                  equal)
                .first -
            out_data;
  out->Resize(phi::make_ddim({num_out}));

  // 2. Calculate inverse index: 'inverse'
  if (return_inverse) {
    index->Resize(phi::make_ddim({num_input}));
    auto* inverse_data = context.template Alloc<IndexT>(index);
    DenseTensor inv_loc;
    inv_loc.Resize(phi::make_ddim({num_input}));
    auto inv_loc_data_ptr = context.template Alloc<IndexT>(&inv_loc);
    thrust::adjacent_difference(thrust::device,
                                in_data_hat,
                                in_data_hat + num_input,
                                inv_loc_data_ptr,
                                not_equal);
    thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
    inv_loc_data_dev[0] = 0;  // without device_ptr, segmentation fault
    thrust::inclusive_scan(thrust::device,
                           inv_loc_data_ptr,
                           inv_loc_data_ptr + num_input,
                           inv_loc_data_ptr);
    thrust::scatter(thrust::device,
                    inv_loc_data_ptr,
                    inv_loc_data_ptr + num_input,
                    indices_data,
                    inverse_data);
  }

  // 3. Calculate sorted index: 'indices'
  if (return_index) {
    DenseTensor tmp_indices;
    tmp_indices.Resize(phi::make_ddim({num_input}));
    auto* tmp_indices_data_ptr = context.template Alloc<InT>(&tmp_indices);
    thrust::copy(thrust::device,
                 in_data_hat,
                 in_data_hat + num_input,
                 tmp_indices_data_ptr);
    thrust::unique_by_key(thrust::device,
                          tmp_indices_data_ptr,
                          tmp_indices_data_ptr + num_input,
                          indices_data,
                          equal);
    indices->Resize(phi::make_ddim({num_out}));
  }

  // 4. Calculate 'counts'
  if (return_counts) {
    counts->Resize(phi::make_ddim({num_out}));
    auto count_data = context.template Alloc<IndexT>(counts);
    // init 'count_data' as 0
    thrust::fill(thrust::device, count_data, count_data + num_out, 0);
    thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
    range_data_ptr_dev[num_out] = num_input;
    thrust::adjacent_difference(thrust::device,
                                range_data_ptr + 1,
                                range_data_ptr + num_out + 1,
                                count_data);
  }
}

// The core logic of computing Unique for a flattened DenseTensor
// (specialization for float16/bfloat16: sorts an index array with a custom
// comparator instead of sorting the values themselves)
template <typename Context, typename InT, typename IndexT>
static typename std::enable_if<
    std::is_same<InT, phi::dtype::float16>::value ||
    std::is_same<InT, phi::dtype::bfloat16>::value>::type
UniqueFlattendCUDATensor(const Context& context,
                         const DenseTensor& in,
                         DenseTensor* out,
                         DenseTensor* indices,
                         DenseTensor* index,
                         DenseTensor* counts,
                         bool return_index,
                         bool return_inverse,
                         bool return_counts,
                         int64_t num_input) {
  // 1. Sort indices
  DenseTensor in_resize;
  in_resize.ShareDataWith(in);
  in_resize.Resize(phi::make_ddim({num_input}));
  const InT* in_data = in_resize.data<InT>();
  auto equal = BinaryEqual<InT>(1, in_data);
  auto not_equal = BinaryNotEqual<InT>(1, in_data);

  indices->Resize(phi::make_ddim({num_input}));
  auto* indices_data = context.template Alloc<IndexT>(indices);

  thrust::sequence(thrust::device, indices_data, indices_data + num_input);
  thrust::sort(thrust::device,
               indices_data,
               indices_data + num_input,
               LessThan<InT>(1, in_data));
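  // Note (added): the thrust::sort above is effectively an argsort.
  // 'indices' is reordered so that in_data[indices[0]] <= in_data[indices[1]]
  // <= ..., while the float16/bfloat16 payload is never moved. For example
  // (hypothetical values), in_data = [3.0, 1.0, 3.0] turns indices [0, 1, 2]
  // into [1, 0, 2]; duplicates become adjacent when viewed through 'indices',
  // so the comparator-based unique_by_key below can compact them.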
  // 2. Calculate inverse indices: 'index'
  if (return_inverse) {
    index->Resize(phi::make_ddim({num_input}));
    auto* inverse_data = context.template Alloc<IndexT>(index);
    DenseTensor inv_loc;
    inv_loc.Resize(phi::make_ddim({num_input}));
    auto inv_loc_data_ptr = context.template Alloc<IndexT>(&inv_loc);
    thrust::adjacent_difference(thrust::device,
                                indices_data,
                                indices_data + num_input,
                                inv_loc_data_ptr,
                                not_equal);
    thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
    inv_loc_data_dev[0] = 0;  // without device_ptr, segmentation fault
    thrust::inclusive_scan(thrust::device,
                           inv_loc_data_ptr,
                           inv_loc_data_ptr + num_input,
                           inv_loc_data_ptr);
    thrust::scatter(thrust::device,
                    inv_loc_data_ptr,
                    inv_loc_data_ptr + num_input,
                    indices_data,
                    inverse_data);
  }

  // 3. Calculate op result and sorted index: 'out' & 'indices'
  DenseTensor range;
  range.Resize(phi::make_ddim({num_input + 1}));
  auto* range_data_ptr = context.template Alloc<IndexT>(&range);
  thrust::sequence(
      thrust::device, range_data_ptr, range_data_ptr + num_input + 1);
  int num_out;
  num_out = thrust::unique_by_key(thrust::device,
                                  indices_data,
                                  indices_data + num_input,
                                  range_data_ptr,
                                  equal)
                .first -
            indices_data;
  indices->Resize(phi::make_ddim({num_out}));
  out->Resize(phi::make_ddim({num_out}));
  context.template Alloc<InT>(out);
  phi::IndexSelectKernel<InT, Context>(context, in_resize, *indices, 0, out);

  // 4. Calculate 'counts'
  if (return_counts) {
    counts->Resize(phi::make_ddim({num_out}));
    auto count_data = context.template Alloc<IndexT>(counts);
    // init 'count_data' as 0
    thrust::fill(thrust::device, count_data, count_data + num_out, 0);
    thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
    range_data_ptr_dev[num_out] = num_input;
    thrust::adjacent_difference(thrust::device,
                                range_data_ptr + 1,
                                range_data_ptr + num_out + 1,
                                count_data);
  }
}

// The core logic of computing Unique when 'axis' is required; it differs
// slightly from the functions above.
template <typename Context,
          typename InT,
          typename IndexT,
          typename equal_T,
          typename not_equal_T>
static void ComputeUniqueDims(const Context& context,
                              DenseTensor* sorted_indices,
                              IndexT* sorted_indices_data,
                              DenseTensor* out,
                              DenseTensor* inverse,
                              DenseTensor* counts,
                              bool return_index,
                              bool return_inverse,
                              bool return_counts,
                              equal_T equal,
                              not_equal_T not_equal,
                              int64_t row) {
  // 1. inverse indices: 'inverse'
  inverse->Resize(phi::make_ddim({row}));
  auto* inverse_data = context.template Alloc<IndexT>(inverse);
  DenseTensor inv_loc;
  inv_loc.Resize(phi::make_ddim({row}));
  auto inv_loc_data_ptr = context.template Alloc<IndexT>(&inv_loc);
  thrust::adjacent_difference(thrust::device,
                              sorted_indices_data,
                              sorted_indices_data + row,
                              inv_loc_data_ptr,
                              not_equal);
  thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
  inv_loc_data_dev[0] = 0;
  thrust::inclusive_scan(thrust::device,
                         inv_loc_data_ptr,
                         inv_loc_data_ptr + row,
                         inv_loc_data_ptr);
  thrust::scatter(thrust::device,
                  inv_loc_data_ptr,
                  inv_loc_data_ptr + row,
                  sorted_indices_data,
                  inverse_data);

  // 2. sorted indices
  DenseTensor range;
  range.Resize(phi::make_ddim({row + 1}));
  auto range_data_ptr = context.template Alloc<IndexT>(&range);
  thrust::sequence(thrust::device, range_data_ptr, range_data_ptr + row + 1);
  int num_out;
  num_out = thrust::unique_by_key(thrust::device,
                                  sorted_indices_data,
                                  sorted_indices_data + row,
                                  range_data_ptr,
                                  equal)
                .first -
            sorted_indices_data;
  thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
  range_data_ptr_dev[num_out] = row;
  sorted_indices->Resize(phi::make_ddim({num_out}));
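  // How the counts trick works (illustrative example, added): 'range' starts
  // as [0, 1, ..., row]. unique_by_key keeps only the first 'range' entry of
  // each run of equal keys, so afterwards range[i] is the offset where the
  // i-th unique group begins. With row = 5 and group sizes 2, 1, 2, 'range'
  // becomes [0, 2, 3, ...]; writing range[num_out] = 5 and taking adjacent
  // differences of range[1..num_out] yields the counts [2, 1, 2] below.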
  // 3. counts: 'counts'
  counts->Resize(phi::make_ddim({num_out}));
  auto* count_data = context.template Alloc<IndexT>(counts);
  thrust::fill(thrust::device, count_data, count_data + num_out, 0);
  thrust::adjacent_difference(thrust::device,
                              range_data_ptr + 1,
                              range_data_ptr + num_out + 1,
                              count_data);
}

// Calculate unique when 'axis' is set
template <typename Context, typename InT, typename IndexT>
static void UniqueDimsCUDATensor(const Context& context,
                                 const DenseTensor& in,
                                 DenseTensor* out,
                                 DenseTensor* indices,
                                 DenseTensor* index,
                                 DenseTensor* counts,
                                 bool return_index,
                                 bool return_inverse,
                                 bool return_counts,
                                 int axis) {
  // 1. Transpose & reshape
  // Transpose tensor: e.g. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2]
  DenseTensor in_trans;
  std::vector<int64_t> in_trans_dims_vec(phi::vectorize(in.dims()));
  auto in_trans_dims = phi::make_ddim(in_trans_dims_vec);
  std::vector<int> permute(in.dims().size());
  bool is_transpose = axis != 0;
  if (is_transpose) {
    std::iota(permute.begin(), permute.end(), 0);
    permute[axis] = 0;
    permute[0] = axis;
    in_trans_dims_vec[axis] = in.dims()[0];
    in_trans_dims_vec[0] = in.dims()[axis];
    in_trans_dims = phi::make_ddim(in_trans_dims_vec);
    in_trans.Resize(in_trans_dims);
    context.template Alloc<InT>(&in_trans);
    phi::funcs::TransCompute<Context, InT>(
        in.dims().size(),  // number of dims
        context,           // device context
        in,                // original DenseTensor
        &in_trans,         // DenseTensor after transpose
        permute);          // axis permutation
  } else {
    in_trans.ShareDataWith(in);
  }
  // Reshape tensor: e.g. [dim1, dim0, dim2] -> [dim1, dim0*dim2]
  auto in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1);
  in_trans.Resize(in_trans_flat_dims);  // now 'in_trans' is 2D

  int64_t col = in_trans.dims()[1];
  int64_t row = in_trans.dims()[0];
  const InT* in_trans_data = in_trans.data<InT>();

  indices->Resize(phi::make_ddim({row}));
  auto* sorted_indices_data = context.template Alloc<IndexT>(indices);

  // 2. Calculate 'indices', 'inverse', 'counts'
  // Init index and sort
  thrust::sequence(
      thrust::device, sorted_indices_data, sorted_indices_data + row);
  thrust::sort(thrust::device,
               sorted_indices_data,
               sorted_indices_data + row,
               LessThan<InT>(col, in_trans_data));
  ComputeUniqueDims<Context, InT, IndexT>(
      context,
      indices,
      sorted_indices_data,
      out,
      index,
      counts,
      return_index,
      return_inverse,
      return_counts,
      BinaryEqual<InT>(col, in_trans_data),
      BinaryNotEqual<InT>(col, in_trans_data),
      row);
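  // At this point (note added for clarity) 'indices' holds, for each unique
  // row of the 2-D view, the index of one representative row in 'in_trans'.
  // E.g. (hypothetical shapes) for in.dims() = [2, 3, 4] and axis = 1, the
  // rows compared above are the 3 slices of shape [2, 4], each flattened to
  // length 8, so unique-over-axis reduces to unique-over-rows.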
  // 3. Select indices and reshape back to get 'out'
  std::vector<int64_t> out_trans_dims_vec = in_trans_dims_vec;
  out_trans_dims_vec[0] = indices->numel();
  if (is_transpose) {
    DenseTensor out_trans;
    out_trans.Resize(phi::make_ddim(out_trans_dims_vec));
    context.template Alloc<InT>(&out_trans);

    phi::IndexSelectKernel<InT, Context>(
        context, in_trans, *indices, 0, &out_trans);

    std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]);
    out->Resize(phi::make_ddim(out_trans_dims_vec));
    context.template Alloc<InT>(out);
    phi::funcs::TransCompute<Context, InT>(
        out_trans.dims().size(), context, out_trans, out, permute);
  } else {
    out->Resize(phi::make_ddim(out_trans_dims_vec));
    context.template Alloc<InT>(out);
    phi::IndexSelectKernel<InT, Context>(context, in_trans, *indices, 0, out);
  }
}

// functor for processing a flattened DenseTensor
template <typename Context, typename InT>
struct UniqueFlattendCUDAFunctor {
  const Context& ctx_;
  const DenseTensor& in_;
  DenseTensor* out_;
  DenseTensor* indices_;
  DenseTensor* index_;
  DenseTensor* counts_;
  const bool return_index_;
  const bool return_inverse_;
  const bool return_counts_;

  UniqueFlattendCUDAFunctor(const Context& context,
                            const DenseTensor& in,
                            DenseTensor* out,
                            DenseTensor* indices,
                            DenseTensor* index,
                            DenseTensor* counts,
                            bool return_index,
                            bool return_inverse,
                            bool return_counts)
      : ctx_(context),
        in_(in),
        out_(out),
        indices_(indices),
        index_(index),
        counts_(counts),
        return_index_(return_index),
        return_inverse_(return_inverse),
        return_counts_(return_counts) {}

  template <typename IndexT>
  void apply() const {
    UniqueFlattendCUDATensor<Context, InT, IndexT>(ctx_,
                                                   in_,
                                                   out_,
                                                   indices_,
                                                   index_,
                                                   counts_,
                                                   return_index_,
                                                   return_inverse_,
                                                   return_counts_,
                                                   in_.numel());
  }
};

// functor for processing a multi-dimensional DenseTensor
template <typename Context, typename InT>
struct UniqueDimsCUDAFunctor {
  const Context& ctx_;
  const DenseTensor& in_;
  DenseTensor* out_;
  DenseTensor* indices_;
  DenseTensor* index_;
  DenseTensor* counts_;
  const int axis_;
  const bool return_index_;
  const bool return_inverse_;
  const bool return_counts_;

  UniqueDimsCUDAFunctor(const Context& context,
                        const DenseTensor& in,
                        DenseTensor* out,
                        DenseTensor* indices,
                        DenseTensor* index,
                        DenseTensor* counts,
                        const int axis,
                        bool return_index,
                        bool return_inverse,
                        bool return_counts)
      : ctx_(context),
        in_(in),
        out_(out),
        indices_(indices),
        index_(index),
        counts_(counts),
        axis_(axis),
        return_index_(return_index),
        return_inverse_(return_inverse),
        return_counts_(return_counts) {}

  template <typename IndexT>
  void apply() const {
    UniqueDimsCUDATensor<Context, InT, IndexT>(ctx_,
                                               in_,
                                               out_,
                                               indices_,
                                               index_,
                                               counts_,
                                               return_index_,
                                               return_inverse_,
                                               return_counts_,
                                               axis_);
  }
};

template <typename T, typename Context>
void UniqueRawKernel(const Context& context,
                     const DenseTensor& x,
                     bool return_index,
                     bool return_inverse,
                     bool return_counts,
                     const std::vector<int>& axis,
                     DataType dtype,
                     bool is_sorted,
                     DenseTensor* out,
                     DenseTensor* indices,
                     DenseTensor* index,
                     DenseTensor* counts) {
  if (dtype == phi::DataType::INT32) {
    PADDLE_ENFORCE_LE(
        x.numel() + 1,
        INT_MAX,
        phi::errors::InvalidArgument(
            "The number of elements in Input(X) should be less than or "
            "equal to INT_MAX, but received num is %d. Please set `dtype` to "
            "int64.",
            x.numel()));
  }
  // if 'axis' is not required, flatten the DenseTensor.
  if (axis.empty()) {
    phi::VisitDataTypeTiny(
        dtype,
        UniqueFlattendCUDAFunctor<Context, T>(context,
                                              x,
                                              out,
                                              indices,
                                              index,
                                              counts,
                                              return_index,
                                              return_inverse,
                                              return_counts));
  } else {
    // 'axis' is required.
    int axis_value = axis[0];
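    // Note (added): only -1 is normalized here (e.g. axis = -1 on a 3-D
    // input maps to axis 2); other negative values are assumed to have been
    // rejected earlier, e.g. during shape inference.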
    axis_value = (axis_value == -1) ? (x.dims().size() - 1) : axis_value;
    phi::VisitDataTypeTiny(
        dtype,
        UniqueDimsCUDAFunctor<Context, T>(context,
                                          x,
                                          out,
                                          indices,
                                          index,
                                          counts,
                                          axis_value,
                                          return_index,
                                          return_inverse,
                                          return_counts));
  }
}

template <typename T, typename Context>
void UniqueKernel(const Context& context,
                  const DenseTensor& x,
                  bool return_index,
                  bool return_inverse,
                  bool return_counts,
                  const std::vector<int>& axis,
                  DataType dtype,
                  DenseTensor* out,
                  DenseTensor* indices,
                  DenseTensor* index,
                  DenseTensor* counts) {
  bool is_sorted = true;
  UniqueRawKernel<T, Context>(context,
                              x,
                              return_index,
                              return_inverse,
                              return_counts,
                              axis,
                              dtype,
                              is_sorted,
                              out,
                              indices,
                              index,
                              counts);
}

}  // namespace phi

PD_REGISTER_KERNEL(unique,
                   GPU,
                   ALL_LAYOUT,
                   phi::UniqueKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   int64_t,
                   int) {
  kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
  kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
  kernel->OutputAt(3).SetDataType(phi::DataType::UNDEFINED);
}

PD_REGISTER_KERNEL(unique_raw,
                   GPU,
                   ALL_LAYOUT,
                   phi::UniqueRawKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   int64_t,
                   int) {
  kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
  kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
  kernel->OutputAt(3).SetDataType(phi::DataType::UNDEFINED);
}
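// Note (added): marking outputs 1-3 (indices, index, counts) as
// DataType::UNDEFINED leaves their dtype unfixed at registration time; it is
// resolved at runtime from the 'dtype' attribute (int32 or int64), which is
// what VisitDataTypeTiny dispatches on above.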