如何实现tf.where(cond, x, y)的等效操作,下面的代码无法运行
Created by: stuhailiu
目标:网络计算过程中存在某一个张量需要进行如下的操作,要求将该张量中大于等于0的值设置为c1,小于0的值设置为c2
import numpy as np import paddle.fluid as fluid
np_inp = np.array([[100, -100], [-100, 100]], dtype=np.float32) with fluid.dygraph.guard(): var_inp = fluid.dygraph.to_variable(np_inp) if_cond = fluid.layers.greater_equal(c, fluid.layers.fill_constant(shape=var_inp .shape, dtype='int64', value=0))
ie = fluid.layers.IfElse(if_cond)
with ie.true_block():
ie.input(var_inp)
ie.output(c1)
with ie.false_block():
ie.input(var_inp)
ie.output(c2)