diff --git a/paddle/phi/kernels/gpu/unique_kernel.cu b/paddle/phi/kernels/gpu/unique_kernel.cu index c073708ed8556989122490f21cc361fc9d962349..10cf1ea8df5343049fd742d79e308221e2384ec3 100644 --- a/paddle/phi/kernels/gpu/unique_kernel.cu +++ b/paddle/phi/kernels/gpu/unique_kernel.cu @@ -30,6 +30,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/unique_functor.h" +#include "paddle/phi/kernels/index_select_kernel.h" namespace phi { @@ -98,76 +99,6 @@ struct BinaryNotEqual { } }; -// index_select() function for DenseTensor -template -void IndexSelect(const Context& context, - const DenseTensor& input, - const DenseTensor& index, - DenseTensor* output, - int dim) { - auto input_dim = input.dims(); - auto input_dim_size = input_dim.size(); - auto output_dim = output->dims(); - - auto slice_size = 1; - for (auto i = dim + 1; i < input_dim_size; i++) { - slice_size *= input_dim[i]; - } - - auto input_width = slice_size * input_dim[dim]; - auto output_width = slice_size * output_dim[dim]; - - auto outer_nums = 1; - for (auto i = 0; i < dim; i++) { - outer_nums *= input_dim[i]; - } - - auto index_size = index.dims()[0]; - - std::vector input_vec; - std::vector index_vec; - phi::TensorToVector(input, context, &input_vec); - phi::TensorToVector(index, context, &index_vec); - std::vector out_vec(output->numel()); - - for (int i = 0; i < index_size; i++) { - PADDLE_ENFORCE_GE( - index_vec[i], - 0, - phi::errors::InvalidArgument( - "Variable value (index) of OP(index_select) " - "expected >= 0 and < %ld, but got %ld. Please check input " - "value.", - input_dim[dim], - index_vec[i])); - PADDLE_ENFORCE_LT( - index_vec[i], - input_dim[dim], - phi::errors::InvalidArgument( - "Variable value (index) of OP(index_select) " - "expected >= 0 and < %ld, but got %ld. Please check input " - "value.", - input_dim[dim], - index_vec[i])); - } - - for (auto i = 0; i < outer_nums; i++) { - auto input_start_offset = i * input_width; - auto output_start_offset = i * output_width; - - for (auto j = 0; j < index_size; j++) { - IndexT index_value = index_vec[j]; - for (auto k = 0; k < slice_size; k++) { - out_vec[output_start_offset + j * slice_size + k] = - input_vec[input_start_offset + index_value * slice_size + k]; - } - } - } - context.template Alloc(output); - phi::TensorFromVector(out_vec, context, output); - output->Resize(output_dim); -} - // The core logic of computing Unique for a flattend DenseTensor template [dim1, dim0, dim2] - std::vector permute(in.dims().size()); - std::iota(permute.begin(), permute.end(), 0); - permute[axis] = 0; - permute[0] = axis; - std::vector in_trans_dims_vec(phi::vectorize(in.dims())); - in_trans_dims_vec[axis] = in.dims()[0]; - in_trans_dims_vec[0] = in.dims()[axis]; DenseTensor in_trans; + std::vector in_trans_dims_vec(phi::vectorize(in.dims())); auto in_trans_dims = phi::make_ddim(in_trans_dims_vec); - in_trans.Resize(in_trans_dims); - context.template Alloc(&in_trans); - phi::funcs::TransCompute( - in.dims().size(), // num of dims - context, // device - in, // original DenseTensor - &in_trans, // DenseTensor after reshape - permute); // index of axis - + std::vector permute(in.dims().size()); + bool is_transpose = axis != 0; + if (is_transpose) { + std::iota(permute.begin(), permute.end(), 0); + permute[axis] = 0; + permute[0] = axis; + in_trans_dims_vec[axis] = in.dims()[0]; + in_trans_dims_vec[0] = in.dims()[axis]; + in_trans_dims = phi::make_ddim(in_trans_dims_vec); + in_trans.Resize(in_trans_dims); + context.template Alloc(&in_trans); + phi::funcs::TransCompute( + in.dims().size(), // num of dims + context, // device + in, // original DenseTensor + &in_trans, // DenseTensor after reshape + permute); // index of axis + } else { + in_trans.ShareDataWith(in); + } // Reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2] auto in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1); in_trans.Resize(in_trans_flat_dims); @@ -407,22 +343,27 @@ static void UniqueDimsCUDATensor(const Context& context, row); // 3. Select indices and reshape back to get 'out' - DenseTensor out_trans; std::vector out_trans_dims_vec = in_trans_dims_vec; out_trans_dims_vec[0] = indices->numel(); - out_trans.Resize(phi::make_ddim(out_trans_dims_vec)); - context.template Alloc(&out_trans); - - IndexSelect(context, in_trans, *indices, &out_trans, 0); - - std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]); - out->Resize(phi::make_ddim(out_trans_dims_vec)); - context.template Alloc(out); - std::vector out_trans_unbind = phi::funcs::Unbind(out_trans); - phi::funcs::ConcatFunctor concat_functor; - concat_functor(context, out_trans_unbind, 0, &out_trans); - phi::funcs::TransCompute( - out_trans.dims().size(), context, out_trans, out, permute); + if (is_transpose) { + DenseTensor out_trans; + out_trans.Resize(phi::make_ddim(out_trans_dims_vec)); + context.template Alloc(&out_trans); + + phi::IndexSelectKernel( + context, in_trans, *indices, 0, &out_trans); + + std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]); + out->Resize(phi::make_ddim(out_trans_dims_vec)); + context.template Alloc(out); + phi::funcs::TransCompute( + out_trans.dims().size(), context, out_trans, out, permute); + } else { + out->Resize(phi::make_ddim(out_trans_dims_vec)); + context.template Alloc(out); + + phi::IndexSelectKernel(context, in_trans, *indices, 0, out); + } } // functor for processing a flattend DenseTensor