IfElse 不支持嵌套?
Created by: Annnnnnnnnnnnn
简单的代码如下:
encIdx = fluid.layers.data(name='encIdx', shape=[1], dtype='float32')
temp = fluid.layers.fill_constant([1], dtype='float32', value=2.)
cond = fluid.layers.equal(encIdx, temp)
branch_1st = fluid.layers.IfElse(cond)
with branch_1st.true_block():
words_a = branch_1st.input(encIdx)
words_b = fluid.layers.exp(words_a)
fluid.layers.Print(words_b)
branch_1st.output(words_a, words_b)
with branch_1st.false_block():
cond = fluid.layers.equal(encIdx, temp)
branch_2nd = fluid.layers.IfElse(cond)
with branch_2nd.true_block():
words_a = branch_2nd.input(encIdx)
words_b = fluid.layers.sqrt(words_a)
fluid.layers.Print(words_b)
branch_2nd.output(words_a, words_b)
with branch_2nd.false_block():
words_a = branch_2nd.input(encIdx)
words_b = fluid.layers.square(words_a)
fluid.layers.Print(words_b)
branch_2nd.output(words_a, words_b)
words_a, words_b = branch_2nd()
branch_1st.output(words_a, words_b)
a, b = branch_1st()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
a_, b_ = exe.run(fluid.default_main_program(),
feed={"encIdx": np.array([2.]).astype(np.float32)},
fetch_list=[a.name, b.name])
print a_, a_.shape
print b_, b_.shape
结果如下, 第二个ifelse块里的false也被执行了,而且结果也多不上。 1558618182 Tensor[exp_0.tmp_0] shape: [1,] dtype: f data: 7.38906,
1558618182 Tensor[sqrt_0.tmp_0] shape: [1,] dtype: f data: 1.41421,
[ 2. 0.] (2,) [ 7.38905621 0. ] (2,)