From c77eb1fd06d944cbd373936c4f2a454b114dc1db Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Tue, 28 Feb 2023 11:15:46 +0800 Subject: [PATCH] zero-dim support for gcd and lcm (#50950) --- .../paddle/fluid/tests/unittests/test_zero_dim_tensor.py | 2 ++ python/paddle/tensor/math.py | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 9a01f01e10d..4c6f26f0e55 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -282,6 +282,8 @@ binary_int_api_list = [ paddle.bitwise_and, paddle.bitwise_or, paddle.bitwise_xor, + paddle.gcd, + paddle.lcm, ] diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 046ab64aa4f..7ca8b4f6a0c 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -4449,15 +4449,19 @@ def gcd(x, y, name=None): y = paddle.broadcast_to(y, shape) x = paddle.abs(x) y = paddle.abs(y) + # TODO(zhouwei25): Support 0D for not_equal tensor with scalar + zero = paddle.full([], 0) def _gcd_cond_fn(x, y): - return paddle.any(y != 0) + # return paddle.any(y != 0) + return paddle.any(y != zero) def _gcd_body_fn(x, y): # paddle.mod will raise an error when any element of y is 0. To avoid # that, we change those zeros to ones. Their values don't matter because # they won't be used. - y_not_equal_0 = y != 0 + # y_not_equal_0 = y != 0 + y_not_equal_0 = y != zero y_safe = paddle.where(y_not_equal_0, y, paddle.ones(y.shape, y.dtype)) x, y = ( paddle.where(y_not_equal_0, y, x), -- GitLab