Unverified commit b813c948, authored by songyouwei, committed by GitHub

support control flow cond in dygraph mode (#22693)

* dygraph support cond op
test=develop

* unittest coverage
test=develop

* fix coverage
test=develop

* fix for coverage
test=develop

* refine TypeError msg
test=develop

* remove restrict
test=develop
Parent: b2c1be85
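
As a quick illustration of the behavior this commit adds, here is a minimal sketch of calling fluid.layers.cond under a dygraph guard, where the predicate is evaluated eagerly and only the selected branch callable runs. It is based on the dygraph unit test added below; the concrete values and branch lambdas are illustrative only.

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    a = fluid.dygraph.to_variable(np.array([0.1]).astype('float32'))
    b = fluid.dygraph.to_variable(np.array([0.23]).astype('float32'))
    # With eager execution, cond picks and runs exactly one branch.
    out = fluid.layers.cond(a < b,
                            lambda: fluid.layers.elementwise_add(a, b),
                            lambda: fluid.layers.elementwise_sub(a, b))
    print(out.numpy())  # a < b, so the add branch runs: roughly [0.33]
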
@@ -2035,13 +2035,35 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
             # [ True True True]]
     """
+    if in_dygraph_mode():
+        assert isinstance(pred, Variable), "The pred in cond must be Variable"
+        assert pred.numpy().size == 1, "condition input's numel should be 1"
+        pred = pred.numpy()[0]
+        if pred:
+            if true_fn is not None:
+                if not callable(true_fn):
+                    raise TypeError(
+                        "The true_fn in cond must be callable, but received {}".
+                        format(type(true_fn).__name__))
+                return true_fn()
+        else:
+            if false_fn is not None:
+                if not callable(false_fn):
+                    raise TypeError(
+                        "The false_fn in cond must be callable, but received {}".
+                        format(type(false_fn).__name__))
+                return false_fn()
+        return None
+
     helper = LayerHelper('cond', **locals())
     true_output = None
     false_output = None
     copy_to_parent_func = lambda var: copy_var_to_parent_block(var, helper)
     if true_fn is not None:
         if not callable(true_fn):
-            raise TypeError("The true_fn in cond must be callable")
+            raise TypeError(
+                "The true_fn in cond must be callable, but received {}".format(
+                    type(true_fn).__name__))
         true_cond_block = ConditionalBlock([pred], is_scalar_condition=True)
         with true_cond_block.block():
             origin_true_output = true_fn()
@@ -2050,7 +2072,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
                                         origin_true_output)
     if false_fn is not None:
         if not callable(false_fn):
-            raise TypeError("The false_fn in cond must be callable")
+            raise TypeError(
+                "The false_fn in cond must be callable, but received {}".format(
+                    type(false_fn).__name__))
         false_cond_block = ConditionalBlock(
             [logical_not(pred)], is_scalar_condition=True)
         with false_cond_block.block():
......
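
A short sketch of the refined TypeError message in dygraph mode, mirroring the assertions in the updated tests below; the literal 'str' inputs are illustrative only.

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    a = fluid.dygraph.to_variable(np.array([0.1]).astype('float32'))
    b = fluid.dygraph.to_variable(np.array([0.23]).astype('float32'))
    try:
        # 'str' is not callable, so cond now reports the received type.
        fluid.layers.cond(a < b, 'str', 'str')
    except TypeError as e:
        print(e)  # The true_fn in cond must be callable, but received str
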
@@ -192,15 +192,11 @@ class TestCondInputOutput(unittest.TestCase):
         with program_guard(main_program, startup_program):
             i = fluid.data(name="i", shape=[1], dtype='int32')
             pred = ((i % 2) == 0)
-            with self.assertRaises(Exception) as e:
+            with self.assertRaises(TypeError):
                 out = layers.cond(pred, i, func_return_one_tensor)
-            self.assertEqual("The true_fn in cond must be callable",
-                             str(e.exception))
-            with self.assertRaises(Exception) as e:
+            with self.assertRaises(TypeError):
                 out = layers.cond(pred, func_return_one_tensor, np.asarray([3]))
-            self.assertEqual("The false_fn in cond must be callable",
-                             str(e.exception))
             with self.assertRaises(Exception) as e:
                 out = layers.cond(pred, func_return_none,
......
@@ -1527,6 +1527,43 @@ class TestLayer(LayerTest):
         for i in range(len(static_ret5)):
             self.assertTrue(dcond5.numpy()[i] == static_ret5[i])
 
+    def test_cond(self):
+        def less_than_branch(a, b):
+            return fluid.layers.elementwise_add(a, b)
+
+        def greater_equal_branch(a, b):
+            return fluid.layers.elementwise_sub(a, b)
+
+        with self.static_graph():
+            a = fluid.layers.fill_constant(
+                shape=[1], dtype='float32', value=0.1)
+            b = fluid.layers.fill_constant(
+                shape=[1], dtype='float32', value=0.23)
+            out = fluid.layers.cond(a >= b, lambda: greater_equal_branch(a, b),
+                                    lambda: less_than_branch(a, b))
+            place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
+            ) else fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            ret = exe.run(fetch_list=[out])
+            static_res = ret[0]
+
+        with self.dynamic_graph():
+            a = fluid.dygraph.to_variable(np.array([0.1]).astype('float32'))
+            b = fluid.dygraph.to_variable(np.array([0.23]).astype('float32'))
+            out = layers.cond(a < b, lambda: less_than_branch(a, b),
+                              lambda: greater_equal_branch(a, b))
+            out2 = layers.cond(a >= b, lambda: greater_equal_branch(a, b),
+                               lambda: less_than_branch(a, b))
+            dynamic_res = out.numpy()
+            dynamic_res2 = out2.numpy()
+            self.assertTrue(np.array_equal(dynamic_res, dynamic_res2))
+            with self.assertRaises(TypeError):
+                layers.cond(a < b, 'str', 'str')
+            with self.assertRaises(TypeError):
+                layers.cond(a >= b, 'str', 'str')
+
+        self.assertTrue(np.array_equal(static_res, dynamic_res))
+
     def test_crop_tensor(self):
         with self.static_graph():
             x = fluid.layers.data(name="x1", shape=[6, 5, 8])
......