未验证 提交 e588f2d9 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] add 0D Tensor UT case for XPU and expand kernel support 0D (#53555)

* [Zero-Dim] add 0D Tensor UT case for XPU

* fix comment

* remove some unnecessary UT
上级 a37ef769
......@@ -37,7 +37,8 @@ void ExpandGradKernel(const Context& ctx,
// Two zero
if (out_grad_dims.size() == 0 && in_grad_dims.size() == 0) {
return;
out_grad_dims = {1};
in_grad_dims = {1};
}
int r = xpu::expand_grad<XPUType>(
......
......@@ -94,26 +94,17 @@ void ExpandKernel(const Context& ctx,
shape_size,
rank));
if (shape_size == 0) {
phi::DDim out_dims = phi::make_ddim(final_expand_shape);
out->Resize(out_dims);
ctx.template Alloc<T>(out);
int r = xpu::copy<XPUType>(ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}
DDim out_dims = phi::make_ddim(final_expand_shape);
out->Resize(out_dims);
ctx.template Alloc<T>(out);
auto& x_shape = vec_in_dims;
auto out_shape = phi::vectorize<int>(out_dims);
if (shape_size == 0) {
x_shape = {1};
out_shape = {1};
}
int r = XPU_SUCCESS;
if (std::is_same<T, bool>::value) {
auto x_data = reinterpret_cast<const int8_t*>(x.data<T>());
auto out_data = reinterpret_cast<int8_t*>(out->data<T>());
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,6 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Note:
# 0D Tensor indicates that the tensor's dimension is 0
# 0D Tensor's shape is always [], numel is 1
# which can be created by paddle.rand([])
import os
import unittest
import numpy as np
......@@ -20,6 +26,7 @@ import paddle
import paddle.nn.functional as F
paddle.set_device('xpu')
paddle.disable_static()
unary_api_list = [
paddle.nn.functional.elu,
......@@ -86,6 +93,8 @@ unary_api_list = [
paddle.bernoulli,
paddle.nn.functional.softmax,
paddle.nn.functional.log_softmax,
paddle.nn.functional.gumbel_softmax,
paddle.nn.functional.alpha_dropout,
]
inplace_api_list = [
......@@ -97,11 +106,11 @@ inplace_api_list = [
# Use to test zero-dim in unary API.
class TestUnaryAPI(unittest.TestCase):
def test_dygraph_unary(self):
paddle.disable_static()
for api in unary_api_list:
x = paddle.rand([])
x.stop_gradient = False
out = api(x)
out.retain_grads()
out.backward()
......@@ -117,8 +126,6 @@ class TestUnaryAPI(unittest.TestCase):
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
paddle.enable_static()
reduce_api_list = [
paddle.sum,
......@@ -139,7 +146,6 @@ reduce_api_list = [
# Use to test zero-dim of reduce API
class TestReduceAPI(unittest.TestCase):
def test_dygraph_reduce(self):
paddle.disable_static()
for api in reduce_api_list:
# 1) x is 0D
if api in [paddle.all, paddle.any]:
......@@ -148,13 +154,18 @@ class TestReduceAPI(unittest.TestCase):
x = paddle.rand([])
x.stop_gradient = False
out = api(x, None)
out.retain_grads()
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
np.testing.assert_allclose(out.numpy(), x.numpy())
out_empty_list = api(x, [])
self.assertEqual(out_empty_list, out)
self.assertEqual(out_empty_list.shape, [])
if x.grad is not None:
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [])
......@@ -175,7 +186,35 @@ class TestReduceAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad.numpy(), np.array(3.0))
paddle.enable_static()
# 2) x is ND, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [3, 5]).astype('bool')
else:
x = paddle.rand([3, 5])
x.stop_gradient = False
out = api(x, None)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
if x.grad is not None:
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [3, 5])
# 3) x is 1D, axis=0, reduce to 0D
if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [5]).astype('bool')
else:
x = paddle.rand([5])
x.stop_gradient = False
out = api(x, 0)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
if x.grad is not None:
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [5])
binary_api_list = [
......@@ -198,33 +237,37 @@ binary_api_list = [
paddle.logical_xor,
paddle.maximum,
paddle.minimum,
paddle.fmax,
paddle.fmin,
paddle.complex,
paddle.kron,
]
binary_int_api_list = [
paddle.bitwise_and,
paddle.bitwise_or,
paddle.bitwise_xor,
paddle.gcd,
paddle.lcm,
]
# Use to test zero-dim of binary API
class TestBinaryAPI(unittest.TestCase):
def test_dygraph_binary(self):
paddle.disable_static()
for api in binary_api_list:
# 1) x is 0D, y is 0D
x = paddle.rand([])
y = paddle.rand([])
x.stop_gradient = False
y.stop_gradient = False
x.retain_grads()
y.retain_grads()
if isinstance(api, dict):
out = api['func'](x, y)
out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y)
np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
else:
out = api(x, y)
out.retain_grads()
out.backward()
......@@ -247,6 +290,7 @@ class TestBinaryAPI(unittest.TestCase):
np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
else:
out = api(x, y)
out.retain_grads()
out.backward()
......@@ -263,14 +307,13 @@ class TestBinaryAPI(unittest.TestCase):
y = paddle.rand([2, 3, 4])
x.stop_gradient = False
y.stop_gradient = False
x.retain_grads()
y.retain_grads()
if isinstance(api, dict):
out = api['func'](x, y)
out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y)
np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
else:
out = api(x, y)
out.retain_grads()
out.backward()
......@@ -288,6 +331,7 @@ class TestBinaryAPI(unittest.TestCase):
y = 0.5
if isinstance(api, dict):
out = getattr(paddle.Tensor, api['cls_method'])(x, y)
out.retain_grads()
out.backward()
......@@ -334,14 +378,11 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(out.shape, [3, 5])
np.testing.assert_array_equal(out.numpy(), out_np)
paddle.enable_static()
# Use to test zero-dim of Sundry API, which is unique and can not be classified
# with others. It can be implemented here flexibly.
class TestSundryAPI(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.x = paddle.rand([])
def test_getitem(self):
......@@ -610,6 +651,100 @@ class TestSundryAPI(unittest.TestCase):
with self.assertRaises(ValueError):
tmp = paddle.topk(x1, k=1, axis=2)
def test_broadcast_to(self):
x = paddle.full([], 1, 'float32')
x.stop_gradient = False
out = paddle.broadcast_to(x, shape=[1])
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [1])
np.testing.assert_allclose(out, 1.0)
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, 1.0)
self.assertEqual(out.grad.shape, [1])
np.testing.assert_allclose(out.grad, 1.0)
# case2
x1 = paddle.full([], 1, 'float32')
x1.stop_gradient = False
out1 = paddle.broadcast_to(x1, shape=[])
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [])
np.testing.assert_allclose(out1, 1.0)
self.assertEqual(x1.grad.shape, [])
np.testing.assert_allclose(x1.grad, 1.0)
self.assertEqual(out1.grad.shape, [])
np.testing.assert_allclose(out1.grad, 1.0)
# case3
x2 = paddle.full([], 1, 'float32')
x2.stop_gradient = False
out2 = paddle.broadcast_to(x2, shape=[1, 1])
out2.retain_grads()
out2.backward()
self.assertEqual(out2.shape, [1, 1])
np.testing.assert_allclose(out2, 1.0)
self.assertEqual(x2.grad.shape, [])
np.testing.assert_allclose(x2.grad, 1.0)
self.assertEqual(out2.grad.shape, [1, 1])
np.testing.assert_allclose(out2.grad, 1.0)
# case4
x3 = paddle.full([], 1, 'float32')
x3.stop_gradient = False
out3 = paddle.broadcast_to(x3, shape=[3, 3])
out3.retain_grads()
out3.backward()
self.assertEqual(out3.shape, [3, 3])
np.testing.assert_allclose(out3, 1.0)
self.assertEqual(x3.grad.shape, [])
np.testing.assert_allclose(x3.grad, 9.0)
self.assertEqual(out3.grad.shape, [3, 3])
np.testing.assert_allclose(out3.grad, 1.0)
def test_broadcast_tensors(self):
# 1) x is 0D, y is 0D
x1 = paddle.full([], 2.0)
x1.stop_gradient = False
x2 = paddle.full([], 2.0)
x2.stop_gradient = False
out1, out2 = paddle.broadcast_tensors([x1, x2])
# backward has bug now
# out1.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
# self.assertEqual(x1.grad.shape, [])
# 2) x is ND , y is 0D
x1 = paddle.full([2, 3], 2.0)
x1.stop_gradient = False
x2 = paddle.full([], 2.0)
x2.stop_gradient = False
out1, out2 = paddle.broadcast_tensors([x1, x2])
# out1.backward()
self.assertEqual(out1.shape, [2, 3])
self.assertEqual(out2.shape, [2, 3])
# self.assertEqual(x1.grad.shape, [2, 3])
# 3) x is 0D , y is ND
x1 = paddle.full([], 2.0)
x1.stop_gradient = False
x2 = paddle.full([2, 3], 2.0)
x2.stop_gradient = False
out1, out2 = paddle.broadcast_tensors([x1, x2])
# out1.backward()
self.assertEqual(out1.shape, [2, 3])
self.assertEqual(out2.shape, [2, 3])
# self.assertEqual(x1.grad.shape, [2, 3])
def test_argmin(self):
# 1) x is 0D
x = paddle.rand([])
......@@ -679,6 +814,7 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.shape, [1, 1])
def test_median(self):
# 1) x is 0D
x = paddle.rand([])
x.stop_gradient = False
out1 = paddle.median(x, 0)
......@@ -701,149 +837,505 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, 3.0)
def test_linear(self):
x = paddle.randn([3, 2])
w = paddle.full(shape=[2, 4], fill_value=0.5)
b = paddle.zeros([])
np.testing.assert_array_equal(
F.linear(x, w, b).numpy(), F.linear(x, w).numpy()
)
# 2) x is 1D
x = paddle.rand([5])
x.stop_gradient = False
out = paddle.median(x, 0)
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [5])
def test_is_floating_point(self):
self.assertTrue(paddle.is_floating_point(self.x))
# 3) x is ND
x = paddle.rand([3, 5])
x.stop_gradient = False
out = paddle.median(x, None)
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [3, 5])
def test_is_integer(self):
x = paddle.randint(0, 10, [])
self.assertTrue(paddle.is_integer(x))
# 4) x is ND, keepdim=True
x = paddle.rand([3, 5])
x.stop_gradient = False
out = paddle.median(x, keepdim=True)
out.backward()
self.assertEqual(out.shape, [1, 1])
self.assertEqual(x.grad.shape, [3, 5])
def test_is_tensor(self):
self.assertTrue(paddle.is_tensor(self.x))
def test_kthvalue(self):
# 1) x is 0D
x = paddle.randn([])
x.stop_gradient = False
out, index = paddle.kthvalue(x, 1)
out.backward()
def test_is_empty(self):
x = paddle.rand([3, 0, 5])
self.assertTrue(paddle.is_empty(x))
self.assertEqual(out.shape, [])
self.assertEqual(out, x)
self.assertEqual(index.shape, [])
self.assertEqual(index, 0)
def test_isfinite(self):
out = paddle.isfinite(self.x)
np.testing.assert_array_equal(out.numpy(), np.array(True))
self.assertEqual(x.grad.shape, [])
self.assertEqual(x.grad, 1.0)
def test_isinf(self):
x = paddle.to_tensor(np.array(float('-inf')))
out = paddle.isinf(x)
np.testing.assert_array_equal(out.numpy(), np.array(True))
# 2) x is 1D
x1 = paddle.randn([5])
x1.stop_gradient = False
out1, index1 = paddle.kthvalue(x1, 1)
out1.backward()
def test_isnan(self):
x = paddle.to_tensor(np.array(float('nan')))
out = paddle.isnan(x)
np.testing.assert_array_equal(out.numpy(), np.array(True))
self.assertEqual(out1.shape, [])
self.assertEqual(index1.shape, [])
self.assertEqual(x1.grad.shape, [5])
def test_isclose(self):
out = paddle.isclose(self.x, self.x)
np.testing.assert_array_equal(out.numpy(), np.array(True))
def test_mode(self):
# 1) x is 0D
x = paddle.randn([])
x.stop_gradient = False
out, index = paddle.mode(x)
out.backward()
def test_clone(self):
out = paddle.clone(self.x)
np.testing.assert_array_equal(out.numpy(), self.x.numpy())
self.assertEqual(out.shape, [])
self.assertEqual(out, x)
self.assertEqual(index.shape, [])
self.assertEqual(index, 0)
def test_assign(self):
out = paddle.assign(self.x)
np.testing.assert_array_equal(out.numpy(), self.x.numpy())
self.assertEqual(x.grad.shape, [])
self.assertEqual(x.grad, 1.0)
def test_item(self):
x = paddle.full([], 0.5)
self.assertEqual(x.item(), 0.5)
# 2) x is 1D
x1 = paddle.randn([5])
x1.stop_gradient = False
out1, index1 = paddle.mode(x1)
out1.backward()
def test_tolist(self):
x = paddle.full([], 0.5)
self.assertEqual(x.tolist(), 0.5)
self.assertEqual(out1.shape, [])
self.assertEqual(index1.shape, [])
def test_numpy(self):
x = paddle.full([], 0.5)
np.testing.assert_array_equal(x.numpy(), np.array(0.5))
self.assertEqual(x1.grad.shape, [5])
def test_numel(self):
def test_is_empty(self):
# 1) x is 0D
out = paddle.numel(self.x)
x = paddle.rand([])
out = paddle.is_empty(x)
self.assertFalse(out)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(1))
# 2) x is ND
x = paddle.full([3, 5], 0.5)
out = paddle.numel(x)
# 2) x is 1D
x = paddle.rand([5])
out = paddle.is_empty(x)
self.assertFalse(out)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(15))
def test_rank(self):
# 1) x is 0D
out = paddle.rank(self.x)
# 3) x is ND
x = paddle.rand([3, 5])
out = paddle.is_empty(x)
self.assertFalse(out)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(0))
# 1) x is ND
x = paddle.full([3, 5], 0.5)
out = paddle.rank(x)
x = paddle.rand([3, 0, 5])
out = paddle.is_empty(x)
self.assertTrue(out)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(2))
def test_shape(self):
out = paddle.shape(self.x)
self.assertEqual(out.shape, [0])
np.testing.assert_array_equal(out.numpy(), np.array([]))
def test_pow_factor(self):
def test_squeeze_(self):
# 1) x is 0D
x = paddle.rand([])
x.squeeze_(0)
self.assertEqual(x.shape, [])
# 2) x is 1D
x = paddle.rand([1])
x.squeeze_(0)
self.assertEqual(x.shape, [])
# 3)x is ND
x = paddle.rand([2, 1])
x.squeeze_(1)
self.assertEqual(x.shape, [2])
def test_as_complex(self):
x = paddle.rand([2])
x.stop_gradient = False
x.retain_grads()
out = paddle.pow(x, 2.0)
out = paddle.as_complex(x)
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [2])
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [2])
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
def test_cast(self):
x = paddle.full([], 1.0, 'float32')
def test_dot(self):
# 1) x is 1D
x = paddle.rand([2])
x.stop_gradient = False
x.retain_grads()
out = paddle.cast(x, 'int32')
y = paddle.rand([2])
y.stop_gradient = False
out = paddle.dot(x, y)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [2])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
def test_clip(self):
x = paddle.uniform([], None, -10, 10)
# 2) x is 2D
x1 = paddle.rand([2, 2])
x1.stop_gradient = False
y1 = paddle.rand([2, 2])
y1.stop_gradient = False
out1 = paddle.dot(x1, y1)
out1.retain_grads()
out1.backward()
self.assertEqual(x1.grad.shape, [2, 2])
self.assertEqual(out1.shape, [2])
self.assertEqual(out1.grad.shape, [2])
def test_inner(self):
# 0) input is 0D
x = paddle.rand([])
x.stop_gradient = False
x.retain_grads()
out = paddle.clip(x, -5, 5)
y = paddle.rand([])
y.stop_gradient = False
out = paddle.inner(x, y)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
def test_increment(self):
x = paddle.rand([])
# 1) input is 1D
x = paddle.rand([2])
x.stop_gradient = False
out = paddle.increment(x, 1.0)
y = paddle.rand([2])
y.stop_gradient = False
out = paddle.inner(x, y)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [2])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
def test_bitwise_not(self):
x = paddle.randint(-1, 1, [])
out1 = ~x
out2 = paddle.bitwise_not(x)
# 2) input is 2D
x = paddle.rand([2, 3])
x.stop_gradient = False
y = paddle.rand([3, 3])
y.stop_gradient = False
out = paddle.inner(x, y)
out.retain_grads()
out.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.shape, [2, 3])
self.assertEqual(out.grad.shape, [2, 3])
def test_tensordot(self):
# 1) input is 1D
x = paddle.arange(10, dtype='float64')
x.stop_gradient = False
y = paddle.arange(10, dtype='float64')
y.stop_gradient = False
out = paddle.tensordot(x, y, axes=1)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [10])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
# 2) input is 2D
x = paddle.arange(6, dtype='float64').reshape([2, 3])
y = paddle.arange(6, dtype='float64').reshape([2, 3])
x.stop_gradient = False
out = paddle.tensordot(x, y, axes=2)
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
def test_metric_accuracy(self):
x = paddle.full(shape=[2, 4], fill_value=0.25)
y = paddle.full(shape=[2, 1], fill_value=1, dtype="int64")
out = paddle.metric.accuracy(input=x, label=y, k=1)
self.assertEqual(out.shape, [])
def test_std(self):
# 1) x is 0D
x = paddle.rand([])
x.stop_gradient = False
out1 = paddle.std(x)
out2 = paddle.std(x, [])
out1.backward()
out2.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
self.assertEqual(out1, 0)
self.assertEqual(out2, 0)
self.assertEqual(x.grad.shape, [])
# 2) x is ND
x = paddle.rand([3, 5])
x.stop_gradient = False
out = paddle.std(x)
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [3, 5])
def test_var(self):
# 1) x is 0D
x = paddle.rand([])
x.stop_gradient = False
out1 = paddle.var(x)
out2 = paddle.var(x, [])
out1.backward()
out2.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
self.assertEqual(out1, 0)
self.assertEqual(out2, 0)
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, 0)
# 2) x is ND
x = paddle.rand([3, 5])
x.stop_gradient = False
out = paddle.std(x)
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [3, 5])
def test_quantile(self):
# 1) x is 0D
x = paddle.rand([])
x.stop_gradient = False
out = paddle.quantile(x, 0.5, axis=None)
out.retain_grads()
out.backward()
out_empty_list = paddle.quantile(x, 0.5, axis=[])
self.assertEqual(out_empty_list, out)
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out, x)
self.assertEqual(x.grad.shape, [])
self.assertEqual(x.grad, 1.0)
self.assertEqual(out.grad.shape, [])
self.assertEqual(out.grad, 1.0)
# 2) x is ND
x = paddle.rand([2, 3])
x.stop_gradient = False
out = paddle.quantile(x, 0.5, axis=None)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(out.grad, 1.0)
self.assertEqual(x.grad.shape, [2, 3])
def test_flip(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.flip(x, axis=[])
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [])
def test_linear(self):
x = paddle.randn([3, 2])
w = paddle.full(shape=[2, 4], fill_value=0.5)
b = paddle.zeros([])
np.testing.assert_array_equal(
F.linear(x, w, b).numpy(), F.linear(x, w).numpy()
)
def test_is_floating_point(self):
self.assertTrue(paddle.is_floating_point(self.x))
def test_is_integer(self):
x = paddle.randint(0, 10, [])
self.assertTrue(paddle.is_integer(x))
def test_is_tensor(self):
self.assertTrue(paddle.is_tensor(self.x))
def test_isfinite(self):
out = paddle.isfinite(self.x)
np.testing.assert_array_equal(out.numpy(), np.array(True))
def test_isinf(self):
x = paddle.to_tensor(np.array(float('-inf')))
out = paddle.isinf(x)
np.testing.assert_array_equal(out.numpy(), np.array(True))
def test_isnan(self):
x = paddle.to_tensor(np.array(float('nan')))
out = paddle.isnan(x)
np.testing.assert_array_equal(out.numpy(), np.array(True))
def test_isclose(self):
out = paddle.isclose(self.x, self.x)
np.testing.assert_array_equal(out.numpy(), np.array(True))
def test_clone(self):
out = paddle.clone(self.x)
np.testing.assert_array_equal(out.numpy(), self.x.numpy())
def test_assign(self):
out = paddle.assign(self.x)
np.testing.assert_array_equal(out.numpy(), self.x.numpy())
def test_item(self):
x = paddle.full([], 0.5)
self.assertEqual(x.item(), 0.5)
def test_tolist(self):
x = paddle.full([], 0.5)
self.assertEqual(x.tolist(), 0.5)
def test_numpy(self):
x = paddle.full([], 0.5)
x_np = x.numpy()
np.testing.assert_array_equal(x_np.shape, ())
np.testing.assert_array_equal(x_np, np.array(0.5))
x_np = x.numpy(False)
np.testing.assert_array_equal(x_np.shape, ())
np.testing.assert_array_equal(x_np, np.array(0.5))
def test_numel(self):
# 1) x is 0D
out = paddle.numel(self.x)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(1))
# 2) x is ND
x = paddle.full([3, 5], 0.5)
out = paddle.numel(x)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(15))
def test_rank(self):
# 1) x is 0D
x = paddle.rand([])
out = paddle.rank(x)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(0))
# 1) x is ND
x = paddle.full([3, 5], 0.5)
out = paddle.rank(x)
self.assertEqual(out.shape, [])
np.testing.assert_array_equal(out.numpy(), np.array(2))
def test_shape(self):
out = paddle.shape(self.x)
np.testing.assert_array_equal(out.numpy(), np.array([]))
self.assertEqual(out.shape, [0])
def test_equal_scalar(self):
x = paddle.rand([])
out = paddle.equal(x, 2.0)
self.assertEqual(out.shape, [])
self.assertEqual(out, False)
x1 = paddle.full([], 2.0)
out1 = paddle.equal(x1, 2.0)
self.assertEqual(out1.shape, [])
self.assertEqual(out1, True)
def test_pow_scalar(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.pow(x, 2.0)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
def test_cast(self):
x = paddle.full([], 1.0, 'float32')
x.stop_gradient = False
out = paddle.cast(x, 'int32')
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
def test_cumprod(self):
x = paddle.full([], 1.0, 'float32')
x.stop_gradient = False
out = paddle.cumprod(x, 0)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
with self.assertRaises(ValueError):
tmp = paddle.cumprod(x, 2)
def test_clip(self):
x = paddle.uniform([], None, -10, 10)
x.stop_gradient = False
out = paddle.clip(x, -5, 5)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
x1 = paddle.uniform([], None, -10, 10)
x1.stop_gradient = False
out1 = paddle.clip(x1, paddle.full([], 5.0), paddle.full([], 5.0))
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out1.grad.shape, [])
self.assertEqual(x1.grad.shape, [])
def test_increment(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.increment(x, 1.0)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
def test_bitwise_not(self):
x = paddle.randint(-1, 1, [])
out1 = ~x
out2 = paddle.bitwise_not(x)
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
def test_logical_not(self):
x = paddle.randint(0, 1, [])
......@@ -852,10 +1344,10 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.shape, [])
def test_searchsorted(self):
# have no backward
x = paddle.to_tensor([1, 3, 5, 7, 9])
y = paddle.rand([])
# only has forward kernel
out = paddle.searchsorted(x, y)
self.assertEqual(out.shape, [])
......@@ -925,11 +1417,85 @@ class TestSundryAPI(unittest.TestCase):
)
index = paddle.full([], 1, 'int64')
out = paddle.gather(x, index, axis=1)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [2])
np.testing.assert_array_equal(out.numpy(), [2.0, 5.0])
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.grad.shape, [2])
def test_gather_nd(self):
x1 = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
x2 = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
)
index1 = paddle.full([1], 1, 'int64')
index2 = paddle.full([2], 1, 'int64')
out1 = paddle.gather_nd(x1, index1)
out2 = paddle.gather_nd(x2, index2)
out1.retain_grads()
out2.retain_grads()
out1.backward()
out2.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
np.testing.assert_array_equal(out1, np.array(3.0))
np.testing.assert_array_equal(out2, np.array(5.0))
self.assertEqual(x1.grad.shape, [5])
self.assertEqual(x2.grad.shape, [2, 3])
self.assertEqual(out1.grad.shape, [])
self.assertEqual(out2.grad.shape, [])
def test_einsum(self):
os.environ['FLAGS_new_einsum'] = "0"
x = paddle.rand([5])
# sum
out1 = paddle.einsum('i->', x)
expect1 = np.einsum('i->', x)
# dot
out2 = paddle.einsum('i,i->', x, x)
expect2 = np.einsum('i,i->', x, x)
out1.retain_grads()
out2.retain_grads()
out1.backward()
out2.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
np.testing.assert_allclose(out1, expect1, rtol=1e-03)
np.testing.assert_allclose(out2, expect2, rtol=1e-03)
def test_einsum_V2(self):
os.environ['FLAGS_new_einsum'] = "1"
x = paddle.rand([5])
# sum
out1 = paddle.einsum('i->', x)
expect1 = np.einsum('i->', x)
# dot
out2 = paddle.einsum('i,i->', x, x)
expect2 = np.einsum('i,i->', x, x)
out1.retain_grads()
out2.retain_grads()
out1.backward()
out2.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, [])
np.testing.assert_allclose(out1, expect1, rtol=1e-03)
np.testing.assert_allclose(out2, expect2, rtol=1e-03)
def test_scatter_1D(self):
# have no backward now
x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
index = paddle.full([], 2, 'int64')
updates = paddle.full([], 4.0)
......@@ -939,6 +1505,7 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.numpy()[2], 4)
def test_scatter_XD(self):
# have no backward now
x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
)
......@@ -1000,24 +1567,39 @@ class TestSundryAPI(unittest.TestCase):
out = paddle.scatter_(x, index, updates)
np.testing.assert_array_equal(out.numpy()[1], [1.0, 2.0, 3.0])
def test_scatter_nd(self):
index = paddle.to_tensor([3], dtype="int64")
updates = paddle.full([], 2, dtype='float32')
out = paddle.scatter_nd(index, updates, [5])
self.assertEqual(out.shape, [5])
self.assertEqual(out.numpy()[3], 2)
def test_flatten(self):
x = paddle.full([], 1, 'float32')
x = paddle.rand([])
x.stop_gradient = False
start_axis = 0
stop_axis = -1
out = paddle.flatten(x, start_axis=start_axis, stop_axis=stop_axis)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [1])
self.assertEqual(out.grad.shape, [1])
self.assertEqual(x.grad.shape, [])
def test_histogram(self):
x = paddle.rand([])
out = paddle.histogram(x, bins=5, min=1, max=5)
self.assertEqual(out.shape, [5])
def test_scale(self):
x = paddle.rand([])
x.stop_gradient = False
x.retain_grads()
out = paddle.scale(x, scale=2.0, bias=1.0)
out.retain_grads()
out.backward()
......@@ -1025,6 +1607,11 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
def test_scale_(self):
x = paddle.rand([])
out = x.scale_(scale=2.0, bias=1.0)
self.assertEqual(out.shape, [])
def test_floor_divide(self):
# 1-d // 0-d
x = paddle.to_tensor([1, -2, 3], dtype="int64")
......@@ -1066,12 +1653,17 @@ class TestSundryAPI(unittest.TestCase):
out2.backward()
out3.backward()
self.assertEqual(x1.grad.shape, [])
self.assertTrue(x1.grad.numpy() == 3)
self.assertEqual(out1.shape, [1])
self.assertEqual(out1.grad.shape, [1])
self.assertTrue(out1.grad.numpy() == 1)
self.assertEqual(out2.shape, [])
self.assertEqual(out2.grad.shape, [])
self.assertTrue(out2.grad.numpy() == 1)
self.assertEqual(out3.shape, [])
self.assertEqual(out3.grad.shape, [])
self.assertTrue(out3.grad.numpy() == 1)
def test_add_n(self):
x1 = paddle.rand([])
......@@ -1090,6 +1682,12 @@ class TestSundryAPI(unittest.TestCase):
out1.backward()
out2.backward()
self.assertEqual(x1.grad.shape, [])
self.assertTrue(x1.grad.numpy() == 1)
self.assertEqual(x2.grad.shape, [])
self.assertTrue(x2.grad.numpy() == 1)
self.assertEqual(x3.grad.shape, [])
self.assertTrue(x3.grad.numpy() == 1)
self.assertEqual(out1.shape, [])
self.assertEqual(out1.grad.shape, [])
self.assertEqual(out2.shape, [])
......@@ -1098,8 +1696,8 @@ class TestSundryAPI(unittest.TestCase):
def test_reshape_list(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.reshape(x, [])
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [])
......@@ -1130,8 +1728,8 @@ class TestSundryAPI(unittest.TestCase):
def test_reshape_tensor(self):
x = paddle.rand([1, 1])
x.stop_gradient = False
out = paddle.reshape(x, [])
out.retain_grads()
out.backward()
self.assertEqual(x.grad.shape, [1, 1])
......@@ -1193,15 +1791,23 @@ class TestSundryAPI(unittest.TestCase):
out = paddle.reshape_(x, new_shape)
self.assertEqual(out.shape, [1, 1])
def test_reverse(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.reverse(x, axis=[])
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
def test_sort(self):
x1 = paddle.rand([])
x2 = paddle.rand([])
x1.stop_gradient = False
x2.stop_gradient = False
x1.retain_grads()
x2.retain_grads()
out1 = paddle.sort(x1, axis=-1)
out2 = paddle.sort(x2, axis=0)
......@@ -1250,36 +1856,153 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x1.grad.numpy(), 0)
self.assertEqual(x2.grad.numpy(), 0)
def test_sigmoid_focal_loss(self):
logit = paddle.to_tensor(
[[0.97, 0.91, 0.03], [0.55, 0.43, 0.71]],
dtype='float32',
stop_gradient=False,
)
label = paddle.to_tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype='float32'
)
fg_num_0 = paddle.full([], 2.0)
fg_num_1 = paddle.full([1], 2.0)
def test_lerp(self):
# 0D + 0D, weight is float scalar
x = paddle.rand([])
y = paddle.rand([])
x.stop_gradient = False
y.stop_gradient = False
out = paddle.lerp(x, y, 0.5)
out.backward()
out0 = F.sigmoid_focal_loss(logit, label, normalizer=fg_num_0)
out1 = F.sigmoid_focal_loss(logit, label, normalizer=fg_num_1)
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [])
self.assertEqual(y.grad.shape, [])
np.testing.assert_array_equal(
out0.numpy(),
out1.numpy(),
)
self.assertEqual(out0.shape, [])
# 0D + 0D, weigh is 0D
x0 = paddle.rand([])
y0 = paddle.rand([])
w0 = paddle.rand([])
x0.stop_gradient = False
y0.stop_gradient = False
y0.retain_grads()
out0.retain_grads()
out0 = paddle.lerp(x0, y0, w0)
out0.backward()
self.assertEqual(out0.grad.shape, [])
self.assertEqual(logit.grad.shape, [2, 3])
self.assertEqual(out0.shape, [])
self.assertEqual(x0.grad.shape, [])
self.assertEqual(y0.grad.shape, [])
# 0D + ND
x1 = paddle.rand([])
y1 = paddle.rand([64, 64])
w1 = paddle.rand([])
x1.stop_gradient = False
y1.stop_gradient = False
x1.retain_grads()
y1.retain_grads()
out1 = paddle.lerp(x1, y1, w1)
out1.backward()
self.assertEqual(out1.shape, [64, 64])
self.assertEqual(x1.grad.shape, [])
self.assertEqual(y1.grad.shape, [64, 64])
# ND + 0D
x2 = paddle.rand([64, 64])
y2 = paddle.rand([])
w2 = paddle.rand([])
x2.stop_gradient = False
y2.stop_gradient = False
x2.retain_grads()
y2.retain_grads()
out2 = paddle.lerp(x2, y2, w2)
out2.backward()
self.assertEqual(out2.shape, [64, 64])
self.assertEqual(x2.grad.shape, [64, 64])
self.assertEqual(y2.grad.shape, [])
def test_repeat_interleave(self):
x = paddle.randn(())
x.stop_gradient = False
out = paddle.repeat_interleave(x, 2, None)
out.backward()
# check shape of output
self.assertEqual(out.shape, [2])
# check grad shape
self.assertEqual(x.grad.shape, [])
repeats = paddle.to_tensor([3], dtype='int32')
out = paddle.repeat_interleave(x, repeats, None)
# check shape of output with 1D repeats
self.assertEqual(out.shape, [3])
# check grad shape with 1D repeats
self.assertEqual(x.grad.shape, [])
def test_allclose(self):
# 1) x is 0D
x = paddle.full([], 0.5)
y = paddle.full([], 0.6)
out = paddle.allclose(x, y)
self.assertEqual(out.shape, [])
self.assertFalse(out)
# 2) x is ND
x = paddle.full([2, 3], 0.5)
y = paddle.full([2, 3], 0.6)
out = paddle.allclose(x, y)
self.assertEqual(out.shape, [])
self.assertFalse(out)
def test_equal_all(self):
# 1) x is 0D
x = paddle.full([], 0.5)
y = paddle.full([], 0.6)
self.assertFalse(paddle.allclose(x, y))
out = paddle.equal_all(x, y)
self.assertEqual(out.shape, [])
self.assertFalse(out)
# 2) x is ND
x = paddle.full([2, 3], 0.5)
y = paddle.full([2, 3], 0.6)
out = paddle.equal_all(x, y)
self.assertEqual(out.shape, [])
self.assertFalse(out)
def test_where(self):
x1 = paddle.full([], 1)
x2 = paddle.full([], 2)
x1.stop_gradient = False
x2.stop_gradient = False
x1.retain_grads()
x2.retain_grads()
out = paddle.where(x1 > x2, x1, x2)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.numpy(), 2)
self.assertEqual(out.grad.shape, [])
self.assertEqual(x1.grad.shape, [])
self.assertEqual(x2.grad.shape, [])
self.assertEqual(x1.grad.numpy(), 0)
self.assertEqual(x2.grad.numpy(), 1)
def test_atan2(self):
x1 = paddle.full([], 0)
x2 = paddle.full([], 2)
x1.retain_grads()
x2.retain_grads()
x1.stop_gradient = False
x2.stop_gradient = False
out = paddle.atan2(x1, x2)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.numpy(), 0)
self.assertEqual(out.grad.shape, [])
self.assertEqual(x1.grad.shape, [])
self.assertEqual(x2.grad.shape, [])
self.assertEqual(x1.grad.numpy(), 0.5)
self.assertEqual(x2.grad.numpy(), 0)
def test_interpolate(self):
from paddle.nn.functional import interpolate
......@@ -1336,12 +2059,73 @@ class TestSundryAPI(unittest.TestCase):
origin_result.numpy(), out3.numpy(), rtol=1e-05
)
def test_equalall(self):
x = paddle.full([], 0.5)
y = paddle.full([], 0.6)
out = paddle.equal_all(x, y)
self.assertEqual(out.shape, [])
self.assertFalse(out)
def test_upsample(self):
from paddle.nn.functional import upsample
input_x = paddle.rand([2, 3, 6, 6])
input_x.stop_gradient = False
output_size = [
paddle.full([], 12, dtype="int32"),
paddle.full([], 12, dtype="int32"),
]
out1 = upsample(
x=input_x, size=output_size, mode="bilinear", align_corners=False
)
out1.backward()
self.assertEqual(out1.shape, [2, 3, 12, 12])
self.assertEqual(input_x.grad.shape, [2, 3, 6, 6])
def test_unstack(self):
x1 = paddle.full([1], 0)
x2 = paddle.full([2], 2)
x1.retain_grads()
x2.retain_grads()
x1.stop_gradient = False
x2.stop_gradient = False
[out1] = paddle.unstack(x1, 0)
out1.retain_grads()
out1.backward()
[out2_1, out2_2] = paddle.unstack(x2, 0)
out2 = paddle.add_n([out2_1, out2_2])
out2.retain_grads()
out2.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out1.numpy(), 0)
self.assertEqual(out2_1.shape, [])
self.assertEqual(out2_1.numpy(), 2)
self.assertEqual(out2_2.shape, [])
self.assertEqual(out2_2.numpy(), 2)
self.assertEqual(x2.grad.shape, [2])
def test_unbind(self):
x1 = paddle.full([1], 0)
x2 = paddle.full([2], 2)
x1.retain_grads()
x2.retain_grads()
x1.stop_gradient = False
x2.stop_gradient = False
[out1] = paddle.unbind(x1, 0)
out1.retain_grads()
out1.backward()
[out2_1, out2_2] = paddle.unbind(x2, 0)
out2 = paddle.add_n([out2_1, out2_2])
out2.retain_grads()
out2.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out1.numpy(), 0)
self.assertEqual(out2_1.shape, [])
self.assertEqual(out2_1.numpy(), 2)
self.assertEqual(out2_2.shape, [])
self.assertEqual(out2_2.numpy(), 2)
self.assertEqual(x2.grad.shape, [2])
def test_maseked_select(self):
x = paddle.rand([])
......@@ -1357,18 +2141,53 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, [])
self.assertEqual(x.grad.numpy(), 1)
def test_squeeze(self):
x1 = paddle.full([], 2)
x1.stop_gradient = False
x1.retain_grads()
out1 = paddle.squeeze(x1, axis=0)
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(x1.grad.shape, [])
x2 = paddle.full([], 3)
x3 = paddle.full([1], 0, dtype='int32')
x2.stop_gradient = False
x2.retain_grads()
out2 = paddle.squeeze(x2, axis=x3)
out2.retain_grads()
out2.backward()
self.assertEqual(out2.shape, [])
self.assertEqual(x2.grad.shape, [])
def test_unsqueeze(self):
x1 = paddle.full([], 2)
x1.stop_gradient = False
x1.retain_grads()
out1 = paddle.unsqueeze(x1, axis=0)
out1.retain_grads()
out1.backward()
self.assertEqual(out1.shape, [1])
self.assertEqual(x1.grad.shape, [])
x2 = paddle.full([], 0, dtype='int32')
out2 = paddle.unsqueeze(x1, axis=x2)
out2.retain_grads()
out2.backward()
self.assertEqual(out2.shape, [1])
self.assertEqual(x1.grad.shape, [])
def test_t(self):
x = paddle.full([], 2.0)
x.stop_gradient = False
x.retain_grads()
out = paddle.t(x)
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, [])
def test_prelu(self):
x1 = paddle.full([], 1.0, 'float32')
......@@ -1397,11 +2216,298 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(x2.grad.shape, [])
self.assertEqual(x2.grad.numpy(), 0.25)
def test_while_loop(self):
def cond(i, x):
return paddle.less_than(i, eleven)
def body(i, x):
x = x + i
i = i + 1
return [i, x]
i = paddle.full([], 1.0, dtype='float32')
i.stop_gradient = False
eleven = paddle.full([], 11, dtype='float32')
x = paddle.full([], 0.0, dtype='float32')
x.stop_gradient = False
out_i, out_x = paddle.static.nn.while_loop(cond, body, [i, x])
out_x.backward()
self.assertEqual(out_i.shape, [])
np.testing.assert_allclose(out_i, np.array(11))
self.assertEqual(out_x.shape, [])
np.testing.assert_allclose(out_x, np.array(55))
self.assertEqual(i.grad.shape, [])
np.testing.assert_allclose(i.grad, np.array(10))
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, np.array(1.0))
def test_to_tensor(self):
out1 = paddle.to_tensor(1)
out2 = paddle.to_tensor(2.5)
out1.retain_grads()
out1.backward()
out2.retain_grads()
out2.backward()
self.assertEqual(out1.shape, [])
self.assertEqual(out1, 1)
self.assertEqual(out2.shape, [])
self.assertEqual(out2, 2.5)
def test_linalg_slogdet(self):
# 2-D input
x = paddle.randn([3, 3])
x.stop_gradient = False
out = paddle.linalg.slogdet(x)
out.retain_grads()
out.backward()
self.assertTrue(out.shape, [2])
self.assertTrue(x.grad.shape, [3, 3])
# 3-D input
x1 = paddle.randn([3, 3, 3])
x1.stop_gradient = False
out1 = paddle.linalg.slogdet(x1)
out1.retain_grads()
out1.backward()
self.assertTrue(out1.shape, [2, 3])
self.assertTrue(x1.grad.shape, [3, 3, 3])
def test_multi_dot(self):
a = paddle.randn([4])
a.stop_gradient = False
b = paddle.randn([4, 5])
b.stop_gradient = False
c = paddle.randn([5])
c.stop_gradient = False
out = paddle.linalg.multi_dot([a, b, c])
out.retain_grads()
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(a.grad.shape, [4])
self.assertEqual(b.grad.shape, [4, 5])
self.assertEqual(c.grad.shape, [5])
def test_cov(self):
xt = paddle.randn((3, 4))
xt.stop_gradient = False
xt_1 = paddle.randn((12,))
xt_1.stop_gradient = False
xt_out = paddle.linalg.cov(xt)
xt_out.retain_grads()
xt_out.backward()
self.assertEqual(xt_out.shape, [3, 3])
self.assertEqual(xt.grad.shape, [3, 4])
xt_1_out = paddle.linalg.cov(xt_1)
xt_1.retain_grads()
xt_1_out.backward()
self.assertEqual(xt_1_out.shape, [])
self.assertEqual(xt_1.grad.shape, [12])
def test_det(self):
xt = paddle.randn([3, 3, 3])
xt.stop_gradient = False
xt_1 = paddle.randn([3, 3])
xt_1.stop_gradient = False
xt_out = paddle.linalg.det(xt)
xt.retain_grads()
xt_out.backward()
self.assertEqual(xt_out.shape, [3])
self.assertEqual(xt.grad.shape, [3, 3, 3])
xt_1_out = paddle.linalg.det(xt_1)
xt_1.retain_grads()
xt_1_out.backward()
self.assertEqual(xt_1_out.shape, [])
self.assertEqual(xt_1.grad.shape, [3, 3])
def test_dist(self):
x = paddle.to_tensor([[3, 3], [3, 3]], dtype="float32")
y = paddle.to_tensor([[3, 3], [3, 1]], dtype="float32")
x.stop_gradient = False
y.stop_gradient = False
out = paddle.dist(x, y, 0)
out.backward()
self.assertEqual(out.shape, [])
np.testing.assert_allclose(out, np.array(1))
self.assertEqual(x.grad.shape, [2, 2])
self.assertEqual(y.grad.shape, [2, 2])
def test_linalg_norm(self):
# 1D input, p = fro ,axis = None, using reduceInferMeta
x_1 = paddle.arange(24, dtype="float32") - 12
x_1.stop_gradient = False
out_1 = paddle.linalg.norm(x_1)
out_1.retain_grads()
out_1.backward()
self.assertEqual(out_1.shape, [])
self.assertTrue(x_1.grad.shape, [24])
# 1D input, p = 1 ,axis = None,
# using p_nrom, as_vector = True
x_2 = paddle.arange(24, dtype="float32") - 12
x_2.stop_gradient = False
out_2 = paddle.linalg.norm(x_2, p=1)
out_2.retain_grads()
out_2.backward()
self.assertEqual(out_2.shape, [])
self.assertEqual(x_2.grad.shape, [24])
# 1D input, p = 1 ,axis = 0,
# using p_nrom, as_vector = False
x_2_p = paddle.arange(24, dtype="float32") - 12
x_2_p.stop_gradient = False
out_2_p = paddle.linalg.norm(x_2_p, p=1, axis=0)
out_2_p.retain_grads()
out_2_p.backward()
self.assertEqual(out_2_p.shape, [])
self.assertEqual(x_2_p.grad.shape, [24])
# 1D input, p = fro ,axis = 0,
# using p_nrom, as_vector = False
x_2_fro = paddle.arange(24, dtype="float32") - 12
x_2_fro.stop_gradient = False
out_2_fro = paddle.linalg.norm(x_2_fro, p="fro", axis=0)
out_2_fro.retain_grads()
out_2_fro.backward()
self.assertEqual(out_2_fro.shape, [])
self.assertEqual(x_2_fro.grad.shape, [24])
# 2D input, p = 1, axis = [0, 1]
# using p_matrix_norm ,depends on paddle.sum
x_3 = paddle.arange(24, dtype="float32").reshape([4, 6])
x_3.stop_gradient = False
out_3 = paddle.linalg.norm(x_3, p=1, axis=[0, 1])
out_3.retain_grads()
out_3.backward()
self.assertEqual(out_3.shape, [])
self.assertEqual(x_3.grad.shape, [4, 6])
# 2D input, p = 1, axis = None
# using p_matrix_norm, depends on paddle.sum
x_4 = paddle.arange(24, dtype="float32").reshape([4, 6])
x_4.stop_gradient = False
out_4 = paddle.linalg.norm(x_4)
out_4.retain_grads()
out_4.backward()
self.assertEqual(out_4.shape, [])
self.assertEqual(x_4.grad.shape, [4, 6])
# 2D input, p = inf, axis = [0, 1]
# using p_matrix_norm, depends on paddle.sum
x_5 = paddle.arange(24, dtype="float32").reshape([4, 6])
x_5.stop_gradient = False
out_5 = paddle.linalg.norm(x_5, p=2, axis=[0, 1])
out_5.retain_grads()
out_5.backward()
self.assertEqual(out_5.shape, [])
self.assertEqual(x_5.grad.shape, [4, 6])
# 2D input, p = -inf, axis = [0, 1]
x_6 = paddle.arange(24, dtype="float32").reshape([4, 6])
x_6.stop_gradient = False
out_6 = paddle.linalg.norm(x_6, p=-float("inf"), axis=[0, 1])
out_6.retain_grads()
out_6.backward()
self.assertEqual(out_6.shape, [])
self.assertEqual(x_6.grad.shape, [4, 6])
def test_linalg_cond(self):
def assert_shape(out):
self.assertEqual(out.shape, [])
x1 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]])
x1.stop_gradient = False
# p = 2 : use paddle.sum
out = paddle.linalg.cond(x1)
out.backward()
assert_shape(out)
self.assertEqual(x1.grad.shape, [3, 3])
# p = fro : use paddle.sum
x2 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]])
x2.stop_gradient = False
out_fro = paddle.linalg.cond(x2, p='fro')
out_fro.backward()
assert_shape(out_fro)
self.assertEqual(x2.grad.shape, [3, 3])
# p = nuc : use paddle.sum
x3 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]])
x3.stop_gradient = False
out_nuc = paddle.linalg.cond(x3, p='nuc')
out_nuc.backward()
assert_shape(out_nuc)
self.assertEqual(x3.grad.shape, [3, 3])
# p in (-1, 1) : use paddle.sum
x4 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]])
x4.stop_gradient = False
out_1 = paddle.linalg.cond(x4, p=1)
out_1.backward()
assert_shape(out_1)
self.assertEqual(x4.grad.shape, [3, 3])
x5 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]])
x5.stop_gradient = False
out_minus_1 = paddle.linalg.cond(x5, p=-1)
out_minus_1.backward()
assert_shape(out_minus_1)
self.assertEqual(x5.grad.shape, [3, 3])
# p in (-2, 2) depends on paddle.sum
x6 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]])
x6.stop_gradient = False
out_2 = paddle.linalg.cond(x6, p=2)
out_2.backward()
assert_shape(out_2)
self.assertEqual(x6.grad.shape, [3, 3])
# p in (-inf, inf):use paddle.sum
x8 = paddle.to_tensor([[1.0, 0, -1], [0, 1, 0], [1, 0, 1]])
x8.stop_gradient = False
out_inf = paddle.linalg.cond(x8, p=float("inf"))
out_inf.backward()
assert_shape(out_inf)
self.assertEqual(x8.grad.shape, [3, 3])
a = paddle.randn([2, 4, 4])
a.stop_gradient = False
a_cond_fro = paddle.linalg.cond(a, p='fro')
a_cond_fro.backward()
self.assertEqual(len(a_cond_fro.shape), 1)
self.assertEqual(a.grad.shape, [2, 4, 4])
def test_trace(self):
x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32")
x.stop_gradient = False
out = paddle.trace(x)
out.backward()
self.assertEqual(out.shape, [])
np.testing.assert_allclose(out, np.array(12))
self.assertEqual(x.grad.shape, [2, 2])
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.shape = [
paddle.full([], 2, 'int32'),
paddle.full([], 3, 'int32'),
......@@ -1545,12 +2651,70 @@ class TestNoBackwardAPI(unittest.TestCase):
self.assertEqual(one_hot_label.shape, [4])
self.assertEqual(one_hot_label.numpy()[2], 1)
def test_where(self):
x1 = paddle.full([], 1)
x2 = paddle.full([], 2)
out = paddle.where(x1 > x2, x1, x2)
def test_unique_consecutive(self):
x = paddle.rand([])
y, inverse, counts = paddle.unique_consecutive(
x,
return_inverse=True,
return_counts=True,
)
self.assertEqual(y, x)
self.assertEqual(inverse, 0)
self.assertEqual(counts, 1)
self.assertEqual(y.shape, [1])
self.assertEqual(inverse.shape, [1])
self.assertEqual(counts.shape, [1])
def test_unique(self):
x = paddle.rand([])
y, index, inverse, counts = paddle.unique(
x,
return_index=True,
return_inverse=True,
return_counts=True,
)
self.assertEqual(y, x)
self.assertEqual(index, 0)
self.assertEqual(inverse, 0)
self.assertEqual(counts, 1)
self.assertEqual(y.shape, [1])
self.assertEqual(index.shape, [1])
self.assertEqual(inverse.shape, [1])
self.assertEqual(counts.shape, [1])
def test_matrix_rank(self):
x = paddle.eye(10)
x.stop_gradient = False
out = paddle.linalg.matrix_rank(x)
self.assertEqual(out.shape, [])
self.assertEqual(out.numpy(), 2)
np.testing.assert_equal(out, np.array(10))
c = paddle.ones(shape=[3, 4, 5])
c.stop_gradient = False
out_c = paddle.linalg.matrix_rank(c)
self.assertEqual(out_c.shape, [3])
np.testing.assert_equal(out_c, np.array([1, 1, 1]))
# 2D, tol->float : OUTPUT 0D
x_tol = paddle.eye(10)
x_tol.stop_gradient = False
out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
self.assertEqual(out_tol.shape, [])
# 3D, tol->float : OUTPUT 1D
c_tol = paddle.ones(shape=[3, 4, 5])
c_tol.stop_gradient = False
out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
self.assertEqual(out_c_tol.shape, [3])
tol_2 = paddle.randn([2])
# 2D, tol->Tensor[1,2] : OUTPUT 1D
d = paddle.eye(10)
out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
self.assertEqual(out_d.shape, [2])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册