Unverified commit 1d37868f, authored by Zhang Zheng, committed by GitHub

[AMP OP&Test] Unique support float16&bfloat16 (#52995)

* [AMP OP&Test] Unique support float16&bfloat16

* add test
Parent commit: 00efdf84
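For context, here is a minimal usage sketch of what this commit enables: calling `paddle.unique` on float16 (or bfloat16) GPU tensors, which previously were not registered for this kernel. This sketch is not part of the commit; the values and device setup are illustrative assumptions (it presumes a CUDA build of Paddle that includes this change).

```python
# Minimal sketch (assumed environment: CUDA build of Paddle with this commit).
import paddle

paddle.set_device('gpu')
x = paddle.to_tensor([2.0, 3.0, 3.0, 1.0, 5.0, 3.0], dtype='float16')
out, index, inverse, counts = paddle.unique(
    x, return_index=True, return_inverse=True, return_counts=True
)
print(out)      # unique values, still float16
print(inverse)  # position of each input element within `out`
```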
@@ -100,24 +100,23 @@ struct BinaryNotEqual {
};
// The core logic of computing Unique for a flattened DenseTensor
template <typename Context,
typename InT,
typename IndexT,
typename equal_T,
typename not_equal_T>
static void UniqueFlattendCUDATensor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
DenseTensor* indices,
DenseTensor* index,
DenseTensor* counts,
bool return_index,
bool return_inverse,
bool return_counts,
equal_T equal,
not_equal_T not_equal,
int64_t num_input) {
template <typename Context, typename InT, typename IndexT>
static typename std::enable_if<
!std::is_same<InT, phi::dtype::float16>::value &&
!std::is_same<InT, phi::dtype::bfloat16>::value>::type
UniqueFlattendCUDATensor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
DenseTensor* indices,
DenseTensor* index,
DenseTensor* counts,
bool return_index,
bool return_inverse,
bool return_counts,
int64_t num_input) {
// 0. Preparation
auto equal = thrust::equal_to<InT>();
auto not_equal = thrust::not_equal_to<InT>();
DenseTensor in_hat;
phi::Copy(context, in, context.GetPlace(), false, &in_hat);
auto* in_data_hat = context.template Alloc<InT>(&in_hat);
@@ -202,6 +201,97 @@ static void UniqueFlattendCUDATensor(const Context& context,
}
}
// The core logic of computing Unique for a flattened DenseTensor
template <typename Context, typename InT, typename IndexT>
static typename std::enable_if<
std::is_same<InT, phi::dtype::float16>::value ||
std::is_same<InT, phi::dtype::bfloat16>::value>::type
UniqueFlattendCUDATensor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
DenseTensor* indices,
DenseTensor* index,
DenseTensor* counts,
bool return_index,
bool return_inverse,
bool return_counts,
int64_t num_input) {
// 1. Sort indices
DenseTensor in_resize;
in_resize.ShareDataWith(in);
in_resize.Resize(phi::make_ddim({num_input}));
const InT* in_data = in_resize.data<InT>();
auto equal = BinaryEqual<InT>(1, in_data);
auto not_equal = BinaryNotEqual<InT>(1, in_data);
indices->Resize(phi::make_ddim({num_input}));
auto* indices_data = context.template Alloc<IndexT>(indices);
thrust::sequence(thrust::device, indices_data, indices_data + num_input);
thrust::sort(thrust::device,
indices_data,
indices_data + num_input,
LessThan<InT>(1, in_data));
// 2. Calculate inverse indices: 'index'
if (return_inverse) {
index->Resize(phi::make_ddim({num_input}));
auto* inverse_data = context.template Alloc<IndexT>(index);
DenseTensor inv_loc;
inv_loc.Resize(phi::make_ddim({num_input}));
auto inv_loc_data_ptr = context.template Alloc<IndexT>(&inv_loc);
thrust::adjacent_difference(thrust::device,
indices_data,
indices_data + num_input,
inv_loc_data_ptr,
not_equal);
thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
inv_loc_data_dev[0] = 0; // without device_ptr, segmentation fault
thrust::inclusive_scan(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + num_input,
inv_loc_data_ptr);
thrust::scatter(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + num_input,
indices_data,
inverse_data);
}
// 3. Calculate op result and sorted index: 'out' & 'indices'
DenseTensor range;
range.Resize(phi::make_ddim({num_input + 1}));
auto* range_data_ptr = context.template Alloc<IndexT>(&range);
thrust::sequence(
thrust::device, range_data_ptr, range_data_ptr + num_input + 1);
int num_out;
num_out = thrust::unique_by_key(thrust::device,
indices_data,
indices_data + num_input,
range_data_ptr,
equal)
.first -
indices_data;
indices->Resize(phi::make_ddim({num_out}));
out->Resize(phi::make_ddim({num_out}));
context.template Alloc<InT>(out);
phi::IndexSelectKernel<InT, Context>(context, in_resize, *indices, 0, out);
// 4. Calculate 'counts'
if (return_counts) {
counts->Resize(phi::make_ddim({num_out}));
auto count_data = context.template Alloc<IndexT>(counts);
// init 'count_data' as 0
thrust::fill(thrust::device, count_data, count_data + num_out, 0);
thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
range_data_ptr_dev[num_out] = num_input;
thrust::adjacent_difference(thrust::device,
range_data_ptr + 1,
range_data_ptr + num_out + 1,
count_data);
}
}
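The fp16/bf16 overload above works on an index array rather than on the values directly: it sorts positions with `LessThan` (which reads the data through the indices), builds the inverse mapping with `adjacent_difference` + `inclusive_scan` + `scatter`, compacts the sorted positions with `unique_by_key`, gathers the unique values with `IndexSelectKernel`, and derives the counts from the compacted run boundaries. The NumPy sketch below mirrors that index-based scheme for readability; it is illustrative only (it uses a stable argsort, which the thrust sort does not guarantee) and is not the kernel itself.

```python
import numpy as np

def unique_via_index_sort(x):
    """Illustrative mirror of the index-sort scheme used by the fp16/bf16 path."""
    n = x.size
    order = np.argsort(x, kind='stable')        # thrust::sort over indices
    sorted_vals = x[order]

    # run-start flags: 1 where a new value begins (adjacent_difference + not_equal)
    new_run = np.empty(n, dtype=np.int64)
    new_run[0] = 0
    new_run[1:] = sorted_vals[1:] != sorted_vals[:-1]

    inv_loc = np.cumsum(new_run)                # inclusive_scan
    inverse = np.empty(n, dtype=np.int64)
    inverse[order] = inv_loc                    # thrust::scatter

    run_starts = np.flatnonzero(np.r_[1, new_run[1:]])  # unique_by_key boundaries
    indices = order[run_starts]                 # one original position per run
    out = x[indices]                            # IndexSelectKernel gather
    counts = np.diff(np.r_[run_starts, n])      # adjacent_difference over 'range'
    return out, indices, inverse, counts

# Same values as the unit test below: x = [2, 3, 3, 1, 5, 3]
out, indices, inverse, counts = unique_via_index_sort(
    np.array([2, 3, 3, 1, 5, 3], dtype=np.float16)
)
# out -> [1, 2, 3, 5], counts -> [1, 1, 3, 1], inverse -> [1, 2, 2, 0, 3, 2]
```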
// The logic of computing Unique when an axis is specified; it is a little
// different from the function above
template <typename Context,
@@ -409,8 +499,6 @@ struct UniqueFlattendCUDAFunctor {
return_index_,
return_inverse_,
return_counts_,
thrust::equal_to<InT>(),
thrust::not_equal_to<InT>(),
in_.numel());
}
};
@@ -548,8 +636,16 @@ void UniqueKernel(const Context& context,
} // namespace phi
PD_REGISTER_KERNEL(
unique, GPU, ALL_LAYOUT, phi::UniqueKernel, float, double, int64_t, int) {
PD_REGISTER_KERNEL(unique,
GPU,
ALL_LAYOUT,
phi::UniqueKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int64_t,
int) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(3).SetDataType(phi::DataType::UNDEFINED);
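Registering the index-like outputs (outputs 1-3) as `phi::DataType::UNDEFINED` leaves their data type unconstrained by the registration; as the tests below show, it follows the op's `dtype` attribute (INT32/INT64) instead. A hedged sketch of the Python-level effect (values illustrative; assumes a CUDA build with this commit):

```python
# Hedged sketch: the index/inverse outputs take the dtype requested via
# paddle.unique's `dtype` argument, while the unique values keep the input dtype.
import paddle

paddle.set_device('gpu')
x = paddle.to_tensor([2.0, 3.0, 3.0, 1.0, 5.0, 3.0], dtype='bfloat16')
out, inverse = paddle.unique(x, return_inverse=True, dtype='int32')
assert inverse.dtype == paddle.int32
print(out.dtype)  # paddle.bfloat16 -- unique values keep the input dtype
```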
@@ -561,6 +657,8 @@ PD_REGISTER_KERNEL(unique_raw,
phi::UniqueRawKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int64_t,
int) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
......
@@ -24,6 +24,7 @@ from paddle.fluid import core
class TestUniqueOp(OpTest):
def setUp(self):
self.op_type = "unique"
self.init_dtype()
self.init_config()
def test_check_output(self):
@@ -31,13 +32,16 @@ class TestUniqueOp(OpTest):
check_dygraph=False
) # unique returns sorted data in dygraph
def init_dtype(self):
self.dtype = np.int64
def init_config(self):
self.inputs = {
'X': np.array([2, 3, 3, 1, 5, 3], dtype='int64'),
'X': np.array([2, 3, 3, 1, 5, 3], dtype=self.dtype),
}
self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)}
self.outputs = {
'Out': np.array([2, 3, 1, 5], dtype='int64'),
'Out': np.array([2, 3, 1, 5], dtype=self.dtype),
'Index': np.array([0, 1, 1, 2, 3, 1], dtype='int32'),
}
@@ -45,25 +49,25 @@ class TestUniqueOp(OpTest):
class TestOne(TestUniqueOp):
def init_config(self):
self.inputs = {
'X': np.array([2], dtype='int64'),
'X': np.array([2], dtype=self.dtype),
}
self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)}
self.outputs = {
'Out': np.array([2], dtype='int64'),
'Out': np.array([2], dtype=self.dtype),
'Index': np.array([0], dtype='int32'),
}
class TestRandom(TestUniqueOp):
def init_config(self):
self.inputs = {'X': np.random.randint(0, 100, (150,), dtype='int64')}
self.inputs = {'X': np.random.randint(0, 100, (150,), dtype=self.dtype)}
self.attrs = {'dtype': int(core.VarDesc.VarType.INT64)}
np_unique, np_index, reverse_index = np.unique(
self.inputs['X'], True, True
)
np_tuple = [(np_unique[i], np_index[i]) for i in range(len(np_unique))]
np_tuple.sort(key=lambda x: x[1])
target_out = np.array([i[0] for i in np_tuple], dtype='int64')
target_out = np.array([i[0] for i in np_tuple], dtype=self.dtype)
target_index = np.array(
[list(target_out).index(i) for i in self.inputs['X']], dtype='int64'
)
@@ -95,11 +99,11 @@ class TestUniqueRaiseError(unittest.TestCase):
class TestOneGPU(TestUniqueOp):
def init_config(self):
self.inputs = {
'X': np.array([2], dtype='int64'),
'X': np.array([2], dtype=self.dtype),
}
self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)}
self.outputs = {
'Out': np.array([2], dtype='int64'),
'Out': np.array([2], dtype=self.dtype),
'Index': np.array([0], dtype='int32'),
}
@@ -116,14 +120,14 @@ class TestOneGPU(TestUniqueOp):
)
class TestRandomGPU(TestUniqueOp):
def init_config(self):
self.inputs = {'X': np.random.randint(0, 100, (150,), dtype='int64')}
self.inputs = {'X': np.random.randint(0, 100, (150,), dtype=self.dtype)}
self.attrs = {'dtype': int(core.VarDesc.VarType.INT64)}
np_unique, np_index, reverse_index = np.unique(
self.inputs['X'], True, True
)
np_tuple = [(np_unique[i], np_index[i]) for i in range(len(np_unique))]
np_tuple.sort(key=lambda x: x[1])
target_out = np.array([i[0] for i in np_tuple], dtype='int64')
target_out = np.array([i[0] for i in np_tuple], dtype=self.dtype)
target_index = np.array(
[list(target_out).index(i) for i in self.inputs['X']], dtype='int64'
)
@@ -139,8 +143,11 @@ class TestRandomGPU(TestUniqueOp):
class TestSortedUniqueOp(TestUniqueOp):
def init_dtype(self):
self.dtype = np.float64
def init_config(self):
self.inputs = {'X': np.array([2, 3, 3, 1, 5, 3], dtype='int64')}
self.inputs = {'X': np.array([2, 3, 3, 1, 5, 3], dtype=self.dtype)}
unique, indices, inverse, count = np.unique(
self.inputs['X'],
return_index=True,
@@ -164,9 +171,35 @@ class TestSortedUniqueOp(TestUniqueOp):
}
class TestSortedUniqueFP16Op(TestSortedUniqueOp):
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestSortedUniqueBF16Op(TestSortedUniqueOp):
def init_dtype(self):
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, check_dygraph=False
) # unique returns sorted data in dygraph
class TestUniqueOpAxisNone(TestUniqueOp):
def init_dtype(self):
self.dtype = np.float64
def init_config(self):
self.inputs = {'X': np.random.random((4, 7, 10)).astype('float64')}
self.inputs = {
'X': np.random.randint(0, 100, (4, 7, 10)).astype(self.dtype)
}
unique, indices, inverse, counts = np.unique(
self.inputs['X'],
return_index=True,
@@ -190,9 +223,35 @@ class TestUniqueOpAxisNone(TestUniqueOp):
}
class TestUniqueOpAxisNoneFP16Op(TestUniqueOpAxisNone):
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestUniqueOpAxisNoneBF16Op(TestUniqueOpAxisNone):
def init_dtype(self):
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, check_dygraph=False
) # unique returns sorted data in dygraph
class TestUniqueOpAxisNeg(TestUniqueOp):
def init_dtype(self):
self.dtype = np.float64
def init_config(self):
self.inputs = {'X': np.random.random((6, 1, 8)).astype('float64')}
self.inputs = {
'X': np.random.randint(0, 100, (6, 1, 8)).astype(self.dtype)
}
unique, indices, inverse, counts = np.unique(
self.inputs['X'],
return_index=True,
@@ -216,9 +275,35 @@ class TestUniqueOpAxisNeg(TestUniqueOp):
}
class TestUniqueOpAxisNegFP16Op(TestUniqueOpAxisNeg):
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestUniqueOpAxisNegBF16Op(TestUniqueOpAxisNeg):
def init_dtype(self):
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, check_dygraph=False
) # unique returns sorted data in dygraph
class TestUniqueOpAxis1(TestUniqueOp):
def init_dtype(self):
self.dtype = np.float64
def init_config(self):
self.inputs = {'X': np.random.random((3, 8, 8)).astype('float64')}
self.inputs = {
'X': np.random.randint(0, 100, (3, 8, 8)).astype(self.dtype)
}
unique, indices, inverse, counts = np.unique(
self.inputs['X'],
return_index=True,
@@ -242,6 +327,27 @@ class TestUniqueOpAxis1(TestUniqueOp):
}
class TestUniqueOpAxis1FP16Op(TestUniqueOpAxis1):
def init_dtype(self):
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestUniqueOpAxis1BF16Op(TestUniqueOpAxis1):
def init_dtype(self):
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, check_dygraph=False
) # unique returns sorted data in dygraph
class TestUniqueAPI(unittest.TestCase):
def test_dygraph_api_out(self):
paddle.disable_static()
......