Unverified commit 7581ef9e authored by houj04, committed by GitHub

[XPU] add fp16 support for compare ops. (#51846)

* [XPU] add fp16 support for compare ops.

* fix ci.
Parent 21a72c55
@@ -241,6 +241,7 @@ XPUOpMap& get_kl2_ops() {
     {"equal",
      XPUKernelSet({phi::DataType::INT64,
                    phi::DataType::INT32,
+                   phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32})},
     {"exp_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"exp", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -371,10 +372,12 @@ XPUOpMap& get_kl2_ops() {
     {"greater_equal",
      XPUKernelSet({phi::DataType::INT64,
                    phi::DataType::INT32,
+                   phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32})},
     {"greater_than",
      XPUKernelSet({phi::DataType::INT64,
                    phi::DataType::INT32,
+                   phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32})},
     {"grid_sampler_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -419,10 +422,12 @@ XPUOpMap& get_kl2_ops() {
     {"less_equal",
      XPUKernelSet({phi::DataType::INT64,
                    phi::DataType::INT32,
+                   phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32})},
     {"less_than",
      XPUKernelSet({phi::DataType::INT64,
                    phi::DataType::INT32,
+                   phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32})},
     {"load", XPUKernelSet({phi::DataType::FLOAT32})},
     {"load_combine",
@@ -489,6 +494,7 @@ XPUOpMap& get_kl2_ops() {
     {"not_equal",
      XPUKernelSet({phi::DataType::INT64,
                    phi::DataType::INT32,
+                   phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32})},
     {"one_hot", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
     {"one_hot_v2",
......
@@ -89,8 +89,14 @@ DEFINE_XPU_COMPARE_KERNEL(GreaterEqual, xpu::broadcast_greater_equal<XPUType>)
 }  // namespace phi

-PD_REGISTER_KERNEL(
-    less_than, XPU, ALL_LAYOUT, phi::LessThanKernel, int, int64_t, float) {
+PD_REGISTER_KERNEL(less_than,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::LessThanKernel,
+                   int,
+                   int64_t,
+                   float,
+                   phi::dtype::float16) {
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }
@@ -100,23 +106,31 @@ PD_REGISTER_KERNEL(less_than_raw,
                    phi::LessThanRawKernel,
                    int,
                    int64_t,
-                   float) {
+                   float,
+                   phi::dtype::float16) {
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }

-#define PD_REGISTER_COMPARE_KERNEL(name, func)                          \
-  PD_REGISTER_KERNEL(                                                   \
-      name, XPU, ALL_LAYOUT, phi::func##Kernel, int, int64_t, float) {  \
-    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);               \
-  }                                                                     \
-  PD_REGISTER_KERNEL(name##_raw,                                        \
-                     XPU,                                               \
-                     ALL_LAYOUT,                                        \
-                     phi::func##RawKernel,                              \
-                     int,                                               \
-                     int64_t,                                           \
-                     float) {                                           \
-    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);               \
+#define PD_REGISTER_COMPARE_KERNEL(name, func)             \
+  PD_REGISTER_KERNEL(name,                                 \
+                     XPU,                                  \
+                     ALL_LAYOUT,                           \
+                     phi::func##Kernel,                    \
+                     int,                                  \
+                     int64_t,                              \
+                     float,                                \
+                     phi::dtype::float16) {                \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);  \
+  }                                                        \
+  PD_REGISTER_KERNEL(name##_raw,                           \
+                     XPU,                                  \
+                     ALL_LAYOUT,                           \
+                     phi::func##RawKernel,                 \
+                     int,                                  \
+                     int64_t,                              \
+                     float,                                \
+                     phi::dtype::float16) {                \
+    kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);  \
   }

 PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
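Note: both registrations keep the output dtype pinned to BOOL via kernel->OutputAt(0).SetDataType, so fp16 inputs still yield bool results. A small sketch of the behavior this enables, again assuming an XPU build and device:

    import paddle

    paddle.set_device('xpu')  # assumes an XPU build and device
    x = paddle.ones([2, 3], dtype='float16')
    y = paddle.zeros([2, 3], dtype='float16')
    out = paddle.greater_equal(x, y)
    assert out.dtype == paddle.bool  # output dtype is forced to BOOL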
......
@@ -66,46 +66,17 @@ void CumsumKernel(const Context& dev_ctx,
     }
   }

-  // special for fp16
-  if (std::is_same<T, dtype::float16>::value) {
-    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-    float* cast_input_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(x.numel());
-    float* temp_result_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(x.numel());
-    // cast to fp32
-    int r =
-        xpu::cast<XPUType, float>(dev_ctx.x_context(),
-                                  reinterpret_cast<const XPUType*>(x.data<T>()),
-                                  cast_input_fp32,
-                                  x.numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-    // cumsum in fp32
-    r = xpu::cumsum<float>(dev_ctx.x_context(),
-                           cast_input_fp32,
-                           temp_result_fp32,
-                           x_shape,
-                           reverse,
-                           exclusive,
-                           axis_as_int);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumsum");
-    // cast back to fp16
-    r = xpu::cast<float, XPUType>(dev_ctx.x_context(),
-                                  temp_result_fp32,
-                                  reinterpret_cast<XPUType*>(out->data<T>()),
-                                  x.numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-  } else {
-    // template<typename T> DLL_EXPORT int cumsum(Context* ctx, const T* x, T*
-    // y, const std::vector<int>& xshape, bool reverse, bool exclusive, int
-    // axis);
-    int r = xpu::cumsum<XPUType>(dev_ctx.x_context(),
-                                 reinterpret_cast<const XPUType*>(x.data<T>()),
-                                 reinterpret_cast<XPUType*>(out->data<T>()),
-                                 x_shape,
-                                 reverse,
-                                 exclusive,
-                                 axis_as_int);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumsum");
-  }
+  // template<typename T> DLL_EXPORT int cumsum(Context* ctx, const T* x, T*
+  // y, const std::vector<int>& xshape, bool reverse, bool exclusive, int
+  // axis);
+  int r = xpu::cumsum<XPUType>(dev_ctx.x_context(),
+                               reinterpret_cast<const XPUType*>(x.data<T>()),
+                               reinterpret_cast<XPUType*>(out->data<T>()),
+                               x_shape,
+                               reverse,
+                               exclusive,
+                               axis_as_int);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumsum");
 }

 }  // namespace phi
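Note: the fp16 special case (cast to fp32, cumsum in fp32, cast back) is removed and xpu::cumsum<XPUType> is now called directly for all dtypes, which suggests the underlying XDNN cumsum now handles fp16 natively. A hedged usage sketch, assuming an XPU build:

    import paddle

    paddle.set_device('xpu')  # assumes an XPU build and device
    x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0], dtype='float16')
    print(paddle.cumsum(x))          # [1., 3., 6., 10.], no fp32 round-trip
    print(paddle.cumsum(x, axis=0))  # same, with an explicit axis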
......
@@ -265,9 +265,8 @@ class TestGaussianRandomAPI(unittest.TestCase):
         def test_default_fp16():
             paddle.framework.set_default_dtype('float16')
-            paddle.tensor.random.gaussian([2, 3])
-
-        self.assertRaises(TypeError, test_default_fp16)
+            out = paddle.tensor.random.gaussian([2, 3])
+            self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP16)

         def test_default_fp32():
             paddle.framework.set_default_dtype('float32')
@@ -281,6 +280,7 @@ class TestGaussianRandomAPI(unittest.TestCase):

         test_default_fp64()
         test_default_fp32()
+        test_default_fp16()

         paddle.enable_static()
@@ -291,9 +291,8 @@ class TestStandardNormalDtype(unittest.TestCase):
         def test_default_fp16():
             paddle.framework.set_default_dtype('float16')
-            paddle.tensor.random.standard_normal([2, 3])
-
-        self.assertRaises(TypeError, test_default_fp16)
+            out = paddle.tensor.random.standard_normal([2, 3])
+            self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP16)

         def test_default_fp32():
             paddle.framework.set_default_dtype('float32')
@@ -307,6 +306,7 @@ class TestStandardNormalDtype(unittest.TestCase):

         test_default_fp64()
         test_default_fp32()
+        test_default_fp16()

         paddle.enable_static()
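Note: both tests flip from expecting a TypeError to asserting that a float16 default dtype is honored. A sketch of the asserted behavior, assuming dynamic mode on a device whose gaussian kernels support fp16:

    import paddle

    paddle.disable_static()  # the updated assertions run in dynamic mode
    paddle.framework.set_default_dtype('float16')
    out = paddle.tensor.random.standard_normal([2, 3])
    assert out.dtype == paddle.float16  # the default dtype is honored
    paddle.framework.set_default_dtype('float32')  # restore the default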
......
@@ -1259,13 +1259,13 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
                 # test stop_gradient
                 value.stop_gradient = True
                 x.stop_gradient = False
-                start = paddle.tensor.layers.fill_constant(
+                start = paddle.tensor.fill_constant(
                     [1], "int32", 5, force_cpu=True
                 )
-                end = paddle.tensor.layers.fill_constant(
+                end = paddle.tensor.fill_constant(
                     [1], "int32", 0, force_cpu=True
                 )
-                step = paddle.tensor.layers.fill_constant(
+                step = paddle.tensor.fill_constant(
                     [1], "int32", -2, force_cpu=True
                 )
......