From ccab0e2a1b1a0b42d9b5314272dd0f3d4c09e485 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Thu, 29 Sep 2022 10:43:11 +0800 Subject: [PATCH] fix uniform_rand_kernel FP16 support in dygraph mode (#46212) --- paddle/phi/kernels/gpu/uniform_random_kernel.cu | 11 ++++++++--- paddle/phi/kernels/uniform_random_kernel.cc | 9 +++++++-- .../fluid/tests/unittests/test_uniform_random_op.py | 9 +++++++++ 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/gpu/uniform_random_kernel.cu b/paddle/phi/kernels/gpu/uniform_random_kernel.cu index 23232970e1..458239814b 100644 --- a/paddle/phi/kernels/gpu/uniform_random_kernel.cu +++ b/paddle/phi/kernels/gpu/uniform_random_kernel.cu @@ -74,8 +74,12 @@ void UniformRandomRawKernel(const Context& dev_ctx, funcs::distribution_and_transform(dev_ctx, out, dist, trans); } else { // Use OP seed - auto func = UniformGenerator( - min.to(), max.to(), seed, diag_num, diag_step, diag_val); + auto func = UniformGenerator(static_cast(min.to()), + static_cast(max.to()), + seed, + diag_num, + diag_step, + static_cast(diag_val)); IndexKernel>(dev_ctx, out, func); } } @@ -87,4 +91,5 @@ PD_REGISTER_KERNEL(uniform_random_raw, ALL_LAYOUT, phi::UniformRandomRawKernel, float, - double) {} + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/uniform_random_kernel.cc b/paddle/phi/kernels/uniform_random_kernel.cc index 11f61e5b4a..6669438cc3 100644 --- a/paddle/phi/kernels/uniform_random_kernel.cc +++ b/paddle/phi/kernels/uniform_random_kernel.cc @@ -51,8 +51,13 @@ PD_REGISTER_KERNEL(uniform_random, phi::dtype::bfloat16) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL( - uniform_random, GPU, ALL_LAYOUT, phi::UniformRandomKernel, float, double) {} +PD_REGISTER_KERNEL(uniform_random, + GPU, + ALL_LAYOUT, + phi::UniformRandomKernel, + float, + double, + phi::dtype::float16) {} #endif #ifdef PADDLE_WITH_XPU diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py index 50ca5d7477..4ffcd19044 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py @@ -585,8 +585,17 @@ class TestUniformDtype(unittest.TestCase): out = paddle.tensor.random.uniform([2, 3]) self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64) + def test_dygraph_fp16(): + if not paddle.is_compiled_with_cuda(): + paddle.enable_static() + return + paddle.set_device('gpu') + out = paddle.uniform([2, 3], dtype=paddle.float16) + self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP16) + test_default_fp64() test_default_fp32() + test_dygraph_fp16() paddle.enable_static() -- GitLab