未验证 提交 792531b6 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] fix create_scalar to create 0D (#51024)

上级 db47dec5
...@@ -539,8 +539,8 @@ static PyObject* tensor__mul__method(TensorObject* self, ...@@ -539,8 +539,8 @@ static PyObject* tensor__mul__method(TensorObject* self,
CastPyArg2Scalar(other_obj, "__mul__", 0); CastPyArg2Scalar(other_obj, "__mul__", 0);
if (PyComplex_Check(other_obj)) { if (PyComplex_Check(other_obj)) {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = other_tensor = full_ad_func(
full_ad_func({1}, value, DataType::COMPLEX64, self_tensor.place()); self_tensor.shape(), value, DataType::COMPLEX64, self_tensor.place());
} else { } else {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
other_tensor = full_ad_func( other_tensor = full_ad_func(
......
...@@ -98,8 +98,7 @@ def monkey_patch_variable(): ...@@ -98,8 +98,7 @@ def monkey_patch_variable():
return var return var
def create_scalar(block, value, dtype): def create_scalar(block, value, dtype):
# TODO(zhouwei): will change to [] which is 0-D Tensor return create_tensor(block, value, dtype, shape=[])
return create_tensor(block, value, dtype, shape=[1])
def create_tensor_with_batchsize(ref_var, value, dtype): def create_tensor_with_batchsize(ref_var, value, dtype):
assert isinstance(ref_var, Variable) assert isinstance(ref_var, Variable)
......
...@@ -139,12 +139,7 @@ class TestUnaryAPI(unittest.TestCase): ...@@ -139,12 +139,7 @@ class TestUnaryAPI(unittest.TestCase):
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
out = api(x) out = api(x)
# TODO(zhouwei): paddle.static.append_backward(out)
# ScaleLossGradOp / append_backward set grad shape to [1]
# after output 0D, may change it to []
# use out.sum() to avoid this two problem now
loss = out.sum()
paddle.static.append_backward(loss)
fetch_list = [x, out] fetch_list = [x, out]
if block.has_var(x.grad_name): if block.has_var(x.grad_name):
...@@ -156,12 +151,10 @@ class TestUnaryAPI(unittest.TestCase): ...@@ -156,12 +151,10 @@ class TestUnaryAPI(unittest.TestCase):
self.assertEqual(item.shape, ()) self.assertEqual(item.shape, ())
# 2) Test CompiledProgram Program # 2) Test CompiledProgram Program
expect_shape = ()
compile_prog = paddle.static.CompiledProgram(main_prog) compile_prog = paddle.static.CompiledProgram(main_prog)
res = exe.run(compile_prog, fetch_list=fetch_list) res = exe.run(compile_prog, fetch_list=fetch_list)
for item in res: for item in res:
self.assertEqual(item.shape, expect_shape) self.assertEqual(item.shape, ())
paddle.disable_static() paddle.disable_static()
...@@ -229,7 +222,7 @@ class TestReduceAPI(unittest.TestCase): ...@@ -229,7 +222,7 @@ class TestReduceAPI(unittest.TestCase):
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
out = api(x, None) out = api(x, None)
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out)
out_empty_list = api(x, None) out_empty_list = api(x, None)
self.assertEqual(out_empty_list.shape, ()) self.assertEqual(out_empty_list.shape, ())
...@@ -437,7 +430,7 @@ class TestBinaryAPI(unittest.TestCase): ...@@ -437,7 +430,7 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(out.shape, out_cls.shape) self.assertEqual(out.shape, out_cls.shape)
else: else:
out = api(x, y) out = api(x, y)
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out)
self.assertEqual(x.shape, ()) self.assertEqual(x.shape, ())
self.assertEqual(y.shape, ()) self.assertEqual(y.shape, ())
...@@ -464,7 +457,7 @@ class TestBinaryAPI(unittest.TestCase): ...@@ -464,7 +457,7 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(out.shape, out_cls.shape) self.assertEqual(out.shape, out_cls.shape)
else: else:
out = api(x, y) out = api(x, y)
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out)
self.assertEqual(x.shape, ()) self.assertEqual(x.shape, ())
self.assertEqual(y.shape, (2, 3, 4)) self.assertEqual(y.shape, (2, 3, 4))
...@@ -491,7 +484,7 @@ class TestBinaryAPI(unittest.TestCase): ...@@ -491,7 +484,7 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(out.shape, out_cls.shape) self.assertEqual(out.shape, out_cls.shape)
else: else:
out = api(x, y) out = api(x, y)
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out)
self.assertEqual(x.shape, (2, 3, 4)) self.assertEqual(x.shape, (2, 3, 4))
self.assertEqual(y.shape, ()) self.assertEqual(y.shape, ())
...@@ -505,9 +498,6 @@ class TestBinaryAPI(unittest.TestCase): ...@@ -505,9 +498,6 @@ class TestBinaryAPI(unittest.TestCase):
self.assertEqual(y_grad.shape, ()) self.assertEqual(y_grad.shape, ())
self.assertEqual(out_grad.shape, (2, 3, 4)) self.assertEqual(out_grad.shape, (2, 3, 4))
# TODO(zhouwei25):
# will open this UT after fix create_scalar in static graph
'''
# 4) x is 0D , y is scalar # 4) x is 0D , y is scalar
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
...@@ -516,17 +506,16 @@ class TestBinaryAPI(unittest.TestCase): ...@@ -516,17 +506,16 @@ class TestBinaryAPI(unittest.TestCase):
out = getattr(paddle.static.Variable, api['cls_method'])( out = getattr(paddle.static.Variable, api['cls_method'])(
x, y x, y
) )
paddle.static.append_backward(out.sum()) paddle.static.append_backward(out)
self.assertEqual(x.shape, ()) self.assertEqual(x.shape, ())
self.assertEqual(out.shape, ()) self.assertEqual(out.shape, ())
if block.has_var(x.name): if block.has_var(x.grad_name):
out_grad = block.var(out.grad_name) out_grad = block.var(out.grad_name)
x_grad = block.var(x.grad_name) x_grad = block.var(x.grad_name)
self.assertEqual(out_grad.shape, ()) self.assertEqual(out_grad.shape, ())
self.assertEqual(x_grad.shape, ()) self.assertEqual(x_grad.shape, ())
'''
for api in binary_int_api_list: for api in binary_int_api_list:
main_prog = paddle.static.Program() main_prog = paddle.static.Program()
...@@ -2154,10 +2143,10 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -2154,10 +2143,10 @@ class TestSundryAPIStatic(unittest.TestCase):
@prog_scope() @prog_scope()
def test_std(self): def test_std(self):
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False
out1 = paddle.std(x) out1 = paddle.std(x)
out2 = paddle.std(x, []) out2 = paddle.std(x, [])
paddle.static.append_backward(out1) paddle.static.append_backward(out1)
paddle.static.append_backward(out2)
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
res = self.exe.run( res = self.exe.run(
...@@ -2166,19 +2155,23 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -2166,19 +2155,23 @@ class TestSundryAPIStatic(unittest.TestCase):
x, x,
out1, out1,
out2, out2,
x.grad_name,
out1.grad_name,
], ],
) )
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
self.assertEqual(res[4].shape, ())
@prog_scope() @prog_scope()
def test_var(self): def test_var(self):
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False
out1 = paddle.var(x) out1 = paddle.var(x)
out2 = paddle.var(x, []) out2 = paddle.var(x, [])
paddle.static.append_backward(out1) paddle.static.append_backward(out1)
paddle.static.append_backward(out2)
prog = paddle.static.default_main_program() prog = paddle.static.default_main_program()
res = self.exe.run( res = self.exe.run(
...@@ -2187,11 +2180,15 @@ class TestSundryAPIStatic(unittest.TestCase): ...@@ -2187,11 +2180,15 @@ class TestSundryAPIStatic(unittest.TestCase):
x, x,
out1, out1,
out2, out2,
x.grad_name,
out1.grad_name,
], ],
) )
self.assertEqual(res[0].shape, ()) self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ()) self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
self.assertEqual(res[4].shape, ())
@prog_scope() @prog_scope()
def test_quantile(self): def test_quantile(self):
...@@ -3651,7 +3648,7 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase): ...@@ -3651,7 +3648,7 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase):
def test_dygraph_unary(self): def test_dygraph_unary(self):
paddle.disable_static() paddle.disable_static()
for api in unary_apis_with_complex_input: for api in unary_apis_with_complex_input:
x = paddle.to_tensor(2.0 + 3.0j).squeeze() x = paddle.rand([]) + 1j * paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
x.retain_grads() x.retain_grads()
out = api(x) out = api(x)
...@@ -3668,7 +3665,6 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase): ...@@ -3668,7 +3665,6 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase):
def test_static_unary(self): def test_static_unary(self):
paddle.enable_static() paddle.enable_static()
for api in unary_apis_with_complex_input: for api in unary_apis_with_complex_input:
main_prog = paddle.static.Program() main_prog = paddle.static.Program()
block = main_prog.global_block() block = main_prog.global_block()
...@@ -3676,18 +3672,10 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase): ...@@ -3676,18 +3672,10 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase):
with paddle.static.program_guard( with paddle.static.program_guard(
main_prog, paddle.static.Program() main_prog, paddle.static.Program()
): ):
# before full support for complex, we cannot create complex tensor with the same code as in dynamic graph x = paddle.complex(paddle.rand([]), paddle.rand([]))
x = paddle.complex(
paddle.to_tensor(2.0), paddle.to_tensor(2.0)
).squeeze()
x.stop_gradient = False x.stop_gradient = False
out = api(x) out = api(x)
# TODO(zhouwei): paddle.static.append_backward(out)
# ScaleLossGradOp / append_backward set grad shape to [1]
# after output 0D, may change it to []
# use out.sum() to avoid this two problem now
loss = out.sum()
paddle.static.append_backward(loss)
fetch_list = [x, out] fetch_list = [x, out]
if block.has_var(x.grad_name): if block.has_var(x.grad_name):
...@@ -3699,12 +3687,10 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase): ...@@ -3699,12 +3687,10 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase):
self.assertEqual(item.shape, ()) self.assertEqual(item.shape, ())
# 2) Test CompiledProgram Program # 2) Test CompiledProgram Program
expect_shape = ()
compile_prog = paddle.static.CompiledProgram(main_prog) compile_prog = paddle.static.CompiledProgram(main_prog)
res = exe.run(compile_prog, fetch_list=fetch_list) res = exe.run(compile_prog, fetch_list=fetch_list)
for item in res: for item in res:
self.assertEqual(item.shape, expect_shape) self.assertEqual(item.shape, ())
paddle.disable_static() paddle.disable_static()
...@@ -3712,66 +3698,88 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase): ...@@ -3712,66 +3698,88 @@ class TestUnaryElementwiseAPIWithComplexInput(unittest.TestCase):
class TestAsReal(unittest.TestCase): class TestAsReal(unittest.TestCase):
def test_dygraph(self): def test_dygraph(self):
paddle.disable_static() paddle.disable_static()
for api in unary_apis_with_complex_input: x = paddle.rand([]) + 1j * paddle.rand([])
x = paddle.to_tensor(2.0 + 3.0j).squeeze() x.stop_gradient = False
x.stop_gradient = False x.retain_grads()
x.retain_grads() out = paddle.as_real(x)
out = paddle.as_real(x) out.retain_grads()
out.retain_grads() out.backward()
out.backward()
self.assertEqual(x.shape, []) self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [2]) self.assertEqual(out.shape, [2])
if x.grad is not None: if x.grad is not None:
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [2]) self.assertEqual(out.grad.shape, [2])
paddle.enable_static() paddle.enable_static()
def test_static(self): def test_static(self):
paddle.enable_static() paddle.enable_static()
for api in unary_apis_with_complex_input: main_prog = paddle.static.Program()
main_prog = paddle.static.Program() block = main_prog.global_block()
block = main_prog.global_block() exe = paddle.static.Executor()
exe = paddle.static.Executor() with paddle.static.program_guard(main_prog, paddle.static.Program()):
with paddle.static.program_guard( x = paddle.complex(paddle.rand([]), paddle.rand([]))
main_prog, paddle.static.Program() x.stop_gradient = False
): out = paddle.as_real(x)
# before full support for complex, we cannot create complex tensor with the same code as in dynamic graph self.assertEqual(x.shape, ())
x = paddle.complex( self.assertEqual(out.shape, (2,))
paddle.to_tensor(2.0), paddle.to_tensor(2.0) paddle.static.append_backward(out.sum())
).squeeze()
x.stop_gradient = False
out = paddle.as_real(x)
self.assertEqual(x.shape, ())
self.assertEqual(out.shape, (2,))
# TODO(zhouwei):
# ScaleLossGradOp / append_backward set grad shape to [1]
# after output 0D, may change it to []
# use out.sum() to avoid this two problem now
loss = out.abs().sum()
paddle.static.append_backward(loss)
fetch_list = [x, out] fetch_list = [x, out]
if block.has_var(x.grad_name): if block.has_var(x.grad_name):
fetch_list.extend([x.grad_name, out.grad_name]) fetch_list.extend([x.grad_name, out.grad_name])
# 1) Test Program res = exe.run(main_prog, fetch_list=fetch_list)
res = exe.run(main_prog, fetch_list=fetch_list) self.assertEqual(res[0].shape, ())
self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, (2,))
self.assertEqual(res[1].shape, (2,)) self.assertEqual(res[2].shape, ())
self.assertEqual(res[2].shape, ()) self.assertEqual(res[3].shape, (2,))
self.assertEqual(res[3].shape, (2,))
# 2) Test CompiledProgram Program paddle.disable_static()
expect_shapes = (), (2,), (), (2,)
compile_prog = paddle.static.CompiledProgram(main_prog)
res = exe.run(compile_prog, fetch_list=fetch_list)
print(res) class TestAsComplex(unittest.TestCase):
for actual, expect in zip(res, expect_shapes): def test_dygraph(self):
self.assertEqual(actual.shape, expect) paddle.disable_static()
x = paddle.rand([2])
x.stop_gradient = False
x.retain_grads()
out = paddle.as_complex(x)
out.retain_grads()
out.backward()
self.assertEqual(x.shape, [2])
self.assertEqual(out.shape, [])
if x.grad is not None:
self.assertEqual(x.grad.shape, [2])
self.assertEqual(out.grad.shape, [])
paddle.enable_static()
def test_static(self):
paddle.enable_static()
main_prog = paddle.static.Program()
block = main_prog.global_block()
exe = paddle.static.Executor()
with paddle.static.program_guard(main_prog, paddle.static.Program()):
x = paddle.rand([2])
x.stop_gradient = False
out = paddle.as_complex(x)
self.assertEqual(x.shape, (2,))
self.assertEqual(out.shape, ())
paddle.static.append_backward(out.sum())
fetch_list = [x, out]
if block.has_var(x.grad_name):
fetch_list.extend([x.grad_name, out.grad_name])
res = exe.run(main_prog, fetch_list=fetch_list)
self.assertEqual(res[0].shape, (2,))
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2,))
self.assertEqual(res[3].shape, ())
paddle.disable_static() paddle.disable_static()
......
...@@ -4469,19 +4469,15 @@ def gcd(x, y, name=None): ...@@ -4469,19 +4469,15 @@ def gcd(x, y, name=None):
y = paddle.broadcast_to(y, shape) y = paddle.broadcast_to(y, shape)
x = paddle.abs(x) x = paddle.abs(x)
y = paddle.abs(y) y = paddle.abs(y)
# TODO(zhouwei25): Support 0D for not_equal tensor with scalar
zero = paddle.full([], 0)
def _gcd_cond_fn(x, y): def _gcd_cond_fn(x, y):
# return paddle.any(y != 0) return paddle.any(y != 0)
return paddle.any(y != zero)
def _gcd_body_fn(x, y): def _gcd_body_fn(x, y):
# paddle.mod will raise an error when any element of y is 0. To avoid # paddle.mod will raise an error when any element of y is 0. To avoid
# that, we change those zeros to ones. Their values don't matter because # that, we change those zeros to ones. Their values don't matter because
# they won't be used. # they won't be used.
# y_not_equal_0 = y != 0 y_not_equal_0 = y != 0
y_not_equal_0 = y != zero
y_safe = paddle.where(y_not_equal_0, y, paddle.ones(y.shape, y.dtype)) y_safe = paddle.where(y_not_equal_0, y, paddle.ones(y.shape, y.dtype))
x, y = ( x, y = (
paddle.where(y_not_equal_0, y, x), paddle.where(y_not_equal_0, y, x),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册