From 2b0fffc2ae0a8b9ee808b5f69ea27d45ac6df17d Mon Sep 17 00:00:00 2001
From: Difer <707065510@qq.com>
Date: Mon, 10 Apr 2023 17:26:19 +0800
Subject: [PATCH] 【Hackathon No57】 add fp16 & bf16 for flip, fp16 for gaussian
 (#52380)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add_fp_bf_for_flip_gaussian_random

* forget convert uint

* fix some error

* fix some error
---
 paddle/phi/kernels/gpu/flip_kernel.cu         |  1 +
 .../paddle/fluid/tests/unittests/test_flip.py | 95 ++++++++++++++++++-
 .../unittests/test_gaussian_random_op.py      | 43 +++++++++
 3 files changed, 135 insertions(+), 4 deletions(-)

diff --git a/paddle/phi/kernels/gpu/flip_kernel.cu b/paddle/phi/kernels/gpu/flip_kernel.cu
index d48ecc4dfac..812d68df92d 100644
--- a/paddle/phi/kernels/gpu/flip_kernel.cu
+++ b/paddle/phi/kernels/gpu/flip_kernel.cu
@@ -145,6 +145,7 @@ PD_REGISTER_KERNEL(flip,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int,
                    int64_t,
                    bool,
diff --git a/python/paddle/fluid/tests/unittests/test_flip.py b/python/paddle/fluid/tests/unittests/test_flip.py
index 197766a5563..a06ef10ca06 100644
--- a/python/paddle/fluid/tests/unittests/test_flip.py
+++ b/python/paddle/fluid/tests/unittests/test_flip.py
@@ -17,7 +17,7 @@ import unittest
 import gradient_checker
 import numpy as np
 from decorator_helper import prog_scope
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle import fluid
@@ -74,9 +74,27 @@ class TestFlipOp(OpTest):
         self.op_type = 'flip'
         self.python_api = paddle.tensor.flip
         self.init_test_case()
-        self.inputs = {'X': np.random.random(self.in_shape).astype('float64')}
         self.init_attrs()
-        self.outputs = {'Out': self.calc_ref_res()}
+        self.init_dtype()
+
+        if self.is_bfloat16_op():
+            self.input = np.random.random(self.in_shape).astype(np.float32)
+        else:
+            self.input = np.random.random(self.in_shape).astype(self.dtype)
+
+        output = self.calc_ref_res()
+
+        if self.is_bfloat16_op():
+            output = output.astype(np.float32)
+            self.inputs = {'X': convert_float_to_uint16(self.input)}
+            self.outputs = {'Out': convert_float_to_uint16(output)}
+        else:
+            self.inputs = {'X': self.input.astype(self.dtype)}
+            output = output.astype(self.dtype)
+            self.outputs = {'Out': output}
+
+    def init_dtype(self):
+        self.dtype = np.float64
 
     def init_attrs(self):
         self.attrs = {"axis": self.axis}
@@ -92,7 +110,7 @@ class TestFlipOp(OpTest):
         self.axis = [0, 1]
 
     def calc_ref_res(self):
-        res = self.inputs['X']
+        res = self.input
         if isinstance(self.axis, int):
             return np.flip(res, self.axis)
         for axis in self.axis:
@@ -136,6 +154,75 @@ class TestFlipOpNegAxis(TestFlipOp):
         self.axis = [-1]
 
 
+# ----------------flip_fp16----------------
+def create_test_fp16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+    )
+    class TestFlipFP16(parent):
+        def init_dtype(self):
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(place)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(place, ["X"], "Out")
+
+    cls_name = "{}_{}".format(parent.__name__, "FP16OP")
+    TestFlipFP16.__name__ = cls_name
+    globals()[cls_name] = TestFlipFP16
+
+
+create_test_fp16_class(TestFlipOp)
+create_test_fp16_class(TestFlipOpAxis1)
+create_test_fp16_class(TestFlipOpAxis2)
+create_test_fp16_class(TestFlipOpAxis3)
+create_test_fp16_class(TestFlipOpAxis4)
+create_test_fp16_class(TestFlipOpEmptyAxis)
+create_test_fp16_class(TestFlipOpNegAxis)
+
+
+# ----------------flip_bf16----------------
+def create_test_bf16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda()
+        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+        "core is not compiled with CUDA or does not support bfloat16",
+    )
+    class TestFlipBF16(parent):
+        def init_dtype(self):
+            self.dtype = np.uint16
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_output_with_place(place)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_bfloat16_supported(place):
+                self.check_grad_with_place(place, ["X"], "Out")
+
+    cls_name = "{}_{}".format(parent.__name__, "BF16OP")
+    TestFlipBF16.__name__ = cls_name
+    globals()[cls_name] = TestFlipBF16
+
+
+create_test_bf16_class(TestFlipOp)
+create_test_bf16_class(TestFlipOpAxis1)
+create_test_bf16_class(TestFlipOpAxis2)
+create_test_bf16_class(TestFlipOpAxis3)
+create_test_bf16_class(TestFlipOpAxis4)
+create_test_bf16_class(TestFlipOpEmptyAxis)
+create_test_bf16_class(TestFlipOpNegAxis)
+
+
 class TestFlipDoubleGradCheck(unittest.TestCase):
     def flip_wrapper(self, x):
         return paddle.flip(x[0], [0, 1])
diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
index 3b8bf23e600..f735835cb58 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
@@ -64,6 +64,49 @@ class TestGaussianRandomOp(OpTest):
         np.testing.assert_allclose(hist, hist2, rtol=0, atol=0.01)
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestGaussianRandomFP16Op(OpTest):
+    def setUp(self):
+        self.op_type = "gaussian_random"
+        self.python_api = paddle.normal
+        self.set_attrs()
+        self.inputs = {}
+        self.use_mkldnn = False
+        self.attrs = {
+            "shape": [123, 92],
+            "mean": self.mean,
+            "std": self.std,
+            "seed": 10,
+            "dtype": paddle.fluid.core.VarDesc.VarType.FP16,
+            "use_mkldnn": self.use_mkldnn,
+        }
+        paddle.seed(10)
+
+        self.outputs = {'Out': np.zeros((123, 92), dtype='float16')}
+
+    def set_attrs(self):
+        self.mean = 1.0
+        self.std = 2.0
+
+    def test_check_output(self):
+        self.check_output_with_place_customized(
+            self.verify_output, place=core.CUDAPlace(0)
+        )
+
+    def verify_output(self, outs):
+        self.assertEqual(outs[0].shape, (123, 92))
+        hist, _ = np.histogram(outs[0], range=(-3, 5))
+        hist = hist.astype("float16")
+        hist /= float(outs[0].size)
+        data = np.random.normal(size=(123, 92), loc=1, scale=2)
+        hist2, _ = np.histogram(data, range=(-3, 5))
+        hist2 = hist2.astype("float16")
+        hist2 /= float(outs[0].size)
+        np.testing.assert_allclose(hist, hist2, rtol=0, atol=0.015)
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
-- 
GitLab