From 361363c321b4b71ed8ac2b785df534ee0aa317eb Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Wed, 12 Aug 2020 17:08:55 +0800 Subject: [PATCH] add pairewise distance for the paddlepaddle api 2.0 add pairewise distance for the paddlepaddle api 2.0 --- .../tests/unittests/test_pairwise_distance.py | 109 ++++++++++++++++++ python/paddle/nn/__init__.py | 1 + python/paddle/nn/layer/__init__.py | 2 + python/paddle/nn/layer/distance.py | 103 +++++++++++++++++ 4 files changed, 215 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_pairwise_distance.py create mode 100644 python/paddle/nn/layer/distance.py diff --git a/python/paddle/fluid/tests/unittests/test_pairwise_distance.py b/python/paddle/fluid/tests/unittests/test_pairwise_distance.py new file mode 100644 index 00000000000..085a717e659 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pairwise_distance.py @@ -0,0 +1,109 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import paddle.fluid as fluid +import numpy as np +import unittest + + +def pairwise_distance(x, y, p=2.0, eps=1e-6, keepdim=False): + return np.linalg.norm(x - y, ord=p, axis=1, keepdims=keepdim) + + +def test_static(x_np, y_np, p=2.0, eps=1e-6, keepdim=False): + prog = paddle.static.Program() + startup_prog = paddle.static.Program() + + place = fluid.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + + with paddle.static.program_guard(prog, startup_prog): + x = paddle.data(name='x', shape=x_np.shape, dtype=x_np.dtype) + y = paddle.data(name='y', shape=y_np.shape, dtype=x_np.dtype) + dist = paddle.nn.layer.distance.PairwiseDistance( + p=p, eps=eps, keepdim=keepdim) + distance = dist(x, y) + exe = paddle.static.Executor(place) + static_ret = exe.run(prog, + feed={'x': x_np, + 'y': y_np}, + fetch_list=[distance]) + static_ret = static_ret[0] + return static_ret + + +def test_dygraph(x_np, y_np, p=2.0, eps=1e-6, keepdim=False): + paddle.disable_static() + x = paddle.to_variable(x_np) + y = paddle.to_variable(y_np) + dist = paddle.nn.layer.distance.PairwiseDistance( + p=p, eps=eps, keepdim=keepdim) + distance = dist(x, y) + dygraph_ret = distance.numpy() + paddle.enable_static() + return dygraph_ret + + +class TestPairwiseDistance(unittest.TestCase): + def test_pairwise_distance(self): + all_shape = [[100, 100], [4, 5, 6, 7]] + dtypes = ['float32', 'float64'] + keeps = [False, True] + for shape in all_shape: + for dtype in dtypes: + for keepdim in keeps: + x_np = np.random.random(shape).astype(dtype) + y_np = np.random.random(shape).astype(dtype) + + static_ret = test_static(x_np, y_np, keepdim=keepdim) + dygraph_ret = test_dygraph(x_np, y_np, keepdim=keepdim) + excepted_value = pairwise_distance( + x_np, y_np, keepdim=keepdim) + + self.assertTrue(np.allclose(static_ret, dygraph_ret)) + self.assertTrue(np.allclose(static_ret, excepted_value)) + self.assertTrue(np.allclose(dygraph_ret, excepted_value)) + + def test_pairwise_distance_broadcast(self): + shape_x = [100, 100] + shape_y = [100, 1] + keepdim = False + x_np = np.random.random(shape_x).astype('float32') + y_np = np.random.random(shape_y).astype('float32') + static_ret = test_static(x_np, y_np, keepdim=keepdim) + dygraph_ret = test_dygraph(x_np, y_np, keepdim=keepdim) + excepted_value = pairwise_distance(x_np, y_np, keepdim=keepdim) + self.assertTrue(np.allclose(static_ret, dygraph_ret)) + self.assertTrue(np.allclose(static_ret, excepted_value)) + self.assertTrue(np.allclose(dygraph_ret, excepted_value)) + + def test_pairwise_distance_different_p(self): + shape = [100, 100] + keepdim = False + p = 3.0 + x_np = np.random.random(shape).astype('float32') + y_np = np.random.random(shape).astype('float32') + static_ret = test_static(x_np, y_np, p=p, keepdim=keepdim) + dygraph_ret = test_dygraph(x_np, y_np, p=p, keepdim=keepdim) + excepted_value = pairwise_distance(x_np, y_np, p=p, keepdim=keepdim) + self.assertTrue(np.allclose(static_ret, dygraph_ret)) + self.assertTrue(np.allclose(static_ret, excepted_value)) + self.assertTrue(np.allclose(dygraph_ret, excepted_value)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 98948fa91e2..aac6b401685 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -93,6 +93,7 @@ from .layer.norm import InstanceNorm #DEFINE_ALIAS # from .layer.rnn import RNNCell #DEFINE_ALIAS # from .layer.rnn import GRUCell #DEFINE_ALIAS # from .layer.rnn import LSTMCell #DEFINE_ALIAS +from .layer.distance import PairwiseDistance #DEFINE_ALIAS from .layer import loss #DEFINE_ALIAS from .layer import conv #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 7173c5b5877..560314788a1 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -20,6 +20,7 @@ from . import conv from . import extension from . import activation from . import norm +from . import distance from .activation import * from .loss import * @@ -69,3 +70,4 @@ from .norm import InstanceNorm #DEFINE_ALIAS # from .rnn import RNNCell #DEFINE_ALIAS # from .rnn import GRUCell #DEFINE_ALIAS # from .rnn import LSTMCell #DEFINE_ALIAS +from .distance import PairwiseDistance #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/distance.py b/python/paddle/nn/layer/distance.py new file mode 100644 index 00000000000..73ad60b9796 --- /dev/null +++ b/python/paddle/nn/layer/distance.py @@ -0,0 +1,103 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['PairwiseDistance'] + +import numpy as np + +import paddle +from ...fluid.dygraph import layers +from ...fluid.framework import core, in_dygraph_mode +from ...fluid.data_feeder import check_variable_and_dtype, check_type +from ...fluid.layer_helper import LayerHelper + + +class PairwiseDistance(layers.Layer): + """ + This operator computes the pairwise distance between two vectors. The + distance is calculated by p-oreder norm: + + .. math:: + + \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}. + + Parameters: + p (float): The order of norm. The default value is 2. + eps (float, optional): Add small value to avoid division by zero, + default value is 1e-6. + keepdim (bool, optional): Whether to reserve the reduced dimension + in the output Tensor. The result tensor is one dimension less than + the result of ``'x-y'`` unless :attr:`keepdim` is True, default + value is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + x: :math:`(N, D)` where `D` is the dimension of vector, available dtype + is float32, float64. + y: :math:`(N, D)`, y have the same shape and dtype as x. + out: :math:`(N)`. If :attr:`keepdim` is ``True``, the out shape is :math:`(N, 1)`. + The same dtype as input tensor. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + paddle.disable_static() + x_np = np.array([[1., 3.], [3., 5.]]).astype(np.float64) + y_np = np.array([[5., 6.], [7., 8.]]).astype(np.float64) + x = paddle.to_variable(x_np) + y = paddle.to_variable(y_np) + dist = paddle.nn.PairwiseDistance() + distance = dist(x, y) + print(distance.numpy()) # [5. 5.] + + """ + + def __init__(self, p=2., eps=1e-6, keepdim=False, name=None): + super(PairwiseDistance, self).__init__() + self.p = p + self.eps = eps + self.keepdim = keepdim + self.name = name + check_type(self.p, 'porder', (float, int), 'PairwiseDistance') + check_type(self.eps, 'epsilon', (float), 'PairwiseDistance') + check_type(self.keepdim, 'keepdim', (bool), 'PairwiseDistance') + + def forward(self, x, y): + if in_dygraph_mode(): + sub = core.ops.elementwise_sub(x, y) + return core.ops.p_norm(sub, 'axis', 1, 'porder', self.p, 'keepdim', + self.keepdim, 'epsilon', self.eps) + + check_variable_and_dtype(x, 'x', ['float32', 'float64'], + 'PairwiseDistance') + check_variable_and_dtype(y, 'y', ['float32', 'float64'], + 'PairwiseDistance') + sub = paddle.elementwise_sub(x, y) + + helper = LayerHelper("p_norm", name=self.name) + attrs = { + 'axis': 1, + 'porder': self.p, + 'keepdim': self.keepdim, + 'epsilon': self.eps, + } + out = helper.create_variable_for_type_inference( + dtype=self._helper.input_dtype(x)) + helper.append_op( + type='p_norm', inputs={'X': sub}, outputs={'Out': out}, attrs=attrs) + + return out -- GitLab