Unverified commit 0bed2203, authored by co63oc, committed by GitHub

Add segment_pool tests (#53785)

Parent 117e951b
...@@ -445,6 +445,57 @@ CUDA_ATOMIC_WRAPPER(Max, phi::dtype::float16) {
}
#endif
inline static __device__ uint32_t bf16_max_to_low_half(uint32_t val, float x) {
phi::dtype::bfloat16 low_half;
// The bfloat16 in lower 16bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half =
static_cast<phi::dtype::bfloat16>(max(static_cast<float>(low_half), x));
return (val & 0xFFFF0000u) | low_half.x;
}
inline static __device__ uint32_t bf16_max_to_high_half(uint32_t val, float x) {
phi::dtype::bfloat16 high_half;
// The bfloat16 in higher 16bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half =
static_cast<phi::dtype::bfloat16>(max(static_cast<float>(high_half), x));
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
CUDA_ATOMIC_WRAPPER(Max, phi::dtype::bfloat16) {
if (*address >= val) {
return *address;
}
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The bfloat16 value stay at lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(
address_as_ui, assumed, bf16_max_to_low_half(assumed, val_f));
} while (old != assumed);
phi::dtype::bfloat16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The bfloat16 value stay at higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(
address_as_ui, assumed, bf16_max_to_high_half(assumed, val_f));
} while (old != assumed);
phi::dtype::bfloat16 ret;
ret.x = old >> 16;
return ret;
}
}
// For atomicMin
USE_CUDA_ATOMIC(Min, int);
USE_CUDA_ATOMIC(Min, unsigned int);
...@@ -580,6 +631,57 @@ CUDA_ATOMIC_WRAPPER(Min, phi::dtype::float16) {
}
#endif
inline static __device__ uint32_t bf16_min_to_low_half(uint32_t val, float x) {
phi::dtype::bfloat16 low_half;
// The bfloat16 in lower 16bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half =
static_cast<phi::dtype::bfloat16>(min(static_cast<float>(low_half), x));
return (val & 0xFFFF0000u) | low_half.x;
}
inline static __device__ uint32_t bf16_min_to_high_half(uint32_t val, float x) {
phi::dtype::bfloat16 high_half;
// The bfloat16 in higher 16bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half =
static_cast<phi::dtype::bfloat16>(min(static_cast<float>(high_half), x));
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}
CUDA_ATOMIC_WRAPPER(Min, phi::dtype::bfloat16) {
if (*address <= val) {
return *address;
}
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The bfloat16 value stay at lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(
address_as_ui, assumed, bf16_min_to_low_half(assumed, val_f));
} while (old != assumed);
phi::dtype::bfloat16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The bfloat16 value stay at higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(
address_as_ui, assumed, bf16_min_to_high_half(assumed, val_f));
} while (old != assumed);
phi::dtype::bfloat16 ret;
ret.x = old >> 16;
return ret;
}
}
#ifdef PADDLE_WITH_CUDA
/*
 * One thread block deals with elementwise atomicAdd for vector of len.
......
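The two new bfloat16 wrappers above reuse the trick already applied to float16: CUDA has no 16-bit atomicCAS, so the bfloat16 value is located in either the low or the high half of the aligned 32-bit word that contains it, the max/min is computed in float, and atomicCAS retries until the updated half-word sticks. A minimal usage sketch (not part of this commit), assuming the CUDA_ATOMIC_WRAPPER macro expands to a device function named phi::CudaAtomicMax as it does for the existing overloads, and that the header lives at paddle/phi/backends/gpu/gpu_primitives.h:

```cuda
#include "paddle/phi/backends/gpu/gpu_primitives.h"  // assumed header path

// Each thread folds one element into the running maximum of its segment.
// `out` must be pre-initialized with the lowest representable bfloat16.
__global__ void SegmentMaxSketch(const phi::dtype::bfloat16 *x,
                                 const int *segment_ids,
                                 phi::dtype::bfloat16 *out,
                                 int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    phi::CudaAtomicMax(&out[segment_ids[i]], x[i]);
  }
}
```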
...@@ -451,6 +451,8 @@ template class SegmentPoolFunctor<GPU, int64_t, int>;
template class SegmentPoolFunctor<GPU, int64_t, int64_t>;
template class SegmentPoolFunctor<GPU, float16, int>;
template class SegmentPoolFunctor<GPU, float16, int64_t>;
template class SegmentPoolFunctor<GPU, phi::dtype::bfloat16, int>;
template class SegmentPoolFunctor<GPU, phi::dtype::bfloat16, int64_t>;
template class SegmentPoolGradFunctor<GPU, float, int>;
template class SegmentPoolGradFunctor<GPU, float, int64_t>;
...@@ -462,6 +464,8 @@ template class SegmentPoolGradFunctor<GPU, int64_t, int>;
template class SegmentPoolGradFunctor<GPU, int64_t, int64_t>;
template class SegmentPoolGradFunctor<GPU, float16, int>;
template class SegmentPoolGradFunctor<GPU, float16, int64_t>;
template class SegmentPoolGradFunctor<GPU, phi::dtype::bfloat16, int>;
template class SegmentPoolGradFunctor<GPU, phi::dtype::bfloat16, int64_t>;
} // namespace funcs
} // namespace phi
...@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(segment_pool_grad,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(segment_pool,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
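The two registration hunks above only show the trailing dtype lists. For orientation, a plausible reconstruction of the full segment_pool registration after this change is sketched below; the kernel symbol phi::SegmentPoolKernel is an assumption based on Paddle's naming convention and is not visible in the truncated diff:

```cpp
PD_REGISTER_KERNEL(segment_pool,
                   GPU,
                   ALL_LAYOUT,
                   phi::SegmentPoolKernel,  // assumed kernel symbol
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
```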
...@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
...@@ -84,7 +84,10 @@ def segment_pool_split(X, SegmentIds, pooltype):
class TestSegmentOps(OpTest):
def set_data(self):
if self.dtype == np.uint16:
x = np.random.uniform(-1, 1, self.shape).astype(self.np_dtype)
else:
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
segment_ids = self.set_segment(len(x), len(x) // 5 + 1)
return x, segment_ids
...@@ -110,10 +113,14 @@ class TestSegmentOps(OpTest):
x, segment_ids = self.set_data()
result = self.compute(x, segment_ids)
self.inputs = {
'X': x,
'SegmentIds': segment_ids.astype(np.int64),
}
if self.dtype == np.uint16:
self.outputs = {'Out': result.astype(self.np_dtype)}
else:
self.outputs = {'Out': result.astype(self.dtype)}
self.convert_bf16()
def test_check_output(self):
self.check_output()
...@@ -121,6 +128,12 @@ class TestSegmentOps(OpTest):
def test_check_grad(self):
self.check_grad(["X"], "Out")
def convert_bf16(self):
if self.dtype == np.uint16:
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)
class TestSegmentSum2(TestSegmentOps):
def prepare(self):
...@@ -141,23 +154,16 @@ class TestSegmentSum2(TestSegmentOps):
class TestSegmentMax(TestSegmentOps):
def compute(self, x, segment_ids):
result, self.gradient = compute_segment_min_max(
x, segment_ids, pooltype="MAX"
)
return result
def prepare(self):
super().prepare()
self.shape = [40, 20]
self.attrs = {'pooltype': "MAX"}
def setUp(self):
self.prepare()
x, segment_ids = self.set_data()
result, self.gradient = self.compute(x, segment_ids)
self.inputs = {
'X': x.astype(self.dtype),
'SegmentIds': segment_ids.astype(np.int32),
}
self.outputs = {'Out': result.astype(self.dtype)}
def test_check_grad(self):
self.check_grad(["X"], "Out", user_defined_grads=[self.gradient])
...@@ -170,7 +176,10 @@ class TestSegmentMax2(TestSegmentMax):
class TestSegmentMin(TestSegmentMax):
def compute(self, x, segment_ids):
result, self.gradient = compute_segment_min_max(
x, segment_ids, pooltype="MIN"
)
return result
def prepare(self):
super().prepare()
...@@ -197,12 +206,17 @@ class TestSegmentMean(TestSegmentOps):
x, segment_ids = self.set_data()
result = self.compute(x, segment_ids)
self.inputs = {'X': x, 'SegmentIds': segment_ids}
if self.dtype == np.uint16:
astype = self.np_dtype
else:
astype = self.dtype
self.outputs = {
'Out': result,
'SummedIds': compute_segment_sum(
np.ones([len(x), 1]).astype(astype), segment_ids
),
}
self.convert_bf16()
class TestSegmentMean2(TestSegmentMean):
...@@ -213,6 +227,106 @@ class TestSegmentMean2(TestSegmentMean):
self.attrs = {'pooltype': "MEAN"}
class TestSegmentSumFP16Op(TestSegmentOps):
def prepare(self):
super().prepare()
self.dtype = np.float16
class TestSegmentMaxFP16Op(TestSegmentMax):
def prepare(self):
super().prepare()
self.dtype = np.float16
class TestSegmentMinFP16Op(TestSegmentMin):
def prepare(self):
super().prepare()
self.dtype = np.float16
class TestSegmentMeanFP16Op(TestSegmentMean):
def prepare(self):
super().prepare()
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestSegmentSumBF16Op(TestSegmentOps):
def prepare(self):
super().prepare()
self.dtype = np.uint16
self.np_dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ["X"], "Out")
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestSegmentMaxBF16Op(TestSegmentMax):
def prepare(self):
super().prepare()
self.dtype = np.uint16
self.np_dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ["X"], "Out", user_defined_grads=[self.gradient]
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestSegmentMinBF16Op(TestSegmentMin):
def prepare(self):
super().prepare()
self.dtype = np.uint16
self.np_dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ["X"], "Out", user_defined_grads=[self.gradient]
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestSegmentMeanBF16Op(TestSegmentMean):
def prepare(self):
super().prepare()
self.dtype = np.uint16
self.np_dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ["X"], "Out")
class API_SegmentOpsTest(unittest.TestCase):
def test_static(self):
with paddle.static.program_guard(paddle.static.Program()):
......
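The new BF16 test classes keep their reference data as float32 (self.np_dtype) and only reinterpret it as uint16 at the end through convert_float_to_uint16. The relationship they rely on is that a bfloat16 encoding is simply the upper 16 bits of the corresponding IEEE float32; a rough sketch of that conversion (the hypothetical helper name below is illustrative and ignores whatever rounding the real convert_float_to_uint16 applies):

```python
import numpy as np

def float32_to_bf16_bits(a):
    # Illustrative sketch only: truncate each float32 to its upper 16 bits,
    # which is (up to rounding) the bfloat16 bit pattern stored as uint16.
    a = np.ascontiguousarray(a, dtype=np.float32)
    return np.right_shift(a.view(np.uint32), 16).astype(np.uint16)
```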
...@@ -56,7 +56,7 @@ def segment_sum(data, segment_ids, name=None):
check_variable_and_dtype(
data,
"X",
("float32", "float64", "int32", "int64", "float16", "uint16"),
"segment_pool",
)
check_variable_and_dtype(
...@@ -114,7 +114,7 @@ def segment_mean(data, segment_ids, name=None):
check_variable_and_dtype(
data,
"X",
("float32", "float64", "int32", "int64", "float16", "uint16"),
"segment_pool",
)
check_variable_and_dtype(
...@@ -170,7 +170,7 @@ def segment_min(data, segment_ids, name=None):
check_variable_and_dtype(
data,
"X",
("float32", "float64", "int32", "int64", "float16", "uint16"),
"segment_pool",
)
check_variable_and_dtype(
...@@ -226,7 +226,7 @@ def segment_max(data, segment_ids, name=None):
check_variable_and_dtype(
data,
"X",
("float32", "float64", "int32", "int64", "float16", "uint16"),
"segment_pool",
)
check_variable_and_dtype(
......
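With "uint16" (the dtype Paddle uses internally for bfloat16) added to the static-graph dtype checks above, the Python segment ops accept bfloat16 inputs end to end. A hedged usage example, assuming the ops are exposed under paddle.geometric as in recent releases and a bfloat16-capable CUDA build:

```python
import paddle

# Sum rows 0 and 1 into segment 0, row 2 into segment 1.
x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype='bfloat16')
ids = paddle.to_tensor([0, 0, 1], dtype='int64')
out = paddle.geometric.segment_sum(x, ids)  # shape [2, 2]
print(out.astype('float32'))
```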