diff --git a/paddle/phi/kernels/gpu/unique_kernel.cu b/paddle/phi/kernels/gpu/unique_kernel.cu index 5fbc7da7d4c136b75b5fdc9bbac700bd91b198c9..5d4399e42e1abb7ef4f10e455c372a353bc7220c 100644 --- a/paddle/phi/kernels/gpu/unique_kernel.cu +++ b/paddle/phi/kernels/gpu/unique_kernel.cu @@ -100,24 +100,23 @@ struct BinaryNotEqual { }; // The core logic of computing Unique for a flattend DenseTensor -template -static void UniqueFlattendCUDATensor(const Context& context, - const DenseTensor& in, - DenseTensor* out, - DenseTensor* indices, - DenseTensor* index, - DenseTensor* counts, - bool return_index, - bool return_inverse, - bool return_counts, - equal_T equal, - not_equal_T not_equal, - int64_t num_input) { +template +static typename std::enable_if< + !std::is_same::value && + !std::is_same::value>::type +UniqueFlattendCUDATensor(const Context& context, + const DenseTensor& in, + DenseTensor* out, + DenseTensor* indices, + DenseTensor* index, + DenseTensor* counts, + bool return_index, + bool return_inverse, + bool return_counts, + int64_t num_input) { // 0. Prepration + auto equal = thrust::equal_to(); + auto not_equal = thrust::not_equal_to(); DenseTensor in_hat; phi::Copy(context, in, context.GetPlace(), false, &in_hat); auto* in_data_hat = context.template Alloc(&in_hat); @@ -202,6 +201,97 @@ static void UniqueFlattendCUDATensor(const Context& context, } } +// The core logic of computing Unique for a flattend DenseTensor +template +static typename std::enable_if< + std::is_same::value || + std::is_same::value>::type +UniqueFlattendCUDATensor(const Context& context, + const DenseTensor& in, + DenseTensor* out, + DenseTensor* indices, + DenseTensor* index, + DenseTensor* counts, + bool return_index, + bool return_inverse, + bool return_counts, + int64_t num_input) { + // 1. Sort indices + DenseTensor in_resize; + in_resize.ShareDataWith(in); + in_resize.Resize(phi::make_ddim({num_input})); + const InT* in_data = in_resize.data(); + auto equal = BinaryEqual(1, in_data); + auto not_equal = BinaryNotEqual(1, in_data); + + indices->Resize(phi::make_ddim({num_input})); + auto* indices_data = context.template Alloc(indices); + + thrust::sequence(thrust::device, indices_data, indices_data + num_input); + thrust::sort(thrust::device, + indices_data, + indices_data + num_input, + LessThan(1, in_data)); + + // 2. Calculate inverse indices: 'index' + if (return_inverse) { + index->Resize(phi::make_ddim({num_input})); + auto* inverse_data = context.template Alloc(index); + DenseTensor inv_loc; + inv_loc.Resize(phi::make_ddim({num_input})); + auto inv_loc_data_ptr = context.template Alloc(&inv_loc); + thrust::adjacent_difference(thrust::device, + indices_data, + indices_data + num_input, + inv_loc_data_ptr, + not_equal); + thrust::device_ptr inv_loc_data_dev(inv_loc_data_ptr); + inv_loc_data_dev[0] = 0; // without device_ptr, segmentation fault + thrust::inclusive_scan(thrust::device, + inv_loc_data_ptr, + inv_loc_data_ptr + num_input, + inv_loc_data_ptr); + thrust::scatter(thrust::device, + inv_loc_data_ptr, + inv_loc_data_ptr + num_input, + indices_data, + inverse_data); + } + + // 3. Calculate op result and sorted index: 'out' & 'indices' + DenseTensor range; + range.Resize(phi::make_ddim({num_input + 1})); + auto* range_data_ptr = context.template Alloc(&range); + thrust::sequence( + thrust::device, range_data_ptr, range_data_ptr + num_input + 1); + int num_out; + num_out = thrust::unique_by_key(thrust::device, + indices_data, + indices_data + num_input, + range_data_ptr, + equal) + .first - + indices_data; + indices->Resize(phi::make_ddim({num_out})); + out->Resize(phi::make_ddim({num_out})); + context.template Alloc(out); + phi::IndexSelectKernel(context, in_resize, *indices, 0, out); + + // 4. Calculate 'counts' + if (return_counts) { + counts->Resize(phi::make_ddim({num_out})); + auto count_data = context.template Alloc(counts); + // init 'count_data' as 0 + thrust::fill(thrust::device, count_data, count_data + num_out, 0); + thrust::device_ptr range_data_ptr_dev(range_data_ptr); + range_data_ptr_dev[num_out] = num_input; + thrust::adjacent_difference(thrust::device, + range_data_ptr + 1, + range_data_ptr + num_out + 1, + count_data); + } +} + // The logic of compute unique with axis required, it's a little different // from above function template (), - thrust::not_equal_to(), in_.numel()); } }; @@ -548,8 +636,16 @@ void UniqueKernel(const Context& context, } // namespace phi -PD_REGISTER_KERNEL( - unique, GPU, ALL_LAYOUT, phi::UniqueKernel, float, double, int64_t, int) { +PD_REGISTER_KERNEL(unique, + GPU, + ALL_LAYOUT, + phi::UniqueKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16, + int64_t, + int) { kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(3).SetDataType(phi::DataType::UNDEFINED); @@ -561,6 +657,8 @@ PD_REGISTER_KERNEL(unique_raw, phi::UniqueRawKernel, float, double, + phi::dtype::float16, + phi::dtype::bfloat16, int64_t, int) { kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); diff --git a/python/paddle/fluid/tests/unittests/test_unique.py b/python/paddle/fluid/tests/unittests/test_unique.py index 2f88ecb849910438289422a4a36a0e4da764b609..8e31f377bba5e7142d7ed2c5faed5d44c637f724 100644 --- a/python/paddle/fluid/tests/unittests/test_unique.py +++ b/python/paddle/fluid/tests/unittests/test_unique.py @@ -24,6 +24,7 @@ from paddle.fluid import core class TestUniqueOp(OpTest): def setUp(self): self.op_type = "unique" + self.init_dtype() self.init_config() def test_check_output(self): @@ -31,13 +32,16 @@ class TestUniqueOp(OpTest): check_dygraph=False ) # unique return sorted data in dygraph + def init_dtype(self): + self.dtype = np.int64 + def init_config(self): self.inputs = { - 'X': np.array([2, 3, 3, 1, 5, 3], dtype='int64'), + 'X': np.array([2, 3, 3, 1, 5, 3], dtype=self.dtype), } self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)} self.outputs = { - 'Out': np.array([2, 3, 1, 5], dtype='int64'), + 'Out': np.array([2, 3, 1, 5], dtype=self.dtype), 'Index': np.array([0, 1, 1, 2, 3, 1], dtype='int32'), } @@ -45,25 +49,25 @@ class TestUniqueOp(OpTest): class TestOne(TestUniqueOp): def init_config(self): self.inputs = { - 'X': np.array([2], dtype='int64'), + 'X': np.array([2], dtype=self.dtype), } self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)} self.outputs = { - 'Out': np.array([2], dtype='int64'), + 'Out': np.array([2], dtype=self.dtype), 'Index': np.array([0], dtype='int32'), } class TestRandom(TestUniqueOp): def init_config(self): - self.inputs = {'X': np.random.randint(0, 100, (150,), dtype='int64')} + self.inputs = {'X': np.random.randint(0, 100, (150,), dtype=self.dtype)} self.attrs = {'dtype': int(core.VarDesc.VarType.INT64)} np_unique, np_index, reverse_index = np.unique( self.inputs['X'], True, True ) np_tuple = [(np_unique[i], np_index[i]) for i in range(len(np_unique))] np_tuple.sort(key=lambda x: x[1]) - target_out = np.array([i[0] for i in np_tuple], dtype='int64') + target_out = np.array([i[0] for i in np_tuple], dtype=self.dtype) target_index = np.array( [list(target_out).index(i) for i in self.inputs['X']], dtype='int64' ) @@ -95,11 +99,11 @@ class TestUniqueRaiseError(unittest.TestCase): class TestOneGPU(TestUniqueOp): def init_config(self): self.inputs = { - 'X': np.array([2], dtype='int64'), + 'X': np.array([2], dtype=self.dtype), } self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)} self.outputs = { - 'Out': np.array([2], dtype='int64'), + 'Out': np.array([2], dtype=self.dtype), 'Index': np.array([0], dtype='int32'), } @@ -116,14 +120,14 @@ class TestOneGPU(TestUniqueOp): ) class TestRandomGPU(TestUniqueOp): def init_config(self): - self.inputs = {'X': np.random.randint(0, 100, (150,), dtype='int64')} + self.inputs = {'X': np.random.randint(0, 100, (150,), dtype=self.dtype)} self.attrs = {'dtype': int(core.VarDesc.VarType.INT64)} np_unique, np_index, reverse_index = np.unique( self.inputs['X'], True, True ) np_tuple = [(np_unique[i], np_index[i]) for i in range(len(np_unique))] np_tuple.sort(key=lambda x: x[1]) - target_out = np.array([i[0] for i in np_tuple], dtype='int64') + target_out = np.array([i[0] for i in np_tuple], dtype=self.dtype) target_index = np.array( [list(target_out).index(i) for i in self.inputs['X']], dtype='int64' ) @@ -139,8 +143,11 @@ class TestRandomGPU(TestUniqueOp): class TestSortedUniqueOp(TestUniqueOp): + def init_dtype(self): + self.dtype = np.float64 + def init_config(self): - self.inputs = {'X': np.array([2, 3, 3, 1, 5, 3], dtype='int64')} + self.inputs = {'X': np.array([2, 3, 3, 1, 5, 3], dtype=self.dtype)} unique, indices, inverse, count = np.unique( self.inputs['X'], return_index=True, @@ -164,9 +171,35 @@ class TestSortedUniqueOp(TestUniqueOp): } +class TestSortedUniqueFP16Op(TestSortedUniqueOp): + def init_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestSortedUniqueBF16Op(TestSortedUniqueOp): + def init_dtype(self): + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place( + place, check_dygraph=False + ) # unique return sorted data in dygraph + + class TestUniqueOpAxisNone(TestUniqueOp): + def init_dtype(self): + self.dtype = np.float64 + def init_config(self): - self.inputs = {'X': np.random.random((4, 7, 10)).astype('float64')} + self.inputs = { + 'X': np.random.randint(0, 100, (4, 7, 10)).astype(self.dtype) + } unique, indices, inverse, counts = np.unique( self.inputs['X'], return_index=True, @@ -190,9 +223,35 @@ class TestUniqueOpAxisNone(TestUniqueOp): } +class TestUniqueOpAxisNoneFP16Op(TestUniqueOpAxisNone): + def init_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestUniqueOpAxisNoneBF16Op(TestUniqueOpAxisNone): + def init_dtype(self): + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place( + place, check_dygraph=False + ) # unique return sorted data in dygraph + + class TestUniqueOpAxisNeg(TestUniqueOp): + def init_dtype(self): + self.dtype = np.float64 + def init_config(self): - self.inputs = {'X': np.random.random((6, 1, 8)).astype('float64')} + self.inputs = { + 'X': np.random.randint(0, 100, (6, 1, 8)).astype(self.dtype) + } unique, indices, inverse, counts = np.unique( self.inputs['X'], return_index=True, @@ -216,9 +275,35 @@ class TestUniqueOpAxisNeg(TestUniqueOp): } +class TestUniqueOpAxisNegFP16Op(TestUniqueOpAxisNeg): + def init_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestUniqueOpAxisNegBF16Op(TestUniqueOpAxisNeg): + def init_dtype(self): + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place( + place, check_dygraph=False + ) # unique return sorted data in dygraph + + class TestUniqueOpAxis1(TestUniqueOp): + def init_dtype(self): + self.dtype = np.float64 + def init_config(self): - self.inputs = {'X': np.random.random((3, 8, 8)).astype('float64')} + self.inputs = { + 'X': np.random.randint(0, 100, (3, 8, 8)).astype(self.dtype) + } unique, indices, inverse, counts = np.unique( self.inputs['X'], return_index=True, @@ -242,6 +327,27 @@ class TestUniqueOpAxis1(TestUniqueOp): } +class TestUniqueOpAxis1FP16Op(TestUniqueOpAxis1): + def init_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestUniqueOpAxis1BF16Op(TestUniqueOpAxis1): + def init_dtype(self): + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place( + place, check_dygraph=False + ) # unique return sorted data in dygraph + + class TestUniqueAPI(unittest.TestCase): def test_dygraph_api_out(self): paddle.disable_static()