未验证 提交 587120ec 编写于 作者: A Ainavo 提交者: GitHub

[fp16] fix fp16 support for nn.PairwiseDistance (#50849)

上级 ebea0885
......@@ -286,6 +286,97 @@ class TestPairwiseDistance(unittest.TestCase):
dygraph_functional_ret, excepted_value, rtol=1e-05
)
def test_pairwise_distance_fp16(self):
epsilon = 1e-6
all_shape = [[5], [100, 100]]
dtypes = ['float16']
p_list = [-1, 0, 1, 2, np.inf, -np.inf]
places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
keeps = [False, True]
for place in places:
for shape in all_shape:
for dtype in dtypes:
for p in p_list:
for keepdim in keeps:
x_np = np.random.random(shape).astype(dtype)
y_np = np.random.random(shape).astype(dtype)
# Currently, the CPU does not support float16
if dtype == "float16" and isinstance(
place, paddle.CPUPlace
):
continue
static_ret = test_static(
place,
x_np,
y_np,
p,
epsilon=epsilon,
keepdim=keepdim,
)
dygraph_ret = test_dygraph(
place,
x_np,
y_np,
p,
epsilon=epsilon,
keepdim=keepdim,
)
excepted_value = np_pairwise_distance(
x_np, y_np, p, epsilon=epsilon, keepdim=keepdim
)
self.assertEqual(
static_ret.shape, excepted_value.shape
)
self.assertEqual(
dygraph_ret.shape, excepted_value.shape
)
np.testing.assert_allclose(
static_ret, excepted_value, atol=1e-03
)
np.testing.assert_allclose(
dygraph_ret, excepted_value, atol=1e-03
)
static_functional_ret = test_static(
place,
x_np,
y_np,
p,
epsilon=epsilon,
keepdim=keepdim,
)
dygraph_functional_ret = test_dygraph(
place,
x_np,
y_np,
p,
epsilon=epsilon,
keepdim=keepdim,
)
self.assertEqual(
static_functional_ret.shape,
excepted_value.shape,
)
self.assertEqual(
dygraph_functional_ret.shape,
excepted_value.shape,
)
np.testing.assert_allclose(
static_functional_ret,
excepted_value,
atol=1e-03,
)
np.testing.assert_allclose(
dygraph_functional_ret,
excepted_value,
atol=1e-03,
)
if __name__ == "__main__":
unittest.main()
......@@ -35,10 +35,10 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None):
Parameters:
x (Tensor): Tensor, shape is :math:`[N, D]` or :math:`[D]`, where :math:`N`
is batch size, :math:`D` is the dimension of vector. Available dtype is
float32, float64.
float16, float32, float64.
y (Tensor): Tensor, shape is :math:`[N, D]` or :math:`[D]`, where :math:`N`
is batch size, :math:`D` is the dimension of vector. Available dtype is
float32, float64.
float16, float32, float64.
p (float, optional): The order of norm. Default: :math:`2.0`.
epsilon (float, optional): Add small value to avoid division by zero.
Default: :math:`1e-6`.
......@@ -84,10 +84,10 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None):
check_type(keepdim, 'keepdim', (bool), 'PairwiseDistance')
check_variable_and_dtype(
x, 'x', ['float32', 'float64'], 'PairwiseDistance'
x, 'x', ['float16', 'float32', 'float64'], 'PairwiseDistance'
)
check_variable_and_dtype(
y, 'y', ['float32', 'float64'], 'PairwiseDistance'
y, 'y', ['float16', 'float32', 'float64'], 'PairwiseDistance'
)
sub = paddle.subtract(x, y)
if epsilon != 0.0:
......
......@@ -40,7 +40,7 @@ class PairwiseDistance(Layer):
Shape:
- x: :math:`[N, D]` or :math:`[D]`, where :math:`N` is batch size, :math:`D`
is the dimension of the data. Available data type is float32, float64.
is the dimension of the data. Available data type is float16, float32, float64.
- y: :math:`[N, D]` or :math:`[D]`, y have the same dtype as x.
- output: The same dtype as input tensor.
- If :attr:`keepdim` is True, the output shape is :math:`[N, 1]` or :math:`[1]`,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册