diff --git a/tests/test_prune_walker.py b/tests/test_prune_walker.py index 46290633a80ce5ae5040161914de583eba3e4e87..e0375c8b6038bfd9a2db1b16f0cf7d03cc98be8a 100644 --- a/tests/test_prune_walker.py +++ b/tests/test_prune_walker.py @@ -61,14 +61,14 @@ class TestPrune(StaticCase): def cond_block1(): cond_conv = conv_bn_layer(conv5, 8, 3, "conv_cond1_1") - fluid.layers.assign(input=cond_conv, output=cond_output) + return cond_conv def cond_block2(): cond_conv1 = conv_bn_layer(conv5, 8, 3, "conv_cond2_1") cond_conv2 = conv_bn_layer(cond_conv1, 8, 3, "conv_cond2_2") - fluid.layers.assign(input=cond_conv2, output=cond_output) + return cond_conv2 - fluid.layers.cond(cond, cond_block1, cond_block2) + cond_output = fluid.layers.cond(cond, cond_block1, cond_block2) sum3 = fluid.layers.sum([sum2, cond_output]) conv6 = conv_bn_layer(sum3, 8, 3, "conv6")