From 51f4291ca1528345ecf891ce83c708730381ab90 Mon Sep 17 00:00:00 2001 From: ming1753 <61511741+ming1753@users.noreply.github.com> Date: Tue, 30 Aug 2022 10:30:57 +0800 Subject: [PATCH] strided_slice grad add fp16 support (#45504) --- .../kernels/gpu/strided_slice_grad_kernel.cu | 2 + .../tests/unittests/test_strided_slice_op.py | 38 +++++++++++++++++++ python/paddle/tensor/manipulation.py | 8 ++-- 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu b/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu index 90d9f1d986..6c5a99e909 100644 --- a/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu @@ -28,6 +28,7 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad, int64_t, float, double, + phi::dtype::float16, phi::dtype::complex, phi::dtype::complex) {} @@ -40,5 +41,6 @@ PD_REGISTER_KERNEL(strided_slice_array_grad, int64_t, float, double, + phi::dtype::float16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py index b68fbd9468..05f4e4a994 100644 --- a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py @@ -990,5 +990,43 @@ class TestStridedSliceTensorArray(unittest.TestCase): self.create_case(Net27(input_size=112, array_size=13)) +@unittest.skipIf(not fluid.core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestStridedSliceFloat16(unittest.TestCase): + + def init_test_case(self): + self.op_type = 'strided_slice' + self.input_shape = [3, 3, 3, 6, 7, 8] + self.axes = [0, 1, 2, 3, 4, 5] + self.starts = [1, 0, 0, 0, 1, 2] + self.ends = [2, 2, 3, 1, 2, 8] + self.strides = [1, 1, 1, 1, 1, 2] + self.infer_flags = [1, 1, 1, 1, 1] + + def check_main(self, x_np, dtype): + paddle.disable_static() + x_np = x_np.astype(dtype) + x = paddle.to_tensor(x_np) + x.stop_gradient = False + output = strided_slice_native_forward(x, self.axes, self.starts, + self.ends, self.strides) + x_grad = paddle.grad(output, x) + output_np = output[0].numpy().astype('float32') + x_grad_np = x_grad[0].numpy().astype('float32') + paddle.enable_static() + return output_np, x_grad_np + + def test_check(self): + self.init_test_case() + x_np = np.random.random(self.input_shape).astype("float16") + + output_np_fp16, x_grad_np_fp16 = self.check_main(x_np, 'float16') + output_np_fp32, x_grad_np_fp32 = self.check_main(x_np, 'float32') + + np.testing.assert_allclose(output_np_fp16, output_np_fp32) + + np.testing.assert_allclose(x_grad_np_fp16, x_grad_np_fp32) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 7b90eaa920..555ce0f427 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -3580,7 +3580,7 @@ def strided_slice(x, axes, starts, ends, strides, name=None): result = [ [2], ] Args: - x (Tensor): An N-D ``Tensor``. The data type is ``bool``, ``float32``, ``float64``, ``int32`` or ``int64``. + x (Tensor): An N-D ``Tensor``. The data type is ``bool``, ``float16``, ``float32``, ``float64``, ``int32`` or ``int64``. axes (list|tuple): The data type is ``int32`` . Axes that `starts` and `ends` apply to. It's optional. If it is not provides, it will be treated as :math:`[0,1,...,len(starts)-1]`. starts (list|tuple|Tensor): The data type is ``int32`` . If ``starts`` is a list or tuple, the elements of it should be integers or Tensors with shape [1]. If ``starts`` is an Tensor, it should be an 1-D Tensor. It represents starting indices of corresponding axis in ``axes``. @@ -3619,9 +3619,9 @@ def strided_slice(x, axes, starts, ends, strides, name=None): helper = LayerHelper('strided_slice', **locals()) - check_variable_and_dtype(x, 'x', - ['bool', 'float32', 'float64', 'int32', 'int64'], - 'strided_slice') + check_variable_and_dtype( + x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], + 'strided_slice') check_type(axes, 'axes', (list, tuple), 'strided_slice') check_type(starts, 'starts', (list, tuple, Variable), 'strided_slice') check_type(ends, 'ends', (list, tuple, Variable), 'strided_slice') -- GitLab