From 46be6854dd67a0db65fc77aeb57cbdd522e9eef3 Mon Sep 17 00:00:00 2001 From: Ainavo <57820731+Ainavo@users.noreply.github.com> Date: Fri, 29 Jul 2022 23:35:12 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PaddlePaddle=20Hackathon=203=20No.12?= =?UTF-8?q?=E3=80=91=E4=B8=BA=20Paddle=20=E6=96=B0=E5=A2=9E=20pairwise=5Fd?= =?UTF-8?q?istance=20(#44161)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add paddle.nn.functional.pairwise_distance (cattidea/Paddle#273) * remove the test case for undefined behavior Co-authored-by: SigureMo --- .../tests/unittests/test_pairwise_distance.py | 334 +++++++++++++++--- python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/distance.py | 109 ++++++ python/paddle/nn/layer/distance.py | 81 ++--- 4 files changed, 424 insertions(+), 102 deletions(-) create mode 100644 python/paddle/nn/functional/distance.py diff --git a/python/paddle/fluid/tests/unittests/test_pairwise_distance.py b/python/paddle/fluid/tests/unittests/test_pairwise_distance.py index 651d9b5ea6..bdcc302de5 100644 --- a/python/paddle/fluid/tests/unittests/test_pairwise_distance.py +++ b/python/paddle/fluid/tests/unittests/test_pairwise_distance.py @@ -20,24 +20,64 @@ import numpy as np import unittest -def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False): - return np.linalg.norm(x - y, ord=p, axis=1, keepdims=keepdim) +def np_pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False): + distance = np.linalg.norm(x - y + epsilon, ord=p, axis=-1, keepdims=keepdim) + # Paddle currently has not supported for 0-d Tensors, so even if keep_dim is False, + # and neither x nor y is batched, a Tensor of shape (1, ) is returned + if distance.ndim == 0: + distance = np.expand_dims(distance, axis=0) + return distance -def test_static(x_np, y_np, p=2.0, epsilon=1e-6, keepdim=False): +def call_pairwise_distance_layer(x, y, p=2., epsilon=1e-6, keepdim='False'): + pairwise_distance = paddle.nn.PairwiseDistance(p=p, + epsilon=epsilon, + keepdim=keepdim) + distance = pairwise_distance(x=x, y=y) + return distance + + +def call_pairwise_distance_functional(x, + y, + p=2., + epsilon=1e-6, + keepdim='False'): + distance = paddle.nn.functional.pairwise_distance(x=x, + y=y, + p=p, + epsilon=epsilon, + keepdim=keepdim) + return distance + + +def test_static(place, + x_np, + y_np, + p=2.0, + epsilon=1e-6, + keepdim=False, + functional=False): prog = paddle.static.Program() startup_prog = paddle.static.Program() - place = fluid.CUDAPlace( 0) if paddle.fluid.core.is_compiled_with_cuda() else fluid.CPUPlace() - + paddle.enable_static() with paddle.static.program_guard(prog, startup_prog): x = paddle.fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype) y = paddle.fluid.data(name='y', shape=y_np.shape, dtype=x_np.dtype) - dist = paddle.nn.layer.distance.PairwiseDistance(p=p, + + if functional: + distance = call_pairwise_distance_functional(x=x, + y=y, + p=p, epsilon=epsilon, keepdim=keepdim) - distance = dist(x, y) + else: + distance = call_pairwise_distance_layer(x=x, + y=y, + p=p, + epsilon=epsilon, + keepdim=keepdim) exe = paddle.static.Executor(place) static_ret = exe.run(prog, feed={ @@ -46,69 +86,279 @@ def test_static(x_np, y_np, p=2.0, epsilon=1e-6, keepdim=False): }, fetch_list=[distance]) static_ret = static_ret[0] + paddle.disable_static() return static_ret -def test_dygraph(x_np, y_np, p=2.0, epsilon=1e-6, keepdim=False): - paddle.disable_static() +def test_dygraph(place, + x_np, + y_np, + p=2.0, + epsilon=1e-6, + keepdim=False, + functional=False): x = paddle.to_tensor(x_np) y = paddle.to_tensor(y_np) - dist = paddle.nn.layer.distance.PairwiseDistance(p=p, - epsilon=epsilon, - keepdim=keepdim) - distance = dist(x, y) - dygraph_ret = distance.numpy() - paddle.enable_static() + if functional: + dy_distance = call_pairwise_distance_functional(x=x, + y=y, + p=p, + epsilon=epsilon, + keepdim=keepdim) + else: + dy_distance = call_pairwise_distance_layer(x=x, + y=y, + p=p, + epsilon=epsilon, + keepdim=keepdim) + dygraph_ret = dy_distance.numpy() return dygraph_ret +def test_legacy_dygraph(place, + x_np, + y_np, + p=2.0, + epsilon=1e-6, + keepdim=False, + functional=False): + paddle.fluid.framework._enable_legacy_dygraph() + x = paddle.to_tensor(x_np) + y = paddle.to_tensor(y_np) + if functional: + legacy_distance = call_pairwise_distance_functional(x=x, + y=y, + p=p, + epsilon=epsilon, + keepdim=keepdim) + else: + legacy_distance = call_pairwise_distance_layer(x=x, + y=y, + p=p, + epsilon=epsilon, + keepdim=keepdim) + legacy_ret = legacy_distance.numpy() + paddle.fluid.framework._disable_legacy_dygraph() + return legacy_ret + + class TestPairwiseDistance(unittest.TestCase): def test_pairwise_distance(self): - all_shape = [[100, 100], [4, 5, 6, 7]] + epsilon = 1e-6 + all_shape = [[5], [100, 100]] dtypes = ['float32', 'float64'] + p_list = [-1, 0, 1, 2, np.inf, -np.inf] + places = [paddle.CPUPlace()] + if paddle.device.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) keeps = [False, True] - for shape in all_shape: - for dtype in dtypes: - for keepdim in keeps: - x_np = np.random.random(shape).astype(dtype) - y_np = np.random.random(shape).astype(dtype) - - static_ret = test_static(x_np, y_np, keepdim=keepdim) - dygraph_ret = test_dygraph(x_np, y_np, keepdim=keepdim) - excepted_value = pairwise_distance(x_np, + for place in places: + for shape in all_shape: + for dtype in dtypes: + for p in p_list: + for keepdim in keeps: + x_np = np.random.random(shape).astype(dtype) + y_np = np.random.random(shape).astype(dtype) + + static_ret = test_static(place, + x_np, + y_np, + p, + epsilon=epsilon, + keepdim=keepdim) + dygraph_ret = test_dygraph(place, + x_np, y_np, + p, + epsilon=epsilon, keepdim=keepdim) + legacy_ret = test_legacy_dygraph(place, + x_np, + y_np, + p, + epsilon=epsilon, + keepdim=keepdim) + excepted_value = np_pairwise_distance( + x_np, y_np, p, epsilon=epsilon, keepdim=keepdim) + + self.assertEqual(static_ret.shape, + excepted_value.shape) + self.assertEqual(dygraph_ret.shape, + excepted_value.shape) + self.assertEqual(legacy_ret.shape, + excepted_value.shape) + + self.assertTrue( + np.allclose(static_ret, excepted_value)) + self.assertTrue( + np.allclose(dygraph_ret, excepted_value)) + self.assertTrue( + np.allclose(legacy_ret, excepted_value)) + + static_functional_ret = test_static(place, + x_np, + y_np, + p, + epsilon=epsilon, + keepdim=keepdim) + dygraph_functional_ret = test_dygraph( + place, + x_np, + y_np, + p, + epsilon=epsilon, + keepdim=keepdim) + legacy_functional_ret = test_legacy_dygraph( + place, + x_np, + y_np, + p, + epsilon=epsilon, + keepdim=keepdim) - self.assertTrue(np.allclose(static_ret, dygraph_ret)) - self.assertTrue(np.allclose(static_ret, excepted_value)) - self.assertTrue(np.allclose(dygraph_ret, excepted_value)) + self.assertEqual(static_functional_ret.shape, + excepted_value.shape) + self.assertEqual(dygraph_functional_ret.shape, + excepted_value.shape) + self.assertEqual(legacy_functional_ret.shape, + excepted_value.shape) - def test_pairwise_distance_broadcast(self): + self.assertTrue( + np.allclose(static_functional_ret, + excepted_value)) + self.assertTrue( + np.allclose(dygraph_functional_ret, + excepted_value)) + self.assertTrue( + np.allclose(legacy_functional_ret, + excepted_value)) + + def test_pairwise_distance_broadcast_1(self): shape_x = [100, 100] shape_y = [100, 1] + epsilon = 1e-6 keepdim = False + place = paddle.CPUPlace() x_np = np.random.random(shape_x).astype('float32') y_np = np.random.random(shape_y).astype('float32') - static_ret = test_static(x_np, y_np, keepdim=keepdim) - dygraph_ret = test_dygraph(x_np, y_np, keepdim=keepdim) - excepted_value = pairwise_distance(x_np, y_np, keepdim=keepdim) - self.assertTrue(np.allclose(static_ret, dygraph_ret)) + static_ret = test_static(place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim) + dygraph_ret = test_dygraph(place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim) + legacy_ret = test_legacy_dygraph(place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim) + excepted_value = np_pairwise_distance(x_np, + y_np, + epsilon=epsilon, + keepdim=keepdim) + + self.assertEqual(static_ret.shape, excepted_value.shape) + self.assertEqual(dygraph_ret.shape, excepted_value.shape) + self.assertEqual(legacy_ret.shape, excepted_value.shape) + self.assertTrue(np.allclose(static_ret, excepted_value)) self.assertTrue(np.allclose(dygraph_ret, excepted_value)) + self.assertTrue(np.allclose(legacy_ret, excepted_value)) + + static_functional_ret = test_static(place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim, + functional=True) + dygraph_functional_ret = test_dygraph(place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim, + functional=True) + legacy_functional_ret = test_legacy_dygraph(place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim, + functional=True) + + self.assertEqual(static_functional_ret.shape, excepted_value.shape) + self.assertEqual(dygraph_functional_ret.shape, excepted_value.shape) + self.assertEqual(legacy_functional_ret.shape, excepted_value.shape) - def test_pairwise_distance_different_p(self): - shape = [100, 100] + self.assertTrue(np.allclose(static_functional_ret, excepted_value)) + self.assertTrue(np.allclose(dygraph_functional_ret, excepted_value)) + self.assertTrue(np.allclose(legacy_functional_ret, excepted_value)) + + def test_pairwise_distance_broadcast_2(self): + shape_x = [100, 100] + shape_y = [100] + epsilon = 1e-6 keepdim = False - p = 3.0 - x_np = np.random.random(shape).astype('float32') - y_np = np.random.random(shape).astype('float32') - static_ret = test_static(x_np, y_np, p=p, keepdim=keepdim) - dygraph_ret = test_dygraph(x_np, y_np, p=p, keepdim=keepdim) - excepted_value = pairwise_distance(x_np, y_np, p=p, keepdim=keepdim) - self.assertTrue(np.allclose(static_ret, dygraph_ret)) + place = paddle.CPUPlace() + x_np = np.random.random(shape_x).astype('float32') + y_np = np.random.random(shape_y).astype('float32') + static_ret = test_static(place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim) + dygraph_ret = test_dygraph(place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim) + legacy_ret = test_legacy_dygraph(place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim) + excepted_value = np_pairwise_distance(x_np, + y_np, + epsilon=epsilon, + keepdim=keepdim) + + self.assertEqual(static_ret.shape, excepted_value.shape) + self.assertEqual(dygraph_ret.shape, excepted_value.shape) + self.assertEqual(legacy_ret.shape, excepted_value.shape) + self.assertTrue(np.allclose(static_ret, excepted_value)) self.assertTrue(np.allclose(dygraph_ret, excepted_value)) + self.assertTrue(np.allclose(legacy_ret, excepted_value)) + + static_functional_ret = test_static(place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim, + functional=True) + dygraph_functional_ret = test_dygraph(place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim, + functional=True) + legacy_functional_ret = test_legacy_dygraph(place=place, + x_np=x_np, + y_np=y_np, + epsilon=epsilon, + keepdim=keepdim, + functional=True) + + self.assertEqual(static_functional_ret.shape, excepted_value.shape) + self.assertEqual(dygraph_functional_ret.shape, excepted_value.shape) + self.assertEqual(legacy_functional_ret.shape, excepted_value.shape) + + self.assertTrue(np.allclose(static_functional_ret, excepted_value)) + self.assertTrue(np.allclose(dygraph_functional_ret, excepted_value)) + self.assertTrue(np.allclose(legacy_functional_ret, excepted_value)) if __name__ == "__main__": diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index b5d2d6f5be..701997e0d0 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -69,6 +69,7 @@ from .conv import conv2d # noqa: F401 from .conv import conv2d_transpose # noqa: F401 from .conv import conv3d # noqa: F401 from .conv import conv3d_transpose # noqa: F401 +from .distance import pairwise_distance # noqa: F401 from .extension import diag_embed # noqa: F401 from .extension import sequence_mask from .loss import binary_cross_entropy # noqa: F401 @@ -137,6 +138,7 @@ __all__ = [ # noqa 'conv2d_transpose', 'conv3d', 'conv3d_transpose', + 'pairwise_distance', 'elu', 'elu_', 'gelu', diff --git a/python/paddle/nn/functional/distance.py b/python/paddle/nn/functional/distance.py new file mode 100644 index 0000000000..8c672ffc69 --- /dev/null +++ b/python/paddle/nn/functional/distance.py @@ -0,0 +1,109 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from ...fluid.data_feeder import check_variable_and_dtype, check_type +from ...fluid.layer_helper import LayerHelper +from paddle import _C_ops +from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph + +__all__ = [] + + +def pairwise_distance(x, y, p=2., epsilon=1e-6, keepdim=False, name=None): + r""" + It computes the pairwise distance between two vectors. The + distance is calculated by p-oreder norm: + + .. math:: + + \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}. + + Parameters: + x (Tensor): Tensor, shape is :math:`[N, D]` or :math:`[D]`, where :math:`N` + is batch size, :math:`D` is the dimension of vector. Available dtype is + float32, float64. + y (Tensor): Tensor, shape is :math:`[N, D]` or :math:`[D]`, where :math:`N` + is batch size, :math:`D` is the dimension of vector. Available dtype is + float32, float64. + p (float, optional): The order of norm. Default: :math:`2.0`. + epsilon (float, optional): Add small value to avoid division by zero. + Default: :math:`1e-6`. + keepdim (bool, optional): Whether to reserve the reduced dimension + in the output Tensor. The result tensor is one dimension less than + the result of ``|x-y|`` unless :attr:`keepdim` is True. Default: False. + name (str, optional): For details, please refer to :ref:`api_guide_Name`. + Generally, no setting is required. Default: None. + + Returns: + Tensor, the dtype is same as input tensor. + - If :attr:`keepdim` is True, the output shape is :math:`[N, 1]` or :math:`[1]`, + depending on whether the input has data shaped as :math:`[N, D]`. + - If :attr:`keepdim` is False, the output shape is :math:`[N]` or :math:`[]`, + depending on whether the input has data shaped as :math:`[N, D]`. + + Examples: + .. code-block:: python + + import paddle + x = paddle.to_tensor([[1., 3.], [3., 5.]], dtype=paddle.float64) + y = paddle.to_tensor([[5., 6.], [7., 8.]], dtype=paddle.float64) + distance = paddle.nn.functional.pairwise_distance(x, y) + print(distance.numpy()) # [5. 5.] + + """ + check_type(p, 'porder', (float, int), 'PairwiseDistance') + check_type(epsilon, 'epsilon', (float), 'PairwiseDistance') + check_type(keepdim, 'keepdim', (bool), 'PairwiseDistance') + if in_dygraph_mode(): + sub = _C_ops.elementwise_sub(x, y) + # p_norm op has not uesd epsilon, so change it to the following. + if epsilon != 0.0: + epsilon = paddle.fluid.dygraph.base.to_variable([epsilon], + dtype=sub.dtype) + sub = _C_ops.elementwise_add(sub, epsilon) + return _C_ops.final_state_p_norm(sub, p, -1, 0., keepdim, False) + + if _in_legacy_dygraph(): + sub = _C_ops.elementwise_sub(x, y) + if epsilon != 0.0: + epsilon = paddle.fluid.dygraph.base.to_variable([epsilon], + dtype=sub.dtype) + sub = _C_ops.elementwise_add(sub, epsilon) + return _C_ops.p_norm(sub, 'axis', -1, 'porder', p, 'keepdim', keepdim, + 'epsilon', 0.) + + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'PairwiseDistance') + check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'PairwiseDistance') + sub = paddle.subtract(x, y) + if epsilon != 0.0: + epsilon_var = sub.block.create_var(dtype=sub.dtype) + epsilon_var = paddle.full(shape=[1], + fill_value=epsilon, + dtype=sub.dtype) + sub = paddle.add(sub, epsilon_var) + helper = LayerHelper("PairwiseDistance", name=name) + attrs = { + 'axis': -1, + 'porder': p, + 'keepdim': keepdim, + 'epsilon': 0., + } + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op(type='p_norm', + inputs={'X': sub}, + outputs={'Out': out}, + attrs=attrs) + + return out diff --git a/python/paddle/nn/layer/distance.py b/python/paddle/nn/layer/distance.py index 7c08e358fc..a7a488c833 100644 --- a/python/paddle/nn/layer/distance.py +++ b/python/paddle/nn/layer/distance.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,22 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np - -import paddle from .. import Layer -from ...fluid.data_feeder import check_variable_and_dtype, check_type -from ...fluid.layer_helper import LayerHelper -from paddle import _C_ops -from paddle import in_dynamic_mode -from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph +from .. import functional as F __all__ = [] class PairwiseDistance(Layer): r""" - This operator computes the pairwise distance between two vectors. The + It computes the pairwise distance between two vectors. The distance is calculated by p-oreder norm: .. math:: @@ -35,33 +28,31 @@ class PairwiseDistance(Layer): \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}. Parameters: - p (float): The order of norm. The default value is 2. - epsilon (float, optional): Add small value to avoid division by zero, - default value is 1e-6. + p (float, optional): The order of norm. Default: :math:`2.0`. + epsilon (float, optional): Add small value to avoid division by zero. + Default: :math:`1e-6`. keepdim (bool, optional): Whether to reserve the reduced dimension in the output Tensor. The result tensor is one dimension less than - the result of ``'x-y'`` unless :attr:`keepdim` is True, default - value is False. - name (str, optional): Name for the operation (optional, default is None). - For more information, please refer to :ref:`api_guide_Name`. + the result of ``|x-y|`` unless :attr:`keepdim` is True. Default: False. + name (str, optional): For details, please refer to :ref:`api_guide_Name`. + Generally, no setting is required. Default: None. Shape: - x: :math:`[N, D]` where `D` is the dimension of vector, available dtype - is float32, float64. - y: :math:`[N, D]`, y have the same shape and dtype as x. - out: :math:`[N]`. If :attr:`keepdim` is ``True``, the out shape is :math:`[N, 1]`. - The same dtype as input tensor. + x: :math:`[N, D]` or :math:`[D]`, where :math:`N` is batch size, :math:`D` + is the dimension of the data. Available data type is float32, float64. + y: :math:`[N, D]` or :math:`[D]`, y have the same dtype as x. + output: The same dtype as input tensor. + - If :attr:`keepdim` is True, the output shape is :math:`[N, 1]` or :math:`[1]`, + depending on whether the input has data shaped as :math:`[N, D]`. + - If :attr:`keepdim` is False, the output shape is :math:`[N]` or :math:`[]`, + depending on whether the input has data shaped as :math:`[N, D]`. Examples: .. code-block:: python import paddle - import numpy as np - paddle.disable_static() - x_np = np.array([[1., 3.], [3., 5.]]).astype(np.float64) - y_np = np.array([[5., 6.], [7., 8.]]).astype(np.float64) - x = paddle.to_tensor(x_np) - y = paddle.to_tensor(y_np) + x = paddle.to_tensor([[1., 3.], [3., 5.]], dtype=paddle.float64) + y = paddle.to_tensor([[5., 6.], [7., 8.]], dtype=paddle.float64) dist = paddle.nn.PairwiseDistance() distance = dist(x, y) print(distance.numpy()) # [5. 5.] @@ -74,41 +65,11 @@ class PairwiseDistance(Layer): self.epsilon = epsilon self.keepdim = keepdim self.name = name - check_type(self.p, 'porder', (float, int), 'PairwiseDistance') - check_type(self.epsilon, 'epsilon', (float), 'PairwiseDistance') - check_type(self.keepdim, 'keepdim', (bool), 'PairwiseDistance') def forward(self, x, y): - if in_dygraph_mode(): - sub = _C_ops.elementwise_sub(x, y) - return _C_ops.final_state_p_norm(sub, self.p, 1, self.epsilon, - self.keepdim, False) - - if _in_legacy_dygraph(): - sub = _C_ops.elementwise_sub(x, y) - return _C_ops.p_norm(sub, 'axis', 1, 'porder', self.p, 'keepdim', - self.keepdim, 'epsilon', self.epsilon) - - check_variable_and_dtype(x, 'x', ['float32', 'float64'], - 'PairwiseDistance') - check_variable_and_dtype(y, 'y', ['float32', 'float64'], - 'PairwiseDistance') - sub = paddle.subtract(x, y) - - helper = LayerHelper("PairwiseDistance", name=self.name) - attrs = { - 'axis': 1, - 'porder': self.p, - 'keepdim': self.keepdim, - 'epsilon': self.epsilon, - } - out = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op(type='p_norm', - inputs={'X': sub}, - outputs={'Out': out}, - attrs=attrs) - return out + return F.pairwise_distance(x, y, self.p, self.epsilon, self.keepdim, + self.name) def extra_repr(self): main_str = 'p={p}' -- GitLab