From 587120ecbe279a7ff2477fe1a6840c506b9c32e8 Mon Sep 17 00:00:00 2001 From: Ainavo <57820731+Ainavo@users.noreply.github.com> Date: Mon, 27 Feb 2023 14:56:43 +0800 Subject: [PATCH] [fp16] fix fp16 support for nn.PairwiseDistance (#50849) --- .../tests/unittests/test_pairwise_distance.py | 91 +++++++++++++++++++ python/paddle/nn/functional/distance.py | 8 +- python/paddle/nn/layer/distance.py | 2 +- 3 files changed, 96 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_pairwise_distance.py b/python/paddle/fluid/tests/unittests/test_pairwise_distance.py index a764612cd9d..fe705adc9b7 100644 --- a/python/paddle/fluid/tests/unittests/test_pairwise_distance.py +++ b/python/paddle/fluid/tests/unittests/test_pairwise_distance.py @@ -286,6 +286,97 @@ class TestPairwiseDistance(unittest.TestCase): dygraph_functional_ret, excepted_value, rtol=1e-05 ) + def test_pairwise_distance_fp16(self): + epsilon = 1e-6 + all_shape = [[5], [100, 100]] + dtypes = ['float16'] + p_list = [-1, 0, 1, 2, np.inf, -np.inf] + places = [paddle.CPUPlace()] + if paddle.device.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + keeps = [False, True] + for place in places: + for shape in all_shape: + for dtype in dtypes: + for p in p_list: + for keepdim in keeps: + x_np = np.random.random(shape).astype(dtype) + y_np = np.random.random(shape).astype(dtype) + # Currently, the CPU does not support float16 + if dtype == "float16" and isinstance( + place, paddle.CPUPlace + ): + continue + static_ret = test_static( + place, + x_np, + y_np, + p, + epsilon=epsilon, + keepdim=keepdim, + ) + dygraph_ret = test_dygraph( + place, + x_np, + y_np, + p, + epsilon=epsilon, + keepdim=keepdim, + ) + excepted_value = np_pairwise_distance( + x_np, y_np, p, epsilon=epsilon, keepdim=keepdim + ) + + self.assertEqual( + static_ret.shape, excepted_value.shape + ) + self.assertEqual( + dygraph_ret.shape, excepted_value.shape + ) + + np.testing.assert_allclose( + static_ret, excepted_value, atol=1e-03 + ) + np.testing.assert_allclose( + dygraph_ret, excepted_value, atol=1e-03 + ) + static_functional_ret = test_static( + place, + x_np, + y_np, + p, + epsilon=epsilon, + keepdim=keepdim, + ) + dygraph_functional_ret = test_dygraph( + place, + x_np, + y_np, + p, + epsilon=epsilon, + keepdim=keepdim, + ) + + self.assertEqual( + static_functional_ret.shape, + excepted_value.shape, + ) + self.assertEqual( + dygraph_functional_ret.shape, + excepted_value.shape, + ) + + np.testing.assert_allclose( + static_functional_ret, + excepted_value, + atol=1e-03, + ) + np.testing.assert_allclose( + dygraph_functional_ret, + excepted_value, + atol=1e-03, + ) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/distance.py b/python/paddle/nn/functional/distance.py index 0bdba06c8b4..e8e209be18a 100644 --- a/python/paddle/nn/functional/distance.py +++ b/python/paddle/nn/functional/distance.py @@ -35,10 +35,10 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None): Parameters: x (Tensor): Tensor, shape is :math:`[N, D]` or :math:`[D]`, where :math:`N` is batch size, :math:`D` is the dimension of vector. Available dtype is - float32, float64. + float16, float32, float64. y (Tensor): Tensor, shape is :math:`[N, D]` or :math:`[D]`, where :math:`N` is batch size, :math:`D` is the dimension of vector. Available dtype is - float32, float64. + float16, float32, float64. p (float, optional): The order of norm. Default: :math:`2.0`. epsilon (float, optional): Add small value to avoid division by zero. Default: :math:`1e-6`. @@ -84,10 +84,10 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None): check_type(keepdim, 'keepdim', (bool), 'PairwiseDistance') check_variable_and_dtype( - x, 'x', ['float32', 'float64'], 'PairwiseDistance' + x, 'x', ['float16', 'float32', 'float64'], 'PairwiseDistance' ) check_variable_and_dtype( - y, 'y', ['float32', 'float64'], 'PairwiseDistance' + y, 'y', ['float16', 'float32', 'float64'], 'PairwiseDistance' ) sub = paddle.subtract(x, y) if epsilon != 0.0: diff --git a/python/paddle/nn/layer/distance.py b/python/paddle/nn/layer/distance.py index f63ce53c4e2..f68cb5d95ee 100644 --- a/python/paddle/nn/layer/distance.py +++ b/python/paddle/nn/layer/distance.py @@ -40,7 +40,7 @@ class PairwiseDistance(Layer): Shape: - x: :math:`[N, D]` or :math:`[D]`, where :math:`N` is batch size, :math:`D` - is the dimension of the data. Available data type is float32, float64. + is the dimension of the data. Available data type is float16, float32, float64. - y: :math:`[N, D]` or :math:`[D]`, y have the same dtype as x. - output: The same dtype as input tensor. - If :attr:`keepdim` is True, the output shape is :math:`[N, 1]` or :math:`[1]`, -- GitLab