diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index ff5f4c865079583d1f0f2e9c8010cbbe3aecf7d2..719849500376c75ac1ed3685bc939d653350ffee 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -126,6 +126,7 @@ from .tensor.linalg import norm  # noqa: F401
 from .tensor.linalg import transpose  # noqa: F401
 from .tensor.linalg import dist  # noqa: F401
 from .tensor.linalg import t  # noqa: F401
+from .tensor.linalg import cdist  # noqa: F401
 from .tensor.linalg import cross  # noqa: F401
 from .tensor.linalg import cholesky  # noqa: F401
 from .tensor.linalg import bmm  # noqa: F401
@@ -536,6 +537,7 @@ __all__ = [  # noqa
     'triu',
     'sin',
     'dist',
+    'cdist',
     'unbind',
     'meshgrid',
     'arange',
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index ba13034fc56edf96baffbcba2b23a8ba2c17a6de..4d083a5febb353dc14eb21137a41ba04bbd0f8f9 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -70,6 +70,7 @@ from .linalg import solve  # noqa: F401
 from .linalg import cholesky_solve  # noqa: F401
 from .linalg import lu  # noqa: F401
 from .linalg import lu_unpack  # noqa: F401
+from .linalg import cdist  # noqa: F401
 from .logic import equal  # noqa: F401
 from .logic import greater_equal  # noqa: F401
 from .logic import greater_than  # noqa: F401
@@ -518,6 +519,7 @@ tensor_method_func = [  # noqa
     'acosh',
     'lu',
     'lu_unpack',
+    'cdist',
     'as_complex',
     'as_real',
     'rad2deg',
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index da77bee601824aab3d074d20ed20dea980e8ae73..67548a8120fd219af65ee27c7f21b50a3a6ac1f4 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -3326,3 +3326,119 @@ def corrcoef(x, rowvar=True, name=None):
     c = paddle.clip(c, -1, 1)
 
     return c
+
+
+def cdist(
+    x, y, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary", name=None
+):
+    r"""
+
+    Compute the p-norm distance between each pair of the two collections of inputs.
+
+    This function is equivalent to `scipy.spatial.distance.cdist(input, 'minkowski', p=p)`
+    if :math:`p \in (0, \infty)`. When :math:`p = 0`, it is equivalent to
+    `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest
+    scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`.
+
+    Args:
+        x (Tensor): A tensor with shape :math:`B \times P \times M`.
+        y (Tensor): A tensor with shape :math:`B \times R \times M`.
+        p (float, optional): The value for the p-norm distance to calculate between each vector pair. Default: :math:`2.0`.
+        compute_mode (str, optional): The mode for computing the distance.
+
+            - ``use_mm_for_euclid_dist_if_necessary``, for p = 2.0 and (P > 25 or R > 25), matrix multiplication will be used to compute the Euclidean distance if possible.
+            - ``use_mm_for_euclid_dist``, for p = 2.0, matrix multiplication will be used to compute the Euclidean distance.
+            - ``donot_use_mm_for_euclid_dist``, matrix multiplication will not be used to compute the Euclidean distance.
+
+            Default: ``use_mm_for_euclid_dist_if_necessary``.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        Tensor, with the same dtype as the input tensors.
+
+        If x has shape :math:`B \times P \times M` and y has shape :math:`B \times R \times M`, then
+        the output will have shape :math:`B \times P \times R`.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            x = paddle.to_tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]], dtype=paddle.float32)
+            y = paddle.to_tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]], dtype=paddle.float32)
+            distance = paddle.cdist(x, y)
+            print(distance)
+            # Tensor(shape=[3, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+            #        [[3.1193, 2.0959], [2.7138, 3.8322], [2.2830, 0.3791]])
+    """
+
+    check_variable_and_dtype(x, 'x', ('float32', 'float64'), 'cdist')
+    check_variable_and_dtype(y, 'y', ('float32', 'float64'), 'cdist')
+    check_type(p, 'p', (float, int), 'cdist')
+
+    if compute_mode not in [
+        'use_mm_for_euclid_dist_if_necessary',
+        'use_mm_for_euclid_dist',
+        'donot_use_mm_for_euclid_dist',
+    ]:
+        raise ValueError(
+            "The compute_mode should be 'use_mm_for_euclid_dist_if_necessary', "
+            "'use_mm_for_euclid_dist' or 'donot_use_mm_for_euclid_dist', "
+            "but received compute_mode is %s." % compute_mode
+        )
+
+    mode = 0
+    if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
+        mode = 0
+    elif compute_mode == 'use_mm_for_euclid_dist':
+        mode = 1
+    elif compute_mode == 'donot_use_mm_for_euclid_dist':
+        mode = 2
+
+    x_shape = list(x.shape)
+    assert len(x_shape) >= 2, (
+        "The x must be at least 2-dimensional, "
+        "but the number of dimensions of x is %s.\n" % len(x_shape)
+    )
+    y_shape = list(y.shape)
+    assert len(y_shape) >= 2, (
+        "The y must be at least 2-dimensional, "
+        "but the number of dimensions of y is %s.\n" % len(y_shape)
+    )
+    assert x_shape[-1] == y_shape[-1], (
+        "The x and y must have the same last dimension, "
+        "but x's last dimension is {} and "
+        "y's last dimension is {}.\n".format(x_shape[-1], y_shape[-1])
+    )
+    assert p >= 0, (
+        "The p must be greater than or equal to 0, "
+        "but received p is %s.\n" % p
+    )
+
+    r1 = x.shape[-2]
+    r2 = y.shape[-2]
+    c1 = x.shape[-1]
+
+    p = float(p)
+
+    if r1 == 0 or r2 == 0:
+        return paddle.empty((r1, r2), dtype=x.dtype)
+
+    if c1 == 0:
+        return paddle.zeros((r1, r2), dtype=x.dtype)
+
+    if p == 2.0 and (mode == 1 or (mode == 0 and (r1 > 25 or r2 > 25))):
+        x_norm = paddle.sum(x.pow(2), axis=-1, keepdim=True)
+        y_norm = paddle.sum(y.pow(2), axis=-1, keepdim=True)
+        y_transposed = paddle.transpose(
+            y, perm=[*range(y.ndim - 2), y.ndim - 1, y.ndim - 2]
+        )
+        y_norm_transposed = paddle.transpose(
+            y_norm,
+            perm=[*range(y_norm.ndim - 2), y_norm.ndim - 1, y_norm.ndim - 2],
+        )
+        res = paddle.matmul(x, y_transposed) * -2 + y_norm_transposed + x_norm
+        res = paddle.clip(res, min=0.0).sqrt()
+        return res
+
+    return paddle.linalg.norm(
+        x[..., None, :] - y[..., None, :, :], p=p, axis=-1
+    )
diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt
index 22457958f7bfd4dce6f8bbbd6cdc6b48a01da215..3e4984832ec0c476971ad91108b51aba63003180 100644
--- a/test/legacy_test/CMakeLists.txt
+++ b/test/legacy_test/CMakeLists.txt
@@ -940,9 +940,11 @@ set_tests_properties(test_paddle_save_load_binary PROPERTIES TIMEOUT 120)
 if(WIN32)
   set_tests_properties(test_static_save_load_large PROPERTIES TIMEOUT 900)
   set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 250)
+  set_tests_properties(test_cdist PROPERTIES TIMEOUT 250)
 else()
   set_tests_properties(test_static_save_load_large PROPERTIES TIMEOUT 600)
   set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 250)
+  set_tests_properties(test_cdist PROPERTIES TIMEOUT 120)
 endif()
 if(WITH_NV_JETSON)
   set_tests_properties(test_concat_op PROPERTIES TIMEOUT 1200)
diff --git a/test/legacy_test/test_cdist.py b/test/legacy_test/test_cdist.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb8460870e99fe5a58bcaf6109003b7c7b2330e3
--- /dev/null
+++ b/test/legacy_test/test_cdist.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+
+
+def ref_cdist(x, y, p=2.0):
+    r1 = x.shape[-2]
+    r2 = y.shape[-2]
+    if r1 == 0 or r2 == 0:
+        return np.empty((r1, r2), x.dtype)
+    return np.linalg.norm(x[..., None, :] - y[..., None, :, :], ord=p, axis=-1)
+
+
+class TestCdistAPI(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(1024)
+        self.x = np.random.rand(10, 20).astype('float32')
+        self.y = np.random.rand(11, 20).astype('float32')
+        self.p = 2.0
+        self.compute_mode = "use_mm_for_euclid_dist_if_necessary"
+        self.init_input()
+        self.place = (
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
+        )
+
+    def init_input(self):
+        pass
+
+    def test_static_api(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype)
+            y = paddle.static.data('y', self.y.shape, dtype=self.y.dtype)
+            out = paddle.cdist(x, y, self.p, self.compute_mode)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out])
+            out_ref = ref_cdist(self.x, self.y, self.p)
+            np.testing.assert_allclose(out_ref, res[0], rtol=1e-5, atol=1e-5)
+
+    def test_dygraph_api(self):
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x)
+        y = paddle.to_tensor(self.y)
+        out = paddle.cdist(x, y, self.p, self.compute_mode)
+        out_ref = ref_cdist(self.x, self.y, self.p)
+        np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5, atol=1e-5)
+        paddle.enable_static()
+
+
+class TestCdistAPICase1(TestCdistAPI):
+    def init_input(self):
+        self.p = 0
+
+
+class TestCdistAPICase2(TestCdistAPI):
+    def init_input(self):
+        self.p = 1.0
+
+
+class TestCdistAPICase3(TestCdistAPI):
+    def init_input(self):
+        self.p = 3.0
+
+
+class TestCdistAPICase4(TestCdistAPI):
+    def init_input(self):
+        self.p = 1.5
+
+
+class TestCdistAPICase5(TestCdistAPI):
+    def init_input(self):
+        self.p = 2.5
+
+
+class TestCdistAPICase6(TestCdistAPI):
+    def init_input(self):
+        self.p = float('inf')
+
+
+class TestCdistAPICase7(TestCdistAPI):
+    def init_input(self):
+        self.x = np.random.rand(50, 20).astype('float64')
+        self.y = np.random.rand(40, 20).astype('float64')
+        self.compute_mode = "use_mm_for_euclid_dist"
+
+
+class TestCdistAPICase8(TestCdistAPI):
+    def init_input(self):
+        self.x = np.random.rand(50, 20).astype('float64')
+        self.y = np.random.rand(40, 20).astype('float64')
+        self.compute_mode = "donot_use_mm_for_euclid_dist"
+
+
+class TestCdistAPICase9(TestCdistAPI):
+    def init_input(self):
+        self.x = np.random.rand(500, 100).astype('float64')
+        self.y = np.random.rand(400, 100).astype('float64')
+
+
+class TestCdistAPICase10(TestCdistAPI):
+    def init_input(self):
+        self.x = np.random.rand(3, 500, 100).astype('float64')
+        self.y = np.random.rand(3, 400, 100).astype('float64')
+
+
+class TestCdistAPICase11(TestCdistAPI):
+    def init_input(self):
+        self.x = np.random.rand(3, 4, 500, 100).astype('float64')
+        self.y = np.random.rand(3, 4, 400, 100).astype('float64')
+
+
+class TestCdistAPICase12(TestCdistAPI):
+    def init_input(self):
+        self.x = np.random.rand(3, 4, 500, 100).astype('float64')
+        self.y = np.random.rand(3, 4, 400, 100).astype('float64')
+        self.p = 3.0
+
+
+# test that the different compute modes produce the same result
+class TestCdistAPICase13(TestCdistAPI):
+    def init_input(self):
+        self.x = np.random.rand(3, 4, 500, 100).astype('float64')
+        self.y = np.random.rand(3, 4, 400, 100).astype('float64')
+
+    def test_static_api(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype)
+            y = paddle.static.data('y', self.y.shape, dtype=self.y.dtype)
+            out0 = paddle.cdist(x, y, self.p, self.compute_mode)
+            out1 = paddle.cdist(x, y, self.p, "donot_use_mm_for_euclid_dist")
+            out2 = paddle.cdist(x, y, self.p, "use_mm_for_euclid_dist")
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(
+                feed={'x': self.x, 'y': self.y}, fetch_list=[out0, out1, out2]
+            )
+            out_ref = ref_cdist(self.x, self.y, self.p)
+            np.testing.assert_allclose(out_ref, res[0])
+            np.testing.assert_allclose(out_ref, res[1])
+            np.testing.assert_allclose(out_ref, res[2])
+
+    def test_dygraph_api(self):
+        paddle.disable_static(self.place)
+        x = paddle.to_tensor(self.x)
+        y = paddle.to_tensor(self.y)
+        out0 = paddle.cdist(x, y, self.p, self.compute_mode)
+        out1 = paddle.cdist(x, y, self.p, "donot_use_mm_for_euclid_dist")
+        out2 = paddle.cdist(x, y, self.p, "use_mm_for_euclid_dist")
+        out_ref = ref_cdist(self.x, self.y, self.p)
+        np.testing.assert_allclose(out_ref, out0.numpy())
+        np.testing.assert_allclose(out_ref, out1.numpy())
+        np.testing.assert_allclose(out_ref, out2.numpy())
+        paddle.enable_static()
+
+
+# test for broadcast
+class TestCdistAPICase14(TestCdistAPI):
+    def init_input(self):
+        self.x = np.random.rand(3, 4, 500, 100).astype('float64')
+        self.y = np.random.rand(1, 4, 400, 100).astype('float64')
+
+
+# test for broadcast
+class TestCdistAPICase15(TestCdistAPI):
+    def init_input(self):
+        self.x = np.random.rand(3, 4, 500, 100).astype('float64')
+        self.y = np.random.rand(4, 400, 100).astype('float64')
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()
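
Usage note (not part of the patch): a minimal sketch of how the new paddle.cdist could be cross-checked against scipy.spatial.distance.cdist, following the Minkowski equivalence stated in the docstring above. It assumes the diff has been applied and scipy is available; the shapes, p value, and tolerances are illustrative choices, not part of the PR.

    import numpy as np
    import paddle
    from scipy.spatial.distance import cdist as scipy_cdist

    # Two small collections of 5-dimensional points (illustrative shapes).
    x = np.random.rand(8, 5).astype('float64')
    y = np.random.rand(6, 5).astype('float64')

    # p-norm distance between every pair of rows; p=3.0 exercises the generic
    # (non matrix-multiplication) path of the new API.
    out = paddle.cdist(paddle.to_tensor(x), paddle.to_tensor(y), p=3.0)

    # scipy's Minkowski metric should agree for p in (0, inf).
    ref = scipy_cdist(x, y, 'minkowski', p=3.0)
    np.testing.assert_allclose(out.numpy(), ref, rtol=1e-5, atol=1e-5)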