diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 9a8586e3761cceac36b2fde48ae8d4a0161f509a..a8b372c28ce4bc54711a287fc54e60a770feebad 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -145,21 +145,6 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): need_check_feed=False) -def create_undefined_var_like(variable): - """ create a undefined var with the same shape and dtype like varaible. - """ - from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM - var = data_layer_not_check(unique_name.generate("undefined_var"), - variable.shape, variable.dtype) - var.stop_gradient = False - helper = LayerHelper('create_undefined_var_like', **locals()) - saved_block_ids = helper.main_program.current_block_idx - helper.main_program.current_block_idx = 0 - assign(RETURN_NO_VALUE_MAGIC_NUM, var) - helper.main_program.current_block_idx = saved_block_ids - return var - - def create_undefined_variable(): from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM var = data_layer_not_check(unique_name.generate("undefined_var"), [1], diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index a9f2eaa40e2a54fbdeeb806b5ba986da1306afaa..6275dff31ad49ed2bbd7b852826048972714e2b5 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -70,6 +70,25 @@ def select_output(input, outputs, mask): return outputs +def _select_input_infer_shape(first_shape, second_shape): + """ + This function infer the output shape by following algorithm: + 1. if the dims is different, raise a error. + 2. compare axis one by one: + if a == b: we set axis to a + if a != b: we set axis to -1 + for compatibility,non declarative mode, we just return second_shape. + """ + if len(first_shape) != len(second_shape): + warnings.warn( + f"the input shapes of select_input should have the same rank, but get {first_shape}, {second_shape}" + ) + return second_shape + out_shape = list( + map(lambda a, b: a if a == b else -1, first_shape, second_shape)) + return out_shape + + def select_input(inputs, mask): """ **select_input** @@ -89,13 +108,15 @@ def select_input(inputs, mask): check_type(inputs, 'inputs', (list, tuple), 'select_input') check_variable_and_dtype(mask, 'mask', ['int32'], 'select_input') - input_dtype = inputs[1].dtype - input_shape = inputs[1].shape - input_type = inputs[1].type + # Select input should expand the shape. If it is - 1 and valid number, use - 1 first. If the dim is different, an error will be reported directly + #assert inputs[0].dtype == inputs[1].dtype, f"Expect the inputs should have the same dtype, but get {inputs[0].dtype} and {inputs[1].dtype}" + output_shape = _select_input_infer_shape(inputs[0].shape, inputs[1].shape) + output_dtype = inputs[1].dtype + output_type = inputs[1].type - out = helper.create_variable(dtype=input_dtype, - shape=input_shape, - type=input_type) + out = helper.create_variable(dtype=output_dtype, + shape=output_shape, + type=output_type) helper.append_op(type='select_input', inputs={ 'X': inputs, @@ -105,9 +126,9 @@ def select_input(inputs, mask): return out -def select_input_with_buildin_type(inputs, mask): +def select_input_with_buildin_type(inputs, mask, name): from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable - from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_var_like + from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar false_var, true_var = inputs if isinstance(false_var, UndefinedVar) and isinstance( @@ -117,7 +138,11 @@ def select_input_with_buildin_type(inputs, mask): return None if isinstance(false_var, Variable) and isinstance(true_var, Variable): - return select_input(inputs, mask) + try: + return select_input(inputs, mask) + except Exception as e: + raise RuntimeError( + f"Exceptions throwed while doing select_input on {name}:\n{e}") elif (isinstance(false_var, (support_ret_buildin_type)) and isinstance(false_var, type(true_var))): @@ -148,24 +173,19 @@ def select_input_with_buildin_type(inputs, mask): if isinstance(a, UndefinedVar): return a return to_static_variable(a) - def create_like_if_undefined_var(a, b): - if isinstance(a, UndefinedVar): return create_undefined_var_like(b) - return a - - # TODO(xiongkun): add warning here. - true_var, false_var = create_var_if_not_undefined_var( - true_var), create_var_if_not_undefined_var(false_var) - inputs = [ - create_like_if_undefined_var(false_var, true_var), - create_like_if_undefined_var(true_var, false_var) - ] + true_var, false_var = to_static_variable(true_var), to_static_variable( + false_var) + inputs = [false_var, true_var] else: raise TypeError( "Unsupported return type of true_fn and false_fn in cond: false_var " "returned by fasle_fn is '{}' and true_var of true_fn is '{}'". format(type(false_var), type(true_var))) - - return select_input(inputs, mask) + try: + return select_input(inputs, mask) + except Exception as e: + raise RuntimeError( + f"Exceptions throwed while doing select_input on {name}:\n{e}") def split_lod_tensor(input, mask, level=0): @@ -2658,9 +2678,16 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): .format(return_name, e)) mask = cast(pred, dtype='int32') - merge_func = lambda false_var, true_var: select_input_with_buildin_type( - [false_var, true_var], mask) - merged_output = map_structure(merge_func, false_output, true_output) + merge_func = lambda name, false_var, true_var: select_input_with_buildin_type( + [false_var, true_var], mask, name) + + def merge_every_var_list(false_vars, true_vars, name): + return map_structure(partial(merge_func, name), false_vars, true_vars) + + merged_output = list( + map(merge_every_var_list, to_sequence(false_output), + to_sequence(true_output), to_sequence(return_names))) + merged_output = pack_sequence_as(false_output, flatten(merged_output)) return merged_output diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index b37accce9d1b84a6378aa38ff7f85deebfbc4275..482206b906abd682f9dfcae79367972fdff77220 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -164,7 +164,7 @@ def nested_if_else(x_v): if paddle.mean(y).numpy()[0] < batch_size: y = fluid.layers.abs(y) else: - tmp = fluid.layers.fill_constant([feat_size], + tmp = fluid.layers.fill_constant(y.shape, dtype='float32', value=-1) y = y - tmp @@ -273,7 +273,7 @@ class NetWithControlFlowIf(fluid.dygraph.Layer): [hidden_dim], dtype='float32', value=9) y = fluid.layers.abs(y) else: - tmp = fluid.layers.fill_constant([5], + tmp = fluid.layers.fill_constant(y.shape, dtype='float32', value=-1) y = y - tmp