未验证 提交 792531b6 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] fix create_scalar to create 0D (#51024)

上级 db47dec5
......@@ -539,8 +539,8 @@ static PyObject* tensor__mul__method(TensorObject* self,
CastPyArg2Scalar(other_obj, "__mul__", 0);
if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard;
other_tensor =
full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place());
other_tensor = full_ad_func(
self_tensor.shape(), value, DataType::COMPLEX64, self_tensor.place());
} else {
eager_gil_scoped_release guard;
other_tensor = full_ad_func(
......
......@@ -98,8 +98,7 @@ def monkey_patch_variable():
return var
def create_scalar(block, value, dtype):
# TODO(zhouwei): will change to [] which is 0-D Tensor
return create_tensor(block, value, dtype, shape=[1])
return create_tensor(block, value, dtype, shape=[])
def create_tensor_with_batchsize(ref_var, value, dtype):
assert isinstance(ref_var, Variable)
......
......@@ -139,12 +139,7 @@ class TestUnaryAPI(unittest.TestCase):
x = paddle.rand([])
x.stop_gradient = False
out = api(x)
# TODO(zhouwei):
# ScaleLossGradOp / append_backward set grad shape to [1]
# after output 0D, may change it to []
# use out.sum() to avoid this two problem now
loss = out.sum()
paddle.static.append_backward(loss)
paddle.static.append_backward(out)
fetch_list = [x, out]
if block.has_var(x.grad_name):
......@@ -156,12 +151,10 @@ class TestUnaryAPI(unittest.TestCase):
self.assertEqual(item.shape, ())
# 2) Test CompiledProgram Program
expect_shape = ()
compile_prog = paddle.static.CompiledProgram(main_prog)
res = exe.run(compile_prog, fetch_list=fetch_list)
for item in res:
self.assertEqual(item.shape, expect_shape)
self.assertEqual(item.shape, ())
paddle.disable_static()
......@@ -229,7 +222,7 @@ class TestReduceAPI(unittest.TestCase):
x = paddle.rand([])
x.stop_gradient = False
out = api(x, None)
paddle.static.append_backward(out.sum())
paddle.static.append_backward(out)
out_empty_list = api(x, None)
self.assertEqual(out_empty_list.shape, ())
......@@ -437,7 +430,7 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(out.shape, out_cls.shape)
else:
out = api(x, y)
paddle.static.append_backward(out.sum())
paddle.static.append_backward(out)
self.assertEqual(x.shape, ())
self.assertEqual(y.shape, ())
......@@ -464,7 +457,7 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(out.shape, out_cls.shape)
else:
out = api(x, y)
paddle.static.append_backward(out.sum())
paddle.static.append_backward(out)
self.assertEqual(x.shape, ())
self.assertEqual(y.shape, (2, 3, 4))
......@@ -491,7 +484,7 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(out.shape, out_cls.shape)
else:
out = api(x, y)
paddle.static.append_backward(out.sum())
paddle.static.append_backward(out)
self.assertEqual(x.shape, (2, 3, 4))
self.assertEqual(y.shape, ())
......@@ -505,9 +498,6 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(y_grad.shape, ())
self.assertEqual(out_grad.shape, (2, 3, 4))
# TODO(zhouwei25):
# will open this UT after fix create_scalar in static graph
'''
# 4) x is 0D , y is scalar
x = paddle.rand([])
x.stop_gradient = False
......@@ -516,17 +506,16 @@ class TestBinaryAPI(unittest.TestCase):
out = getattr(paddle.static.Variable, api['cls_method'])(
x, y
)
paddle.static.append_backward(out.sum())
paddle.static.append_backward(out)
self.assertEqual(x.shape, ())
self.assertEqual(out.shape, ())
if block.has_var(x.name):
if block.has_var(x.grad_name):
out_grad = block.var(out.grad_name)
x_grad = block.var(x.grad_name)
self.assertEqual(out_grad.shape, ())
self.assertEqual(x_grad.shape, ())
'''
for api in binary_int_api_list:
main_prog = paddle.static.Program()
......@@ -2154,10 +2143,10 @@ class TestSundryAPIStatic(unittest.TestCase):
@prog_scope()
def test_std(self):
x = paddle.rand([])
x.stop_gradient = False
out1 = paddle.std(x)
out2 = paddle.std(x, [])
paddle.static.append_backward(out1)
paddle.static.append_backward(out2)
prog = paddle.static.default_main_program()
res = self.exe.run(
......@@ -2166,19 +2155,23 @@ class TestSundryAPIStatic(unittest.TestCase):
x,
out1,
out2,
x.grad_name,
out1.grad_name,
],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
self.assertEqual(res[4].shape, ())
@prog_scope()
def test_var(self):
x = paddle.rand([])
x.stop_gradient = False
out1 = paddle.var(x)
out2 = paddle.var(x, [])
paddle.static.append_backward(out1)
paddle.static.append_backward(out2)
prog = paddle.static.default_main_program()
res = self.exe.run(
......@@ -2187,11 +2180,15 @@ class TestSundryAPIStatic(unittest.TestCase):
x,
out1,
out2,
x.grad_name,
out1.grad_name,
],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
self.assertEqual(res[4].shape, ())
@prog_scope()
def test_quantile(self):
......@@ -3651,7 +3648,7 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase):
def test_dygraph_unary(self):
paddle.disable_static()
for api in unary_apis_with_complex_input:
x = paddle.to_tensor(2.0 + 3.0j).squeeze()
x = paddle.rand([]) + 1j * paddle.rand([])
x.stop_gradient = False
x.retain_grads()
out = api(x)
......@@ -3668,7 +3665,6 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase):
def test_static_unary(self):
paddle.enable_static()
for api in unary_apis_with_complex_input:
main_prog = paddle.static.Program()
block = main_prog.global_block()
......@@ -3676,18 +3672,10 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase):
with paddle.static.program_guard(
main_prog, paddle.static.Program()
):
# before full support for complex, we cannot create complex tensor with the same code as in dynamic graph
x = paddle.complex(
paddle.to_tensor(2.0), paddle.to_tensor(2.0)
).squeeze()
x = paddle.complex(paddle.rand([]), paddle.rand([]))
x.stop_gradient = False
out = api(x)
# TODO(zhouwei):
# ScaleLossGradOp / append_backward set grad shape to [1]
# after output 0D, may change it to []
# use out.sum() to avoid this two problem now
loss = out.sum()
paddle.static.append_backward(loss)
paddle.static.append_backward(out)
fetch_list = [x, out]
if block.has_var(x.grad_name):
......@@ -3699,12 +3687,10 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase):
self.assertEqual(item.shape, ())
# 2) Test CompiledProgram Program
expect_shape = ()
compile_prog = paddle.static.CompiledProgram(main_prog)
res = exe.run(compile_prog, fetch_list=fetch_list)
for item in res:
self.assertEqual(item.shape, expect_shape)
self.assertEqual(item.shape, ())
paddle.disable_static()
......@@ -3712,66 +3698,88 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase):
class TestAsReal(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
for api in unary_apis_with_complex_input:
x = paddle.to_tensor(2.0 + 3.0j).squeeze()
x.stop_gradient = False
x.retain_grads()
out = paddle.as_real(x)
out.retain_grads()
out.backward()
x = paddle.rand([]) + 1j * paddle.rand([])
x.stop_gradient = False
x.retain_grads()
out = paddle.as_real(x)
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [2])
if x.grad is not None:
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [2])
self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [2])
if x.grad is not None:
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [2])
paddle.enable_static()
def test_static(self):
paddle.enable_static()
for api in unary_apis_with_complex_input:
main_prog = paddle.static.Program()
block = main_prog.global_block()
exe = paddle.static.Executor()
with paddle.static.program_guard(
main_prog, paddle.static.Program()
):
# before full support for complex, we cannot create complex tensor with the same code as in dynamic graph
x = paddle.complex(
paddle.to_tensor(2.0), paddle.to_tensor(2.0)
).squeeze()
x.stop_gradient = False
out = paddle.as_real(x)
self.assertEqual(x.shape, ())
self.assertEqual(out.shape, (2,))
# TODO(zhouwei):
# ScaleLossGradOp / append_backward set grad shape to [1]
# after output 0D, may change it to []
# use out.sum() to avoid this two problem now
loss = out.abs().sum()
paddle.static.append_backward(loss)
main_prog = paddle.static.Program()
block = main_prog.global_block()
exe = paddle.static.Executor()
with paddle.static.program_guard(main_prog, paddle.static.Program()):
x = paddle.complex(paddle.rand([]), paddle.rand([]))
x.stop_gradient = False
out = paddle.as_real(x)
self.assertEqual(x.shape, ())
self.assertEqual(out.shape, (2,))
paddle.static.append_backward(out.sum())
fetch_list = [x, out]
if block.has_var(x.grad_name):
fetch_list.extend([x.grad_name, out.grad_name])
fetch_list = [x, out]
if block.has_var(x.grad_name):
fetch_list.extend([x.grad_name, out.grad_name])
# 1) Test Program
res = exe.run(main_prog, fetch_list=fetch_list)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2,))
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, (2,))
res = exe.run(main_prog, fetch_list=fetch_list)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2,))
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, (2,))
# 2) Test CompiledProgram Program
expect_shapes = (), (2,), (), (2,)
compile_prog = paddle.static.CompiledProgram(main_prog)
paddle.disable_static()
res = exe.run(compile_prog, fetch_list=fetch_list)
print(res)
for actual, expect in zip(res, expect_shapes):
self.assertEqual(actual.shape, expect)
class TestAsComplex(unittest.TestCase):
def test_dygraph(self):
paddle.disable_static()
x = paddle.rand([2])
x.stop_gradient = False
x.retain_grads()
out = paddle.as_complex(x)
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [2])
self.assertEqual(out.shape, [])
if x.grad is not None:
self.assertEqual(x.grad.shape, [2])
self.assertEqual(out.grad.shape, [])
paddle.enable_static()
def test_static(self):
paddle.enable_static()
main_prog = paddle.static.Program()
block = main_prog.global_block()
exe = paddle.static.Executor()
with paddle.static.program_guard(main_prog, paddle.static.Program()):
x = paddle.rand([2])
x.stop_gradient = False
out = paddle.as_complex(x)
self.assertEqual(x.shape, (2,))
self.assertEqual(out.shape, ())
paddle.static.append_backward(out.sum())
fetch_list = [x, out]
if block.has_var(x.grad_name):
fetch_list.extend([x.grad_name, out.grad_name])
res = exe.run(main_prog, fetch_list=fetch_list)
self.assertEqual(res[0].shape, (2,))
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2,))
self.assertEqual(res[3].shape, ())
paddle.disable_static()
......
......@@ -4469,19 +4469,15 @@ def gcd(x, y, name=None):
y = paddle.broadcast_to(y, shape)
x = paddle.abs(x)
y = paddle.abs(y)
# TODO(zhouwei25): Support 0D for not_equal tensor with scalar
zero = paddle.full([], 0)
def _gcd_cond_fn(x, y):
# return paddle.any(y != 0)
return paddle.any(y != zero)
return paddle.any(y != 0)
def _gcd_body_fn(x, y):
# paddle.mod will raise an error when any element of y is 0. To avoid
# that, we change those zeros to ones. Their values don't matter because
# they won't be used.
# y_not_equal_0 = y != 0
y_not_equal_0 = y != zero
y_not_equal_0 = y != 0
y_safe = paddle.where(y_not_equal_0, y, paddle.ones(y.shape, y.dtype))
x, y = (
paddle.where(y_not_equal_0, y, x),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册