From 5cdd9f2cbdf57246e7912ec8ff99929980466044 Mon Sep 17 00:00:00 2001 From: Wang Xinyu Date: Thu, 30 Mar 2023 11:31:54 +0800 Subject: [PATCH] [AMP OP&Test] Strided slice fp16 and bf16 unitest (#52220) * stride slice fp16 and bf16 unitest * fix code style * add self.dtype --- .../kernels/cpu/strided_slice_grad_kernel.cc | 2 + .../phi/kernels/cpu/strided_slice_kernel.cc | 2 + .../kernels/gpu/strided_slice_grad_kernel.cu | 2 + .../phi/kernels/gpu/strided_slice_kernel.cu | 2 + .../phi/kernels/strided_slice_grad_kernel.cc | 2 + paddle/phi/kernels/strided_slice_kernel.cc | 2 + .../tests/unittests/test_strided_slice_op.py | 76 ++++++++++++++++++- 7 files changed, 86 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc b/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc index e6c812cf6bd..21ef18dbd90 100644 --- a/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc @@ -28,6 +28,7 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad, int64_t, float, double, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} @@ -40,5 +41,6 @@ PD_REGISTER_KERNEL(strided_slice_array_grad, int64_t, float, double, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/strided_slice_kernel.cc b/paddle/phi/kernels/cpu/strided_slice_kernel.cc index d0aa7b2f4ce..e9a1671bcc4 100644 --- a/paddle/phi/kernels/cpu/strided_slice_kernel.cc +++ b/paddle/phi/kernels/cpu/strided_slice_kernel.cc @@ -28,6 +28,7 @@ PD_REGISTER_KERNEL(strided_slice_raw, int64_t, float, double, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} @@ -40,5 +41,6 @@ PD_REGISTER_KERNEL(strided_slice_array, int64_t, float, double, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu b/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu index 6c5a99e9095..08ac3da93bb 100644 --- a/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu @@ -29,6 +29,7 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad, float, double, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} @@ -42,5 +43,6 @@ PD_REGISTER_KERNEL(strided_slice_array_grad, float, double, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/strided_slice_kernel.cu b/paddle/phi/kernels/gpu/strided_slice_kernel.cu index 786ccb287c2..9b88322e20a 100644 --- a/paddle/phi/kernels/gpu/strided_slice_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_slice_kernel.cu @@ -29,6 +29,7 @@ PD_REGISTER_KERNEL(strided_slice_raw, float, double, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} @@ -42,5 +43,6 @@ PD_REGISTER_KERNEL(strided_slice_array, float, double, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/strided_slice_grad_kernel.cc b/paddle/phi/kernels/strided_slice_grad_kernel.cc index af8994cd8c5..7582f751bf1 100644 --- a/paddle/phi/kernels/strided_slice_grad_kernel.cc +++ b/paddle/phi/kernels/strided_slice_grad_kernel.cc @@ -52,6 +52,7 @@ PD_REGISTER_KERNEL(strided_slice_grad, int64_t, float, double, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -65,6 +66,7 @@ PD_REGISTER_KERNEL(strided_slice_grad, float, double, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} #endif diff --git a/paddle/phi/kernels/strided_slice_kernel.cc b/paddle/phi/kernels/strided_slice_kernel.cc index 3ceb8057235..68377dbe846 100644 --- a/paddle/phi/kernels/strided_slice_kernel.cc +++ b/paddle/phi/kernels/strided_slice_kernel.cc @@ -43,6 +43,7 @@ PD_REGISTER_KERNEL(strided_slice, int64_t, float, double, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -56,6 +57,7 @@ PD_REGISTER_KERNEL(strided_slice, float, double, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} #endif diff --git a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py index dc5397fb4f4..de5a1bcb19a 100644 --- a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from eager_op_test import OpTest +from eager_op_test import OpTest, convert_float_to_uint16 import paddle from paddle import fluid @@ -600,7 +600,7 @@ class TestStridedSliceAPI(unittest.TestCase): feed={ "x": input, 'starts': np.array([-3, 0, 2]).astype("int32"), - 'ends': np.array([3, 2147483648, -1]).astype("int64"), + 'ends': np.array([3, 2147483647, -1]).astype("int32"), 'strides': np.array([1, 1, 1]).astype("int32"), }, fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7], @@ -1011,5 +1011,77 @@ class TestStridedSliceFloat16(unittest.TestCase): np.testing.assert_allclose(x_grad_np_fp16, x_grad_np_fp32) +class TestStrideSliceFP16Op(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = 'strided_slice' + self.dtype = np.float16 + self.python_api = paddle.strided_slice + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides + ) + + self.inputs = {'Input': self.input.astype(self.dtype)} + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + 'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags, + } + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_check_grad(self): + self.check_grad({'Input'}, 'Out', check_eager=True) + + def initTestCase(self): + self.input = np.random.rand(100) + self.axes = [0] + self.starts = [-4] + self.ends = [-3] + self.strides = [1] + self.infer_flags = [1] + + +class TestStrideSliceBF16Op(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = 'strided_slice' + self.dtype = np.uint16 + self.python_api = paddle.strided_slice + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides + ) + + self.inputs = { + 'Input': convert_float_to_uint16(self.input.astype(np.float32)) + } + self.outputs = {'Out': convert_float_to_uint16(self.output)} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + 'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags, + } + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_check_grad(self): + self.check_grad({'Input'}, 'Out', check_eager=True) + + def initTestCase(self): + self.input = np.random.rand(100) + self.axes = [0] + self.starts = [-4] + self.ends = [-3] + self.strides = [1] + self.infer_flags = [1] + + if __name__ == "__main__": unittest.main() -- GitLab