From e588f2d93fca03fde1f8542d4f5d027f028b165b Mon Sep 17 00:00:00 2001 From: zhwesky2010 <1183042833@qq.com> Date: Tue, 9 May 2023 17:31:18 +0800 Subject: [PATCH] [Zero-Dim] add 0D Tensor UT case for XPU and expand kernel support 0D (#53555) * [Zero-Dim] add 0D Tensor UT case for XPU * fix comment * remove some unnecessary UT --- paddle/phi/kernels/xpu/expand_grad_kernel.cc | 3 +- paddle/phi/kernels/xpu/expand_kernel.cc | 17 +- test/xpu/test_zero_dim_tensor_xpu.py | 1458 ++++++++++++++++-- 3 files changed, 1317 insertions(+), 161 deletions(-) diff --git a/paddle/phi/kernels/xpu/expand_grad_kernel.cc b/paddle/phi/kernels/xpu/expand_grad_kernel.cc index 52fc0fd38f1..1665b8e3192 100644 --- a/paddle/phi/kernels/xpu/expand_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/expand_grad_kernel.cc @@ -37,7 +37,8 @@ void ExpandGradKernel(const Context& ctx, // Two zero if (out_grad_dims.size() == 0 && in_grad_dims.size() == 0) { - return; + out_grad_dims = {1}; + in_grad_dims = {1}; } int r = xpu::expand_grad( diff --git a/paddle/phi/kernels/xpu/expand_kernel.cc b/paddle/phi/kernels/xpu/expand_kernel.cc index 10b8c18e9d3..d8808d3c3aa 100644 --- a/paddle/phi/kernels/xpu/expand_kernel.cc +++ b/paddle/phi/kernels/xpu/expand_kernel.cc @@ -94,26 +94,17 @@ void ExpandKernel(const Context& ctx, shape_size, rank)); - if (shape_size == 0) { - phi::DDim out_dims = phi::make_ddim(final_expand_shape); - out->Resize(out_dims); - ctx.template Alloc(out); - - int r = xpu::copy(ctx.x_context(), - reinterpret_cast(x.data()), - reinterpret_cast(out->data()), - x.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); - return; - } DDim out_dims = phi::make_ddim(final_expand_shape); out->Resize(out_dims); ctx.template Alloc(out); auto& x_shape = vec_in_dims; auto out_shape = phi::vectorize(out_dims); + if (shape_size == 0) { + x_shape = {1}; + out_shape = {1}; + } int r = XPU_SUCCESS; - if (std::is_same::value) { auto x_data = reinterpret_cast(x.data()); auto out_data = reinterpret_cast(out->data()); diff --git a/test/xpu/test_zero_dim_tensor_xpu.py b/test/xpu/test_zero_dim_tensor_xpu.py index 9369a9b0ed3..7591b3a402f 100644 --- a/test/xpu/test_zero_dim_tensor_xpu.py +++ b/test/xpu/test_zero_dim_tensor_xpu.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Note: +# 0D Tensor indicates that the tensor's dimension is 0 +# 0D Tensor's shape is always [], numel is 1 +# which can be created by paddle.rand([]) + +import os import unittest import numpy as np @@ -20,6 +26,7 @@ import paddle import paddle.nn.functional as F paddle.set_device('xpu') +paddle.disable_static() unary_api_list = [ paddle.nn.functional.elu, @@ -86,6 +93,8 @@ unary_api_list = [ paddle.bernoulli, paddle.nn.functional.softmax, paddle.nn.functional.log_softmax, + paddle.nn.functional.gumbel_softmax, + paddle.nn.functional.alpha_dropout, ] inplace_api_list = [ @@ -97,11 +106,11 @@ inplace_api_list = [ # Use to test zero-dim in unary API. class TestUnaryAPI(unittest.TestCase): def test_dygraph_unary(self): - paddle.disable_static() for api in unary_api_list: x = paddle.rand([]) x.stop_gradient = False out = api(x) + out.retain_grads() out.backward() @@ -117,8 +126,6 @@ class TestUnaryAPI(unittest.TestCase): self.assertEqual(x.shape, []) self.assertEqual(out.shape, []) - paddle.enable_static() - reduce_api_list = [ paddle.sum, @@ -139,7 +146,6 @@ reduce_api_list = [ # Use to test zero-dim of reduce API class TestReduceAPI(unittest.TestCase): def test_dygraph_reduce(self): - paddle.disable_static() for api in reduce_api_list: # 1) x is 0D if api in [paddle.all, paddle.any]: @@ -148,13 +154,18 @@ class TestReduceAPI(unittest.TestCase): x = paddle.rand([]) x.stop_gradient = False out = api(x, None) - out.retain_grads() + out.retain_grads() out.backward() self.assertEqual(x.shape, []) self.assertEqual(out.shape, []) np.testing.assert_allclose(out.numpy(), x.numpy()) + + out_empty_list = api(x, []) + self.assertEqual(out_empty_list, out) + self.assertEqual(out_empty_list.shape, []) + if x.grad is not None: self.assertEqual(x.grad.shape, []) self.assertEqual(out.grad.shape, []) @@ -175,7 +186,35 @@ class TestReduceAPI(unittest.TestCase): self.assertEqual(x.grad.shape, []) np.testing.assert_allclose(x.grad.numpy(), np.array(3.0)) - paddle.enable_static() + # 2) x is ND, reduce to 0D + if api in [paddle.all, paddle.any]: + x = paddle.randint(0, 2, [3, 5]).astype('bool') + else: + x = paddle.rand([3, 5]) + x.stop_gradient = False + out = api(x, None) + out.retain_grads() + out.backward() + + self.assertEqual(out.shape, []) + if x.grad is not None: + self.assertEqual(out.grad.shape, []) + self.assertEqual(x.grad.shape, [3, 5]) + + # 3) x is 1D, axis=0, reduce to 0D + if api in [paddle.all, paddle.any]: + x = paddle.randint(0, 2, [5]).astype('bool') + else: + x = paddle.rand([5]) + x.stop_gradient = False + out = api(x, 0) + out.retain_grads() + out.backward() + + self.assertEqual(out.shape, []) + if x.grad is not None: + self.assertEqual(out.grad.shape, []) + self.assertEqual(x.grad.shape, [5]) binary_api_list = [ @@ -198,33 +237,37 @@ binary_api_list = [ paddle.logical_xor, paddle.maximum, paddle.minimum, + paddle.fmax, + paddle.fmin, + paddle.complex, + paddle.kron, ] binary_int_api_list = [ paddle.bitwise_and, paddle.bitwise_or, paddle.bitwise_xor, + paddle.gcd, + paddle.lcm, ] # Use to test zero-dim of binary API class TestBinaryAPI(unittest.TestCase): def test_dygraph_binary(self): - paddle.disable_static() for api in binary_api_list: # 1) x is 0D, y is 0D x = paddle.rand([]) y = paddle.rand([]) x.stop_gradient = False y.stop_gradient = False - x.retain_grads() - y.retain_grads() if isinstance(api, dict): out = api['func'](x, y) out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y) np.testing.assert_array_equal(out_cls.numpy(), out.numpy()) else: out = api(x, y) + out.retain_grads() out.backward() @@ -247,6 +290,7 @@ class TestBinaryAPI(unittest.TestCase): np.testing.assert_array_equal(out_cls.numpy(), out.numpy()) else: out = api(x, y) + out.retain_grads() out.backward() @@ -263,14 +307,13 @@ class TestBinaryAPI(unittest.TestCase): y = paddle.rand([2, 3, 4]) x.stop_gradient = False y.stop_gradient = False - x.retain_grads() - y.retain_grads() if isinstance(api, dict): out = api['func'](x, y) out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y) np.testing.assert_array_equal(out_cls.numpy(), out.numpy()) else: out = api(x, y) + out.retain_grads() out.backward() @@ -288,6 +331,7 @@ class TestBinaryAPI(unittest.TestCase): y = 0.5 if isinstance(api, dict): out = getattr(paddle.Tensor, api['cls_method'])(x, y) + out.retain_grads() out.backward() @@ -334,14 +378,11 @@ class TestBinaryAPI(unittest.TestCase): self.assertEqual(out.shape, [3, 5]) np.testing.assert_array_equal(out.numpy(), out_np) - paddle.enable_static() - # Use to test zero-dim of Sundry API, which is unique and can not be classified # with others. It can be implemented here flexibly. class TestSundryAPI(unittest.TestCase): def setUp(self): - paddle.disable_static() self.x = paddle.rand([]) def test_getitem(self): @@ -610,6 +651,100 @@ class TestSundryAPI(unittest.TestCase): with self.assertRaises(ValueError): tmp = paddle.topk(x1, k=1, axis=2) + def test_broadcast_to(self): + x = paddle.full([], 1, 'float32') + x.stop_gradient = False + out = paddle.broadcast_to(x, shape=[1]) + out.retain_grads() + out.backward() + + self.assertEqual(out.shape, [1]) + np.testing.assert_allclose(out, 1.0) + self.assertEqual(x.grad.shape, []) + np.testing.assert_allclose(x.grad, 1.0) + self.assertEqual(out.grad.shape, [1]) + np.testing.assert_allclose(out.grad, 1.0) + + # case2 + x1 = paddle.full([], 1, 'float32') + x1.stop_gradient = False + out1 = paddle.broadcast_to(x1, shape=[]) + out1.retain_grads() + out1.backward() + + self.assertEqual(out1.shape, []) + np.testing.assert_allclose(out1, 1.0) + self.assertEqual(x1.grad.shape, []) + np.testing.assert_allclose(x1.grad, 1.0) + self.assertEqual(out1.grad.shape, []) + np.testing.assert_allclose(out1.grad, 1.0) + + # case3 + x2 = paddle.full([], 1, 'float32') + x2.stop_gradient = False + out2 = paddle.broadcast_to(x2, shape=[1, 1]) + out2.retain_grads() + out2.backward() + + self.assertEqual(out2.shape, [1, 1]) + np.testing.assert_allclose(out2, 1.0) + self.assertEqual(x2.grad.shape, []) + np.testing.assert_allclose(x2.grad, 1.0) + self.assertEqual(out2.grad.shape, [1, 1]) + np.testing.assert_allclose(out2.grad, 1.0) + + # case4 + x3 = paddle.full([], 1, 'float32') + x3.stop_gradient = False + out3 = paddle.broadcast_to(x3, shape=[3, 3]) + out3.retain_grads() + out3.backward() + + self.assertEqual(out3.shape, [3, 3]) + np.testing.assert_allclose(out3, 1.0) + self.assertEqual(x3.grad.shape, []) + np.testing.assert_allclose(x3.grad, 9.0) + self.assertEqual(out3.grad.shape, [3, 3]) + np.testing.assert_allclose(out3.grad, 1.0) + + def test_broadcast_tensors(self): + # 1) x is 0D, y is 0D + x1 = paddle.full([], 2.0) + x1.stop_gradient = False + x2 = paddle.full([], 2.0) + x2.stop_gradient = False + out1, out2 = paddle.broadcast_tensors([x1, x2]) + # backward has bug now + # out1.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + # self.assertEqual(x1.grad.shape, []) + + # 2) x is ND , y is 0D + x1 = paddle.full([2, 3], 2.0) + x1.stop_gradient = False + x2 = paddle.full([], 2.0) + x2.stop_gradient = False + out1, out2 = paddle.broadcast_tensors([x1, x2]) + # out1.backward() + + self.assertEqual(out1.shape, [2, 3]) + self.assertEqual(out2.shape, [2, 3]) + # self.assertEqual(x1.grad.shape, [2, 3]) + + # 3) x is 0D , y is ND + x1 = paddle.full([], 2.0) + x1.stop_gradient = False + x2 = paddle.full([2, 3], 2.0) + x2.stop_gradient = False + out1, out2 = paddle.broadcast_tensors([x1, x2]) + # out1.backward() + + self.assertEqual(out1.shape, [2, 3]) + self.assertEqual(out2.shape, [2, 3]) + # self.assertEqual(x1.grad.shape, [2, 3]) + def test_argmin(self): # 1) x is 0D x = paddle.rand([]) @@ -679,6 +814,7 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out.shape, [1, 1]) def test_median(self): + # 1) x is 0D x = paddle.rand([]) x.stop_gradient = False out1 = paddle.median(x, 0) @@ -701,149 +837,505 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x.grad.shape, []) np.testing.assert_allclose(x.grad, 3.0) - def test_linear(self): - x = paddle.randn([3, 2]) - w = paddle.full(shape=[2, 4], fill_value=0.5) - b = paddle.zeros([]) - - np.testing.assert_array_equal( - F.linear(x, w, b).numpy(), F.linear(x, w).numpy() - ) + # 2) x is 1D + x = paddle.rand([5]) + x.stop_gradient = False + out = paddle.median(x, 0) + out.backward() + self.assertEqual(out.shape, []) + self.assertEqual(x.grad.shape, [5]) - def test_is_floating_point(self): - self.assertTrue(paddle.is_floating_point(self.x)) + # 3) x is ND + x = paddle.rand([3, 5]) + x.stop_gradient = False + out = paddle.median(x, None) + out.backward() + self.assertEqual(out.shape, []) + self.assertEqual(x.grad.shape, [3, 5]) - def test_is_integer(self): - x = paddle.randint(0, 10, []) - self.assertTrue(paddle.is_integer(x)) + # 4) x is ND, keepdim=True + x = paddle.rand([3, 5]) + x.stop_gradient = False + out = paddle.median(x, keepdim=True) + out.backward() + self.assertEqual(out.shape, [1, 1]) + self.assertEqual(x.grad.shape, [3, 5]) - def test_is_tensor(self): - self.assertTrue(paddle.is_tensor(self.x)) + def test_kthvalue(self): + # 1) x is 0D + x = paddle.randn([]) + x.stop_gradient = False + out, index = paddle.kthvalue(x, 1) + out.backward() - def test_is_empty(self): - x = paddle.rand([3, 0, 5]) - self.assertTrue(paddle.is_empty(x)) + self.assertEqual(out.shape, []) + self.assertEqual(out, x) + self.assertEqual(index.shape, []) + self.assertEqual(index, 0) - def test_isfinite(self): - out = paddle.isfinite(self.x) - np.testing.assert_array_equal(out.numpy(), np.array(True)) + self.assertEqual(x.grad.shape, []) + self.assertEqual(x.grad, 1.0) - def test_isinf(self): - x = paddle.to_tensor(np.array(float('-inf'))) - out = paddle.isinf(x) - np.testing.assert_array_equal(out.numpy(), np.array(True)) + # 2) x is 1D + x1 = paddle.randn([5]) + x1.stop_gradient = False + out1, index1 = paddle.kthvalue(x1, 1) + out1.backward() - def test_isnan(self): - x = paddle.to_tensor(np.array(float('nan'))) - out = paddle.isnan(x) - np.testing.assert_array_equal(out.numpy(), np.array(True)) + self.assertEqual(out1.shape, []) + self.assertEqual(index1.shape, []) + self.assertEqual(x1.grad.shape, [5]) - def test_isclose(self): - out = paddle.isclose(self.x, self.x) - np.testing.assert_array_equal(out.numpy(), np.array(True)) + def test_mode(self): + # 1) x is 0D + x = paddle.randn([]) + x.stop_gradient = False + out, index = paddle.mode(x) + out.backward() - def test_clone(self): - out = paddle.clone(self.x) - np.testing.assert_array_equal(out.numpy(), self.x.numpy()) + self.assertEqual(out.shape, []) + self.assertEqual(out, x) + self.assertEqual(index.shape, []) + self.assertEqual(index, 0) - def test_assign(self): - out = paddle.assign(self.x) - np.testing.assert_array_equal(out.numpy(), self.x.numpy()) + self.assertEqual(x.grad.shape, []) + self.assertEqual(x.grad, 1.0) - def test_item(self): - x = paddle.full([], 0.5) - self.assertEqual(x.item(), 0.5) + # 2) x is 1D + x1 = paddle.randn([5]) + x1.stop_gradient = False + out1, index1 = paddle.mode(x1) + out1.backward() - def test_tolist(self): - x = paddle.full([], 0.5) - self.assertEqual(x.tolist(), 0.5) + self.assertEqual(out1.shape, []) + self.assertEqual(index1.shape, []) - def test_numpy(self): - x = paddle.full([], 0.5) - np.testing.assert_array_equal(x.numpy(), np.array(0.5)) + self.assertEqual(x1.grad.shape, [5]) - def test_numel(self): + def test_is_empty(self): # 1) x is 0D - out = paddle.numel(self.x) + x = paddle.rand([]) + out = paddle.is_empty(x) + self.assertFalse(out) self.assertEqual(out.shape, []) - np.testing.assert_array_equal(out.numpy(), np.array(1)) - # 2) x is ND - x = paddle.full([3, 5], 0.5) - out = paddle.numel(x) + # 2) x is 1D + x = paddle.rand([5]) + out = paddle.is_empty(x) + self.assertFalse(out) self.assertEqual(out.shape, []) - np.testing.assert_array_equal(out.numpy(), np.array(15)) - def test_rank(self): - # 1) x is 0D - out = paddle.rank(self.x) + # 3) x is ND + x = paddle.rand([3, 5]) + out = paddle.is_empty(x) + self.assertFalse(out) self.assertEqual(out.shape, []) - np.testing.assert_array_equal(out.numpy(), np.array(0)) - # 1) x is ND - x = paddle.full([3, 5], 0.5) - out = paddle.rank(x) + x = paddle.rand([3, 0, 5]) + out = paddle.is_empty(x) + self.assertTrue(out) self.assertEqual(out.shape, []) - np.testing.assert_array_equal(out.numpy(), np.array(2)) - - def test_shape(self): - out = paddle.shape(self.x) - self.assertEqual(out.shape, [0]) - np.testing.assert_array_equal(out.numpy(), np.array([])) - def test_pow_factor(self): + def test_squeeze_(self): + # 1) x is 0D x = paddle.rand([]) + x.squeeze_(0) + self.assertEqual(x.shape, []) + + # 2) x is 1D + x = paddle.rand([1]) + x.squeeze_(0) + self.assertEqual(x.shape, []) + + # 3)x is ND + x = paddle.rand([2, 1]) + x.squeeze_(1) + self.assertEqual(x.shape, [2]) + + def test_as_complex(self): + x = paddle.rand([2]) x.stop_gradient = False - x.retain_grads() - out = paddle.pow(x, 2.0) + out = paddle.as_complex(x) out.retain_grads() out.backward() + self.assertEqual(x.shape, [2]) self.assertEqual(out.shape, []) + self.assertEqual(x.grad.shape, [2]) self.assertEqual(out.grad.shape, []) - self.assertEqual(x.grad.shape, []) - def test_cast(self): - x = paddle.full([], 1.0, 'float32') + def test_dot(self): + # 1) x is 1D + x = paddle.rand([2]) x.stop_gradient = False - x.retain_grads() - out = paddle.cast(x, 'int32') + y = paddle.rand([2]) + y.stop_gradient = False + out = paddle.dot(x, y) out.retain_grads() out.backward() + self.assertEqual(x.grad.shape, [2]) self.assertEqual(out.shape, []) self.assertEqual(out.grad.shape, []) - self.assertEqual(x.grad.shape, []) - def test_clip(self): - x = paddle.uniform([], None, -10, 10) + # 2) x is 2D + x1 = paddle.rand([2, 2]) + x1.stop_gradient = False + y1 = paddle.rand([2, 2]) + y1.stop_gradient = False + out1 = paddle.dot(x1, y1) + out1.retain_grads() + out1.backward() + + self.assertEqual(x1.grad.shape, [2, 2]) + self.assertEqual(out1.shape, [2]) + self.assertEqual(out1.grad.shape, [2]) + + def test_inner(self): + # 0) input is 0D + x = paddle.rand([]) x.stop_gradient = False - x.retain_grads() - out = paddle.clip(x, -5, 5) + y = paddle.rand([]) + y.stop_gradient = False + out = paddle.inner(x, y) out.retain_grads() out.backward() + self.assertEqual(x.grad.shape, []) self.assertEqual(out.shape, []) self.assertEqual(out.grad.shape, []) - self.assertEqual(x.grad.shape, []) - def test_increment(self): - x = paddle.rand([]) + # 1) input is 1D + x = paddle.rand([2]) x.stop_gradient = False - out = paddle.increment(x, 1.0) + y = paddle.rand([2]) + y.stop_gradient = False + out = paddle.inner(x, y) + out.retain_grads() out.backward() + self.assertEqual(x.grad.shape, [2]) self.assertEqual(out.shape, []) self.assertEqual(out.grad.shape, []) - self.assertEqual(x.grad.shape, []) - def test_bitwise_not(self): - x = paddle.randint(-1, 1, []) - out1 = ~x - out2 = paddle.bitwise_not(x) + # 2) input is 2D + x = paddle.rand([2, 3]) + x.stop_gradient = False + y = paddle.rand([3, 3]) + y.stop_gradient = False + out = paddle.inner(x, y) + out.retain_grads() + out.backward() - self.assertEqual(out1.shape, []) - self.assertEqual(out2.shape, []) + self.assertEqual(x.grad.shape, [2, 3]) + self.assertEqual(out.shape, [2, 3]) + self.assertEqual(out.grad.shape, [2, 3]) + + def test_tensordot(self): + # 1) input is 1D + x = paddle.arange(10, dtype='float64') + x.stop_gradient = False + y = paddle.arange(10, dtype='float64') + y.stop_gradient = False + out = paddle.tensordot(x, y, axes=1) + out.retain_grads() + out.backward() + + self.assertEqual(x.grad.shape, [10]) + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + + # 2) input is 2D + x = paddle.arange(6, dtype='float64').reshape([2, 3]) + y = paddle.arange(6, dtype='float64').reshape([2, 3]) + x.stop_gradient = False + out = paddle.tensordot(x, y, axes=2) + out.retain_grads() + out.backward() + + self.assertEqual(x.grad.shape, [2, 3]) + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + + def test_metric_accuracy(self): + x = paddle.full(shape=[2, 4], fill_value=0.25) + y = paddle.full(shape=[2, 1], fill_value=1, dtype="int64") + out = paddle.metric.accuracy(input=x, label=y, k=1) + self.assertEqual(out.shape, []) + + def test_std(self): + # 1) x is 0D + x = paddle.rand([]) + x.stop_gradient = False + out1 = paddle.std(x) + out2 = paddle.std(x, []) + out1.backward() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + self.assertEqual(out1, 0) + self.assertEqual(out2, 0) + + self.assertEqual(x.grad.shape, []) + + # 2) x is ND + x = paddle.rand([3, 5]) + x.stop_gradient = False + out = paddle.std(x) + out.backward() + + self.assertEqual(out.shape, []) + self.assertEqual(x.grad.shape, [3, 5]) + + def test_var(self): + # 1) x is 0D + x = paddle.rand([]) + x.stop_gradient = False + out1 = paddle.var(x) + out2 = paddle.var(x, []) + out1.backward() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + self.assertEqual(out1, 0) + self.assertEqual(out2, 0) + + self.assertEqual(x.grad.shape, []) + np.testing.assert_allclose(x.grad, 0) + + # 2) x is ND + x = paddle.rand([3, 5]) + x.stop_gradient = False + out = paddle.std(x) + out.backward() + + self.assertEqual(out.shape, []) + self.assertEqual(x.grad.shape, [3, 5]) + + def test_quantile(self): + # 1) x is 0D + x = paddle.rand([]) + x.stop_gradient = False + out = paddle.quantile(x, 0.5, axis=None) + + out.retain_grads() + out.backward() + + out_empty_list = paddle.quantile(x, 0.5, axis=[]) + self.assertEqual(out_empty_list, out) + + self.assertEqual(x.shape, []) + self.assertEqual(out.shape, []) + self.assertEqual(out, x) + + self.assertEqual(x.grad.shape, []) + self.assertEqual(x.grad, 1.0) + self.assertEqual(out.grad.shape, []) + self.assertEqual(out.grad, 1.0) + + # 2) x is ND + x = paddle.rand([2, 3]) + x.stop_gradient = False + out = paddle.quantile(x, 0.5, axis=None) + + out.retain_grads() + out.backward() + + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + self.assertEqual(out.grad, 1.0) + self.assertEqual(x.grad.shape, [2, 3]) + + def test_flip(self): + x = paddle.rand([]) + x.stop_gradient = False + out = paddle.flip(x, axis=[]) + out.retain_grads() + out.backward() + self.assertEqual(x.shape, []) + self.assertEqual(out.shape, []) + self.assertEqual(x.grad.shape, []) + self.assertEqual(out.grad.shape, []) + + def test_linear(self): + x = paddle.randn([3, 2]) + w = paddle.full(shape=[2, 4], fill_value=0.5) + b = paddle.zeros([]) + + np.testing.assert_array_equal( + F.linear(x, w, b).numpy(), F.linear(x, w).numpy() + ) + + def test_is_floating_point(self): + self.assertTrue(paddle.is_floating_point(self.x)) + + def test_is_integer(self): + x = paddle.randint(0, 10, []) + self.assertTrue(paddle.is_integer(x)) + + def test_is_tensor(self): + self.assertTrue(paddle.is_tensor(self.x)) + + def test_isfinite(self): + out = paddle.isfinite(self.x) + np.testing.assert_array_equal(out.numpy(), np.array(True)) + + def test_isinf(self): + x = paddle.to_tensor(np.array(float('-inf'))) + out = paddle.isinf(x) + np.testing.assert_array_equal(out.numpy(), np.array(True)) + + def test_isnan(self): + x = paddle.to_tensor(np.array(float('nan'))) + out = paddle.isnan(x) + np.testing.assert_array_equal(out.numpy(), np.array(True)) + + def test_isclose(self): + out = paddle.isclose(self.x, self.x) + np.testing.assert_array_equal(out.numpy(), np.array(True)) + + def test_clone(self): + out = paddle.clone(self.x) + np.testing.assert_array_equal(out.numpy(), self.x.numpy()) + + def test_assign(self): + out = paddle.assign(self.x) + np.testing.assert_array_equal(out.numpy(), self.x.numpy()) + + def test_item(self): + x = paddle.full([], 0.5) + self.assertEqual(x.item(), 0.5) + + def test_tolist(self): + x = paddle.full([], 0.5) + self.assertEqual(x.tolist(), 0.5) + + def test_numpy(self): + x = paddle.full([], 0.5) + x_np = x.numpy() + np.testing.assert_array_equal(x_np.shape, ()) + np.testing.assert_array_equal(x_np, np.array(0.5)) + + x_np = x.numpy(False) + np.testing.assert_array_equal(x_np.shape, ()) + np.testing.assert_array_equal(x_np, np.array(0.5)) + + def test_numel(self): + # 1) x is 0D + out = paddle.numel(self.x) + self.assertEqual(out.shape, []) + np.testing.assert_array_equal(out.numpy(), np.array(1)) + + # 2) x is ND + x = paddle.full([3, 5], 0.5) + out = paddle.numel(x) + self.assertEqual(out.shape, []) + np.testing.assert_array_equal(out.numpy(), np.array(15)) + + def test_rank(self): + # 1) x is 0D + x = paddle.rand([]) + out = paddle.rank(x) + self.assertEqual(out.shape, []) + np.testing.assert_array_equal(out.numpy(), np.array(0)) + + # 1) x is ND + x = paddle.full([3, 5], 0.5) + out = paddle.rank(x) + self.assertEqual(out.shape, []) + np.testing.assert_array_equal(out.numpy(), np.array(2)) + + def test_shape(self): + out = paddle.shape(self.x) + np.testing.assert_array_equal(out.numpy(), np.array([])) + self.assertEqual(out.shape, [0]) + + def test_equal_scalar(self): + x = paddle.rand([]) + out = paddle.equal(x, 2.0) + self.assertEqual(out.shape, []) + self.assertEqual(out, False) + + x1 = paddle.full([], 2.0) + out1 = paddle.equal(x1, 2.0) + self.assertEqual(out1.shape, []) + self.assertEqual(out1, True) + + def test_pow_scalar(self): + x = paddle.rand([]) + x.stop_gradient = False + out = paddle.pow(x, 2.0) + out.retain_grads() + out.backward() + + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + self.assertEqual(x.grad.shape, []) + + def test_cast(self): + x = paddle.full([], 1.0, 'float32') + x.stop_gradient = False + out = paddle.cast(x, 'int32') + out.retain_grads() + out.backward() + + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + self.assertEqual(x.grad.shape, []) + + def test_cumprod(self): + x = paddle.full([], 1.0, 'float32') + x.stop_gradient = False + out = paddle.cumprod(x, 0) + out.retain_grads() + out.backward() + + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + self.assertEqual(x.grad.shape, []) + + with self.assertRaises(ValueError): + tmp = paddle.cumprod(x, 2) + + def test_clip(self): + x = paddle.uniform([], None, -10, 10) + x.stop_gradient = False + out = paddle.clip(x, -5, 5) + out.retain_grads() + out.backward() + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + self.assertEqual(x.grad.shape, []) + + x1 = paddle.uniform([], None, -10, 10) + x1.stop_gradient = False + out1 = paddle.clip(x1, paddle.full([], 5.0), paddle.full([], 5.0)) + out1.retain_grads() + out1.backward() + self.assertEqual(out1.shape, []) + self.assertEqual(out1.grad.shape, []) + self.assertEqual(x1.grad.shape, []) + + def test_increment(self): + x = paddle.rand([]) + x.stop_gradient = False + out = paddle.increment(x, 1.0) + out.retain_grads() + out.backward() + + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + self.assertEqual(x.grad.shape, []) + + def test_bitwise_not(self): + x = paddle.randint(-1, 1, []) + out1 = ~x + out2 = paddle.bitwise_not(x) + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) def test_logical_not(self): x = paddle.randint(0, 1, []) @@ -852,10 +1344,10 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out.shape, []) def test_searchsorted(self): + # have no backward x = paddle.to_tensor([1, 3, 5, 7, 9]) y = paddle.rand([]) - # only has forward kernel out = paddle.searchsorted(x, y) self.assertEqual(out.shape, []) @@ -925,11 +1417,85 @@ class TestSundryAPI(unittest.TestCase): ) index = paddle.full([], 1, 'int64') out = paddle.gather(x, index, axis=1) + out.retain_grads() + out.backward() self.assertEqual(out.shape, [2]) np.testing.assert_array_equal(out.numpy(), [2.0, 5.0]) + self.assertEqual(x.grad.shape, [2, 3]) + self.assertEqual(out.grad.shape, [2]) + + def test_gather_nd(self): + x1 = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False) + x2 = paddle.to_tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False + ) + + index1 = paddle.full([1], 1, 'int64') + index2 = paddle.full([2], 1, 'int64') + + out1 = paddle.gather_nd(x1, index1) + out2 = paddle.gather_nd(x2, index2) + + out1.retain_grads() + out2.retain_grads() + + out1.backward() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + np.testing.assert_array_equal(out1, np.array(3.0)) + np.testing.assert_array_equal(out2, np.array(5.0)) + self.assertEqual(x1.grad.shape, [5]) + self.assertEqual(x2.grad.shape, [2, 3]) + self.assertEqual(out1.grad.shape, []) + self.assertEqual(out2.grad.shape, []) + + def test_einsum(self): + os.environ['FLAGS_new_einsum'] = "0" + x = paddle.rand([5]) + # sum + out1 = paddle.einsum('i->', x) + expect1 = np.einsum('i->', x) + # dot + out2 = paddle.einsum('i,i->', x, x) + expect2 = np.einsum('i,i->', x, x) + + out1.retain_grads() + out2.retain_grads() + + out1.backward() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + np.testing.assert_allclose(out1, expect1, rtol=1e-03) + np.testing.assert_allclose(out2, expect2, rtol=1e-03) + + def test_einsum_V2(self): + os.environ['FLAGS_new_einsum'] = "1" + x = paddle.rand([5]) + # sum + out1 = paddle.einsum('i->', x) + expect1 = np.einsum('i->', x) + # dot + out2 = paddle.einsum('i,i->', x, x) + expect2 = np.einsum('i,i->', x, x) + + out1.retain_grads() + out2.retain_grads() + + out1.backward() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + np.testing.assert_allclose(out1, expect1, rtol=1e-03) + np.testing.assert_allclose(out2, expect2, rtol=1e-03) def test_scatter_1D(self): + # have no backward now x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False) index = paddle.full([], 2, 'int64') updates = paddle.full([], 4.0) @@ -939,6 +1505,7 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out.numpy()[2], 4) def test_scatter_XD(self): + # have no backward now x = paddle.to_tensor( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False ) @@ -1000,24 +1567,39 @@ class TestSundryAPI(unittest.TestCase): out = paddle.scatter_(x, index, updates) np.testing.assert_array_equal(out.numpy()[1], [1.0, 2.0, 3.0]) + def test_scatter_nd(self): + index = paddle.to_tensor([3], dtype="int64") + updates = paddle.full([], 2, dtype='float32') + + out = paddle.scatter_nd(index, updates, [5]) + self.assertEqual(out.shape, [5]) + self.assertEqual(out.numpy()[3], 2) + def test_flatten(self): - x = paddle.full([], 1, 'float32') + x = paddle.rand([]) x.stop_gradient = False start_axis = 0 stop_axis = -1 out = paddle.flatten(x, start_axis=start_axis, stop_axis=stop_axis) + out.retain_grads() out.backward() self.assertEqual(out.shape, [1]) + self.assertEqual(out.grad.shape, [1]) self.assertEqual(x.grad.shape, []) + def test_histogram(self): + x = paddle.rand([]) + out = paddle.histogram(x, bins=5, min=1, max=5) + self.assertEqual(out.shape, [5]) + def test_scale(self): x = paddle.rand([]) x.stop_gradient = False - x.retain_grads() out = paddle.scale(x, scale=2.0, bias=1.0) + out.retain_grads() out.backward() @@ -1025,6 +1607,11 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out.grad.shape, []) self.assertEqual(x.grad.shape, []) + def test_scale_(self): + x = paddle.rand([]) + out = x.scale_(scale=2.0, bias=1.0) + self.assertEqual(out.shape, []) + def test_floor_divide(self): # 1-d // 0-d x = paddle.to_tensor([1, -2, 3], dtype="int64") @@ -1066,12 +1653,17 @@ class TestSundryAPI(unittest.TestCase): out2.backward() out3.backward() + self.assertEqual(x1.grad.shape, []) + self.assertTrue(x1.grad.numpy() == 3) self.assertEqual(out1.shape, [1]) self.assertEqual(out1.grad.shape, [1]) + self.assertTrue(out1.grad.numpy() == 1) self.assertEqual(out2.shape, []) self.assertEqual(out2.grad.shape, []) + self.assertTrue(out2.grad.numpy() == 1) self.assertEqual(out3.shape, []) self.assertEqual(out3.grad.shape, []) + self.assertTrue(out3.grad.numpy() == 1) def test_add_n(self): x1 = paddle.rand([]) @@ -1090,6 +1682,12 @@ class TestSundryAPI(unittest.TestCase): out1.backward() out2.backward() + self.assertEqual(x1.grad.shape, []) + self.assertTrue(x1.grad.numpy() == 1) + self.assertEqual(x2.grad.shape, []) + self.assertTrue(x2.grad.numpy() == 1) + self.assertEqual(x3.grad.shape, []) + self.assertTrue(x3.grad.numpy() == 1) self.assertEqual(out1.shape, []) self.assertEqual(out1.grad.shape, []) self.assertEqual(out2.shape, []) @@ -1098,8 +1696,8 @@ class TestSundryAPI(unittest.TestCase): def test_reshape_list(self): x = paddle.rand([]) x.stop_gradient = False - out = paddle.reshape(x, []) + out.retain_grads() out.backward() self.assertEqual(x.grad.shape, []) @@ -1130,8 +1728,8 @@ class TestSundryAPI(unittest.TestCase): def test_reshape_tensor(self): x = paddle.rand([1, 1]) x.stop_gradient = False - out = paddle.reshape(x, []) + out.retain_grads() out.backward() self.assertEqual(x.grad.shape, [1, 1]) @@ -1193,15 +1791,23 @@ class TestSundryAPI(unittest.TestCase): out = paddle.reshape_(x, new_shape) self.assertEqual(out.shape, [1, 1]) + def test_reverse(self): + x = paddle.rand([]) + x.stop_gradient = False + out = paddle.reverse(x, axis=[]) + out.retain_grads() + out.backward() + self.assertEqual(x.shape, []) + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + def test_sort(self): x1 = paddle.rand([]) x2 = paddle.rand([]) x1.stop_gradient = False x2.stop_gradient = False - x1.retain_grads() x2.retain_grads() - out1 = paddle.sort(x1, axis=-1) out2 = paddle.sort(x2, axis=0) @@ -1250,36 +1856,153 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x1.grad.numpy(), 0) self.assertEqual(x2.grad.numpy(), 0) - def test_sigmoid_focal_loss(self): - logit = paddle.to_tensor( - [[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]], - dtype='float32', - stop_gradient=False, - ) - label = paddle.to_tensor( - [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32' - ) - fg_num_0 = paddle.full([], 2.0) - fg_num_1 = paddle.full([1], 2.0) + def test_lerp(self): + # 0D + 0D, weight is float scalar + x = paddle.rand([]) + y = paddle.rand([]) + x.stop_gradient = False + y.stop_gradient = False + out = paddle.lerp(x, y, 0.5) + out.backward() - out0 = F.sigmoid_focal_loss(logit, label, normalizer=fg_num_0) - out1 = F.sigmoid_focal_loss(logit, label, normalizer=fg_num_1) + self.assertEqual(out.shape, []) + self.assertEqual(x.grad.shape, []) + self.assertEqual(y.grad.shape, []) - np.testing.assert_array_equal( - out0.numpy(), - out1.numpy(), - ) - self.assertEqual(out0.shape, []) + # 0D + 0D, weigh is 0D + x0 = paddle.rand([]) + y0 = paddle.rand([]) + w0 = paddle.rand([]) + x0.stop_gradient = False + y0.stop_gradient = False + y0.retain_grads() - out0.retain_grads() + out0 = paddle.lerp(x0, y0, w0) out0.backward() - self.assertEqual(out0.grad.shape, []) - self.assertEqual(logit.grad.shape, [2, 3]) + + self.assertEqual(out0.shape, []) + self.assertEqual(x0.grad.shape, []) + self.assertEqual(y0.grad.shape, []) + + # 0D + ND + x1 = paddle.rand([]) + y1 = paddle.rand([64, 64]) + w1 = paddle.rand([]) + x1.stop_gradient = False + y1.stop_gradient = False + x1.retain_grads() + y1.retain_grads() + + out1 = paddle.lerp(x1, y1, w1) + out1.backward() + + self.assertEqual(out1.shape, [64, 64]) + self.assertEqual(x1.grad.shape, []) + self.assertEqual(y1.grad.shape, [64, 64]) + + # ND + 0D + x2 = paddle.rand([64, 64]) + y2 = paddle.rand([]) + w2 = paddle.rand([]) + x2.stop_gradient = False + y2.stop_gradient = False + x2.retain_grads() + y2.retain_grads() + + out2 = paddle.lerp(x2, y2, w2) + out2.backward() + + self.assertEqual(out2.shape, [64, 64]) + self.assertEqual(x2.grad.shape, [64, 64]) + self.assertEqual(y2.grad.shape, []) + + def test_repeat_interleave(self): + x = paddle.randn(()) + x.stop_gradient = False + + out = paddle.repeat_interleave(x, 2, None) + out.backward() + + # check shape of output + self.assertEqual(out.shape, [2]) + + # check grad shape + self.assertEqual(x.grad.shape, []) + + repeats = paddle.to_tensor([3], dtype='int32') + out = paddle.repeat_interleave(x, repeats, None) + + # check shape of output with 1D repeats + self.assertEqual(out.shape, [3]) + + # check grad shape with 1D repeats + self.assertEqual(x.grad.shape, []) def test_allclose(self): + # 1) x is 0D + x = paddle.full([], 0.5) + y = paddle.full([], 0.6) + out = paddle.allclose(x, y) + self.assertEqual(out.shape, []) + self.assertFalse(out) + + # 2) x is ND + x = paddle.full([2, 3], 0.5) + y = paddle.full([2, 3], 0.6) + out = paddle.allclose(x, y) + self.assertEqual(out.shape, []) + self.assertFalse(out) + + def test_equal_all(self): + # 1) x is 0D x = paddle.full([], 0.5) y = paddle.full([], 0.6) - self.assertFalse(paddle.allclose(x, y)) + out = paddle.equal_all(x, y) + self.assertEqual(out.shape, []) + self.assertFalse(out) + + # 2) x is ND + x = paddle.full([2, 3], 0.5) + y = paddle.full([2, 3], 0.6) + out = paddle.equal_all(x, y) + self.assertEqual(out.shape, []) + self.assertFalse(out) + + def test_where(self): + x1 = paddle.full([], 1) + x2 = paddle.full([], 2) + x1.stop_gradient = False + x2.stop_gradient = False + x1.retain_grads() + x2.retain_grads() + out = paddle.where(x1 > x2, x1, x2) + out.retain_grads() + out.backward() + self.assertEqual(out.shape, []) + self.assertEqual(out.numpy(), 2) + self.assertEqual(out.grad.shape, []) + self.assertEqual(x1.grad.shape, []) + self.assertEqual(x2.grad.shape, []) + self.assertEqual(x1.grad.numpy(), 0) + self.assertEqual(x2.grad.numpy(), 1) + + def test_atan2(self): + x1 = paddle.full([], 0) + x2 = paddle.full([], 2) + x1.retain_grads() + x2.retain_grads() + x1.stop_gradient = False + x2.stop_gradient = False + out = paddle.atan2(x1, x2) + out.retain_grads() + out.backward() + self.assertEqual(out.shape, []) + self.assertEqual(out.numpy(), 0) + self.assertEqual(out.grad.shape, []) + self.assertEqual(x1.grad.shape, []) + self.assertEqual(x2.grad.shape, []) + self.assertEqual(x1.grad.numpy(), 0.5) + self.assertEqual(x2.grad.numpy(), 0) def test_interpolate(self): from paddle.nn.functional import interpolate @@ -1336,12 +2059,73 @@ class TestSundryAPI(unittest.TestCase): origin_result.numpy(), out3.numpy(), rtol=1e-05 ) - def test_equalall(self): - x = paddle.full([], 0.5) - y = paddle.full([], 0.6) - out = paddle.equal_all(x, y) - self.assertEqual(out.shape, []) - self.assertFalse(out) + def test_upsample(self): + from paddle.nn.functional import upsample + + input_x = paddle.rand([2, 3, 6, 6]) + input_x.stop_gradient = False + + output_size = [ + paddle.full([], 12, dtype="int32"), + paddle.full([], 12, dtype="int32"), + ] + out1 = upsample( + x=input_x, size=output_size, mode="bilinear", align_corners=False + ) + out1.backward() + + self.assertEqual(out1.shape, [2, 3, 12, 12]) + self.assertEqual(input_x.grad.shape, [2, 3, 6, 6]) + + def test_unstack(self): + x1 = paddle.full([1], 0) + x2 = paddle.full([2], 2) + x1.retain_grads() + x2.retain_grads() + x1.stop_gradient = False + x2.stop_gradient = False + + [out1] = paddle.unstack(x1, 0) + out1.retain_grads() + out1.backward() + [out2_1, out2_2] = paddle.unstack(x2, 0) + out2 = paddle.add_n([out2_1, out2_2]) + out2.retain_grads() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out1.numpy(), 0) + + self.assertEqual(out2_1.shape, []) + self.assertEqual(out2_1.numpy(), 2) + self.assertEqual(out2_2.shape, []) + self.assertEqual(out2_2.numpy(), 2) + self.assertEqual(x2.grad.shape, [2]) + + def test_unbind(self): + x1 = paddle.full([1], 0) + x2 = paddle.full([2], 2) + x1.retain_grads() + x2.retain_grads() + x1.stop_gradient = False + x2.stop_gradient = False + + [out1] = paddle.unbind(x1, 0) + out1.retain_grads() + out1.backward() + [out2_1, out2_2] = paddle.unbind(x2, 0) + out2 = paddle.add_n([out2_1, out2_2]) + out2.retain_grads() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out1.numpy(), 0) + + self.assertEqual(out2_1.shape, []) + self.assertEqual(out2_1.numpy(), 2) + self.assertEqual(out2_2.shape, []) + self.assertEqual(out2_2.numpy(), 2) + self.assertEqual(x2.grad.shape, [2]) def test_maseked_select(self): x = paddle.rand([]) @@ -1357,18 +2141,53 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.numpy(), 1) + def test_squeeze(self): + x1 = paddle.full([], 2) + x1.stop_gradient = False + x1.retain_grads() + out1 = paddle.squeeze(x1, axis=0) + out1.retain_grads() + out1.backward() + self.assertEqual(out1.shape, []) + self.assertEqual(x1.grad.shape, []) + + x2 = paddle.full([], 3) + x3 = paddle.full([1], 0, dtype='int32') + x2.stop_gradient = False + x2.retain_grads() + out2 = paddle.squeeze(x2, axis=x3) + out2.retain_grads() + out2.backward() + self.assertEqual(out2.shape, []) + self.assertEqual(x2.grad.shape, []) + def test_unsqueeze(self): x1 = paddle.full([], 2) x1.stop_gradient = False + x1.retain_grads() out1 = paddle.unsqueeze(x1, axis=0) + out1.retain_grads() out1.backward() self.assertEqual(out1.shape, [1]) self.assertEqual(x1.grad.shape, []) x2 = paddle.full([], 0, dtype='int32') out2 = paddle.unsqueeze(x1, axis=x2) + out2.retain_grads() out2.backward() self.assertEqual(out2.shape, [1]) + self.assertEqual(x1.grad.shape, []) + + def test_t(self): + x = paddle.full([], 2.0) + x.stop_gradient = False + x.retain_grads() + out = paddle.t(x) + out.retain_grads() + out.backward() + self.assertEqual(out.shape, []) + self.assertEqual(out.grad.shape, []) + self.assertEqual(x.grad.shape, []) def test_prelu(self): x1 = paddle.full([], 1.0, 'float32') @@ -1397,11 +2216,298 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(x2.grad.shape, []) self.assertEqual(x2.grad.numpy(), 0.25) + def test_while_loop(self): + def cond(i, x): + return paddle.less_than(i, eleven) + + def body(i, x): + x = x + i + i = i + 1 + return [i, x] + + i = paddle.full([], 1.0, dtype='float32') + i.stop_gradient = False + eleven = paddle.full([], 11, dtype='float32') + x = paddle.full([], 0.0, dtype='float32') + x.stop_gradient = False + out_i, out_x = paddle.static.nn.while_loop(cond, body, [i, x]) + out_x.backward() + + self.assertEqual(out_i.shape, []) + np.testing.assert_allclose(out_i, np.array(11)) + self.assertEqual(out_x.shape, []) + np.testing.assert_allclose(out_x, np.array(55)) + self.assertEqual(i.grad.shape, []) + np.testing.assert_allclose(i.grad, np.array(10)) + self.assertEqual(x.grad.shape, []) + np.testing.assert_allclose(x.grad, np.array(1.0)) + + def test_to_tensor(self): + out1 = paddle.to_tensor(1) + out2 = paddle.to_tensor(2.5) + + out1.retain_grads() + out1.backward() + out2.retain_grads() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out1, 1) + self.assertEqual(out2.shape, []) + self.assertEqual(out2, 2.5) + + def test_linalg_slogdet(self): + # 2-D input + x = paddle.randn([3, 3]) + x.stop_gradient = False + out = paddle.linalg.slogdet(x) + out.retain_grads() + out.backward() + + self.assertTrue(out.shape, [2]) + self.assertTrue(x.grad.shape, [3, 3]) + + # 3-D input + x1 = paddle.randn([3, 3, 3]) + x1.stop_gradient = False + out1 = paddle.linalg.slogdet(x1) + out1.retain_grads() + out1.backward() + + self.assertTrue(out1.shape, [2, 3]) + self.assertTrue(x1.grad.shape, [3, 3, 3]) + + def test_multi_dot(self): + a = paddle.randn([4]) + a.stop_gradient = False + b = paddle.randn([4, 5]) + b.stop_gradient = False + c = paddle.randn([5]) + c.stop_gradient = False + + out = paddle.linalg.multi_dot([a, b, c]) + out.retain_grads() + out.backward() + + self.assertEqual(out.shape, []) + self.assertEqual(a.grad.shape, [4]) + self.assertEqual(b.grad.shape, [4, 5]) + self.assertEqual(c.grad.shape, [5]) + + def test_cov(self): + xt = paddle.randn((3, 4)) + xt.stop_gradient = False + xt_1 = paddle.randn((12,)) + xt_1.stop_gradient = False + + xt_out = paddle.linalg.cov(xt) + xt_out.retain_grads() + xt_out.backward() + self.assertEqual(xt_out.shape, [3, 3]) + self.assertEqual(xt.grad.shape, [3, 4]) + + xt_1_out = paddle.linalg.cov(xt_1) + xt_1.retain_grads() + xt_1_out.backward() + self.assertEqual(xt_1_out.shape, []) + self.assertEqual(xt_1.grad.shape, [12]) + + def test_det(self): + xt = paddle.randn([3, 3, 3]) + xt.stop_gradient = False + xt_1 = paddle.randn([3, 3]) + xt_1.stop_gradient = False + + xt_out = paddle.linalg.det(xt) + xt.retain_grads() + xt_out.backward() + self.assertEqual(xt_out.shape, [3]) + self.assertEqual(xt.grad.shape, [3, 3, 3]) + + xt_1_out = paddle.linalg.det(xt_1) + xt_1.retain_grads() + xt_1_out.backward() + self.assertEqual(xt_1_out.shape, []) + self.assertEqual(xt_1.grad.shape, [3, 3]) + + def test_dist(self): + x = paddle.to_tensor([[3, 3], [3, 3]], dtype="float32") + y = paddle.to_tensor([[3, 3], [3, 1]], dtype="float32") + x.stop_gradient = False + y.stop_gradient = False + out = paddle.dist(x, y, 0) + out.backward() + + self.assertEqual(out.shape, []) + np.testing.assert_allclose(out, np.array(1)) + self.assertEqual(x.grad.shape, [2, 2]) + self.assertEqual(y.grad.shape, [2, 2]) + + def test_linalg_norm(self): + # 1D input, p = fro ,axis = None, using reduceInferMeta + x_1 = paddle.arange(24, dtype="float32") - 12 + x_1.stop_gradient = False + out_1 = paddle.linalg.norm(x_1) + out_1.retain_grads() + out_1.backward() + + self.assertEqual(out_1.shape, []) + self.assertTrue(x_1.grad.shape, [24]) + + # 1D input, p = 1 ,axis = None, + # using p_nrom, as_vector = True + x_2 = paddle.arange(24, dtype="float32") - 12 + x_2.stop_gradient = False + out_2 = paddle.linalg.norm(x_2, p=1) + out_2.retain_grads() + out_2.backward() + + self.assertEqual(out_2.shape, []) + self.assertEqual(x_2.grad.shape, [24]) + + # 1D input, p = 1 ,axis = 0, + # using p_nrom, as_vector = False + x_2_p = paddle.arange(24, dtype="float32") - 12 + x_2_p.stop_gradient = False + out_2_p = paddle.linalg.norm(x_2_p, p=1, axis=0) + out_2_p.retain_grads() + out_2_p.backward() + + self.assertEqual(out_2_p.shape, []) + self.assertEqual(x_2_p.grad.shape, [24]) + + # 1D input, p = fro ,axis = 0, + # using p_nrom, as_vector = False + x_2_fro = paddle.arange(24, dtype="float32") - 12 + x_2_fro.stop_gradient = False + out_2_fro = paddle.linalg.norm(x_2_fro, p="fro", axis=0) + out_2_fro.retain_grads() + out_2_fro.backward() + + self.assertEqual(out_2_fro.shape, []) + self.assertEqual(x_2_fro.grad.shape, [24]) + + # 2D input, p = 1, axis = [0, 1] + # using p_matrix_norm ,depends on paddle.sum + x_3 = paddle.arange(24, dtype="float32").reshape([4, 6]) + x_3.stop_gradient = False + out_3 = paddle.linalg.norm(x_3, p=1, axis=[0, 1]) + out_3.retain_grads() + out_3.backward() + self.assertEqual(out_3.shape, []) + self.assertEqual(x_3.grad.shape, [4, 6]) + + # 2D input, p = 1, axis = None + # using p_matrix_norm, depends on paddle.sum + x_4 = paddle.arange(24, dtype="float32").reshape([4, 6]) + x_4.stop_gradient = False + out_4 = paddle.linalg.norm(x_4) + out_4.retain_grads() + out_4.backward() + self.assertEqual(out_4.shape, []) + self.assertEqual(x_4.grad.shape, [4, 6]) + + # 2D input, p = inf, axis = [0, 1] + # using p_matrix_norm, depends on paddle.sum + x_5 = paddle.arange(24, dtype="float32").reshape([4, 6]) + x_5.stop_gradient = False + out_5 = paddle.linalg.norm(x_5, p=2, axis=[0, 1]) + out_5.retain_grads() + out_5.backward() + + self.assertEqual(out_5.shape, []) + self.assertEqual(x_5.grad.shape, [4, 6]) + + # 2D input, p = -inf, axis = [0, 1] + x_6 = paddle.arange(24, dtype="float32").reshape([4, 6]) + x_6.stop_gradient = False + out_6 = paddle.linalg.norm(x_6, p=-float("inf"), axis=[0, 1]) + out_6.retain_grads() + out_6.backward() + + self.assertEqual(out_6.shape, []) + self.assertEqual(x_6.grad.shape, [4, 6]) + + def test_linalg_cond(self): + def assert_shape(out): + self.assertEqual(out.shape, []) + + x1 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]]) + x1.stop_gradient = False + # p = 2 : use paddle.sum + out = paddle.linalg.cond(x1) + out.backward() + assert_shape(out) + self.assertEqual(x1.grad.shape, [3, 3]) + + # p = fro : use paddle.sum + x2 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]]) + x2.stop_gradient = False + out_fro = paddle.linalg.cond(x2, p='fro') + out_fro.backward() + assert_shape(out_fro) + self.assertEqual(x2.grad.shape, [3, 3]) + + # p = nuc : use paddle.sum + x3 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]]) + x3.stop_gradient = False + out_nuc = paddle.linalg.cond(x3, p='nuc') + out_nuc.backward() + assert_shape(out_nuc) + self.assertEqual(x3.grad.shape, [3, 3]) + + # p in (-1, 1) : use paddle.sum + x4 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]]) + x4.stop_gradient = False + out_1 = paddle.linalg.cond(x4, p=1) + out_1.backward() + assert_shape(out_1) + self.assertEqual(x4.grad.shape, [3, 3]) + + x5 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]]) + x5.stop_gradient = False + out_minus_1 = paddle.linalg.cond(x5, p=-1) + out_minus_1.backward() + assert_shape(out_minus_1) + self.assertEqual(x5.grad.shape, [3, 3]) + + # p in (-2, 2) depends on paddle.sum + x6 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]]) + x6.stop_gradient = False + out_2 = paddle.linalg.cond(x6, p=2) + out_2.backward() + assert_shape(out_2) + self.assertEqual(x6.grad.shape, [3, 3]) + + # p in (-inf, inf):use paddle.sum + x8 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]]) + x8.stop_gradient = False + out_inf = paddle.linalg.cond(x8, p=float("inf")) + out_inf.backward() + assert_shape(out_inf) + self.assertEqual(x8.grad.shape, [3, 3]) + + a = paddle.randn([2, 4, 4]) + a.stop_gradient = False + a_cond_fro = paddle.linalg.cond(a, p='fro') + a_cond_fro.backward() + self.assertEqual(len(a_cond_fro.shape), 1) + self.assertEqual(a.grad.shape, [2, 4, 4]) + + def test_trace(self): + x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32") + x.stop_gradient = False + out = paddle.trace(x) + out.backward() + + self.assertEqual(out.shape, []) + np.testing.assert_allclose(out, np.array(12)) + self.assertEqual(x.grad.shape, [2, 2]) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): def setUp(self): - paddle.disable_static() self.shape = [ paddle.full([], 2, 'int32'), paddle.full([], 3, 'int32'), @@ -1545,12 +2651,70 @@ class TestNoBackwardAPI(unittest.TestCase): self.assertEqual(one_hot_label.shape, [4]) self.assertEqual(one_hot_label.numpy()[2], 1) - def test_where(self): - x1 = paddle.full([], 1) - x2 = paddle.full([], 2) - out = paddle.where(x1 > x2, x1, x2) + def test_unique_consecutive(self): + x = paddle.rand([]) + y, inverse, counts = paddle.unique_consecutive( + x, + return_inverse=True, + return_counts=True, + ) + + self.assertEqual(y, x) + self.assertEqual(inverse, 0) + self.assertEqual(counts, 1) + self.assertEqual(y.shape, [1]) + self.assertEqual(inverse.shape, [1]) + self.assertEqual(counts.shape, [1]) + + def test_unique(self): + x = paddle.rand([]) + y, index, inverse, counts = paddle.unique( + x, + return_index=True, + return_inverse=True, + return_counts=True, + ) + + self.assertEqual(y, x) + self.assertEqual(index, 0) + self.assertEqual(inverse, 0) + self.assertEqual(counts, 1) + self.assertEqual(y.shape, [1]) + self.assertEqual(index.shape, [1]) + self.assertEqual(inverse.shape, [1]) + self.assertEqual(counts.shape, [1]) + + def test_matrix_rank(self): + x = paddle.eye(10) + x.stop_gradient = False + out = paddle.linalg.matrix_rank(x) + self.assertEqual(out.shape, []) - self.assertEqual(out.numpy(), 2) + np.testing.assert_equal(out, np.array(10)) + + c = paddle.ones(shape=[3, 4, 5]) + c.stop_gradient = False + out_c = paddle.linalg.matrix_rank(c) + self.assertEqual(out_c.shape, [3]) + np.testing.assert_equal(out_c, np.array([1, 1, 1])) + + # 2D, tol->float : OUTPUT 0D + x_tol = paddle.eye(10) + x_tol.stop_gradient = False + out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1) + self.assertEqual(out_tol.shape, []) + + # 3D, tol->float : OUTPUT 1D + c_tol = paddle.ones(shape=[3, 4, 5]) + c_tol.stop_gradient = False + out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1) + self.assertEqual(out_c_tol.shape, [3]) + + tol_2 = paddle.randn([2]) + # 2D, tol->Tensor[1,2] : OUTPUT 1D + d = paddle.eye(10) + out_d = paddle.linalg.matrix_rank(d, tol=tol_2) + self.assertEqual(out_d.shape, [2]) if __name__ == "__main__": -- GitLab