Unverified commit 5fd115f3, authored by zhouweiwei2014, committed by GitHub

[Zero-Dim] add static graph gradient test method for 0D Tensor input (#49755)

Parent 8d512b8f
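This patch moves the XPU zero-dim tests from per-tensor x.retain_grads() / out.retain_grads() calls to the global FLAGS_retain_grad_for_all_tensor flag, renames test_dygraph to test_dygraph_reduce, and tightens the dygraph shape and gradient assertions. The static-graph gradient checks referred to in the title are not part of the hunks shown below; the snippet that follows is only a minimal sketch of that kind of check, assuming the usual paddle.static workflow (program_guard, append_backward, Executor, fetching the input's grad_name), with paddle.tanh standing in for whichever API is under test. It is illustrative and not taken from the patch.

import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.rand([])                   # 0-D input
    x.stop_gradient = False
    out = paddle.tanh(x)                  # stand-in for the API under test
    paddle.static.append_backward(out)    # add the gradient ops for out

exe = paddle.static.Executor()
exe.run(startup_prog)
# Fetch the forward output and the gradient of x (named "<x.name>@GRAD").
out_np, x_grad_np = exe.run(main_prog, fetch_list=[out, x.grad_name])
assert out_np.shape == ()      # the 0-D output stays 0-D
assert x_grad_np.shape == ()   # its gradient keeps the 0-D shape

In dygraph mode the same property is exercised directly: run out.backward() and assert that x.grad and out.grad keep the expected shapes, which is exactly what the updated assertions in the hunks below check.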
@@ -17,9 +17,11 @@ import unittest
 import numpy as np
 import paddle
+import paddle.fluid as fluid
 import paddle.nn.functional as F

 paddle.set_device('xpu')
+fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})

 unary_api_list = [
@@ -100,9 +102,7 @@ class TestUnaryAPI(unittest.TestCase):
         for api in unary_api_list:
             x = paddle.rand([])
             x.stop_gradient = False
-            x.retain_grads()
             out = api(x)
-            out.retain_grads()
             out.backward()

             self.assertEqual(x.shape, [])
@@ -138,25 +138,22 @@ reduce_api_list = [
 # Use to test zero-dim of reduce API
 class TestReduceAPI(unittest.TestCase):
-    def test_dygraph(self):
+    def test_dygraph_reduce(self):
         paddle.disable_static()
         for api in reduce_api_list:
+            # 1) x is 0D
             if api in [paddle.all, paddle.any]:
                 x = paddle.randint(0, 2, []).astype('bool')
-                out = api(x, None)
-                self.assertEqual(x.shape, [])
-                self.assertEqual(out.shape, [])
             else:
                 x = paddle.rand([])
-                x.stop_gradient = False
-                x.retain_grads()
-                out = api(x, None)
-                out.retain_grads()
-                out.backward()
-
-                self.assertEqual(x.shape, [])
-                self.assertEqual(x.grad.shape, [])
-                self.assertEqual(out.shape, [])
-                self.assertEqual(out.grad.shape, [])
+            x.stop_gradient = False
+            out = api(x, None)
+            out.backward()
+
+            self.assertEqual(x.shape, [])
+            self.assertEqual(out.shape, [])
+            if x.grad is not None:
+                self.assertEqual(x.grad.shape, [])
+                self.assertEqual(out.grad.shape, [])

         paddle.enable_static()
@@ -196,29 +193,28 @@ class TestBinaryAPI(unittest.TestCase):
     def test_dygraph_binary(self):
         paddle.disable_static()
         for api in binary_api_list:
-            # 1) x/y is 0D
+            # 1) x is 0D, y is 0D
             x = paddle.rand([])
             y = paddle.rand([])
             x.stop_gradient = False
             y.stop_gradient = False
-            x.retain_grads()
-            y.retain_grads()
             if isinstance(api, dict):
                 out = api['func'](x, y)
                 out_cls = getattr(paddle.Tensor, api['cls_method'])(x, y)
                 np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
             else:
                 out = api(x, y)
-            out.retain_grads()
-            self.assertEqual(out.shape, [])
             out.backward()

+            self.assertEqual(x.shape, [])
+            self.assertEqual(y.shape, [])
+            self.assertEqual(out.shape, [])
             if x.grad is not None:
                 self.assertEqual(x.grad.shape, [])
                 self.assertEqual(y.grad.shape, [])
                 self.assertEqual(out.grad.shape, [])

-            # 2) x is not 0D , y is 0D
+            # 2) x is ND, y is 0D
             x = paddle.rand([2, 3, 4])
             y = paddle.rand([])
             x.stop_gradient = False
@@ -229,16 +225,17 @@ class TestBinaryAPI(unittest.TestCase):
                 np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
             else:
                 out = api(x, y)
-            out.retain_grads()
-            self.assertEqual(out.shape, [2, 3, 4])
             out.backward()

+            self.assertEqual(x.shape, [2, 3, 4])
+            self.assertEqual(y.shape, [])
+            self.assertEqual(out.shape, [2, 3, 4])
             if x.grad is not None:
                 self.assertEqual(x.grad.shape, [2, 3, 4])
                 self.assertEqual(y.grad.shape, [])
                 self.assertEqual(out.grad.shape, [2, 3, 4])

-            # 3) x is 0D , y is not 0D
+            # 3) x is 0D , y is ND
             x = paddle.rand([])
             y = paddle.rand([2, 3, 4])
             x.stop_gradient = False
@@ -249,10 +246,11 @@ class TestBinaryAPI(unittest.TestCase):
                 np.testing.assert_array_equal(out_cls.numpy(), out.numpy())
             else:
                 out = api(x, y)
-            out.retain_grads()
-            self.assertEqual(out.shape, [2, 3, 4])
             out.backward()

+            self.assertEqual(x.shape, [])
+            self.assertEqual(y.shape, [2, 3, 4])
+            self.assertEqual(out.shape, [2, 3, 4])
             if x.grad is not None:
                 self.assertEqual(x.grad.shape, [])
                 self.assertEqual(y.grad.shape, [2, 3, 4])
@@ -260,26 +258,32 @@ class TestBinaryAPI(unittest.TestCase):
             # 4) x is 0D , y is scalar
             x = paddle.rand([])
-            y = 0.5
             x.stop_gradient = False
+            y = 0.5
             if isinstance(api, dict):
                 out = getattr(paddle.Tensor, api['cls_method'])(x, y)
+                out.backward()
+
+                self.assertEqual(x.shape, [])
                 self.assertEqual(out.shape, [])
+                if x.grad is not None:
+                    self.assertEqual(x.grad.shape, [])
+                    self.assertEqual(out.grad.shape, [])

         for api in binary_int_api_list:
-            # 1) x/y is 0D
+            # 1) x is 0D, y is 0D
             x = paddle.randint(-10, 10, [])
             y = paddle.randint(-10, 10, [])
             out = api(x, y)
             self.assertEqual(out.shape, [])

-            # 2) x is not 0D , y is 0D
+            # 2) x is ND, y is 0D
             x = paddle.randint(-10, 10, [3, 5])
             y = paddle.randint(-10, 10, [])
             out = api(x, y)
             self.assertEqual(out.shape, [3, 5])

-            # 3) x is 0D , y is not 0D
+            # 3) x is 0D , y is ND
             x = paddle.randint(-10, 10, [])
             y = paddle.randint(-10, 10, [3, 5])
             out = api(x, y)
@@ -374,9 +378,7 @@ class TestSundryAPI(unittest.TestCase):
     def test_pow_factor(self):
         x = paddle.rand([])
         x.stop_gradient = False
-        x.retain_grads()
         out = paddle.pow(x, 2.0)
-        out.retain_grads()
         out.backward()

         self.assertEqual(out.shape, [])
@@ -386,9 +388,7 @@ class TestSundryAPI(unittest.TestCase):
     def test_cast(self):
         x = paddle.full([], 1.0, 'float32')
         x.stop_gradient = False
-        x.retain_grads()
         out = paddle.cast(x, 'int32')
-        out.retain_grads()
         out.backward()

         self.assertEqual(out.shape, [])
@@ -399,7 +399,6 @@ class TestSundryAPI(unittest.TestCase):
         x = paddle.uniform([], None, -10, 10)
         x.stop_gradient = False
         out = paddle.clip(x, -5, 5)
-        out.retain_grads()
         out.backward()

         self.assertEqual(out.shape, [])
@@ -444,11 +443,11 @@ class TestSundryAPI(unittest.TestCase):
         x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
         index = paddle.full([], 2, 'int64')
         out = paddle.gather(x, index)
-        out.retain_grads()
         out.backward()

         self.assertEqual(out.shape, [])
         self.assertEqual(out.numpy(), 5)
+        self.assertEqual(x.grad.shape, [5])
         self.assertEqual(out.grad.shape, [])

     def test_gather_xD_axis_0(self):
@@ -457,61 +456,62 @@ class TestSundryAPI(unittest.TestCase):
         )
         index = paddle.full([], 1, 'int64')
         out = paddle.gather(x, index)
-        out.retain_grads()
         out.backward()

         self.assertEqual(out.shape, [3])
-        for i in range(3):
-            self.assertEqual(out.numpy()[i], x.numpy()[1][i])
+        np.testing.assert_array_equal(out.numpy(), x.numpy()[1, :])
+        self.assertEqual(x.grad.shape, [2, 3])
         self.assertEqual(out.grad.shape, [3])

     def test_gather_xD_axis_1(self):
-        x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+        x = paddle.to_tensor(
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
+        )
         index = paddle.full([], 1, 'int64')
         out = paddle.gather(x, index, axis=1)
+        out.backward()

         self.assertEqual(out.shape, [2])
-        for i in range(2):
-            self.assertEqual(out.numpy()[i], x.numpy()[i][1])
+        np.testing.assert_array_equal(out.numpy(), [2.0, 5.0])
+        self.assertEqual(x.grad.shape, [2, 3])
+        self.assertEqual(out.grad.shape, [2])

     def test_scatter_1D(self):
-        x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0])
+        x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
         index = paddle.full([], 2, 'int64')
         updates = paddle.full([], 4.0)
         out = paddle.scatter(x, index, updates)
+        out.backward()

+        self.assertEqual(out.shape, [5])
         self.assertEqual(out.numpy()[2], 4)
+        self.assertEqual(out.grad.shape, [5])

     def test_scatter_XD(self):
-        x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+        x = paddle.to_tensor(
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
+        )
         index = paddle.full([], 1, 'int64')
         updates = paddle.to_tensor([1.0, 2.0, 3.0])
         out = paddle.scatter(x, index, updates)
+        out.backward()

-        for i in range(3):
-            self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
+        self.assertEqual(out.shape, [2, 3])
+        np.testing.assert_array_equal(out.numpy()[1], [1.0, 2.0, 3.0])
+        self.assertEqual(out.grad.shape, [2, 3])

     def test_diagflat(self):
         x1 = paddle.rand([])
         x2 = paddle.rand([])
         x3 = paddle.rand([])
         x1.stop_gradient = False
         x2.stop_gradient = False
         x3.stop_gradient = False
-        x1.retain_grads()
-        x2.retain_grads()
-        x3.retain_grads()

         out1 = paddle.diagflat(x1, 1)
         out2 = paddle.diagflat(x2, -1)
         out3 = paddle.diagflat(x3, 0)
-        out1.retain_grads()
-        out2.retain_grads()
-        out3.retain_grads()

         out1.backward()
         out2.backward()
         out3.backward()
@@ -541,9 +541,7 @@ class TestSundryAPI(unittest.TestCase):
         index = paddle.full([], 1, 'int64')
         updates = paddle.to_tensor([1.0, 2.0, 3.0])
         out = paddle.scatter_(x, index, updates)
-        for i in range(3):
-            self.assertEqual(out.numpy()[1][i], updates.numpy()[i])
+        np.testing.assert_array_equal(out.numpy()[1], [1.0, 2.0, 3.0])

     def test_flatten(self):
         x = paddle.full([], 1, 'float32')
@@ -561,9 +559,7 @@ class TestSundryAPI(unittest.TestCase):
     def test_scale(self):
         x = paddle.rand([])
         x.stop_gradient = False
-        x.retain_grads()
         out = paddle.scale(x, scale=2.0, bias=1.0)
-        out.retain_grads()
         out.backward()

         self.assertEqual(out.shape, [])
@@ -598,31 +594,26 @@ class TestSundryAPI(unittest.TestCase):
     def test_reshape_list(self):
         x = paddle.rand([])
         x.stop_gradient = False
-        x.retain_grads()
         out = paddle.reshape(x, [])
-        out.retain_grads()
         out.backward()
         self.assertEqual(x.grad.shape, [])
         self.assertEqual(out.shape, [])
         self.assertEqual(out.grad.shape, [])

         out = paddle.reshape(x, [1])
-        out.retain_grads()
         out.backward()
         self.assertEqual(x.grad.shape, [])
         self.assertEqual(out.shape, [1])
         self.assertEqual(out.grad.shape, [1])

         out = paddle.reshape(x, [-1])
-        out.retain_grads()
         out.backward()
         self.assertEqual(x.grad.shape, [])
         self.assertEqual(out.shape, [1])
         self.assertEqual(out.grad.shape, [1])

         out = paddle.reshape(x, [-1, 1])
-        out.retain_grads()
         out.backward()
         self.assertEqual(x.grad.shape, [])
         self.assertEqual(out.shape, [1, 1])
@@ -631,26 +622,22 @@ class TestSundryAPI(unittest.TestCase):
     def test_reshape_tensor(self):
         x = paddle.rand([1, 1])
         x.stop_gradient = False
-        x.retain_grads()
         out = paddle.reshape(x, [])
-        out.retain_grads()
         out.backward()
         self.assertEqual(x.grad.shape, [1, 1])
         self.assertEqual(out.shape, [])
         self.assertEqual(out.grad.shape, [])

-        new_shape = paddle.full([], 1, "int32")
+        new_shape = paddle.to_tensor([1, 1, 1], "int32")
         out = paddle.reshape(x, new_shape)
-        out.retain_grads()
         out.backward()
         self.assertEqual(x.grad.shape, [1, 1])
-        self.assertEqual(out.shape, [1])
-        self.assertEqual(out.grad.shape, [1])
+        self.assertEqual(out.shape, [1, 1, 1])
+        self.assertEqual(out.grad.shape, [1, 1, 1])

-        new_shape = paddle.full([], -1, "int32")
+        new_shape = paddle.to_tensor([-1], "int32")
         out = paddle.reshape(x, new_shape)
-        out.retain_grads()
         out.backward()
         self.assertEqual(x.grad.shape, [1, 1])
         self.assertEqual(out.shape, [1])
@@ -658,7 +645,6 @@ class TestSundryAPI(unittest.TestCase):
         new_shape = [paddle.full([], -1, "int32"), paddle.full([], 1, "int32")]
         out = paddle.reshape(x, new_shape)
-        out.retain_grads()
         out.backward()

         self.assertEqual(x.grad.shape, [1, 1])
         self.assertEqual(out.shape, [1, 1])
@@ -700,13 +686,9 @@ class TestSundryAPI(unittest.TestCase):
         x2 = paddle.rand([])
         x1.stop_gradient = False
         x2.stop_gradient = False
-        x1.retain_grads()
-        x2.retain_grads()

         out1 = paddle.sort(x1, axis=-1)
         out2 = paddle.sort(x2, axis=0)
-        out1.retain_grads()
-        out2.retain_grads()

         out1.backward()
         out2.backward()
@@ -727,13 +709,9 @@ class TestSundryAPI(unittest.TestCase):
         x2 = paddle.rand([])
         x1.stop_gradient = False
         x2.stop_gradient = False
-        x1.retain_grads()
-        x2.retain_grads()

         out1 = paddle.argsort(x1, axis=-1)
         out2 = paddle.argsort(x2, axis=0)
-        out1.retain_grads()
-        out2.retain_grads()

         out1.backward()
         out2.backward()
......