未验证 提交 0c617377 编写于 作者: S songyouwei 提交者: GitHub

add case and switch_case unittests for dygraph mode (#22790)

test=develop
上级 ddb9b46f
......@@ -1564,6 +1564,110 @@ class TestLayer(LayerTest):
self.assertTrue(np.array_equal(static_res, dynamic_res))
def test_case(self):
def fn_1():
return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)
def fn_2():
return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[3], dtype='int32', value=3)
with self.static_graph():
x = layers.fill_constant(shape=[1], dtype='float32', value=0.3)
y = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)
pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3
pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1
pred_3 = layers.equal(x, y) # false: 0.3 == 0.1
out_1 = layers.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
static_res1, static_res2 = exe.run(fetch_list=[out_1, out_2])
with self.dynamic_graph():
x = layers.fill_constant(shape=[1], dtype='float32', value=0.3)
y = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)
pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3
pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1
pred_3 = layers.equal(x, y) # false: 0.3 == 0.1
out_1 = layers.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
dynamic_res1 = out_1.numpy()
dynamic_res2 = out_2.numpy()
self.assertTrue(np.array_equal(static_res1, dynamic_res1))
self.assertTrue(np.array_equal(static_res2, dynamic_res2))
def test_switch_case(self):
def fn_1():
return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)
def fn_2():
return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[3], dtype='int32', value=3)
with self.static_graph():
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
out_1 = layers.switch_case(
branch_index=index_1,
branch_fns={1: fn_1,
2: fn_2},
default=fn_3)
out_2 = layers.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3)
out_3 = layers.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
static_res1, static_res2, static_res3 = exe.run(
fetch_list=[out_1, out_2, out_3])
with self.dynamic_graph():
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
out_1 = layers.switch_case(
branch_index=index_1,
branch_fns={1: fn_1,
2: fn_2},
default=fn_3)
out_2 = layers.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3)
out_3 = layers.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
dynamic_res1 = out_1.numpy()
dynamic_res2 = out_2.numpy()
dynamic_res3 = out_3.numpy()
self.assertTrue(np.array_equal(static_res1, dynamic_res1))
self.assertTrue(np.array_equal(static_res2, dynamic_res2))
self.assertTrue(np.array_equal(static_res3, dynamic_res3))
def test_crop_tensor(self):
with self.static_graph():
x = fluid.layers.data(name="x1", shape=[6, 5, 8])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册