Unverified · Commit 46be6854 authored by Ainavo, committed by GitHub

[PaddlePaddle Hackathon 3 No.12] Add pairwise_distance to Paddle (#44161)

* add paddle.nn.functional.pairwise_distance (cattidea/Paddle#273)
* remove the test case for undefined behavior
Co-authored-by: SigureMo <sigure.qaq@gmail.com>
Parent 88490567
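For orientation, a minimal usage sketch of the API this commit introduces, in both its functional and layer forms. It assumes a Paddle build that already contains this change; the values mirror the docstring example below.

import paddle

x = paddle.to_tensor([[1., 3.], [3., 5.]], dtype='float64')
y = paddle.to_tensor([[5., 6.], [7., 8.]], dtype='float64')

# Functional form added by this commit.
dist_f = paddle.nn.functional.pairwise_distance(x, y, p=2.)

# Layer form, which after this commit delegates to the functional API.
dist_l = paddle.nn.PairwiseDistance(p=2.)(x, y)

print(dist_f.numpy())  # [5. 5.]
print(dist_l.numpy())  # [5. 5.]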
...
@@ -20,24 +20,64 @@ import numpy as np
import unittest


def np_pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False):
    distance = np.linalg.norm(x - y + epsilon, ord=p, axis=-1, keepdims=keepdim)
    # Paddle does not yet support 0-d Tensors, so even if keepdim is False and
    # neither x nor y is batched, a Tensor of shape (1,) is returned.
    if distance.ndim == 0:
        distance = np.expand_dims(distance, axis=0)
    return distance
def call_pairwise_distance_layer(x, y, p=2., epsilon=1e-6, keepdim=False):
    pairwise_distance = paddle.nn.PairwiseDistance(p=p,
                                                   epsilon=epsilon,
                                                   keepdim=keepdim)
    distance = pairwise_distance(x=x, y=y)
    return distance


def call_pairwise_distance_functional(x, y, p=2., epsilon=1e-6, keepdim=False):
    distance = paddle.nn.functional.pairwise_distance(x=x,
                                                      y=y,
                                                      p=p,
                                                      epsilon=epsilon,
                                                      keepdim=keepdim)
    return distance
def test_static(place,
                x_np,
                y_np,
                p=2.0,
                epsilon=1e-6,
                keepdim=False,
                functional=False):
    prog = paddle.static.Program()
    startup_prog = paddle.static.Program()

    place = fluid.CUDAPlace(
        0) if paddle.fluid.core.is_compiled_with_cuda() else fluid.CPUPlace()

    paddle.enable_static()
    with paddle.static.program_guard(prog, startup_prog):
        x = paddle.fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype)
        y = paddle.fluid.data(name='y', shape=y_np.shape, dtype=x_np.dtype)

        if functional:
            distance = call_pairwise_distance_functional(x=x,
                                                         y=y,
                                                         p=p,
                                                         epsilon=epsilon,
                                                         keepdim=keepdim)
        else:
            distance = call_pairwise_distance_layer(x=x,
                                                    y=y,
                                                    p=p,
                                                    epsilon=epsilon,
                                                    keepdim=keepdim)

        exe = paddle.static.Executor(place)
        static_ret = exe.run(prog,
                             feed={
...
@@ -46,69 +86,279 @@ def test_static(x_np, y_np, p=2.0, epsilon=1e-6, keepdim=False):
                             },
                             fetch_list=[distance])

        static_ret = static_ret[0]
    paddle.disable_static()

    return static_ret
def test_dygraph(place,
                 x_np,
                 y_np,
                 p=2.0,
                 epsilon=1e-6,
                 keepdim=False,
                 functional=False):
    x = paddle.to_tensor(x_np)
    y = paddle.to_tensor(y_np)
    if functional:
        dy_distance = call_pairwise_distance_functional(x=x,
                                                        y=y,
                                                        p=p,
                                                        epsilon=epsilon,
                                                        keepdim=keepdim)
    else:
        dy_distance = call_pairwise_distance_layer(x=x,
                                                   y=y,
                                                   p=p,
                                                   epsilon=epsilon,
                                                   keepdim=keepdim)
    dygraph_ret = dy_distance.numpy()
    return dygraph_ret
def test_legacy_dygraph(place,
                        x_np,
                        y_np,
                        p=2.0,
                        epsilon=1e-6,
                        keepdim=False,
                        functional=False):
    paddle.fluid.framework._enable_legacy_dygraph()
    x = paddle.to_tensor(x_np)
    y = paddle.to_tensor(y_np)
    if functional:
        legacy_distance = call_pairwise_distance_functional(x=x,
                                                            y=y,
                                                            p=p,
                                                            epsilon=epsilon,
                                                            keepdim=keepdim)
    else:
        legacy_distance = call_pairwise_distance_layer(x=x,
                                                       y=y,
                                                       p=p,
                                                       epsilon=epsilon,
                                                       keepdim=keepdim)
    legacy_ret = legacy_distance.numpy()
    paddle.fluid.framework._disable_legacy_dygraph()
    return legacy_ret
class TestPairwiseDistance(unittest.TestCase):

    def test_pairwise_distance(self):
        epsilon = 1e-6
        all_shape = [[5], [100, 100]]
        dtypes = ['float32', 'float64']
        p_list = [-1, 0, 1, 2, np.inf, -np.inf]
        places = [paddle.CPUPlace()]
        if paddle.device.is_compiled_with_cuda():
            places.append(paddle.CUDAPlace(0))
        keeps = [False, True]
        for place in places:
            for shape in all_shape:
                for dtype in dtypes:
                    for p in p_list:
                        for keepdim in keeps:
                            x_np = np.random.random(shape).astype(dtype)
                            y_np = np.random.random(shape).astype(dtype)

                            static_ret = test_static(place,
                                                     x_np,
                                                     y_np,
                                                     p,
                                                     epsilon=epsilon,
                                                     keepdim=keepdim)
                            dygraph_ret = test_dygraph(place,
                                                       x_np,
                                                       y_np,
                                                       p,
                                                       epsilon=epsilon,
                                                       keepdim=keepdim)
                            legacy_ret = test_legacy_dygraph(place,
                                                             x_np,
                                                             y_np,
                                                             p,
                                                             epsilon=epsilon,
                                                             keepdim=keepdim)
                            excepted_value = np_pairwise_distance(
                                x_np, y_np, p, epsilon=epsilon,
                                keepdim=keepdim)

                            self.assertEqual(static_ret.shape,
                                             excepted_value.shape)
                            self.assertEqual(dygraph_ret.shape,
                                             excepted_value.shape)
                            self.assertEqual(legacy_ret.shape,
                                             excepted_value.shape)

                            self.assertTrue(
                                np.allclose(static_ret, excepted_value))
                            self.assertTrue(
                                np.allclose(dygraph_ret, excepted_value))
                            self.assertTrue(
                                np.allclose(legacy_ret, excepted_value))

                            static_functional_ret = test_static(
                                place,
                                x_np,
                                y_np,
                                p,
                                epsilon=epsilon,
                                keepdim=keepdim,
                                functional=True)
                            dygraph_functional_ret = test_dygraph(
                                place,
                                x_np,
                                y_np,
                                p,
                                epsilon=epsilon,
                                keepdim=keepdim,
                                functional=True)
                            legacy_functional_ret = test_legacy_dygraph(
                                place,
                                x_np,
                                y_np,
                                p,
                                epsilon=epsilon,
                                keepdim=keepdim,
                                functional=True)

                            self.assertEqual(static_functional_ret.shape,
                                             excepted_value.shape)
                            self.assertEqual(dygraph_functional_ret.shape,
                                             excepted_value.shape)
                            self.assertEqual(legacy_functional_ret.shape,
                                             excepted_value.shape)

                            self.assertTrue(
                                np.allclose(static_functional_ret,
                                            excepted_value))
                            self.assertTrue(
                                np.allclose(dygraph_functional_ret,
                                            excepted_value))
                            self.assertTrue(
                                np.allclose(legacy_functional_ret,
                                            excepted_value))
    def test_pairwise_distance_broadcast_1(self):
        shape_x = [100, 100]
        shape_y = [100, 1]
        epsilon = 1e-6
        keepdim = False
        place = paddle.CPUPlace()
        x_np = np.random.random(shape_x).astype('float32')
        y_np = np.random.random(shape_y).astype('float32')
        static_ret = test_static(place=place,
                                 x_np=x_np,
                                 y_np=y_np,
                                 epsilon=epsilon,
                                 keepdim=keepdim)
        dygraph_ret = test_dygraph(place=place,
                                   x_np=x_np,
                                   y_np=y_np,
                                   epsilon=epsilon,
                                   keepdim=keepdim)
        legacy_ret = test_legacy_dygraph(place=place,
                                         x_np=x_np,
                                         y_np=y_np,
                                         epsilon=epsilon,
                                         keepdim=keepdim)
        excepted_value = np_pairwise_distance(x_np,
                                              y_np,
                                              epsilon=epsilon,
                                              keepdim=keepdim)

        self.assertEqual(static_ret.shape, excepted_value.shape)
        self.assertEqual(dygraph_ret.shape, excepted_value.shape)
        self.assertEqual(legacy_ret.shape, excepted_value.shape)

        self.assertTrue(np.allclose(static_ret, excepted_value))
        self.assertTrue(np.allclose(dygraph_ret, excepted_value))
        self.assertTrue(np.allclose(legacy_ret, excepted_value))

        static_functional_ret = test_static(place=place,
                                            x_np=x_np,
                                            y_np=y_np,
                                            epsilon=epsilon,
                                            keepdim=keepdim,
                                            functional=True)
        dygraph_functional_ret = test_dygraph(place=place,
                                              x_np=x_np,
                                              y_np=y_np,
                                              epsilon=epsilon,
                                              keepdim=keepdim,
                                              functional=True)
        legacy_functional_ret = test_legacy_dygraph(place=place,
                                                    x_np=x_np,
                                                    y_np=y_np,
                                                    epsilon=epsilon,
                                                    keepdim=keepdim,
                                                    functional=True)

        self.assertEqual(static_functional_ret.shape, excepted_value.shape)
        self.assertEqual(dygraph_functional_ret.shape, excepted_value.shape)
        self.assertEqual(legacy_functional_ret.shape, excepted_value.shape)

        self.assertTrue(np.allclose(static_functional_ret, excepted_value))
        self.assertTrue(np.allclose(dygraph_functional_ret, excepted_value))
        self.assertTrue(np.allclose(legacy_functional_ret, excepted_value))
    def test_pairwise_distance_broadcast_2(self):
        shape_x = [100, 100]
        shape_y = [100]
        epsilon = 1e-6
        keepdim = False
        place = paddle.CPUPlace()
        x_np = np.random.random(shape_x).astype('float32')
        y_np = np.random.random(shape_y).astype('float32')
        static_ret = test_static(place=place,
                                 x_np=x_np,
                                 y_np=y_np,
                                 epsilon=epsilon,
                                 keepdim=keepdim)
        dygraph_ret = test_dygraph(place=place,
                                   x_np=x_np,
                                   y_np=y_np,
                                   epsilon=epsilon,
                                   keepdim=keepdim)
        legacy_ret = test_legacy_dygraph(place=place,
                                         x_np=x_np,
                                         y_np=y_np,
                                         epsilon=epsilon,
                                         keepdim=keepdim)
        excepted_value = np_pairwise_distance(x_np,
                                              y_np,
                                              epsilon=epsilon,
                                              keepdim=keepdim)

        self.assertEqual(static_ret.shape, excepted_value.shape)
        self.assertEqual(dygraph_ret.shape, excepted_value.shape)
        self.assertEqual(legacy_ret.shape, excepted_value.shape)

        self.assertTrue(np.allclose(static_ret, excepted_value))
        self.assertTrue(np.allclose(dygraph_ret, excepted_value))
        self.assertTrue(np.allclose(legacy_ret, excepted_value))

        static_functional_ret = test_static(place=place,
                                            x_np=x_np,
                                            y_np=y_np,
                                            epsilon=epsilon,
                                            keepdim=keepdim,
                                            functional=True)
        dygraph_functional_ret = test_dygraph(place=place,
                                              x_np=x_np,
                                              y_np=y_np,
                                              epsilon=epsilon,
                                              keepdim=keepdim,
                                              functional=True)
        legacy_functional_ret = test_legacy_dygraph(place=place,
                                                    x_np=x_np,
                                                    y_np=y_np,
                                                    epsilon=epsilon,
                                                    keepdim=keepdim,
                                                    functional=True)

        self.assertEqual(static_functional_ret.shape, excepted_value.shape)
        self.assertEqual(dygraph_functional_ret.shape, excepted_value.shape)
        self.assertEqual(legacy_functional_ret.shape, excepted_value.shape)

        self.assertTrue(np.allclose(static_functional_ret, excepted_value))
        self.assertTrue(np.allclose(dygraph_functional_ret, excepted_value))
        self.assertTrue(np.allclose(legacy_functional_ret, excepted_value))
if __name__ == "__main__":
...
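The two broadcast tests above feed y of shape [100, 1] and [100] against x of shape [100, 100]. The sketch below uses only numpy and the same reference formula as np_pairwise_distance to show what that broadcasting amounts to; the inputs are tiny illustrative values, not the tests' random data.

import numpy as np

x = np.array([[1., 2.], [3., 4.]], dtype='float32')  # shape [2, 2]
y = np.array([[1.], [2.]], dtype='float32')          # shape [2, 1], broadcast over the last axis

# Same reference computation as np_pairwise_distance with p=2, keepdim=False.
diff = x - y + 1e-6
print(np.linalg.norm(diff, ord=2, axis=-1))  # approximately [1. 2.236]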
...
@@ -69,6 +69,7 @@ from .conv import conv2d  # noqa: F401
from .conv import conv2d_transpose  # noqa: F401
from .conv import conv3d  # noqa: F401
from .conv import conv3d_transpose  # noqa: F401
from .distance import pairwise_distance  # noqa: F401
from .extension import diag_embed  # noqa: F401
from .extension import sequence_mask
from .loss import binary_cross_entropy  # noqa: F401
@@ -137,6 +138,7 @@ __all__ = [  # noqa
    'conv2d_transpose',
    'conv3d',
    'conv3d_transpose',
    'pairwise_distance',
    'elu',
    'elu_',
    'gelu',
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from ...fluid.data_feeder import check_variable_and_dtype, check_type
from ...fluid.layer_helper import LayerHelper
from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
__all__ = []
def pairwise_distance(x, y, p=2., epsilon=1e-6, keepdim=False, name=None):
    r"""
    It computes the pairwise distance between two vectors. The
    distance is calculated by p-order norm:

    .. math::

        \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}.

    Parameters:
        x (Tensor): Tensor, shape is :math:`[N, D]` or :math:`[D]`, where :math:`N`
            is the batch size and :math:`D` is the dimension of the vector. Available
            dtypes are float32, float64.
        y (Tensor): Tensor, shape is :math:`[N, D]` or :math:`[D]`, where :math:`N`
            is the batch size and :math:`D` is the dimension of the vector. Available
            dtypes are float32, float64.
        p (float, optional): The order of norm. Default: :math:`2.0`.
        epsilon (float, optional): Add small value to avoid division by zero.
            Default: :math:`1e-6`.
        keepdim (bool, optional): Whether to retain the reduced dimension
            in the output Tensor. The result tensor has one fewer dimension than
            ``|x - y|`` unless :attr:`keepdim` is True. Default: False.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`.
            Generally, no setting is required. Default: None.

    Returns:
        Tensor. The dtype is the same as the input tensors.

        - If :attr:`keepdim` is True, the output shape is :math:`[N, 1]` or :math:`[1]`,
          depending on whether the input has data shaped as :math:`[N, D]`.
        - If :attr:`keepdim` is False, the output shape is :math:`[N]` or :math:`[]`,
          depending on whether the input has data shaped as :math:`[N, D]`.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1., 3.], [3., 5.]], dtype=paddle.float64)
            y = paddle.to_tensor([[5., 6.], [7., 8.]], dtype=paddle.float64)
            distance = paddle.nn.functional.pairwise_distance(x, y)
            print(distance.numpy())  # [5. 5.]

    """
    check_type(p, 'porder', (float, int), 'PairwiseDistance')
    check_type(epsilon, 'epsilon', (float), 'PairwiseDistance')
    check_type(keepdim, 'keepdim', (bool), 'PairwiseDistance')
    if in_dygraph_mode():
        sub = _C_ops.elementwise_sub(x, y)
        # The p_norm op does not use epsilon, so add it to the difference instead.
        if epsilon != 0.0:
            epsilon = paddle.fluid.dygraph.base.to_variable([epsilon],
                                                            dtype=sub.dtype)
            sub = _C_ops.elementwise_add(sub, epsilon)
        return _C_ops.final_state_p_norm(sub, p, -1, 0., keepdim, False)

    if _in_legacy_dygraph():
        sub = _C_ops.elementwise_sub(x, y)
        if epsilon != 0.0:
            epsilon = paddle.fluid.dygraph.base.to_variable([epsilon],
                                                            dtype=sub.dtype)
            sub = _C_ops.elementwise_add(sub, epsilon)
        return _C_ops.p_norm(sub, 'axis', -1, 'porder', p, 'keepdim', keepdim,
                             'epsilon', 0.)

    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'PairwiseDistance')
    check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'PairwiseDistance')
    sub = paddle.subtract(x, y)
    if epsilon != 0.0:
        epsilon_var = sub.block.create_var(dtype=sub.dtype)
        epsilon_var = paddle.full(shape=[1],
                                  fill_value=epsilon,
                                  dtype=sub.dtype)
        sub = paddle.add(sub, epsilon_var)
    helper = LayerHelper("PairwiseDistance", name=name)
    attrs = {
        'axis': -1,
        'porder': p,
        'keepdim': keepdim,
        'epsilon': 0.,
    }
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(type='p_norm',
                     inputs={'X': sub},
                     outputs={'Out': out},
                     attrs=attrs)

    return out
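In every branch above, the computation reduces to a p-norm over the last axis of x - y, with epsilon added to the difference. A rough equivalence sketch against paddle.linalg.norm, assuming dygraph mode and the defaults p=2.0 and epsilon=1e-6:

import paddle

x = paddle.rand([4, 8], dtype='float32')
y = paddle.rand([4, 8], dtype='float32')

d1 = paddle.nn.functional.pairwise_distance(x, y)
# What the op computes internally: p-norm of (x - y + epsilon) over axis -1.
d2 = paddle.linalg.norm(x - y + 1e-6, p=2, axis=-1)

print(paddle.allclose(d1, d2))  # expected: Tensor containing True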
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
@@ -12,22 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import Layer
from .. import functional as F

__all__ = []


class PairwiseDistance(Layer):
    r"""
    It computes the pairwise distance between two vectors. The
    distance is calculated by p-order norm:

    .. math::
...
@@ -35,33 +28,31 @@ class PairwiseDistance(Layer):

        \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}.

    Parameters:
        p (float, optional): The order of norm. Default: :math:`2.0`.
        epsilon (float, optional): Add small value to avoid division by zero.
            Default: :math:`1e-6`.
        keepdim (bool, optional): Whether to retain the reduced dimension
            in the output Tensor. The result tensor has one fewer dimension than
            ``|x - y|`` unless :attr:`keepdim` is True. Default: False.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`.
            Generally, no setting is required. Default: None.

    Shape:
        x: :math:`[N, D]` or :math:`[D]`, where :math:`N` is the batch size and
            :math:`D` is the dimension of the data. Available data types are
            float32, float64.
        y: :math:`[N, D]` or :math:`[D]`, y has the same dtype as x.
        output: The same dtype as the input tensors.
            - If :attr:`keepdim` is True, the output shape is :math:`[N, 1]` or :math:`[1]`,
              depending on whether the input has data shaped as :math:`[N, D]`.
            - If :attr:`keepdim` is False, the output shape is :math:`[N]` or :math:`[]`,
              depending on whether the input has data shaped as :math:`[N, D]`.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1., 3.], [3., 5.]], dtype=paddle.float64)
            y = paddle.to_tensor([[5., 6.], [7., 8.]], dtype=paddle.float64)
            dist = paddle.nn.PairwiseDistance()
            distance = dist(x, y)
            print(distance.numpy())  # [5. 5.]
...
@@ -74,41 +65,11 @@ class PairwiseDistance(Layer):
        self.epsilon = epsilon
        self.keepdim = keepdim
        self.name = name

    def forward(self, x, y):
        return F.pairwise_distance(x, y, self.p, self.epsilon, self.keepdim,
                                   self.name)

    def extra_repr(self):
        main_str = 'p={p}'
...
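To round out the docstring above, a short sketch of the keepdim behaviour for un-batched, [D]-shaped inputs. The shapes follow the documented contract; exact 0-d handling may differ across Paddle versions.

import paddle

x = paddle.to_tensor([1., 3.], dtype='float64')  # shape [D]
y = paddle.to_tensor([5., 6.], dtype='float64')

dist = paddle.nn.PairwiseDistance(p=2., keepdim=True)
out = dist(x, y)
print(out.shape)  # expected: [1], since keepdim retains the reduced last axis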