未验证 提交 7581ef9e 编写于 作者: H houj04 提交者: GitHub

[XPU] add fp16 support for compare ops. (#51846)

* [XPU] add fp16 support for compare ops.

* fix ci.
上级 21a72c55
......@@ -241,6 +241,7 @@ XPUOpMap& get_kl2_ops() {
{"equal",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"exp_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"exp", XPUKernelSet({phi::DataType::FLOAT32})},
......@@ -371,10 +372,12 @@ XPUOpMap& get_kl2_ops() {
{"greater_equal",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"greater_than",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"grid_sampler_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})},
......@@ -419,10 +422,12 @@ XPUOpMap& get_kl2_ops() {
{"less_equal",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"less_than",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"load", XPUKernelSet({phi::DataType::FLOAT32})},
{"load_combine",
......@@ -489,6 +494,7 @@ XPUOpMap& get_kl2_ops() {
{"not_equal",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"one_hot", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"one_hot_v2",
......
......@@ -89,8 +89,14 @@ DEFINE_XPU_COMPARE_KERNEL(GreaterEqual, xpu::broadcast_greater_equal<XPUType>)
} // namespace phi
PD_REGISTER_KERNEL(
less_than, XPU, ALL_LAYOUT, phi::LessThanKernel, int, int64_t, float) {
PD_REGISTER_KERNEL(less_than,
XPU,
ALL_LAYOUT,
phi::LessThanKernel,
int,
int64_t,
float,
phi::dtype::float16) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
......@@ -100,13 +106,20 @@ PD_REGISTER_KERNEL(less_than_raw,
phi::LessThanRawKernel,
int,
int64_t,
float) {
float,
phi::dtype::float16) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL( \
name, XPU, ALL_LAYOUT, phi::func##Kernel, int, int64_t, float) { \
PD_REGISTER_KERNEL(name, \
XPU, \
ALL_LAYOUT, \
phi::func##Kernel, \
int, \
int64_t, \
float, \
phi::dtype::float16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
} \
PD_REGISTER_KERNEL(name##_raw, \
......@@ -115,7 +128,8 @@ PD_REGISTER_KERNEL(less_than_raw,
phi::func##RawKernel, \
int, \
int64_t, \
float) { \
float, \
phi::dtype::float16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
......
......@@ -66,34 +66,6 @@ void CumsumKernel(const Context& dev_ctx,
}
}
// special for fp16
if (std::is_same<T, dtype::float16>::value) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
float* cast_input_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(x.numel());
float* temp_result_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(x.numel());
// cast to fp32
int r =
xpu::cast<XPUType, float>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
cast_input_fp32,
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
// cumsum in fp32
r = xpu::cumsum<float>(dev_ctx.x_context(),
cast_input_fp32,
temp_result_fp32,
x_shape,
reverse,
exclusive,
axis_as_int);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumsum");
// cast back to fp16
r = xpu::cast<float, XPUType>(dev_ctx.x_context(),
temp_result_fp32,
reinterpret_cast<XPUType*>(out->data<T>()),
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
} else {
// template<typename T> DLL_EXPORT int cumsum(Context* ctx, const T* x, T*
// y, const std::vector<int>& xshape, bool reverse, bool exclusive, int
// axis);
......@@ -105,7 +77,6 @@ void CumsumKernel(const Context& dev_ctx,
exclusive,
axis_as_int);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cumsum");
}
}
} // namespace phi
......
......@@ -265,9 +265,8 @@ class TestGaussianRandomAPI(unittest.TestCase):
def test_default_fp16():
paddle.framework.set_default_dtype('float16')
paddle.tensor.random.gaussian([2, 3])
self.assertRaises(TypeError, test_default_fp16)
out = paddle.tensor.random.gaussian([2, 3])
self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP16)
def test_default_fp32():
paddle.framework.set_default_dtype('float32')
......@@ -281,6 +280,7 @@ class TestGaussianRandomAPI(unittest.TestCase):
test_default_fp64()
test_default_fp32()
test_default_fp16()
paddle.enable_static()
......@@ -291,9 +291,8 @@ class TestStandardNormalDtype(unittest.TestCase):
def test_default_fp16():
paddle.framework.set_default_dtype('float16')
paddle.tensor.random.standard_normal([2, 3])
self.assertRaises(TypeError, test_default_fp16)
out = paddle.tensor.random.standard_normal([2, 3])
self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP16)
def test_default_fp32():
paddle.framework.set_default_dtype('float32')
......@@ -307,6 +306,7 @@ class TestStandardNormalDtype(unittest.TestCase):
test_default_fp64()
test_default_fp32()
test_default_fp16()
paddle.enable_static()
......
......@@ -1259,13 +1259,13 @@ class XPUTestSetValueOp(XPUOpTestWrapper):
# test stop_gradient
value.stop_gradient = True
x.stop_gradient = False
start = paddle.tensor.layers.fill_constant(
start = paddle.tensor.fill_constant(
[1], "int32", 5, force_cpu=True
)
end = paddle.tensor.layers.fill_constant(
end = paddle.tensor.fill_constant(
[1], "int32", 0, force_cpu=True
)
step = paddle.tensor.layers.fill_constant(
step = paddle.tensor.fill_constant(
[1], "int32", -2, force_cpu=True
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册