From 608a3f2842fd79ae455ef94a2c98b8f44c797b9d Mon Sep 17 00:00:00 2001
From: LoneRanger <836253168@qq.com>
Date: Mon, 31 Jul 2023 10:38:13 +0800
Subject: [PATCH] 【PaddlePaddle Hackathon 4】No.56: add fp16 test and bf16 for
 poisson (#51662)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add fp16 and bf16 support for poisson

* add fp16 and bf16 support for searchsorted

* fix bug

* Update test_searchsorted_op.py

fix function name

* Update test_poisson_op.py

fix function name

* fix bug

* remove the searchsorted changes

* Update test_poisson_op.py

* fix bug of TestPoissonBF16Op

* Update test_poisson_op.py

* Update test_poisson_op.py

* Update test_poisson_op.py

* fix bug of import

* fix bug
---
 paddle/phi/kernels/gpu/poisson_grad_kernel.cu | 10 ++-
 paddle/phi/kernels/gpu/poisson_kernel.cu      | 10 ++-
 test/legacy_test/test_poisson_op.py           | 63 ++++++++++++++++++-
 3 files changed, 77 insertions(+), 6 deletions(-)

diff --git a/paddle/phi/kernels/gpu/poisson_grad_kernel.cu b/paddle/phi/kernels/gpu/poisson_grad_kernel.cu
index 8c16bc51fff..be7d28a6630 100644
--- a/paddle/phi/kernels/gpu/poisson_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/poisson_grad_kernel.cu
@@ -15,5 +15,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/poisson_grad_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    poisson_grad, GPU, ALL_LAYOUT, phi::PoissonGradKernel, float, double) {}
+PD_REGISTER_KERNEL(poisson_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::PoissonGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/poisson_kernel.cu b/paddle/phi/kernels/gpu/poisson_kernel.cu
index 302a9fe5ce5..1d1968b30ae 100644
--- a/paddle/phi/kernels/gpu/poisson_kernel.cu
+++ b/paddle/phi/kernels/gpu/poisson_kernel.cu
@@ -64,5 +64,11 @@ void PoissonKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
 
 } // namespace phi
 
-PD_REGISTER_KERNEL(
-    poisson, GPU, ALL_LAYOUT, phi::PoissonKernel, float, double) {}
+PD_REGISTER_KERNEL(poisson,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::PoissonKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/test/legacy_test/test_poisson_op.py b/test/legacy_test/test_poisson_op.py
index ee66d578014..84edf6a3221 100644
--- a/test/legacy_test/test_poisson_op.py
+++ b/test/legacy_test/test_poisson_op.py
@@ -16,9 +16,14 @@ import math
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import (
+    OpTest,
+    convert_float_to_uint16,
+    convert_uint16_to_float,
+)
 
 import paddle
+from paddle.fluid import core
 
 paddle.enable_static()
 paddle.seed(100)
@@ -42,17 +47,20 @@ class TestPoissonOp1(OpTest):
     def setUp(self):
         self.op_type = "poisson"
         self.python_api = paddle.tensor.poisson
+        self.init_dtype()
         self.config()
 
         self.attrs = {}
         self.inputs = {'X': np.full([2048, 1024], self.lam, dtype=self.dtype)}
         self.outputs = {'Out': np.ones([2048, 1024], dtype=self.dtype)}
 
+    def init_dtype(self):
+        self.dtype = "float64"
+
     def config(self):
         self.lam = 10
         self.a = 5
         self.b = 15
-        self.dtype = "float64"
 
     def verify_output(self, outs):
         hist, prob = output_hist(np.array(outs[0]), self.lam, self.a, self.b)
@@ -368,5 +376,56 @@ class TestPoissonAPI(unittest.TestCase):
         paddle.enable_static()
 
 
+class TestPoissonFP16OP(TestPoissonOp1):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16", +) +class TestPoissonBF16Op(OpTest): + def setUp(self): + self.op_type = "poisson" + self.python_api = paddle.tensor.poisson + self.__class__.op_type = self.op_type + self.config() + x = np.full([2048, 1024], self.lam, dtype="float32") + out = np.ones([2048, 1024], dtype="float32") + self.attrs = {} + self.inputs = {'X': convert_float_to_uint16(x)} + self.outputs = {'Out': convert_float_to_uint16(out)} + + def config(self): + self.lam = 10 + self.a = 5 + self.b = 15 + self.dtype = np.uint16 + + def verify_output(self, outs): + hist, prob = output_hist( + convert_uint16_to_float(np.array(outs[0])), self.lam, self.a, self.b + ) + np.testing.assert_allclose(hist, prob, rtol=0.01) + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place_customized(self.verify_output, place) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, + ['X'], + 'Out', + user_defined_grads=[np.zeros([2048, 1024], dtype="float32")], + user_defined_grad_outputs=[ + np.random.rand(2048, 1024).astype("float32") + ], + ) + + if __name__ == "__main__": unittest.main() -- GitLab