未验证 提交 b85c9b56 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2Static ] Fix bugs when select inputs meeting different shape or undefined-var (#45916)

* fix select_input with different shape errors:
1. select_input_with_buildin_type directly return non-undefinedvar branch when meeting undefined var
2. the output shape of select_input is inferred from inputs.

* reverse the logic in select_input
上级 6833ecfe
......@@ -145,21 +145,6 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
need_check_feed=False)
def create_undefined_var_like(variable):
""" create a undefined var with the same shape and dtype like varaible.
"""
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
var = data_layer_not_check(unique_name.generate("undefined_var"),
variable.shape, variable.dtype)
var.stop_gradient = False
helper = LayerHelper('create_undefined_var_like', **locals())
saved_block_ids = helper.main_program.current_block_idx
helper.main_program.current_block_idx = 0
assign(RETURN_NO_VALUE_MAGIC_NUM, var)
helper.main_program.current_block_idx = saved_block_ids
return var
def create_undefined_variable():
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
var = data_layer_not_check(unique_name.generate("undefined_var"), [1],
......
......@@ -70,6 +70,25 @@ def select_output(input, outputs, mask):
return outputs
def _select_input_infer_shape(first_shape, second_shape):
"""
This function infer the output shape by following algorithm:
1. if the dims is different, raise a error.
2. compare axis one by one:
if a == b: we set axis to a
if a != b: we set axis to -1
for compatibility,non declarative mode, we just return second_shape.
"""
if len(first_shape) != len(second_shape):
warnings.warn(
f"the input shapes of select_input should have the same rank, but get {first_shape}, {second_shape}"
)
return second_shape
out_shape = list(
map(lambda a, b: a if a == b else -1, first_shape, second_shape))
return out_shape
def select_input(inputs, mask):
"""
**select_input**
......@@ -89,13 +108,15 @@ def select_input(inputs, mask):
check_type(inputs, 'inputs', (list, tuple), 'select_input')
check_variable_and_dtype(mask, 'mask', ['int32'], 'select_input')
input_dtype = inputs[1].dtype
input_shape = inputs[1].shape
input_type = inputs[1].type
# Select input should expand the shape. If it is - 1 and valid number, use - 1 first. If the dim is different, an error will be reported directly
#assert inputs[0].dtype == inputs[1].dtype, f"Expect the inputs should have the same dtype, but get {inputs[0].dtype} and {inputs[1].dtype}"
output_shape = _select_input_infer_shape(inputs[0].shape, inputs[1].shape)
output_dtype = inputs[1].dtype
output_type = inputs[1].type
out = helper.create_variable(dtype=input_dtype,
shape=input_shape,
type=input_type)
out = helper.create_variable(dtype=output_dtype,
shape=output_shape,
type=output_type)
helper.append_op(type='select_input',
inputs={
'X': inputs,
......@@ -105,9 +126,9 @@ def select_input(inputs, mask):
return out
def select_input_with_buildin_type(inputs, mask):
def select_input_with_buildin_type(inputs, mask, name):
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_var_like
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
false_var, true_var = inputs
if isinstance(false_var, UndefinedVar) and isinstance(
......@@ -117,7 +138,11 @@ def select_input_with_buildin_type(inputs, mask):
return None
if isinstance(false_var, Variable) and isinstance(true_var, Variable):
try:
return select_input(inputs, mask)
except Exception as e:
raise RuntimeError(
f"Exceptions throwed while doing select_input on {name}:\n{e}")
elif (isinstance(false_var, (support_ret_buildin_type))
and isinstance(false_var, type(true_var))):
......@@ -148,24 +173,19 @@ def select_input_with_buildin_type(inputs, mask):
if isinstance(a, UndefinedVar): return a
return to_static_variable(a)
def create_like_if_undefined_var(a, b):
if isinstance(a, UndefinedVar): return create_undefined_var_like(b)
return a
# TODO(xiongkun): add warning here.
true_var, false_var = create_var_if_not_undefined_var(
true_var), create_var_if_not_undefined_var(false_var)
inputs = [
create_like_if_undefined_var(false_var, true_var),
create_like_if_undefined_var(true_var, false_var)
]
true_var, false_var = to_static_variable(true_var), to_static_variable(
false_var)
inputs = [false_var, true_var]
else:
raise TypeError(
"Unsupported return type of true_fn and false_fn in cond: false_var "
"returned by fasle_fn is '{}' and true_var of true_fn is '{}'".
format(type(false_var), type(true_var)))
try:
return select_input(inputs, mask)
except Exception as e:
raise RuntimeError(
f"Exceptions throwed while doing select_input on {name}:\n{e}")
def split_lod_tensor(input, mask, level=0):
......@@ -2658,9 +2678,16 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
.format(return_name, e))
mask = cast(pred, dtype='int32')
merge_func = lambda false_var, true_var: select_input_with_buildin_type(
[false_var, true_var], mask)
merged_output = map_structure(merge_func, false_output, true_output)
merge_func = lambda name, false_var, true_var: select_input_with_buildin_type(
[false_var, true_var], mask, name)
def merge_every_var_list(false_vars, true_vars, name):
return map_structure(partial(merge_func, name), false_vars, true_vars)
merged_output = list(
map(merge_every_var_list, to_sequence(false_output),
to_sequence(true_output), to_sequence(return_names)))
merged_output = pack_sequence_as(false_output, flatten(merged_output))
return merged_output
......
......@@ -164,7 +164,7 @@ def nested_if_else(x_v):
if paddle.mean(y).numpy()[0] < batch_size:
y = fluid.layers.abs(y)
else:
tmp = fluid.layers.fill_constant([feat_size],
tmp = fluid.layers.fill_constant(y.shape,
dtype='float32',
value=-1)
y = y - tmp
......@@ -273,7 +273,7 @@ class NetWithControlFlowIf(fluid.dygraph.Layer):
[hidden_dim], dtype='float32', value=9)
y = fluid.layers.abs(y)
else:
tmp = fluid.layers.fill_constant([5],
tmp = fluid.layers.fill_constant(y.shape,
dtype='float32',
value=-1)
y = y - tmp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册