Created by: songyouwei
动态图下支持使用控制流op cond
。依赖此op的case
、switch_case
op也一并支持。
cond
sample:
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import numpy as np
#
# pseudocode:
# if 0.1 < 0.23:
# return 1, True
# else:
# return 3, 2
#
def true_func():
return layers.fill_constant(
shape=[1, 2], dtype='int32', value=1), layers.fill_constant(
shape=[2, 3], dtype='bool', value=True)
def false_func():
return layers.fill_constant(
shape=[3, 4], dtype='float32', value=3), layers.fill_constant(
shape=[4, 5], dtype='int64', value=2)
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(np.array([0.1]))
y = fluid.dygraph.to_variable(np.array([0.23]))
pred = layers.less_than(x, y)
out = layers.cond(pred, true_func, false_func)
# out is a tuple containing 2 tensors
# out[0] = [[1 1]]
# out[1] = [[ True True True]
# [ True True True]]
switch_case
sample:
import paddle.fluid as fluid
import paddle.fluid.layers as layers
def fn_1():
return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)
def fn_2():
return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[3], dtype='int32', value=3)
with fluid.dygraph.guard():
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
out_1 = layers.switch_case(
branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2},
default=fn_3)
out_2 = layers.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3)
# Argument default is None and no index matches. fn_3 will be called because of the max index 7.
out_3 = layers.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
print(out_1.numpy()) # [[1. 1.]]
print(out_2.numpy()) # [[2 2] [2 2]]
print(out_3.numpy()) # [3 3 3]
case
sample:
import paddle.fluid as fluid
import paddle.fluid.layers as layers
def fn_1():
return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)
def fn_2():
return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[3], dtype='int32', value=3)
with fluid.dygraph.guard():
x = layers.fill_constant(shape=[1], dtype='float32', value=0.3)
y = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)
pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3
pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1
pred_3 = layers.equal(x, y) # false: 0.3 == 0.1
# Call fn_1 because pred_1 is True
out_1 = layers.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
# Argument default is None and no pred in pred_fn_pairs is True. fn_3 will be called.
# because fn_3 is the last callable in pred_fn_pairs.
out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
print(out_1.numpy()) # [[1. 1.]]
print(out_2.numpy()) # [3 3 3]