diff --git a/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc b/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc index e6c812cf6bd5aa3b4d5119b380986ecc2802e073..21ef18dbd90cfac3c54884a936a5271c989aae5b 100644 --- a/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc @@ -28,6 +28,7 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad, int64_t, float, double, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} @@ -40,5 +41,6 @@ PD_REGISTER_KERNEL(strided_slice_array_grad, int64_t, float, double, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} diff --git a/paddle/phi/kernels/cpu/strided_slice_kernel.cc b/paddle/phi/kernels/cpu/strided_slice_kernel.cc index d0aa7b2f4cee62e2611f7a509053142353d746a9..e9a1671bcc4c913bca94783a07676ade716fc4d7 100644 --- a/paddle/phi/kernels/cpu/strided_slice_kernel.cc +++ b/paddle/phi/kernels/cpu/strided_slice_kernel.cc @@ -28,6 +28,7 @@ PD_REGISTER_KERNEL(strided_slice_raw, int64_t, float, double, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} @@ -40,5 +41,6 @@ PD_REGISTER_KERNEL(strided_slice_array, int64_t, float, double, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} diff --git a/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu b/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu index 6c5a99e9095b63e09dd208506afb038b74901618..08ac3da93bb49d36b2409ac306f3f775137546c7 100644 --- a/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu @@ -29,6 +29,7 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad, float, double, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} @@ -42,5 +43,6 @@ PD_REGISTER_KERNEL(strided_slice_array_grad, float, double, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} diff --git a/paddle/phi/kernels/gpu/strided_slice_kernel.cu 
b/paddle/phi/kernels/gpu/strided_slice_kernel.cu index 786ccb287c2719c8b6da433b626ca271c7f06710..9b88322e20a06ecba3af390a7d970d84539acde7 100644 --- a/paddle/phi/kernels/gpu/strided_slice_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_slice_kernel.cu @@ -29,6 +29,7 @@ PD_REGISTER_KERNEL(strided_slice_raw, float, double, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} @@ -42,5 +43,6 @@ PD_REGISTER_KERNEL(strided_slice_array, float, double, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} diff --git a/paddle/phi/kernels/strided_slice_grad_kernel.cc b/paddle/phi/kernels/strided_slice_grad_kernel.cc index af8994cd8c51571e039a51c09d87b725987e05d0..7582f751bf16a55d1ae0a00e646fa15fcc2818c6 100644 --- a/paddle/phi/kernels/strided_slice_grad_kernel.cc +++ b/paddle/phi/kernels/strided_slice_grad_kernel.cc @@ -52,6 +52,7 @@ PD_REGISTER_KERNEL(strided_slice_grad, int64_t, float, double, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -65,6 +66,7 @@ PD_REGISTER_KERNEL(strided_slice_grad, float, double, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} #endif diff --git a/paddle/phi/kernels/strided_slice_kernel.cc b/paddle/phi/kernels/strided_slice_kernel.cc index 3ceb8057235df5162d4f9af1d2383fb16533ca33..68377dbe8468ed0234e2e3a750a7c7712a58e7ec 100644 --- a/paddle/phi/kernels/strided_slice_kernel.cc +++ b/paddle/phi/kernels/strided_slice_kernel.cc @@ -43,6 +43,7 @@ PD_REGISTER_KERNEL(strided_slice, int64_t, float, double, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -56,6 +57,7 @@ PD_REGISTER_KERNEL(strided_slice, float, double, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex<float>, phi::dtype::complex<double>) {} #endif diff --git 
a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py index dc5397fb4f4891cfc9f5e9440ac59d777e98ed67..de5a1bcb19a0f99474d2847aacf4a24b9533ecab 100644 --- a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from eager_op_test import OpTest +from eager_op_test import OpTest, convert_float_to_uint16 import paddle from paddle import fluid @@ -600,7 +600,7 @@ class TestStridedSliceAPI(unittest.TestCase): feed={ "x": input, 'starts': np.array([-3, 0, 2]).astype("int32"), - 'ends': np.array([3, 2147483648, -1]).astype("int64"), + 'ends': np.array([3, 2147483647, -1]).astype("int32"), 'strides': np.array([1, 1, 1]).astype("int32"), }, fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7], @@ -1011,5 +1011,77 @@ class TestStridedSliceFloat16(unittest.TestCase): np.testing.assert_allclose(x_grad_np_fp16, x_grad_np_fp32) +class TestStrideSliceFP16Op(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = 'strided_slice' + self.dtype = np.float16 + self.python_api = paddle.strided_slice + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides + ) + + self.inputs = {'Input': self.input.astype(self.dtype)} + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + 'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags, + } + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_check_grad(self): + self.check_grad({'Input'}, 'Out', check_eager=True) + + def initTestCase(self): + self.input = np.random.rand(100) + self.axes = [0] + self.starts = [-4] + self.ends = [-3] + self.strides = [1] + self.infer_flags = [1] + + +class TestStrideSliceBF16Op(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = 
'strided_slice' + self.dtype = np.uint16 + self.python_api = paddle.strided_slice + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides + ) + + self.inputs = { + 'Input': convert_float_to_uint16(self.input.astype(np.float32)) + } + self.outputs = {'Out': convert_float_to_uint16(self.output)} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + 'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags, + } + + def test_check_output(self): + self.check_output(check_eager=True) + + def test_check_grad(self): + self.check_grad({'Input'}, 'Out', check_eager=True) + + def initTestCase(self): + self.input = np.random.rand(100) + self.axes = [0] + self.starts = [-4] + self.ends = [-3] + self.strides = [1] + self.infer_flags = [1] + + if __name__ == "__main__": unittest.main()