Unverified commit 791637cf, authored by Huihuang Zheng, committed by GitHub

Support 0d Tensor in ConditionalBlockOp (#49842)

Support 0d Tensor in ConditionalBlockOp

1. Add dygraph 0d tensor support for ConditionalBlockOp
2. Set scalar loss shape when `append_backward`
Parent 73f97de0
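In dygraph mode, this change lets paddle.static.nn.cond take a 0-d boolean tensor as the predicate and keep 0-d shapes through the output and its gradient. A minimal sketch of that behavior, mirroring the new test_0d_tensor_dygraph case added below (values are illustrative, not part of the commit):

import paddle

# A 0-d tensor drives both the predicate and the branch outputs.
paddle.disable_static()

a = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
a.stop_gradient = False

# The predicate `a >= 0` is itself a 0-d boolean tensor.
out = paddle.static.nn.cond(a >= 0, lambda: a, lambda: -a)
out.backward()

print(float(out))     # 2.0
print(out.shape)      # [] -- the output stays 0-d
print(a.grad.shape)   # [] -- and so does the gradient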
@@ -385,12 +385,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
 def _create_loss_op_desc_(loss):
+    create_shape = [] if len(loss.shape) == 0 else [1]
     op_desc = _create_op_desc_(
         "fill_constant",
         {},
         {"Out": [_append_grad_suffix_(loss.name)]},
         {
-            "shape": [1],
+            "shape": create_shape,
             "value": 1.0,
             "dtype": loss.dtype,
             "force_cpu": False,
...
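The hunk above changes the fill_constant op that append_backward uses to seed the loss gradient: it is now created with shape [] when the loss itself is 0-d, instead of the unconditional [1]. A minimal static-graph sketch of the call path this affects, assuming a 0-d loss is accepted end to end after this change (variable names are illustrative):

import paddle

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.full(shape=[], dtype='float32', fill_value=3.0)
    x.stop_gradient = False
    loss = x * x  # a 0-d loss

    # append_backward seeds loss@GRAD with a fill_constant op; after this
    # commit that op gets shape [] for a 0-d loss rather than [1].
    paddle.static.append_backward(loss)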
@@ -103,6 +103,7 @@ class TestCondInputOutput(unittest.TestCase):
         exe = fluid.Executor(place)
         (ret,) = exe.run(main_program, fetch_list=[out.name])
         np.testing.assert_allclose(np.asarray(ret), np.array(2), rtol=1e-05)
+        self.assertEqual(ret.shape, ())

     def test_0d_tensor_as_cond(self):
         """
@@ -129,7 +130,7 @@
         y = paddle.full(shape=[], dtype='float32', fill_value=0.23)
         pred = paddle.greater_equal(y, x)
         out = paddle.static.nn.cond(pred, true_func, false_func)
-        # out is one tensor
+        # out is a tensor
         place = (
             fluid.CUDAPlace(0)
@@ -168,14 +169,41 @@
             if core.is_compiled_with_cuda()
             else fluid.CPUPlace()
         )
         exe = fluid.Executor(place)
         ret = exe.run(main_program, fetch_list=[out.name, a.grad_name])
         np.testing.assert_allclose(
             np.asarray(ret[0]), np.array(2.0), rtol=1e-05
         )
+        self.assertEqual(ret[0].shape, ())
         np.testing.assert_allclose(
             np.asarray(ret[1]), np.array(-1.0), rtol=1e-05
         )
+        self.assertEqual(ret[1].shape, ())
+
+    def test_0d_tensor_dygraph(self):
+        """
+        pseudocode:
+        a = -2.0
+        if a >= 0:
+            return a
+        else:
+            return -a
+        """
+        paddle.disable_static()
+        a = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
+        a.stop_gradient = False
+        out = paddle.static.nn.cond(a >= 0, lambda: a, lambda: -a)
+        out.backward()
+        np.testing.assert_allclose(np.asarray(out), np.array(2.0), rtol=1e-05)
+        self.assertEqual(out.shape, [])
+        np.testing.assert_allclose(
+            np.asarray(a.grad), np.array(-1.0), rtol=1e-05
+        )
+        self.assertEqual(a.grad.shape, [])

     def test_return_var_tuple(self):
         """
@@ -527,9 +555,11 @@ class TestCondNestedControlFlow(unittest.TestCase):
         np.testing.assert_allclose(
             np.asarray(ret[0]), np.array(7.0), rtol=1e-05
         )
+        self.assertEqual(ret[0].shape, ())
         np.testing.assert_allclose(
             np.asarray(ret[1]), np.array(2.0), rtol=1e-05
         )
+        self.assertEqual(ret[1].shape, ())

     def test_cond_op_in_condition(self):
         paddle.enable_static()
...
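The new assertions above check two different notions of shape: results fetched through Executor.run are numpy arrays, whose 0-d shape is the empty tuple (), while a dygraph paddle.Tensor reports a 0-d shape as the empty list []. A small sketch of the distinction (illustrative, not part of the commit):

import numpy as np
import paddle

# What Executor.run hands back for a 0-d output is a numpy array:
fetched = np.array(2.0)
print(fetched.shape)   # () -- numpy reports 0-d as an empty tuple

# A dygraph 0-d tensor reports its shape as an empty list:
paddle.disable_static()
t = paddle.full(shape=[], dtype='float32', fill_value=2.0)
print(t.shape)         # [] -- paddle.Tensor.shape is a Python list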
@@ -969,7 +969,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
     if _non_static_mode():
         assert isinstance(pred, Variable), "The pred in cond must be Variable"
         assert pred.size == 1, "condition input's numel should be 1"
-        pred = pred.numpy()[0]
+        pred = pred.numpy().item()
         if pred:
             if true_fn is not None:
                 if not callable(true_fn):
...
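The cond hunk above swaps pred.numpy()[0] for pred.numpy().item() in the dygraph branch: indexing a 0-d numpy array raises IndexError, while .item() extracts the scalar from both the 0-d and the single-element case. A numpy-only sketch of the difference:

import numpy as np

pred_0d = np.array(True)    # what a 0-d boolean tensor converts to
pred_1d = np.array([True])  # the previously assumed 1-element case

# .item() extracts the Python scalar in both cases.
print(pred_0d.item(), pred_1d.item())  # True True

try:
    pred_0d[0]  # a 0-d array has no axis to index
except IndexError as exc:
    print("IndexError:", exc)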